bpo-29302: Implement contextlib.AsyncExitStack. (#4790)

This commit is contained in:
Ilya Kulakov 2018-01-25 12:51:18 -08:00 committed by Yury Selivanov
parent 6ab62920c8
commit 1aa094f740
6 changed files with 451 additions and 81 deletions

View file

@ -435,6 +435,44 @@ Functions and classes provided:
callbacks registered, the arguments passed in will indicate that no callbacks registered, the arguments passed in will indicate that no
exception occurred. exception occurred.
.. class:: AsyncExitStack()
An :ref:`asynchronous context manager <async-context-managers>`, similar
to :class:`ExitStack`, that supports combining both synchronous and
asynchronous context managers, as well as having coroutines for
cleanup logic.
The :meth:`close` method is not implemented, :meth:`aclose` must be used
instead.
.. method:: enter_async_context(cm)
Similar to :meth:`enter_context` but expects an asynchronous context
manager.
.. method:: push_async_exit(exit)
Similar to :meth:`push` but expects either an asynchronous context manager
or a coroutine.
.. method:: push_async_callback(callback, *args, **kwds)
Similar to :meth:`callback` but expects a coroutine.
.. method:: aclose()
Similar to :meth:`close` but properly handles awaitables.
Continuing the example for :func:`asynccontextmanager`::
async with AsyncExitStack() as stack:
connections = [await stack.enter_async_context(get_connection())
for i in range(5)]
# All opened connections will automatically be released at the end of
# the async with statement, even if attempts to open a connection
# later in the list raise an exception.
.. versionadded:: 3.7
Examples and Recipes Examples and Recipes
-------------------- --------------------

View file

@ -379,6 +379,9 @@ contextlib
:class:`~contextlib.AbstractAsyncContextManager` have been added. (Contributed :class:`~contextlib.AbstractAsyncContextManager` have been added. (Contributed
by Jelle Zijlstra in :issue:`29679` and :issue:`30241`.) by Jelle Zijlstra in :issue:`29679` and :issue:`30241`.)
:class:`contextlib.AsyncExitStack` has been added. (Contributed by
Alexander Mohr and Ilya Kulakov in :issue:`29302`.)
cProfile cProfile
-------- --------

View file

@ -7,7 +7,7 @@ from functools import wraps
__all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext", __all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext",
"AbstractContextManager", "AbstractAsyncContextManager", "AbstractContextManager", "AbstractAsyncContextManager",
"ContextDecorator", "ExitStack", "AsyncExitStack", "ContextDecorator", "ExitStack",
"redirect_stdout", "redirect_stderr", "suppress"] "redirect_stdout", "redirect_stderr", "suppress"]
@ -365,85 +365,102 @@ class suppress(AbstractContextManager):
return exctype is not None and issubclass(exctype, self._exceptions) return exctype is not None and issubclass(exctype, self._exceptions)
# Inspired by discussions on http://bugs.python.org/issue13585 class _BaseExitStack:
class ExitStack(AbstractContextManager): """A base class for ExitStack and AsyncExitStack."""
"""Context manager for dynamic management of a stack of exit callbacks
For example: @staticmethod
def _create_exit_wrapper(cm, cm_exit):
def _exit_wrapper(exc_type, exc, tb):
return cm_exit(cm, exc_type, exc, tb)
return _exit_wrapper
with ExitStack() as stack: @staticmethod
files = [stack.enter_context(open(fname)) for fname in filenames] def _create_cb_wrapper(callback, *args, **kwds):
# All opened files will automatically be closed at the end of def _exit_wrapper(exc_type, exc, tb):
# the with statement, even if attempts to open files later callback(*args, **kwds)
# in the list raise an exception return _exit_wrapper
"""
def __init__(self): def __init__(self):
self._exit_callbacks = deque() self._exit_callbacks = deque()
def pop_all(self): def pop_all(self):
"""Preserve the context stack by transferring it to a new instance""" """Preserve the context stack by transferring it to a new instance."""
new_stack = type(self)() new_stack = type(self)()
new_stack._exit_callbacks = self._exit_callbacks new_stack._exit_callbacks = self._exit_callbacks
self._exit_callbacks = deque() self._exit_callbacks = deque()
return new_stack return new_stack
def _push_cm_exit(self, cm, cm_exit):
"""Helper to correctly register callbacks to __exit__ methods"""
def _exit_wrapper(*exc_details):
return cm_exit(cm, *exc_details)
_exit_wrapper.__self__ = cm
self.push(_exit_wrapper)
def push(self, exit): def push(self, exit):
"""Registers a callback with the standard __exit__ method signature """Registers a callback with the standard __exit__ method signature.
Can suppress exceptions the same way __exit__ methods can.
Can suppress exceptions the same way __exit__ method can.
Also accepts any object with an __exit__ method (registering a call Also accepts any object with an __exit__ method (registering a call
to the method instead of the object itself) to the method instead of the object itself).
""" """
# We use an unbound method rather than a bound method to follow # We use an unbound method rather than a bound method to follow
# the standard lookup behaviour for special methods # the standard lookup behaviour for special methods.
_cb_type = type(exit) _cb_type = type(exit)
try: try:
exit_method = _cb_type.__exit__ exit_method = _cb_type.__exit__
except AttributeError: except AttributeError:
# Not a context manager, so assume its a callable # Not a context manager, so assume it's a callable.
self._exit_callbacks.append(exit) self._push_exit_callback(exit)
else: else:
self._push_cm_exit(exit, exit_method) self._push_cm_exit(exit, exit_method)
return exit # Allow use as a decorator return exit # Allow use as a decorator.
def callback(self, callback, *args, **kwds):
"""Registers an arbitrary callback and arguments.
Cannot suppress exceptions.
"""
def _exit_wrapper(exc_type, exc, tb):
callback(*args, **kwds)
# We changed the signature, so using @wraps is not appropriate, but
# setting __wrapped__ may still help with introspection
_exit_wrapper.__wrapped__ = callback
self.push(_exit_wrapper)
return callback # Allow use as a decorator
def enter_context(self, cm): def enter_context(self, cm):
"""Enters the supplied context manager """Enters the supplied context manager.
If successful, also pushes its __exit__ method as a callback and If successful, also pushes its __exit__ method as a callback and
returns the result of the __enter__ method. returns the result of the __enter__ method.
""" """
# We look up the special methods on the type to match the with statement # We look up the special methods on the type to match the with
# statement.
_cm_type = type(cm) _cm_type = type(cm)
_exit = _cm_type.__exit__ _exit = _cm_type.__exit__
result = _cm_type.__enter__(cm) result = _cm_type.__enter__(cm)
self._push_cm_exit(cm, _exit) self._push_cm_exit(cm, _exit)
return result return result
def close(self): def callback(self, callback, *args, **kwds):
"""Immediately unwind the context stack""" """Registers an arbitrary callback and arguments.
self.__exit__(None, None, None)
Cannot suppress exceptions.
"""
_exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds)
# We changed the signature, so using @wraps is not appropriate, but
# setting __wrapped__ may still help with introspection.
_exit_wrapper.__wrapped__ = callback
self._push_exit_callback(_exit_wrapper)
return callback # Allow use as a decorator
def _push_cm_exit(self, cm, cm_exit):
"""Helper to correctly register callbacks to __exit__ methods."""
_exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
_exit_wrapper.__self__ = cm
self._push_exit_callback(_exit_wrapper, True)
def _push_exit_callback(self, callback, is_sync=True):
self._exit_callbacks.append((is_sync, callback))
# Inspired by discussions on http://bugs.python.org/issue13585
class ExitStack(_BaseExitStack, AbstractContextManager):
"""Context manager for dynamic management of a stack of exit callbacks.
For example:
with ExitStack() as stack:
files = [stack.enter_context(open(fname)) for fname in filenames]
# All opened files will automatically be closed at the end of
# the with statement, even if attempts to open files later
# in the list raise an exception.
"""
def __enter__(self):
return self
def __exit__(self, *exc_details): def __exit__(self, *exc_details):
received_exc = exc_details[0] is not None received_exc = exc_details[0] is not None
@ -470,7 +487,8 @@ class ExitStack(AbstractContextManager):
suppressed_exc = False suppressed_exc = False
pending_raise = False pending_raise = False
while self._exit_callbacks: while self._exit_callbacks:
cb = self._exit_callbacks.pop() is_sync, cb = self._exit_callbacks.pop()
assert is_sync
try: try:
if cb(*exc_details): if cb(*exc_details):
suppressed_exc = True suppressed_exc = True
@ -493,6 +511,147 @@ class ExitStack(AbstractContextManager):
raise raise
return received_exc and suppressed_exc return received_exc and suppressed_exc
def close(self):
"""Immediately unwind the context stack."""
self.__exit__(None, None, None)
# Inspired by discussions on https://bugs.python.org/issue29302
class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager):
"""Async context manager for dynamic management of a stack of exit
callbacks.
For example:
async with AsyncExitStack() as stack:
connections = [await stack.enter_async_context(get_connection())
for i in range(5)]
# All opened connections will automatically be released at the
# end of the async with statement, even if attempts to open a
# connection later in the list raise an exception.
"""
@staticmethod
def _create_async_exit_wrapper(cm, cm_exit):
async def _exit_wrapper(exc_type, exc, tb):
return await cm_exit(cm, exc_type, exc, tb)
return _exit_wrapper
@staticmethod
def _create_async_cb_wrapper(callback, *args, **kwds):
async def _exit_wrapper(exc_type, exc, tb):
await callback(*args, **kwds)
return _exit_wrapper
async def enter_async_context(self, cm):
"""Enters the supplied async context manager.
If successful, also pushes its __aexit__ method as a callback and
returns the result of the __aenter__ method.
"""
_cm_type = type(cm)
_exit = _cm_type.__aexit__
result = await _cm_type.__aenter__(cm)
self._push_async_cm_exit(cm, _exit)
return result
def push_async_exit(self, exit):
"""Registers a coroutine function with the standard __aexit__ method
signature.
Can suppress exceptions the same way __aexit__ method can.
Also accepts any object with an __aexit__ method (registering a call
to the method instead of the object itself).
"""
_cb_type = type(exit)
try:
exit_method = _cb_type.__aexit__
except AttributeError:
# Not an async context manager, so assume it's a coroutine function
self._push_exit_callback(exit, False)
else:
self._push_async_cm_exit(exit, exit_method)
return exit # Allow use as a decorator
def push_async_callback(self, callback, *args, **kwds):
"""Registers an arbitrary coroutine function and arguments.
Cannot suppress exceptions.
"""
_exit_wrapper = self._create_async_cb_wrapper(callback, *args, **kwds)
# We changed the signature, so using @wraps is not appropriate, but
# setting __wrapped__ may still help with introspection.
_exit_wrapper.__wrapped__ = callback
self._push_exit_callback(_exit_wrapper, False)
return callback # Allow use as a decorator
async def aclose(self):
"""Immediately unwind the context stack."""
await self.__aexit__(None, None, None)
def _push_async_cm_exit(self, cm, cm_exit):
"""Helper to correctly register coroutine function to __aexit__
method."""
_exit_wrapper = self._create_async_exit_wrapper(cm, cm_exit)
_exit_wrapper.__self__ = cm
self._push_exit_callback(_exit_wrapper, False)
async def __aenter__(self):
return self
async def __aexit__(self, *exc_details):
received_exc = exc_details[0] is not None
# We manipulate the exception state so it behaves as though
# we were actually nesting multiple with statements
frame_exc = sys.exc_info()[1]
def _fix_exception_context(new_exc, old_exc):
# Context may not be correct, so find the end of the chain
while 1:
exc_context = new_exc.__context__
if exc_context is old_exc:
# Context is already set correctly (see issue 20317)
return
if exc_context is None or exc_context is frame_exc:
break
new_exc = exc_context
# Change the end of the chain to point to the exception
# we expect it to reference
new_exc.__context__ = old_exc
# Callbacks are invoked in LIFO order to match the behaviour of
# nested context managers
suppressed_exc = False
pending_raise = False
while self._exit_callbacks:
is_sync, cb = self._exit_callbacks.pop()
try:
if is_sync:
cb_suppress = cb(*exc_details)
else:
cb_suppress = await cb(*exc_details)
if cb_suppress:
suppressed_exc = True
pending_raise = False
exc_details = (None, None, None)
except:
new_exc_details = sys.exc_info()
# simulate the stack of exceptions by setting the context
_fix_exception_context(new_exc_details[1], exc_details[1])
pending_raise = True
exc_details = new_exc_details
if pending_raise:
try:
# bare "raise exc_details[1]" replaces our carefully
# set-up context
fixed_ctx = exc_details[1].__context__
raise exc_details[1]
except BaseException:
exc_details[1].__context__ = fixed_ctx
raise
return received_exc and suppressed_exc
class nullcontext(AbstractContextManager): class nullcontext(AbstractContextManager):
"""Context manager that does no additional processing. """Context manager that does no additional processing.

View file

@ -1,5 +1,6 @@
"""Unit tests for contextlib.py, and other context managers.""" """Unit tests for contextlib.py, and other context managers."""
import asyncio
import io import io
import sys import sys
import tempfile import tempfile
@ -505,17 +506,18 @@ class TestContextDecorator(unittest.TestCase):
self.assertEqual(state, [1, 'something else', 999]) self.assertEqual(state, [1, 'something else', 999])
class TestExitStack(unittest.TestCase): class TestBaseExitStack:
exit_stack = None
@support.requires_docstrings @support.requires_docstrings
def test_instance_docs(self): def test_instance_docs(self):
# Issue 19330: ensure context manager instances have good docstrings # Issue 19330: ensure context manager instances have good docstrings
cm_docstring = ExitStack.__doc__ cm_docstring = self.exit_stack.__doc__
obj = ExitStack() obj = self.exit_stack()
self.assertEqual(obj.__doc__, cm_docstring) self.assertEqual(obj.__doc__, cm_docstring)
def test_no_resources(self): def test_no_resources(self):
with ExitStack(): with self.exit_stack():
pass pass
def test_callback(self): def test_callback(self):
@ -531,7 +533,7 @@ class TestExitStack(unittest.TestCase):
def _exit(*args, **kwds): def _exit(*args, **kwds):
"""Test metadata propagation""" """Test metadata propagation"""
result.append((args, kwds)) result.append((args, kwds))
with ExitStack() as stack: with self.exit_stack() as stack:
for args, kwds in reversed(expected): for args, kwds in reversed(expected):
if args and kwds: if args and kwds:
f = stack.callback(_exit, *args, **kwds) f = stack.callback(_exit, *args, **kwds)
@ -543,9 +545,9 @@ class TestExitStack(unittest.TestCase):
f = stack.callback(_exit) f = stack.callback(_exit)
self.assertIs(f, _exit) self.assertIs(f, _exit)
for wrapper in stack._exit_callbacks: for wrapper in stack._exit_callbacks:
self.assertIs(wrapper.__wrapped__, _exit) self.assertIs(wrapper[1].__wrapped__, _exit)
self.assertNotEqual(wrapper.__name__, _exit.__name__) self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
self.assertIsNone(wrapper.__doc__, _exit.__doc__) self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
self.assertEqual(result, expected) self.assertEqual(result, expected)
def test_push(self): def test_push(self):
@ -565,21 +567,21 @@ class TestExitStack(unittest.TestCase):
self.fail("Should not be called!") self.fail("Should not be called!")
def __exit__(self, *exc_details): def __exit__(self, *exc_details):
self.check_exc(*exc_details) self.check_exc(*exc_details)
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(_expect_ok) stack.push(_expect_ok)
self.assertIs(stack._exit_callbacks[-1], _expect_ok) self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
cm = ExitCM(_expect_ok) cm = ExitCM(_expect_ok)
stack.push(cm) stack.push(cm)
self.assertIs(stack._exit_callbacks[-1].__self__, cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
stack.push(_suppress_exc) stack.push(_suppress_exc)
self.assertIs(stack._exit_callbacks[-1], _suppress_exc) self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
cm = ExitCM(_expect_exc) cm = ExitCM(_expect_exc)
stack.push(cm) stack.push(cm)
self.assertIs(stack._exit_callbacks[-1].__self__, cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
stack.push(_expect_exc) stack.push(_expect_exc)
self.assertIs(stack._exit_callbacks[-1], _expect_exc) self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
stack.push(_expect_exc) stack.push(_expect_exc)
self.assertIs(stack._exit_callbacks[-1], _expect_exc) self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
1/0 1/0
def test_enter_context(self): def test_enter_context(self):
@ -591,19 +593,19 @@ class TestExitStack(unittest.TestCase):
result = [] result = []
cm = TestCM() cm = TestCM()
with ExitStack() as stack: with self.exit_stack() as stack:
@stack.callback # Registered first => cleaned up last @stack.callback # Registered first => cleaned up last
def _exit(): def _exit():
result.append(4) result.append(4)
self.assertIsNotNone(_exit) self.assertIsNotNone(_exit)
stack.enter_context(cm) stack.enter_context(cm)
self.assertIs(stack._exit_callbacks[-1].__self__, cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
result.append(2) result.append(2)
self.assertEqual(result, [1, 2, 3, 4]) self.assertEqual(result, [1, 2, 3, 4])
def test_close(self): def test_close(self):
result = [] result = []
with ExitStack() as stack: with self.exit_stack() as stack:
@stack.callback @stack.callback
def _exit(): def _exit():
result.append(1) result.append(1)
@ -614,7 +616,7 @@ class TestExitStack(unittest.TestCase):
def test_pop_all(self): def test_pop_all(self):
result = [] result = []
with ExitStack() as stack: with self.exit_stack() as stack:
@stack.callback @stack.callback
def _exit(): def _exit():
result.append(3) result.append(3)
@ -627,12 +629,12 @@ class TestExitStack(unittest.TestCase):
def test_exit_raise(self): def test_exit_raise(self):
with self.assertRaises(ZeroDivisionError): with self.assertRaises(ZeroDivisionError):
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(lambda *exc: False) stack.push(lambda *exc: False)
1/0 1/0
def test_exit_suppress(self): def test_exit_suppress(self):
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(lambda *exc: True) stack.push(lambda *exc: True)
1/0 1/0
@ -696,7 +698,7 @@ class TestExitStack(unittest.TestCase):
return True return True
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.callback(raise_exc, IndexError) stack.callback(raise_exc, IndexError)
stack.callback(raise_exc, KeyError) stack.callback(raise_exc, KeyError)
stack.callback(raise_exc, AttributeError) stack.callback(raise_exc, AttributeError)
@ -724,7 +726,7 @@ class TestExitStack(unittest.TestCase):
return True return True
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.callback(lambda: None) stack.callback(lambda: None)
stack.callback(raise_exc, IndexError) stack.callback(raise_exc, IndexError)
except Exception as exc: except Exception as exc:
@ -733,7 +735,7 @@ class TestExitStack(unittest.TestCase):
self.fail("Expected IndexError, but no exception was raised") self.fail("Expected IndexError, but no exception was raised")
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.callback(raise_exc, KeyError) stack.callback(raise_exc, KeyError)
stack.push(suppress_exc) stack.push(suppress_exc)
stack.callback(raise_exc, IndexError) stack.callback(raise_exc, IndexError)
@ -760,7 +762,7 @@ class TestExitStack(unittest.TestCase):
# fix, ExitStack would try to fix it *again* and get into an # fix, ExitStack would try to fix it *again* and get into an
# infinite self-referential loop # infinite self-referential loop
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.enter_context(gets_the_context_right(exc4)) stack.enter_context(gets_the_context_right(exc4))
stack.enter_context(gets_the_context_right(exc3)) stack.enter_context(gets_the_context_right(exc3))
stack.enter_context(gets_the_context_right(exc2)) stack.enter_context(gets_the_context_right(exc2))
@ -787,7 +789,7 @@ class TestExitStack(unittest.TestCase):
exc4 = Exception(4) exc4 = Exception(4)
exc5 = Exception(5) exc5 = Exception(5)
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.callback(raise_nested, exc4, exc5) stack.callback(raise_nested, exc4, exc5)
stack.callback(raise_nested, exc2, exc3) stack.callback(raise_nested, exc2, exc3)
raise exc1 raise exc1
@ -801,27 +803,25 @@ class TestExitStack(unittest.TestCase):
self.assertIsNone( self.assertIsNone(
exc.__context__.__context__.__context__.__context__.__context__) exc.__context__.__context__.__context__.__context__.__context__)
def test_body_exception_suppress(self): def test_body_exception_suppress(self):
def suppress_exc(*exc_details): def suppress_exc(*exc_details):
return True return True
try: try:
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(suppress_exc) stack.push(suppress_exc)
1/0 1/0
except IndexError as exc: except IndexError as exc:
self.fail("Expected no exception, got IndexError") self.fail("Expected no exception, got IndexError")
def test_exit_exception_chaining_suppress(self): def test_exit_exception_chaining_suppress(self):
with ExitStack() as stack: with self.exit_stack() as stack:
stack.push(lambda *exc: True) stack.push(lambda *exc: True)
stack.push(lambda *exc: 1/0) stack.push(lambda *exc: 1/0)
stack.push(lambda *exc: {}[1]) stack.push(lambda *exc: {}[1])
def test_excessive_nesting(self): def test_excessive_nesting(self):
# The original implementation would die with RecursionError here # The original implementation would die with RecursionError here
with ExitStack() as stack: with self.exit_stack() as stack:
for i in range(10000): for i in range(10000):
stack.callback(int) stack.callback(int)
@ -829,10 +829,10 @@ class TestExitStack(unittest.TestCase):
class Example(object): pass class Example(object): pass
cm = Example() cm = Example()
cm.__exit__ = object() cm.__exit__ = object()
stack = ExitStack() stack = self.exit_stack()
self.assertRaises(AttributeError, stack.enter_context, cm) self.assertRaises(AttributeError, stack.enter_context, cm)
stack.push(cm) stack.push(cm)
self.assertIs(stack._exit_callbacks[-1], cm) self.assertIs(stack._exit_callbacks[-1][1], cm)
def test_dont_reraise_RuntimeError(self): def test_dont_reraise_RuntimeError(self):
# https://bugs.python.org/issue27122 # https://bugs.python.org/issue27122
@ -856,7 +856,7 @@ class TestExitStack(unittest.TestCase):
# The UniqueRuntimeError should be caught by second()'s exception # The UniqueRuntimeError should be caught by second()'s exception
# handler which chain raised a new UniqueException. # handler which chain raised a new UniqueException.
with self.assertRaises(UniqueException) as err_ctx: with self.assertRaises(UniqueException) as err_ctx:
with ExitStack() as es_ctx: with self.exit_stack() as es_ctx:
es_ctx.enter_context(second()) es_ctx.enter_context(second())
es_ctx.enter_context(first()) es_ctx.enter_context(first())
raise UniqueRuntimeError("please no infinite loop.") raise UniqueRuntimeError("please no infinite loop.")
@ -869,6 +869,10 @@ class TestExitStack(unittest.TestCase):
self.assertIs(exc.__cause__, exc.__context__) self.assertIs(exc.__cause__, exc.__context__)
class TestExitStack(TestBaseExitStack, unittest.TestCase):
exit_stack = ExitStack
class TestRedirectStream: class TestRedirectStream:
redirect_stream = None redirect_stream = None

View file

@ -1,9 +1,11 @@
import asyncio import asyncio
from contextlib import asynccontextmanager, AbstractAsyncContextManager from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack
import functools import functools
from test import support from test import support
import unittest import unittest
from .test_contextlib import TestBaseExitStack
def _async_test(func): def _async_test(func):
"""Decorator to turn an async function into a test case.""" """Decorator to turn an async function into a test case."""
@ -255,5 +257,168 @@ class AsyncContextManagerTestCase(unittest.TestCase):
self.assertEqual(target, (11, 22, 33, 44)) self.assertEqual(target, (11, 22, 33, 44))
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
class SyncAsyncExitStack(AsyncExitStack):
@staticmethod
def run_coroutine(coro):
loop = asyncio.get_event_loop()
f = asyncio.ensure_future(coro)
f.add_done_callback(lambda f: loop.stop())
loop.run_forever()
exc = f.exception()
if not exc:
return f.result()
else:
context = exc.__context__
try:
raise exc
except:
exc.__context__ = context
raise exc
def close(self):
return self.run_coroutine(self.aclose())
def __enter__(self):
return self.run_coroutine(self.__aenter__())
def __exit__(self, *exc_details):
return self.run_coroutine(self.__aexit__(*exc_details))
exit_stack = SyncAsyncExitStack
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.addCleanup(self.loop.close)
@_async_test
async def test_async_callback(self):
expected = [
((), {}),
((1,), {}),
((1,2), {}),
((), dict(example=1)),
((1,), dict(example=1)),
((1,2), dict(example=1)),
]
result = []
async def _exit(*args, **kwds):
"""Test metadata propagation"""
result.append((args, kwds))
async with AsyncExitStack() as stack:
for args, kwds in reversed(expected):
if args and kwds:
f = stack.push_async_callback(_exit, *args, **kwds)
elif args:
f = stack.push_async_callback(_exit, *args)
elif kwds:
f = stack.push_async_callback(_exit, **kwds)
else:
f = stack.push_async_callback(_exit)
self.assertIs(f, _exit)
for wrapper in stack._exit_callbacks:
self.assertIs(wrapper[1].__wrapped__, _exit)
self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
self.assertEqual(result, expected)
@_async_test
async def test_async_push(self):
exc_raised = ZeroDivisionError
async def _expect_exc(exc_type, exc, exc_tb):
self.assertIs(exc_type, exc_raised)
async def _suppress_exc(*exc_details):
return True
async def _expect_ok(exc_type, exc, exc_tb):
self.assertIsNone(exc_type)
self.assertIsNone(exc)
self.assertIsNone(exc_tb)
class ExitCM(object):
def __init__(self, check_exc):
self.check_exc = check_exc
async def __aenter__(self):
self.fail("Should not be called!")
async def __aexit__(self, *exc_details):
await self.check_exc(*exc_details)
async with self.exit_stack() as stack:
stack.push_async_exit(_expect_ok)
self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
cm = ExitCM(_expect_ok)
stack.push_async_exit(cm)
self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
stack.push_async_exit(_suppress_exc)
self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
cm = ExitCM(_expect_exc)
stack.push_async_exit(cm)
self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
stack.push_async_exit(_expect_exc)
self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
stack.push_async_exit(_expect_exc)
self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
1/0
@_async_test
async def test_async_enter_context(self):
class TestCM(object):
async def __aenter__(self):
result.append(1)
async def __aexit__(self, *exc_details):
result.append(3)
result = []
cm = TestCM()
async with AsyncExitStack() as stack:
@stack.push_async_callback # Registered first => cleaned up last
async def _exit():
result.append(4)
self.assertIsNotNone(_exit)
await stack.enter_async_context(cm)
self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
result.append(2)
self.assertEqual(result, [1, 2, 3, 4])
@_async_test
async def test_async_exit_exception_chaining(self):
# Ensure exception chaining matches the reference behaviour
async def raise_exc(exc):
raise exc
saved_details = None
async def suppress_exc(*exc_details):
nonlocal saved_details
saved_details = exc_details
return True
try:
async with self.exit_stack() as stack:
stack.push_async_callback(raise_exc, IndexError)
stack.push_async_callback(raise_exc, KeyError)
stack.push_async_callback(raise_exc, AttributeError)
stack.push_async_exit(suppress_exc)
stack.push_async_callback(raise_exc, ValueError)
1 / 0
except IndexError as exc:
self.assertIsInstance(exc.__context__, KeyError)
self.assertIsInstance(exc.__context__.__context__, AttributeError)
# Inner exceptions were suppressed
self.assertIsNone(exc.__context__.__context__.__context__)
else:
self.fail("Expected IndexError, but no exception was raised")
# Check the inner exceptions
inner_exc = saved_details[1]
self.assertIsInstance(inner_exc, ValueError)
self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -0,0 +1 @@
Add contextlib.AsyncExitStack. Patch by Alexander Mohr and Ilya Kulakov.