[3.9] bpo-44566: resolve differences between asynccontextmanager and contextmanager (GH-27024). (#27269)

(cherry picked from commit 7f1c330da3)

Co-authored-by: Thomas Grainger <tagrain@gmail.com>
This commit is contained in:
Łukasz Langa 2021-07-20 21:12:58 +02:00 committed by GitHub
parent dae4928dd0
commit 1c5c9c89ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 44 deletions

View file

@ -97,18 +97,20 @@ class _GeneratorContextManagerBase:
# for the class instead. # for the class instead.
# See http://bugs.python.org/issue19404 for more details. # See http://bugs.python.org/issue19404 for more details.
class _GeneratorContextManager(_GeneratorContextManagerBase,
AbstractContextManager,
ContextDecorator):
"""Helper for @contextmanager decorator."""
def _recreate_cm(self): def _recreate_cm(self):
# _GCM instances are one-shot context managers, so the # _GCMB instances are one-shot context managers, so the
# CM must be recreated each time a decorated function is # CM must be recreated each time a decorated function is
# called # called
return self.__class__(self.func, self.args, self.kwds) return self.__class__(self.func, self.args, self.kwds)
class _GeneratorContextManager(
_GeneratorContextManagerBase,
AbstractContextManager,
ContextDecorator,
):
"""Helper for @contextmanager decorator."""
def __enter__(self): def __enter__(self):
# do not keep args and kwds alive unnecessarily # do not keep args and kwds alive unnecessarily
# they are only needed for recreation, which is not possible anymore # they are only needed for recreation, which is not possible anymore
@ -118,8 +120,8 @@ class _GeneratorContextManager(_GeneratorContextManagerBase,
except StopIteration: except StopIteration:
raise RuntimeError("generator didn't yield") from None raise RuntimeError("generator didn't yield") from None
def __exit__(self, type, value, traceback): def __exit__(self, typ, value, traceback):
if type is None: if typ is None:
try: try:
next(self.gen) next(self.gen)
except StopIteration: except StopIteration:
@ -130,9 +132,9 @@ class _GeneratorContextManager(_GeneratorContextManagerBase,
if value is None: if value is None:
# Need to force instantiation so we can reliably # Need to force instantiation so we can reliably
# tell if we get the same exception back # tell if we get the same exception back
value = type() value = typ()
try: try:
self.gen.throw(type, value, traceback) self.gen.throw(typ, value, traceback)
except StopIteration as exc: except StopIteration as exc:
# Suppress StopIteration *unless* it's the same exception that # Suppress StopIteration *unless* it's the same exception that
# was passed to throw(). This prevents a StopIteration # was passed to throw(). This prevents a StopIteration
@ -142,35 +144,39 @@ class _GeneratorContextManager(_GeneratorContextManagerBase,
# Don't re-raise the passed in exception. (issue27122) # Don't re-raise the passed in exception. (issue27122)
if exc is value: if exc is value:
return False return False
# Likewise, avoid suppressing if a StopIteration exception # Avoid suppressing if a StopIteration exception
# was passed to throw() and later wrapped into a RuntimeError # was passed to throw() and later wrapped into a RuntimeError
# (see PEP 479). # (see PEP 479 for sync generators; async generators also
if type is StopIteration and exc.__cause__ is value: # have this behavior). But do this only if the exception wrapped
# by the RuntimeError is actually Stop(Async)Iteration (see
# issue29692).
if (
isinstance(value, StopIteration)
and exc.__cause__ is value
):
return False return False
raise raise
except: except BaseException as exc:
# only re-raise if it's *not* the exception that was # only re-raise if it's *not* the exception that was
# passed to throw(), because __exit__() must not raise # passed to throw(), because __exit__() must not raise
# an exception unless __exit__() itself failed. But throw() # an exception unless __exit__() itself failed. But throw()
# has to raise the exception to signal propagation, so this # has to raise the exception to signal propagation, so this
# fixes the impedance mismatch between the throw() protocol # fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol. # and the __exit__() protocol.
# if exc is not value:
# This cannot use 'except BaseException as exc' (as in the raise
# async implementation) to maintain compatibility with return False
# Python 2, where old-style class exceptions are not caught
# by 'except BaseException'.
if sys.exc_info()[1] is value:
return False
raise
raise RuntimeError("generator didn't stop after throw()") raise RuntimeError("generator didn't stop after throw()")
class _AsyncGeneratorContextManager(_GeneratorContextManagerBase, class _AsyncGeneratorContextManager(_GeneratorContextManagerBase,
AbstractAsyncContextManager): AbstractAsyncContextManager):
"""Helper for @asynccontextmanager.""" """Helper for @asynccontextmanager decorator."""
async def __aenter__(self): async def __aenter__(self):
# do not keep args and kwds alive unnecessarily
# they are only needed for recreation, which is not possible anymore
del self.args, self.kwds, self.func
try: try:
return await self.gen.__anext__() return await self.gen.__anext__()
except StopAsyncIteration: except StopAsyncIteration:
@ -181,35 +187,48 @@ class _AsyncGeneratorContextManager(_GeneratorContextManagerBase,
try: try:
await self.gen.__anext__() await self.gen.__anext__()
except StopAsyncIteration: except StopAsyncIteration:
return return False
else: else:
raise RuntimeError("generator didn't stop") raise RuntimeError("generator didn't stop")
else: else:
if value is None: if value is None:
# Need to force instantiation so we can reliably
# tell if we get the same exception back
value = typ() value = typ()
# See _GeneratorContextManager.__exit__ for comments on subtleties
# in this implementation
try: try:
await self.gen.athrow(typ, value, traceback) await self.gen.athrow(typ, value, traceback)
raise RuntimeError("generator didn't stop after athrow()")
except StopAsyncIteration as exc: except StopAsyncIteration as exc:
# Suppress StopIteration *unless* it's the same exception that
# was passed to throw(). This prevents a StopIteration
# raised inside the "with" statement from being suppressed.
return exc is not value return exc is not value
except RuntimeError as exc: except RuntimeError as exc:
# Don't re-raise the passed in exception. (issue27122)
if exc is value: if exc is value:
return False return False
# Avoid suppressing if a StopIteration exception # Avoid suppressing if a Stop(Async)Iteration exception
# was passed to throw() and later wrapped into a RuntimeError # was passed to athrow() and later wrapped into a RuntimeError
# (see PEP 479 for sync generators; async generators also # (see PEP 479 for sync generators; async generators also
# have this behavior). But do this only if the exception wrapped # have this behavior). But do this only if the exception wrapped
# by the RuntimeError is actully Stop(Async)Iteration (see # by the RuntimeError is actully Stop(Async)Iteration (see
# issue29692). # issue29692).
if isinstance(value, (StopIteration, StopAsyncIteration)): if (
if exc.__cause__ is value: isinstance(value, (StopIteration, StopAsyncIteration))
return False and exc.__cause__ is value
):
return False
raise raise
except BaseException as exc: except BaseException as exc:
# only re-raise if it's *not* the exception that was
# passed to throw(), because __exit__() must not raise
# an exception unless __exit__() itself failed. But throw()
# has to raise the exception to signal propagation, so this
# fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol.
if exc is not value: if exc is not value:
raise raise
return False
raise RuntimeError("generator didn't stop after athrow()")
def contextmanager(func): def contextmanager(func):

View file

@ -125,19 +125,22 @@ class ContextManagerTestCase(unittest.TestCase):
self.assertEqual(state, [1, 42, 999]) self.assertEqual(state, [1, 42, 999])
def test_contextmanager_except_stopiter(self): def test_contextmanager_except_stopiter(self):
stop_exc = StopIteration('spam')
@contextmanager @contextmanager
def woohoo(): def woohoo():
yield yield
try:
with self.assertWarnsRegex(DeprecationWarning, class StopIterationSubclass(StopIteration):
"StopIteration"): pass
with woohoo():
raise stop_exc for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
except Exception as ex: with self.subTest(type=type(stop_exc)):
self.assertIs(ex, stop_exc) try:
else: with woohoo():
self.fail('StopIteration was suppressed') raise stop_exc
except Exception as ex:
self.assertIs(ex, stop_exc)
else:
self.fail(f'{stop_exc} was suppressed')
def test_contextmanager_except_pep479(self): def test_contextmanager_except_pep479(self):
code = """\ code = """\

View file

@ -207,7 +207,18 @@ class AsyncContextManagerTestCase(unittest.TestCase):
async def woohoo(): async def woohoo():
yield yield
for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')): class StopIterationSubclass(StopIteration):
pass
class StopAsyncIterationSubclass(StopAsyncIteration):
pass
for stop_exc in (
StopIteration('spam'),
StopAsyncIteration('ham'),
StopIterationSubclass('spam'),
StopAsyncIterationSubclass('spam')
):
with self.subTest(type=type(stop_exc)): with self.subTest(type=type(stop_exc)):
try: try:
async with woohoo(): async with woohoo():

View file

@ -0,0 +1 @@
handle StopIteration subclass raised from @contextlib.contextmanager generator