mirror of
https://github.com/python/cpython.git
synced 2025-11-03 03:22:27 +00:00
bpo-32314: Fix asyncio.run() to cancel runinng tasks on shutdown (#5262)
This commit is contained in:
parent
fc2f407829
commit
a4afcdfa55
4 changed files with 122 additions and 15 deletions
|
|
@ -228,14 +228,9 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
self._coroutine_origin_tracking_enabled = False
|
self._coroutine_origin_tracking_enabled = False
|
||||||
self._coroutine_origin_tracking_saved_depth = None
|
self._coroutine_origin_tracking_saved_depth = None
|
||||||
|
|
||||||
if hasattr(sys, 'get_asyncgen_hooks'):
|
# A weak set of all asynchronous generators that are
|
||||||
# Python >= 3.6
|
# being iterated by the loop.
|
||||||
# A weak set of all asynchronous generators that are
|
self._asyncgens = weakref.WeakSet()
|
||||||
# being iterated by the loop.
|
|
||||||
self._asyncgens = weakref.WeakSet()
|
|
||||||
else:
|
|
||||||
self._asyncgens = None
|
|
||||||
|
|
||||||
# Set to True when `loop.shutdown_asyncgens` is called.
|
# Set to True when `loop.shutdown_asyncgens` is called.
|
||||||
self._asyncgens_shutdown_called = False
|
self._asyncgens_shutdown_called = False
|
||||||
|
|
||||||
|
|
@ -354,7 +349,7 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
"""Shutdown all active asynchronous generators."""
|
"""Shutdown all active asynchronous generators."""
|
||||||
self._asyncgens_shutdown_called = True
|
self._asyncgens_shutdown_called = True
|
||||||
|
|
||||||
if self._asyncgens is None or not len(self._asyncgens):
|
if not len(self._asyncgens):
|
||||||
# If Python version is <3.6 or we don't have any asynchronous
|
# If Python version is <3.6 or we don't have any asynchronous
|
||||||
# generators alive.
|
# generators alive.
|
||||||
return
|
return
|
||||||
|
|
@ -386,10 +381,10 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
'Cannot run the event loop while another loop is running')
|
'Cannot run the event loop while another loop is running')
|
||||||
self._set_coroutine_origin_tracking(self._debug)
|
self._set_coroutine_origin_tracking(self._debug)
|
||||||
self._thread_id = threading.get_ident()
|
self._thread_id = threading.get_ident()
|
||||||
if self._asyncgens is not None:
|
|
||||||
old_agen_hooks = sys.get_asyncgen_hooks()
|
old_agen_hooks = sys.get_asyncgen_hooks()
|
||||||
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
|
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
|
||||||
finalizer=self._asyncgen_finalizer_hook)
|
finalizer=self._asyncgen_finalizer_hook)
|
||||||
try:
|
try:
|
||||||
events._set_running_loop(self)
|
events._set_running_loop(self)
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -401,8 +396,7 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
self._thread_id = None
|
self._thread_id = None
|
||||||
events._set_running_loop(None)
|
events._set_running_loop(None)
|
||||||
self._set_coroutine_origin_tracking(False)
|
self._set_coroutine_origin_tracking(False)
|
||||||
if self._asyncgens is not None:
|
sys.set_asyncgen_hooks(*old_agen_hooks)
|
||||||
sys.set_asyncgen_hooks(*old_agen_hooks)
|
|
||||||
|
|
||||||
def run_until_complete(self, future):
|
def run_until_complete(self, future):
|
||||||
"""Run until the Future is done.
|
"""Run until the Future is done.
|
||||||
|
|
@ -1374,6 +1368,7 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
- 'message': Error message;
|
- 'message': Error message;
|
||||||
- 'exception' (optional): Exception object;
|
- 'exception' (optional): Exception object;
|
||||||
- 'future' (optional): Future instance;
|
- 'future' (optional): Future instance;
|
||||||
|
- 'task' (optional): Task instance;
|
||||||
- 'handle' (optional): Handle instance;
|
- 'handle' (optional): Handle instance;
|
||||||
- 'protocol' (optional): Protocol instance;
|
- 'protocol' (optional): Protocol instance;
|
||||||
- 'transport' (optional): Transport instance;
|
- 'transport' (optional): Transport instance;
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ __all__ = 'run',
|
||||||
|
|
||||||
from . import coroutines
|
from . import coroutines
|
||||||
from . import events
|
from . import events
|
||||||
|
from . import tasks
|
||||||
|
|
||||||
|
|
||||||
def run(main, *, debug=False):
|
def run(main, *, debug=False):
|
||||||
|
|
@ -42,7 +43,31 @@ def run(main, *, debug=False):
|
||||||
return loop.run_until_complete(main)
|
return loop.run_until_complete(main)
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
|
_cancel_all_tasks(loop)
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
finally:
|
finally:
|
||||||
events.set_event_loop(None)
|
events.set_event_loop(None)
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _cancel_all_tasks(loop):
|
||||||
|
to_cancel = [task for task in tasks.all_tasks(loop)
|
||||||
|
if not task.done()]
|
||||||
|
if not to_cancel:
|
||||||
|
return
|
||||||
|
|
||||||
|
for task in to_cancel:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
loop.run_until_complete(
|
||||||
|
tasks.gather(*to_cancel, loop=loop, return_exceptions=True))
|
||||||
|
|
||||||
|
for task in to_cancel:
|
||||||
|
if task.cancelled():
|
||||||
|
continue
|
||||||
|
if task.exception() is not None:
|
||||||
|
loop.call_exception_handler({
|
||||||
|
'message': 'unhandled exception during asyncio.run() shutdown',
|
||||||
|
'exception': task.exception(),
|
||||||
|
'task': task,
|
||||||
|
})
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
from . import utils as test_utils
|
||||||
|
|
||||||
|
|
||||||
class TestPolicy(asyncio.AbstractEventLoopPolicy):
|
class TestPolicy(asyncio.AbstractEventLoopPolicy):
|
||||||
|
|
@ -98,3 +99,81 @@ class RunTests(BaseTest):
|
||||||
with self.assertRaisesRegex(RuntimeError,
|
with self.assertRaisesRegex(RuntimeError,
|
||||||
'cannot be called from a running'):
|
'cannot be called from a running'):
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
||||||
|
def test_asyncio_run_cancels_hanging_tasks(self):
|
||||||
|
lo_task = None
|
||||||
|
|
||||||
|
async def leftover():
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
nonlocal lo_task
|
||||||
|
lo_task = asyncio.create_task(leftover())
|
||||||
|
return 123
|
||||||
|
|
||||||
|
self.assertEqual(asyncio.run(main()), 123)
|
||||||
|
self.assertTrue(lo_task.done())
|
||||||
|
|
||||||
|
def test_asyncio_run_reports_hanging_tasks_errors(self):
|
||||||
|
lo_task = None
|
||||||
|
call_exc_handler_mock = mock.Mock()
|
||||||
|
|
||||||
|
async def leftover():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
1 / 0
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
loop.call_exception_handler = call_exc_handler_mock
|
||||||
|
|
||||||
|
nonlocal lo_task
|
||||||
|
lo_task = asyncio.create_task(leftover())
|
||||||
|
return 123
|
||||||
|
|
||||||
|
self.assertEqual(asyncio.run(main()), 123)
|
||||||
|
self.assertTrue(lo_task.done())
|
||||||
|
|
||||||
|
call_exc_handler_mock.assert_called_with({
|
||||||
|
'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
|
||||||
|
'task': lo_task,
|
||||||
|
'exception': test_utils.MockInstanceOf(ZeroDivisionError)
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
|
||||||
|
spinner = None
|
||||||
|
lazyboy = None
|
||||||
|
|
||||||
|
class FancyExit(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def fidget():
|
||||||
|
while True:
|
||||||
|
yield 1
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
async def spin():
|
||||||
|
nonlocal spinner
|
||||||
|
spinner = fidget()
|
||||||
|
try:
|
||||||
|
async for the_meaning_of_life in spinner: # NoQA
|
||||||
|
pass
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
1 / 0
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
loop.call_exception_handler = mock.Mock()
|
||||||
|
|
||||||
|
nonlocal lazyboy
|
||||||
|
lazyboy = asyncio.create_task(spin())
|
||||||
|
raise FancyExit
|
||||||
|
|
||||||
|
with self.assertRaises(FancyExit):
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
self.assertTrue(lazyboy.done())
|
||||||
|
|
||||||
|
self.assertIsNone(spinner.ag_frame)
|
||||||
|
self.assertFalse(spinner.ag_running)
|
||||||
|
|
|
||||||
|
|
@ -485,6 +485,14 @@ class MockPattern(str):
|
||||||
return bool(re.search(str(self), other, re.S))
|
return bool(re.search(str(self), other, re.S))
|
||||||
|
|
||||||
|
|
||||||
|
class MockInstanceOf:
|
||||||
|
def __init__(self, type):
|
||||||
|
self._type = type
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, self._type)
|
||||||
|
|
||||||
|
|
||||||
def get_function_source(func):
|
def get_function_source(func):
|
||||||
source = format_helpers._get_function_source(func)
|
source = format_helpers._get_function_source(func)
|
||||||
if source is None:
|
if source is None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue