349 lines
12 KiB
Python
349 lines
12 KiB
Python
|
import errno
|
||
|
import os
|
||
|
import selectors
|
||
|
import signal
|
||
|
import socket
|
||
|
import struct
|
||
|
import sys
|
||
|
import threading
|
||
|
import warnings
|
||
|
|
||
|
from . import connection
|
||
|
from . import process
|
||
|
from .context import reduction
|
||
|
from . import resource_tracker
|
||
|
from . import spawn
|
||
|
from . import util
|
||
|
|
||
|
# Public API of this module: aliases bound to methods of the module-level
# ForkServer singleton defined at the bottom of the file.
__all__ = [
    'ensure_running',
    'get_inherited_fds',
    'connect_to_new_process',
    'set_forkserver_preload',
]
|
||
|
|
||
|
#
|
||
|
#
|
||
|
#
|
||
|
|
||
|
# Upper bound on the number of fds carried by a single fork request; the
# server receives at most one more than this and rejects anything larger.
MAXFDS_TO_SEND = 256

# Wire format for the pids and exit codes written to client status pipes.
SIGNED_STRUCT = struct.Struct('q')  # large enough for pid_t
|
||
|
|
||
|
#
|
||
|
# Forkserver class
|
||
|
#
|
||
|
|
||
|
class ForkServer(object):
    '''Client-side handle on a single fork server process.

    ensure_running() launches the server (running main() in a new
    interpreter) listening on a Unix socket; connect_to_new_process() asks it
    to fork a new child.  The instance remembers the server's socket address,
    its pid, and the write end of the "alive" pipe whose closure tells the
    server to exit.
    '''

    def __init__(self):
        # Address the server's listening AF_UNIX socket is bound to.
        self._forkserver_address = None
        # Write end of the "alive" pipe; closing it asks the server to stop.
        self._forkserver_alive_fd = None
        self._forkserver_pid = None
        # Populated by _serve_one() inside processes forked by the server.
        self._inherited_fds = None
        self._lock = threading.Lock()
        self._preload_modules = ['__main__']

    def _stop(self):
        # Method used by unit tests to stop the server
        with self._lock:
            self._stop_unlocked()

    def _stop_unlocked(self):
        '''Stop the server, if any; the caller must hold self._lock.'''
        if self._forkserver_pid is None:
            return

        # close the "alive" file descriptor asks the server to stop
        os.close(self._forkserver_alive_fd)
        self._forkserver_alive_fd = None

        os.waitpid(self._forkserver_pid, 0)
        self._forkserver_pid = None

        # Only filesystem sockets leave a file behind to remove.
        if not util.is_abstract_socket_namespace(self._forkserver_address):
            os.unlink(self._forkserver_address)
        self._forkserver_address = None

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        *modules_names* must be an iterable of str.  It is copied into a new
        list, so generators are accepted and later mutation by the caller has
        no effect.  A bare str is rejected because it would otherwise be
        iterated character by character.

        Raises TypeError if the argument is a str, is not iterable, or
        contains a non-str item; the previous setting is then left unchanged.
        '''
        if isinstance(modules_names, str):
            raise TypeError('module_names must be a list of strings')
        # list() raises TypeError for non-iterables, matching the contract.
        modules_names = list(modules_names)
        # Validate the *new* value (previously the old list was checked).
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* is too long to fit in one request.
        '''
        self.ensure_running()
        # Four fds always precede the caller's: child_r, child_w, the
        # "alive" fd and the resource tracker fd.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            child_r, parent_w = os.pipe()
            allfds = [child_r, child_w, self._forkserver_alive_fd,
                      resource_tracker.getfd()]
            allfds += fds
            try:
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                # Sending failed: the caller never receives the parent ends.
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child-side ends are closed here whether or not the
                # send succeeded.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            resource_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, _ = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            # main_path / sys_path are only forwarded when something will be
            # preloaded.
            if self._preload_modules:
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # Restrict a filesystem socket to the current user.
                if not util.is_abstract_socket_namespace(address):
                    os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # Only the server process keeps the read end.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid
|
||
|
|
||
|
#
|
||
|
#
|
||
|
#
|
||
|
|
||
|
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Entry point of the fork server process.  ForkServer.ensure_running()
    starts it with ``python -c`` passing the listening socket fd, the read
    end of the "alive" pipe, the preload list, and (optionally) main_path /
    sys_path taken from spawn.get_preparation_data().

    After preloading, the server loops until told to stop:
      * EOF on *alive_r* (all holders of the write end closed it) -> exit
        via SystemExit;
      * SIGCHLD, observed through the signal wakeup fd -> reap children and
        write each exit code to that child's client status pipe;
      * a connection on the listener -> receive fds, fork, and run the
        requested process in the child via _serve_one().

    *sys_path* is accepted but not referenced in this function.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            # Flag the import of __main__ as an inherited (child-side) import
            # for the duration of import_main_path().
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # Preloading is best effort; missing modules are skipped.
                pass

    util._close_stdin()

    # Self-pipe used as the signal wakeup fd; both ends non-blocking.
    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Remember the previous handlers so forked children can restore them.
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Block until at least one registered fd is ready.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE: under ``python -O`` this assert (and its read) is
                    # stripped; SystemExit is raised either way.
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536) # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            break
                        if pid == 0:
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            returncode = os.waitstatus_to_exitcode(sts)

                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client
                        # Ask for one more than allowed so oversize requests
                        # are detectable.
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            # NOTE(review): RuntimeError is not caught by the
                            # OSError handler below, so this ends the server
                            # loop (and the received fds are not closed).
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        # Remaining fds: alive fd, resource tracker fd, then
                        # the caller's own fds (see connect_to_new_process).
                        child_r, child_w, *fds = fds
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            # Exit status if _serve_one() raises.
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                # Never return into the server loop.
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # Keep child_w to report the exit code on SIGCHLD.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # A client that aborts its connection is not fatal.
                if e.errno != errno.ECONNABORTED:
                    raise
||
|
|
||
|
|
||
|
def _serve_one(child_r, fds, unused_fds, handlers):
    '''Run one requested process inside a freshly forked server child.

    Restores the signal handlers given in *handlers*, closes the fds listed
    in *unused_fds*, records the fds received from the client (alive fd,
    resource tracker fd, then any further inherited fds) and runs the process
    via spawn._main(); its result is returned to be used as the exit code.
    '''
    # Detach from the server's wakeup pipe before restoring old handlers.
    signal.set_wakeup_fd(-1)
    for signum, previous in handlers.items():
        signal.signal(signum, previous)
    for stale_fd in unused_fds:
        os.close(stale_fd)

    alive_fd, tracker_fd, *extra_fds = fds
    _forkserver._forkserver_alive_fd = alive_fd
    resource_tracker._resource_tracker._fd = tracker_fd
    _forkserver._inherited_fds = extra_fds

    # Run process object received over pipe; a dup of child_r is passed as
    # the parent sentinel.
    sentinel = os.dup(child_r)
    return spawn._main(child_r, sentinel)
|
||
|
|
||
|
|
||
|
#
|
||
|
# Read and write signed numbers
|
||
|
#
|
||
|
|
||
|
def read_signed(fd):
    '''Read one SIGNED_STRUCT-encoded integer from *fd*.

    Keeps reading until a full value has arrived, since os.read() may return
    fewer bytes than requested.  Raises EOFError if the stream ends first.
    '''
    needed = SIGNED_STRUCT.size
    chunks = []
    received = 0
    while received < needed:
        chunk = os.read(fd, needed - received)
        if not chunk:
            raise EOFError('unexpected EOF')
        chunks.append(chunk)
        received += len(chunk)
    return SIGNED_STRUCT.unpack(b''.join(chunks))[0]
|
||
|
|
||
|
def write_signed(fd, n):
    '''Write integer *n* to *fd* encoded with SIGNED_STRUCT.

    Loops until every byte is written, because os.write() may perform a
    partial write.  A zero-byte write raises RuntimeError.
    '''
    payload = SIGNED_STRUCT.pack(n)
    offset = 0
    while offset < len(payload):
        written = os.write(fd, payload[offset:])
        if not written:
            raise RuntimeError('should not get here')
        offset += written
|
||
|
|
||
|
#
|
||
|
#
|
||
|
#
|
||
|
|
||
|
# Module-wide singleton; the public functions below are its bound methods.
_forkserver = ForkServer()

(ensure_running,
 get_inherited_fds,
 connect_to_new_process,
 set_forkserver_preload) = (_forkserver.ensure_running,
                            _forkserver.get_inherited_fds,
                            _forkserver.connect_to_new_process,
                            _forkserver.set_forkserver_preload)
|