asyncio: Refactor SIGCHLD handling. By Anthony Baire.

This commit is contained in:
Guido van Rossum 2013-11-04 15:50:46 -08:00
parent ccea08462b
commit 0eaa5ac9b5
5 changed files with 1315 additions and 199 deletions

View file

@ -1,10 +1,11 @@
"""Event loop and event loop policy.""" """Event loop and event loop policy."""
__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', __all__ = ['AbstractEventLoopPolicy',
'AbstractEventLoop', 'AbstractServer', 'AbstractEventLoop', 'AbstractServer',
'Handle', 'TimerHandle', 'Handle', 'TimerHandle',
'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop_policy', 'set_event_loop_policy',
'get_event_loop', 'set_event_loop', 'new_event_loop', 'get_event_loop', 'set_event_loop', 'new_event_loop',
'get_child_watcher', 'set_child_watcher',
] ]
import subprocess import subprocess
@ -318,8 +319,18 @@ class AbstractEventLoopPolicy:
"""XXX""" """XXX"""
raise NotImplementedError raise NotImplementedError
# Child processes handling (Unix only).
class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): def get_child_watcher(self):
"""XXX"""
raise NotImplementedError
def set_child_watcher(self, watcher):
"""XXX"""
raise NotImplementedError
class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy):
"""Default policy implementation for accessing the event loop. """Default policy implementation for accessing the event loop.
In this policy, each thread has its own event loop. However, we In this policy, each thread has its own event loop. However, we
@ -332,28 +343,34 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy):
associated). associated).
""" """
_loop = None _loop_factory = None
_set_called = False
class _Local(threading.local):
_loop = None
_set_called = False
def __init__(self):
self._local = self._Local()
def get_event_loop(self): def get_event_loop(self):
"""Get the event loop. """Get the event loop.
This may be None or an instance of EventLoop. This may be None or an instance of EventLoop.
""" """
if (self._loop is None and if (self._local._loop is None and
not self._set_called and not self._local._set_called and
isinstance(threading.current_thread(), threading._MainThread)): isinstance(threading.current_thread(), threading._MainThread)):
self._loop = self.new_event_loop() self._local._loop = self.new_event_loop()
assert self._loop is not None, \ assert self._local._loop is not None, \
('There is no current event loop in thread %r.' % ('There is no current event loop in thread %r.' %
threading.current_thread().name) threading.current_thread().name)
return self._loop return self._local._loop
def set_event_loop(self, loop): def set_event_loop(self, loop):
"""Set the event loop.""" """Set the event loop."""
self._set_called = True self._local._set_called = True
assert loop is None or isinstance(loop, AbstractEventLoop) assert loop is None or isinstance(loop, AbstractEventLoop)
self._loop = loop self._local._loop = loop
def new_event_loop(self): def new_event_loop(self):
"""Create a new event loop. """Create a new event loop.
@ -361,12 +378,7 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy):
You must call set_event_loop() to make this the current event You must call set_event_loop() to make this the current event
loop. loop.
""" """
if sys.platform == 'win32': # pragma: no cover return self._loop_factory()
from . import windows_events
return windows_events.SelectorEventLoop()
else: # pragma: no cover
from . import unix_events
return unix_events.SelectorEventLoop()
# Event loop policy. The policy itself is always global, even if the # Event loop policy. The policy itself is always global, even if the
@ -375,12 +387,22 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy):
# call to get_event_loop_policy(). # call to get_event_loop_policy().
_event_loop_policy = None _event_loop_policy = None
# Lock for protecting the on-the-fly creation of the event loop policy.
_lock = threading.Lock()
def _init_event_loop_policy():
    """Create the global default event loop policy if none is set yet.

    The ``None`` check is performed while holding the module-level
    ``_lock`` so that threads racing to create the policy end up sharing
    a single instance.
    """
    global _event_loop_policy
    with _lock:
        if _event_loop_policy is None:  # pragma: no branch
            # Imported lazily from the package; presumably the package
            # re-exports the platform-specific policy class -- verify
            # against the package __init__.
            from . import DefaultEventLoopPolicy as policy_class
            _event_loop_policy = policy_class()
def get_event_loop_policy(): def get_event_loop_policy():
"""XXX""" """XXX"""
global _event_loop_policy
if _event_loop_policy is None: if _event_loop_policy is None:
_event_loop_policy = DefaultEventLoopPolicy() _init_event_loop_policy()
return _event_loop_policy return _event_loop_policy
@ -404,3 +426,13 @@ def set_event_loop(loop):
def new_event_loop(): def new_event_loop():
"""XXX""" """XXX"""
return get_event_loop_policy().new_event_loop() return get_event_loop_policy().new_event_loop()
def get_child_watcher():
    """Return the child watcher of the current event loop policy."""
    policy = get_event_loop_policy()
    return policy.get_child_watcher()
def set_child_watcher(watcher):
    """Install *watcher* as the child watcher of the current policy.

    The call is delegated to the current event loop policy and whatever
    its set_child_watcher() returns is passed back to the caller.
    """
    policy = get_event_loop_policy()
    return policy.set_child_watcher(watcher)

View file

@ -8,6 +8,7 @@ import socket
import stat import stat
import subprocess import subprocess
import sys import sys
import threading
from . import base_subprocess from . import base_subprocess
@ -20,7 +21,10 @@ from . import transports
from .log import logger from .log import logger
__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR'] __all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR',
'AbstractChildWatcher', 'SafeChildWatcher',
'FastChildWatcher', 'DefaultEventLoopPolicy',
]
STDIN = 0 STDIN = 0
STDOUT = 1 STDOUT = 1
@ -31,7 +35,7 @@ if sys.platform == 'win32': # pragma: no cover
raise ImportError('Signals are not really supported on Windows') raise ImportError('Signals are not really supported on Windows')
class SelectorEventLoop(selector_events.BaseSelectorEventLoop): class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
"""Unix event loop """Unix event loop
Adds signal handling to SelectorEventLoop Adds signal handling to SelectorEventLoop
@ -40,17 +44,10 @@ class SelectorEventLoop(selector_events.BaseSelectorEventLoop):
def __init__(self, selector=None): def __init__(self, selector=None):
super().__init__(selector) super().__init__(selector)
self._signal_handlers = {} self._signal_handlers = {}
self._subprocesses = {}
def _socketpair(self): def _socketpair(self):
return socket.socketpair() return socket.socketpair()
def close(self):
handler = self._signal_handlers.get(signal.SIGCHLD)
if handler is not None:
self.remove_signal_handler(signal.SIGCHLD)
super().close()
def add_signal_handler(self, sig, callback, *args): def add_signal_handler(self, sig, callback, *args):
"""Add a handler for a signal. UNIX only. """Add a handler for a signal. UNIX only.
@ -152,49 +149,20 @@ class SelectorEventLoop(selector_events.BaseSelectorEventLoop):
def _make_subprocess_transport(self, protocol, args, shell, def _make_subprocess_transport(self, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=None, **kwargs): extra=None, **kwargs):
self._reg_sigchld() with events.get_child_watcher() as watcher:
transp = _UnixSubprocessTransport(self, protocol, args, shell, transp = _UnixSubprocessTransport(self, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=None, **kwargs) extra=None, **kwargs)
self._subprocesses[transp.get_pid()] = transp watcher.add_child_handler(transp.get_pid(),
self._child_watcher_callback, transp)
yield from transp._post_init() yield from transp._post_init()
return transp return transp
def _reg_sigchld(self): def _child_watcher_callback(self, pid, returncode, transp):
if signal.SIGCHLD not in self._signal_handlers: self.call_soon_threadsafe(transp._process_exited, returncode)
self.add_signal_handler(signal.SIGCHLD, self._sig_chld)
def _sig_chld(self): def _subprocess_closed(self, transp):
try: pass
# Because of signal coalescing, we must keep calling waitpid() as
# long as we're able to reap a child.
while True:
try:
pid, status = os.waitpid(-1, os.WNOHANG)
except ChildProcessError:
break # No more child processes exist.
if pid == 0:
break # All remaining child processes are still alive.
elif os.WIFSIGNALED(status):
# A child process died because of a signal.
returncode = -os.WTERMSIG(status)
elif os.WIFEXITED(status):
# A child process exited (e.g. sys.exit()).
returncode = os.WEXITSTATUS(status)
else:
# A child exited, but we don't understand its status.
# This shouldn't happen, but if it does, let's just
# return that status; perhaps that helps debug it.
returncode = status
transp = self._subprocesses.get(pid)
if transp is not None:
transp._process_exited(returncode)
except Exception:
logger.exception('Unknown exception in SIGCHLD handler')
def _subprocess_closed(self, transport):
pid = transport.get_pid()
self._subprocesses.pop(pid, None)
def _set_nonblocking(fd): def _set_nonblocking(fd):
@ -423,3 +391,335 @@ class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport):
if stdin_w is not None: if stdin_w is not None:
stdin.close() stdin.close()
self._proc.stdin = open(stdin_w.detach(), 'rb', buffering=bufsize) self._proc.stdin = open(stdin_w.detach(), 'rb', buffering=bufsize)
class AbstractChildWatcher:
    """Interface for objects that monitor child processes.

    A watcher keeps track of a set of subprocesses and reports when one of
    them terminates or is interrupted by a signal.

    Handlers are registered with .add_child_handler(). A new process must
    be started inside a 'with' block on the watcher, so that the watcher can
    suspend its activity until the new process is fully registered (some
    implementations need this to avoid a race condition).

    Example:
        with watcher:
            proc = subprocess.Popen("sleep 1")
            watcher.add_child_handler(proc.pid, callback)

    Notes:
        Implementations of this class must be thread-safe.

        Since child watcher objects may catch the SIGCHLD signal and call
        waitpid(-1), there should be only one active object per process.
    """

    def add_child_handler(self, pid, callback, *args):
        """Register a handler for the termination of process 'pid'.

        callback(pid, returncode, *args) is to be called once the process
        terminates. Registering a second callback for the same process
        replaces the first one.

        Note: callback() must be thread-safe
        """
        raise NotImplementedError

    def remove_child_handler(self, pid):
        """Remove the handler registered for process 'pid'.

        Returns True if a handler was removed, False if there was nothing
        to remove.
        """
        raise NotImplementedError

    def set_loop(self, loop):
        """Attach the watcher to another event loop.

        Note: loop may be None
        """
        raise NotImplementedError

    def close(self):
        """Close the watcher.

        Must be called so that any underlying resource gets released.
        """
        raise NotImplementedError

    def __enter__(self):
        """Enter the watcher's context, in which processes may be started.

        Implementations must return self.
        """
        raise NotImplementedError

    def __exit__(self, a, b, c):
        """Leave the watcher's context."""
        raise NotImplementedError
class BaseChildWatcher(AbstractChildWatcher):
    """Common machinery for SIGCHLD-driven child watchers.

    Holds the pid -> (callback, args) registry and manages the SIGCHLD
    handler installed on the attached event loop. Subclasses supply the
    reaping strategy through _do_waitpid() and _do_waitpid_all().
    """

    def __init__(self, loop):
        self._loop = None
        self._callbacks = {}
        self.set_loop(loop)

    def close(self):
        """Detach from the event loop and drop every registered handler."""
        self.set_loop(None)
        self._callbacks.clear()

    def _do_waitpid(self, expected_pid):
        """Reap one specific child; provided by subclasses."""
        raise NotImplementedError

    def _do_waitpid_all(self):
        """Reap every child that can be reaped; provided by subclasses."""
        raise NotImplementedError

    def set_loop(self, loop):
        """Move the SIGCHLD handler from the current loop to 'loop'.

        'loop' may be None, in which case the watcher is only detached.
        """
        assert loop is None or isinstance(loop, events.AbstractEventLoop)

        previous = self._loop
        if previous is not None:
            previous.remove_signal_handler(signal.SIGCHLD)

        self._loop = loop
        if loop is None:
            return

        loop.add_signal_handler(signal.SIGCHLD, self._sig_chld)
        # Prevent a race condition in case a child terminated
        # during the switch.
        self._do_waitpid_all()

    def remove_child_handler(self, pid):
        """Forget the handler for 'pid'; return whether one was present."""
        try:
            del self._callbacks[pid]
        except KeyError:
            return False
        return True

    def _sig_chld(self):
        """SIGCHLD entry point: reap children, logging any failure."""
        try:
            self._do_waitpid_all()
        except Exception:
            logger.exception('Unknown exception in SIGCHLD handler')

    def _compute_returncode(self, status):
        """Translate a waitpid() status into a subprocess-style returncode."""
        if os.WIFSIGNALED(status):
            # The child process died because of a signal.
            return -os.WTERMSIG(status)
        if os.WIFEXITED(status):
            # The child process exited (e.g sys.exit()).
            return os.WEXITSTATUS(status)
        # The child exited, but we don't understand its status.
        # This shouldn't happen, but if it does, let's just
        # return that status; perhaps that helps debug it.
        return status
class SafeChildWatcher(BaseChildWatcher):
    """'Safe' child watcher implementation.

    Instead of calling os.waitpid(-1), the SIGCHLD handler polls each
    registered process explicitly, so that other code spawning processes
    is not disrupted.

    This is a safe solution but it has a significant overhead when handling
    a big number of children (O(n) each time SIGCHLD is raised)
    """

    def __enter__(self):
        # Nothing to suspend in this implementation.
        return self

    def __exit__(self, a, b, c):
        pass

    def add_child_handler(self, pid, callback, *args):
        """Register callback(pid, returncode, *args) for process 'pid'."""
        self._callbacks[pid] = (callback, args)

        # Prevent a race condition in case the child is already terminated.
        self._do_waitpid(pid)

    def _do_waitpid_all(self):
        """Poll every registered pid."""
        # Iterate over a snapshot: _do_waitpid() removes entries from the
        # registry as children are reaped.
        for registered_pid in tuple(self._callbacks):
            self._do_waitpid(registered_pid)

    def _do_waitpid(self, expected_pid):
        """Poll one pid without blocking and fire its handler if it ended."""
        assert expected_pid > 0

        try:
            pid, status = os.waitpid(expected_pid, os.WNOHANG)
        except ChildProcessError:
            # The child process is already reaped
            # (may happen if waitpid() is called elsewhere).
            pid = expected_pid
            returncode = 255
            logger.warning(
                "Unknown child process pid %d, will report returncode 255",
                pid)
        else:
            if pid == 0:
                # The child process is still alive.
                return
            returncode = self._compute_returncode(status)

        # The entry may be missing if .remove_child_handler() was called
        # after os.waitpid() returned.
        entry = self._callbacks.pop(pid, None)
        if entry is None:  # pragma: no cover
            return
        callback, args = entry
        callback(pid, returncode, *args)
class FastChildWatcher(BaseChildWatcher):
    """'Fast' child watcher implementation.

    This implementation reaps every terminated process by calling
    os.waitpid(-1) directly, possibly breaking other code spawning processes
    and waiting for their termination.

    There is no noticeable overhead when handling a big number of children
    (O(1) each time a child terminates).

    Children reaped while a 'with watcher:' block is active but before
    their handler is registered are parked in self._zombies (pid ->
    returncode) so that add_child_handler() can still report them.
    """

    def __init__(self, loop):
        # This state must exist before the base constructor runs: it calls
        # set_loop(), which (for a non-None loop) calls _do_waitpid_all(),
        # and that method uses _lock, _forks and _zombies.
        self._lock = threading.Lock()
        self._zombies = {}
        self._forks = 0
        super().__init__(loop)

    def close(self):
        """Detach from the loop and drop handlers and parked zombies."""
        super().close()
        self._zombies.clear()

    def __enter__(self):
        """Announce that a process is about to be started; returns self."""
        with self._lock:
            self._forks += 1

            return self

    def __exit__(self, a, b, c):
        """Leave the context; report zombies nobody claimed.

        Once the outermost context exits, any pid still parked in
        self._zombies did not get a handler, so it is logged and dropped.
        """
        with self._lock:
            self._forks -= 1

            if self._forks or not self._zombies:
                return

            collateral_victims = str(self._zombies)
            self._zombies.clear()

        logger.warning(
            "Caught subprocesses termination from unknown pids: %s",
            collateral_victims)

    def add_child_handler(self, pid, callback, *args):
        """Register callback(pid, returncode, *args) for process 'pid'.

        Must be called inside a 'with watcher:' block. If the child was
        already reaped, the callback is fired immediately.
        """
        assert self._forks, "Must use the context manager"

        # The zombie lookup and the registration happen under the lock so
        # that _do_waitpid_all() cannot reap the child in between and park
        # it in _zombies after we looked there.
        with self._lock:
            try:
                returncode = self._zombies.pop(pid)
            except KeyError:
                # Not reaped yet: remember the handler for later.
                self._callbacks[pid] = callback, args
                return

        # Child is dead, therefore we can fire the callback immediately
        # (outside the lock, so the callback may use the watcher).
        callback(pid, returncode, *args)

    def _do_waitpid_all(self):
        """Reap every terminated child and dispatch to its handler."""
        # Because of signal coalescing, we must keep calling waitpid() as
        # long as we're able to reap a child.
        while True:
            try:
                pid, status = os.waitpid(-1, os.WNOHANG)
            except ChildProcessError:
                # No more child processes exist.
                return
            if pid == 0:
                # A child process is still alive.
                return

            returncode = self._compute_returncode(status)

            # Handler lookup and zombie parking are one atomic step with
            # respect to add_child_handler().
            with self._lock:
                try:
                    callback, args = self._callbacks.pop(pid)
                except KeyError:
                    # unknown child
                    if self._forks:
                        # It may not be registered yet.
                        self._zombies[pid] = returncode
                        continue
                    callback = None

            if callback is None:
                logger.warning(
                    "Caught subprocess termination from unknown pid: "
                    "%d -> %d", pid, returncode)
            else:
                callback(pid, returncode, *args)
class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy):
    """Default event loop policy on Unix.

    Uses _UnixSelectorEventLoop as the loop factory and additionally owns
    a child watcher, stored on the policy itself (not per thread).
    """
    _loop_factory = _UnixSelectorEventLoop

    def __init__(self):
        super().__init__()
        self._watcher = None

    def _init_watcher(self):
        """Create the default SafeChildWatcher if none exists yet."""
        with events._lock:
            if self._watcher is None:  # pragma: no branch
                in_main_thread = isinstance(threading.current_thread(),
                                            threading._MainThread)
                # Only the main thread's loop is handed to the watcher;
                # from any other thread it starts detached.
                loop = self._local._loop if in_main_thread else None
                self._watcher = SafeChildWatcher(loop)

    def set_event_loop(self, loop):
        """Set the event loop.

        As a side effect, if a child watcher was set before, then calling
        .set_event_loop() from the main thread will call .set_loop(loop) on
        the child watcher.
        """
        super().set_event_loop(loop)

        watcher = self._watcher
        if watcher is None:
            return
        if isinstance(threading.current_thread(), threading._MainThread):
            watcher.set_loop(loop)

    def get_child_watcher(self):
        """Get the child watcher.

        If not yet set, a SafeChildWatcher object is automatically created.
        """
        if self._watcher is None:
            self._init_watcher()
        return self._watcher

    def set_child_watcher(self, watcher):
        """Set the child watcher, closing the previous one if there was one."""
        assert watcher is None or isinstance(watcher, AbstractChildWatcher)

        previous = self._watcher
        if previous is not None:
            previous.close()

        self._watcher = watcher
SelectorEventLoop = _UnixSelectorEventLoop
DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy

View file

@ -7,6 +7,7 @@ import weakref
import struct import struct
import _winapi import _winapi
from . import events
from . import base_subprocess from . import base_subprocess
from . import futures from . import futures
from . import proactor_events from . import proactor_events
@ -17,7 +18,9 @@ from .log import logger
from . import _overlapped from . import _overlapped
__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor',
'DefaultEventLoopPolicy',
]
NULL = 0 NULL = 0
@ -108,7 +111,7 @@ class PipeServer(object):
__del__ = close __del__ = close
class SelectorEventLoop(selector_events.BaseSelectorEventLoop): class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop):
"""Windows version of selector event loop.""" """Windows version of selector event loop."""
def _socketpair(self): def _socketpair(self):
@ -453,3 +456,13 @@ class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport):
f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) f = self._loop._proactor.wait_for_handle(int(self._proc._handle))
f.add_done_callback(callback) f.add_done_callback(callback)
SelectorEventLoop = _WindowsSelectorEventLoop
class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy):
_loop_factory = SelectorEventLoop
DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy

View file

@ -1308,8 +1308,17 @@ else:
from asyncio import selectors from asyncio import selectors
from asyncio import unix_events from asyncio import unix_events
class UnixEventLoopTestsMixin(EventLoopTestsMixin):
    """Event loop test mixin that installs a child watcher per test."""

    def setUp(self):
        super().setUp()
        # Attach a fresh SafeChildWatcher to the loop created by the base
        # setUp().
        watcher = unix_events.SafeChildWatcher(self.loop)
        events.set_child_watcher(watcher)

    def tearDown(self):
        # Uninstall the watcher before the base tearDown() runs.
        events.set_child_watcher(None)
        super().tearDown()
if hasattr(selectors, 'KqueueSelector'): if hasattr(selectors, 'KqueueSelector'):
class KqueueEventLoopTests(EventLoopTestsMixin, class KqueueEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
@ -1318,7 +1327,7 @@ else:
selectors.KqueueSelector()) selectors.KqueueSelector())
if hasattr(selectors, 'EpollSelector'): if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(EventLoopTestsMixin, class EPollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
@ -1326,7 +1335,7 @@ else:
return unix_events.SelectorEventLoop(selectors.EpollSelector()) return unix_events.SelectorEventLoop(selectors.EpollSelector())
if hasattr(selectors, 'PollSelector'): if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(EventLoopTestsMixin, class PollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
@ -1334,7 +1343,7 @@ else:
return unix_events.SelectorEventLoop(selectors.PollSelector()) return unix_events.SelectorEventLoop(selectors.PollSelector())
# Should always exist. # Should always exist.
class SelectEventLoopTests(EventLoopTestsMixin, class SelectEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
@ -1557,25 +1566,36 @@ class ProtocolsAbsTests(unittest.TestCase):
class PolicyTests(unittest.TestCase): class PolicyTests(unittest.TestCase):
def create_policy(self):
    """Return a new DefaultEventLoopPolicy for the running platform."""
    if sys.platform == "win32":
        from asyncio import windows_events as platform_events
    else:
        from asyncio import unix_events as platform_events
    return platform_events.DefaultEventLoopPolicy()
def test_event_loop_policy(self): def test_event_loop_policy(self):
policy = events.AbstractEventLoopPolicy() policy = events.AbstractEventLoopPolicy()
self.assertRaises(NotImplementedError, policy.get_event_loop) self.assertRaises(NotImplementedError, policy.get_event_loop)
self.assertRaises(NotImplementedError, policy.set_event_loop, object()) self.assertRaises(NotImplementedError, policy.set_event_loop, object())
self.assertRaises(NotImplementedError, policy.new_event_loop) self.assertRaises(NotImplementedError, policy.new_event_loop)
self.assertRaises(NotImplementedError, policy.get_child_watcher)
self.assertRaises(NotImplementedError, policy.set_child_watcher,
object())
def test_get_event_loop(self): def test_get_event_loop(self):
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
self.assertIsNone(policy._loop) self.assertIsNone(policy._local._loop)
loop = policy.get_event_loop() loop = policy.get_event_loop()
self.assertIsInstance(loop, events.AbstractEventLoop) self.assertIsInstance(loop, events.AbstractEventLoop)
self.assertIs(policy._loop, loop) self.assertIs(policy._local._loop, loop)
self.assertIs(loop, policy.get_event_loop()) self.assertIs(loop, policy.get_event_loop())
loop.close() loop.close()
def test_get_event_loop_after_set_none(self): def test_get_event_loop_after_set_none(self):
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
policy.set_event_loop(None) policy.set_event_loop(None)
self.assertRaises(AssertionError, policy.get_event_loop) self.assertRaises(AssertionError, policy.get_event_loop)
@ -1583,7 +1603,7 @@ class PolicyTests(unittest.TestCase):
def test_get_event_loop_thread(self, m_current_thread): def test_get_event_loop_thread(self, m_current_thread):
def f(): def f():
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
self.assertRaises(AssertionError, policy.get_event_loop) self.assertRaises(AssertionError, policy.get_event_loop)
th = threading.Thread(target=f) th = threading.Thread(target=f)
@ -1591,14 +1611,14 @@ class PolicyTests(unittest.TestCase):
th.join() th.join()
def test_new_event_loop(self): def test_new_event_loop(self):
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
loop = policy.new_event_loop() loop = policy.new_event_loop()
self.assertIsInstance(loop, events.AbstractEventLoop) self.assertIsInstance(loop, events.AbstractEventLoop)
loop.close() loop.close()
def test_set_event_loop(self): def test_set_event_loop(self):
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
old_loop = policy.get_event_loop() old_loop = policy.get_event_loop()
self.assertRaises(AssertionError, policy.set_event_loop, object()) self.assertRaises(AssertionError, policy.set_event_loop, object())
@ -1621,7 +1641,7 @@ class PolicyTests(unittest.TestCase):
old_policy = events.get_event_loop_policy() old_policy = events.get_event_loop_policy()
policy = events.DefaultEventLoopPolicy() policy = self.create_policy()
events.set_event_loop_policy(policy) events.set_event_loop_policy(policy)
self.assertIs(policy, events.get_event_loop_policy()) self.assertIs(policy, events.get_event_loop_policy())
self.assertIsNot(policy, old_policy) self.assertIsNot(policy, old_policy)

File diff suppressed because it is too large Load diff