bpo-39622: Interrupt the main asyncio task on Ctrl+C (GH-32105)

Co-authored-by: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
This commit is contained in:
Andrew Svetlov 2022-03-30 15:15:06 +03:00 committed by GitHub
parent 04acfa94bb
commit f08a191882
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 2 deletions

View file

@ -2,8 +2,13 @@ __all__ = ('Runner', 'run')
import contextvars
import enum
import functools
import threading
import signal
import sys
from . import coroutines
from . import events
from . import exceptions
from . import tasks
@ -47,6 +52,7 @@ class Runner:
self._loop_factory = loop_factory
self._loop = None
self._context = None
self._interrupt_count = 0
def __enter__(self):
self._lazy_init()
@ -89,7 +95,28 @@ class Runner:
if context is None:
context = self._context
task = self._loop.create_task(coro, context=context)
return self._loop.run_until_complete(task)
if (threading.current_thread() is threading.main_thread()
and signal.getsignal(signal.SIGINT) is signal.default_int_handler
):
sigint_handler = functools.partial(self._on_sigint, main_task=task)
signal.signal(signal.SIGINT, sigint_handler)
else:
sigint_handler = None
self._interrupt_count = 0
try:
return self._loop.run_until_complete(task)
except exceptions.CancelledError:
if self._interrupt_count > 0 and task.uncancel() == 0:
raise KeyboardInterrupt()
else:
raise # CancelledError
finally:
if (sigint_handler is not None
and signal.getsignal(signal.SIGINT) is sigint_handler
):
signal.signal(signal.SIGINT, signal.default_int_handler)
def _lazy_init(self):
if self._state is _State.CLOSED:
@ -105,6 +132,14 @@ class Runner:
self._context = contextvars.copy_context()
self._state = _State.INITIALIZED
def _on_sigint(self, signum, frame, main_task):
self._interrupt_count += 1
if self._interrupt_count == 1 and not main_task.done():
main_task.cancel()
# wakeup loop if it is blocked by select() with long timeout
self._loop.call_soon_threadsafe(lambda: None)
return
raise KeyboardInterrupt()
def run(main, *, debug=None):