forkserver.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. import errno
  2. import os
  3. import selectors
  4. import signal
  5. import socket
  6. import struct
  7. import sys
  8. import threading
  9. import warnings
  10. from . import connection
  11. from . import process
  12. from .context import reduction
  13. from . import semaphore_tracker
  14. from . import spawn
  15. from . import util
  16. __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
  17. 'set_forkserver_preload']
  18. #
  19. #
  20. #
  21. MAXFDS_TO_SEND = 256
  22. SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t
  23. #
  24. # Forkserver class
  25. #
  26. class ForkServer(object):
  27. def __init__(self):
  28. self._forkserver_address = None
  29. self._forkserver_alive_fd = None
  30. self._forkserver_pid = None
  31. self._inherited_fds = None
  32. self._lock = threading.Lock()
  33. self._preload_modules = ['__main__']
  34. def set_forkserver_preload(self, modules_names):
  35. '''Set list of module names to try to load in forkserver process.'''
  36. if not all(type(mod) is str for mod in self._preload_modules):
  37. raise TypeError('module_names must be a list of strings')
  38. self._preload_modules = modules_names
  39. def get_inherited_fds(self):
  40. '''Return list of fds inherited from parent process.
  41. This returns None if the current process was not started by fork
  42. server.
  43. '''
  44. return self._inherited_fds
  45. def connect_to_new_process(self, fds):
  46. '''Request forkserver to create a child process.
  47. Returns a pair of fds (status_r, data_w). The calling process can read
  48. the child process's pid and (eventually) its returncode from status_r.
  49. The calling process should write to data_w the pickled preparation and
  50. process data.
  51. '''
  52. self.ensure_running()
  53. if len(fds) + 4 >= MAXFDS_TO_SEND:
  54. raise ValueError('too many fds')
  55. with socket.socket(socket.AF_UNIX) as client:
  56. client.connect(self._forkserver_address)
  57. parent_r, child_w = os.pipe()
  58. child_r, parent_w = os.pipe()
  59. allfds = [child_r, child_w, self._forkserver_alive_fd,
  60. semaphore_tracker.getfd()]
  61. allfds += fds
  62. try:
  63. reduction.sendfds(client, allfds)
  64. return parent_r, parent_w
  65. except:
  66. os.close(parent_r)
  67. os.close(parent_w)
  68. raise
  69. finally:
  70. os.close(child_r)
  71. os.close(child_w)
  72. def ensure_running(self):
  73. '''Make sure that a fork server is running.
  74. This can be called from any process. Note that usually a child
  75. process will just reuse the forkserver started by its parent, so
  76. ensure_running() will do nothing.
  77. '''
  78. with self._lock:
  79. semaphore_tracker.ensure_running()
  80. if self._forkserver_pid is not None:
  81. # forkserver was launched before, is it still running?
  82. pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
  83. if not pid:
  84. # still alive
  85. return
  86. # dead, launch it again
  87. os.close(self._forkserver_alive_fd)
  88. self._forkserver_address = None
  89. self._forkserver_alive_fd = None
  90. self._forkserver_pid = None
  91. cmd = ('from multiprocessing.forkserver import main; ' +
  92. 'main(%d, %d, %r, **%r)')
  93. if self._preload_modules:
  94. desired_keys = {'main_path', 'sys_path'}
  95. data = spawn.get_preparation_data('ignore')
  96. data = {x: y for x, y in data.items() if x in desired_keys}
  97. else:
  98. data = {}
  99. with socket.socket(socket.AF_UNIX) as listener:
  100. address = connection.arbitrary_address('AF_UNIX')
  101. listener.bind(address)
  102. os.chmod(address, 0o600)
  103. listener.listen()
  104. # all client processes own the write end of the "alive" pipe;
  105. # when they all terminate the read end becomes ready.
  106. alive_r, alive_w = os.pipe()
  107. try:
  108. fds_to_pass = [listener.fileno(), alive_r]
  109. cmd %= (listener.fileno(), alive_r, self._preload_modules,
  110. data)
  111. exe = spawn.get_executable()
  112. args = [exe] + util._args_from_interpreter_flags()
  113. args += ['-c', cmd]
  114. pid = util.spawnv_passfds(exe, args, fds_to_pass)
  115. except:
  116. os.close(alive_w)
  117. raise
  118. finally:
  119. os.close(alive_r)
  120. self._forkserver_address = address
  121. self._forkserver_alive_fd = alive_w
  122. self._forkserver_pid = pid
  123. #
  124. #
  125. #
  126. def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
  127. '''Run forkserver.'''
  128. if preload:
  129. if '__main__' in preload and main_path is not None:
  130. process.current_process()._inheriting = True
  131. try:
  132. spawn.import_main_path(main_path)
  133. finally:
  134. del process.current_process()._inheriting
  135. for modname in preload:
  136. try:
  137. __import__(modname)
  138. except ImportError:
  139. pass
  140. util._close_stdin()
  141. sig_r, sig_w = os.pipe()
  142. os.set_blocking(sig_r, False)
  143. os.set_blocking(sig_w, False)
  144. def sigchld_handler(*_unused):
  145. # Dummy signal handler, doesn't do anything
  146. pass
  147. handlers = {
  148. # unblocking SIGCHLD allows the wakeup fd to notify our event loop
  149. signal.SIGCHLD: sigchld_handler,
  150. # protect the process from ^C
  151. signal.SIGINT: signal.SIG_IGN,
  152. }
  153. old_handlers = {sig: signal.signal(sig, val)
  154. for (sig, val) in handlers.items()}
  155. # calling os.write() in the Python signal handler is racy
  156. signal.set_wakeup_fd(sig_w)
  157. # map child pids to client fds
  158. pid_to_fd = {}
  159. with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
  160. selectors.DefaultSelector() as selector:
  161. _forkserver._forkserver_address = listener.getsockname()
  162. selector.register(listener, selectors.EVENT_READ)
  163. selector.register(alive_r, selectors.EVENT_READ)
  164. selector.register(sig_r, selectors.EVENT_READ)
  165. while True:
  166. try:
  167. while True:
  168. rfds = [key.fileobj for (key, events) in selector.select()]
  169. if rfds:
  170. break
  171. if alive_r in rfds:
  172. # EOF because no more client processes left
  173. assert os.read(alive_r, 1) == b'', "Not at EOF?"
  174. raise SystemExit
  175. if sig_r in rfds:
  176. # Got SIGCHLD
  177. os.read(sig_r, 65536) # exhaust
  178. while True:
  179. # Scan for child processes
  180. try:
  181. pid, sts = os.waitpid(-1, os.WNOHANG)
  182. except ChildProcessError:
  183. break
  184. if pid == 0:
  185. break
  186. child_w = pid_to_fd.pop(pid, None)
  187. if child_w is not None:
  188. if os.WIFSIGNALED(sts):
  189. returncode = -os.WTERMSIG(sts)
  190. else:
  191. if not os.WIFEXITED(sts):
  192. raise AssertionError(
  193. "Child {0:n} status is {1:n}".format(
  194. pid,sts))
  195. returncode = os.WEXITSTATUS(sts)
  196. # Send exit code to client process
  197. try:
  198. write_signed(child_w, returncode)
  199. except BrokenPipeError:
  200. # client vanished
  201. pass
  202. os.close(child_w)
  203. else:
  204. # This shouldn't happen really
  205. warnings.warn('forkserver: waitpid returned '
  206. 'unexpected pid %d' % pid)
  207. if listener in rfds:
  208. # Incoming fork request
  209. with listener.accept()[0] as s:
  210. # Receive fds from client
  211. fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
  212. if len(fds) > MAXFDS_TO_SEND:
  213. raise RuntimeError(
  214. "Too many ({0:n}) fds to send".format(
  215. len(fds)))
  216. child_r, child_w, *fds = fds
  217. s.close()
  218. pid = os.fork()
  219. if pid == 0:
  220. # Child
  221. code = 1
  222. try:
  223. listener.close()
  224. selector.close()
  225. unused_fds = [alive_r, child_w, sig_r, sig_w]
  226. unused_fds.extend(pid_to_fd.values())
  227. code = _serve_one(child_r, fds,
  228. unused_fds,
  229. old_handlers)
  230. except Exception:
  231. sys.excepthook(*sys.exc_info())
  232. sys.stderr.flush()
  233. finally:
  234. os._exit(code)
  235. else:
  236. # Send pid to client process
  237. try:
  238. write_signed(child_w, pid)
  239. except BrokenPipeError:
  240. # client vanished
  241. pass
  242. pid_to_fd[pid] = child_w
  243. os.close(child_r)
  244. for fd in fds:
  245. os.close(fd)
  246. except OSError as e:
  247. if e.errno != errno.ECONNABORTED:
  248. raise
  249. def _serve_one(child_r, fds, unused_fds, handlers):
  250. # close unnecessary stuff and reset signal handlers
  251. signal.set_wakeup_fd(-1)
  252. for sig, val in handlers.items():
  253. signal.signal(sig, val)
  254. for fd in unused_fds:
  255. os.close(fd)
  256. (_forkserver._forkserver_alive_fd,
  257. semaphore_tracker._semaphore_tracker._fd,
  258. *_forkserver._inherited_fds) = fds
  259. # Run process object received over pipe
  260. code = spawn._main(child_r)
  261. return code
  262. #
  263. # Read and write signed numbers
  264. #
  265. def read_signed(fd):
  266. data = b''
  267. length = SIGNED_STRUCT.size
  268. while len(data) < length:
  269. s = os.read(fd, length - len(data))
  270. if not s:
  271. raise EOFError('unexpected EOF')
  272. data += s
  273. return SIGNED_STRUCT.unpack(data)[0]
  274. def write_signed(fd, n):
  275. msg = SIGNED_STRUCT.pack(n)
  276. while msg:
  277. nbytes = os.write(fd, msg)
  278. if nbytes == 0:
  279. raise RuntimeError('should not get here')
  280. msg = msg[nbytes:]
  281. #
  282. #
  283. #
  284. _forkserver = ForkServer()
  285. ensure_running = _forkserver.ensure_running
  286. get_inherited_fds = _forkserver.get_inherited_fds
  287. connect_to_new_process = _forkserver.connect_to_new_process
  288. set_forkserver_preload = _forkserver.set_forkserver_preload