mirror of
				https://github.com/python/cpython.git
				synced 2025-11-04 03:44:55 +00:00 
			
		
		
		
	* Revert "gh-125472: Revert "gh-124958: fix asyncio.TaskGroup and _PyFuture refcycles (#12… (#125476)"
This reverts commit e99650b80a.
* fix incompatibility with gh-124392
		
	
			
		
			
				
	
	
		
			275 lines
		
	
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			275 lines
		
	
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Adapted with permission from the EdgeDB project;
 | 
						|
# license: PSFL.
 | 
						|
 | 
						|
 | 
						|
# Public API of this module (the names exported by ``from ... import *``).
__all__ = ("TaskGroup",)
 | 
						|
 | 
						|
from . import events
 | 
						|
from . import exceptions
 | 
						|
from . import tasks
 | 
						|
 | 
						|
 | 
						|
class TaskGroup:
 | 
						|
    """Asynchronous context manager for managing groups of tasks.
 | 
						|
 | 
						|
    Example use:
 | 
						|
 | 
						|
        async with asyncio.TaskGroup() as group:
 | 
						|
            task1 = group.create_task(some_coroutine(...))
 | 
						|
            task2 = group.create_task(other_coroutine(...))
 | 
						|
        print("Both tasks have completed now.")
 | 
						|
 | 
						|
    All tasks are awaited when the context manager exits.
 | 
						|
 | 
						|
    Any exceptions other than `asyncio.CancelledError` raised within
 | 
						|
    a task will cancel all remaining tasks and wait for them to exit.
 | 
						|
    The exceptions are then combined and raised as an `ExceptionGroup`.
 | 
						|
    """
 | 
						|
    def __init__(self):
 | 
						|
        self._entered = False
 | 
						|
        self._exiting = False
 | 
						|
        self._aborting = False
 | 
						|
        self._loop = None
 | 
						|
        self._parent_task = None
 | 
						|
        self._parent_cancel_requested = False
 | 
						|
        self._tasks = set()
 | 
						|
        self._errors = []
 | 
						|
        self._base_error = None
 | 
						|
        self._on_completed_fut = None
 | 
						|
 | 
						|
    def __repr__(self):
 | 
						|
        info = ['']
 | 
						|
        if self._tasks:
 | 
						|
            info.append(f'tasks={len(self._tasks)}')
 | 
						|
        if self._errors:
 | 
						|
            info.append(f'errors={len(self._errors)}')
 | 
						|
        if self._aborting:
 | 
						|
            info.append('cancelling')
 | 
						|
        elif self._entered:
 | 
						|
            info.append('entered')
 | 
						|
 | 
						|
        info_str = ' '.join(info)
 | 
						|
        return f'<TaskGroup{info_str}>'
 | 
						|
 | 
						|
    async def __aenter__(self):
 | 
						|
        if self._entered:
 | 
						|
            raise RuntimeError(
 | 
						|
                f"TaskGroup {self!r} has already been entered")
 | 
						|
        if self._loop is None:
 | 
						|
            self._loop = events.get_running_loop()
 | 
						|
        self._parent_task = tasks.current_task(self._loop)
 | 
						|
        if self._parent_task is None:
 | 
						|
            raise RuntimeError(
 | 
						|
                f'TaskGroup {self!r} cannot determine the parent task')
 | 
						|
        self._entered = True
 | 
						|
 | 
						|
        return self
 | 
						|
 | 
						|
    async def __aexit__(self, et, exc, tb):
 | 
						|
        tb = None
 | 
						|
        try:
 | 
						|
            return await self._aexit(et, exc)
 | 
						|
        finally:
 | 
						|
            # Exceptions are heavy objects that can have object
 | 
						|
            # cycles (bad for GC); let's not keep a reference to
 | 
						|
            # a bunch of them. It would be nicer to use a try/finally
 | 
						|
            # in __aexit__ directly but that introduced some diff noise
 | 
						|
            self._parent_task = None
 | 
						|
            self._errors = None
 | 
						|
            self._base_error = None
 | 
						|
            exc = None
 | 
						|
 | 
						|
    async def _aexit(self, et, exc):
 | 
						|
        self._exiting = True
 | 
						|
 | 
						|
        if (exc is not None and
 | 
						|
                self._is_base_error(exc) and
 | 
						|
                self._base_error is None):
 | 
						|
            self._base_error = exc
 | 
						|
 | 
						|
        if et is not None and issubclass(et, exceptions.CancelledError):
 | 
						|
            propagate_cancellation_error = exc
 | 
						|
        else:
 | 
						|
            propagate_cancellation_error = None
 | 
						|
 | 
						|
        if et is not None:
 | 
						|
            if not self._aborting:
 | 
						|
                # Our parent task is being cancelled:
 | 
						|
                #
 | 
						|
                #    async with TaskGroup() as g:
 | 
						|
                #        g.create_task(...)
 | 
						|
                #        await ...  # <- CancelledError
 | 
						|
                #
 | 
						|
                # or there's an exception in "async with":
 | 
						|
                #
 | 
						|
                #    async with TaskGroup() as g:
 | 
						|
                #        g.create_task(...)
 | 
						|
                #        1 / 0
 | 
						|
                #
 | 
						|
                self._abort()
 | 
						|
 | 
						|
        # We use while-loop here because "self._on_completed_fut"
 | 
						|
        # can be cancelled multiple times if our parent task
 | 
						|
        # is being cancelled repeatedly (or even once, when
 | 
						|
        # our own cancellation is already in progress)
 | 
						|
        while self._tasks:
 | 
						|
            if self._on_completed_fut is None:
 | 
						|
                self._on_completed_fut = self._loop.create_future()
 | 
						|
 | 
						|
            try:
 | 
						|
                await self._on_completed_fut
 | 
						|
            except exceptions.CancelledError as ex:
 | 
						|
                if not self._aborting:
 | 
						|
                    # Our parent task is being cancelled:
 | 
						|
                    #
 | 
						|
                    #    async def wrapper():
 | 
						|
                    #        async with TaskGroup() as g:
 | 
						|
                    #            g.create_task(foo)
 | 
						|
                    #
 | 
						|
                    # "wrapper" is being cancelled while "foo" is
 | 
						|
                    # still running.
 | 
						|
                    propagate_cancellation_error = ex
 | 
						|
                    self._abort()
 | 
						|
 | 
						|
            self._on_completed_fut = None
 | 
						|
 | 
						|
        assert not self._tasks
 | 
						|
 | 
						|
        if self._base_error is not None:
 | 
						|
            try:
 | 
						|
                raise self._base_error
 | 
						|
            finally:
 | 
						|
                exc = None
 | 
						|
 | 
						|
        if self._parent_cancel_requested:
 | 
						|
            # If this flag is set we *must* call uncancel().
 | 
						|
            if self._parent_task.uncancel() == 0:
 | 
						|
                # If there are no pending cancellations left,
 | 
						|
                # don't propagate CancelledError.
 | 
						|
                propagate_cancellation_error = None
 | 
						|
 | 
						|
        # Propagate CancelledError if there is one, except if there
 | 
						|
        # are other errors -- those have priority.
 | 
						|
        try:
 | 
						|
            if propagate_cancellation_error is not None and not self._errors:
 | 
						|
                try:
 | 
						|
                    raise propagate_cancellation_error
 | 
						|
                finally:
 | 
						|
                    exc = None
 | 
						|
        finally:
 | 
						|
            propagate_cancellation_error = None
 | 
						|
 | 
						|
        if et is not None and not issubclass(et, exceptions.CancelledError):
 | 
						|
            self._errors.append(exc)
 | 
						|
 | 
						|
        if self._errors:
 | 
						|
            # If the parent task is being cancelled from the outside
 | 
						|
            # of the taskgroup, un-cancel and re-cancel the parent task,
 | 
						|
            # which will keep the cancel count stable.
 | 
						|
            if self._parent_task.cancelling():
 | 
						|
                self._parent_task.uncancel()
 | 
						|
                self._parent_task.cancel()
 | 
						|
            try:
 | 
						|
                raise BaseExceptionGroup(
 | 
						|
                    'unhandled errors in a TaskGroup',
 | 
						|
                    self._errors,
 | 
						|
                ) from None
 | 
						|
            finally:
 | 
						|
                exc = None
 | 
						|
 | 
						|
 | 
						|
    def create_task(self, coro, *, name=None, context=None):
 | 
						|
        """Create a new task in this group and return it.
 | 
						|
 | 
						|
        Similar to `asyncio.create_task`.
 | 
						|
        """
 | 
						|
        if not self._entered:
 | 
						|
            coro.close()
 | 
						|
            raise RuntimeError(f"TaskGroup {self!r} has not been entered")
 | 
						|
        if self._exiting and not self._tasks:
 | 
						|
            coro.close()
 | 
						|
            raise RuntimeError(f"TaskGroup {self!r} is finished")
 | 
						|
        if self._aborting:
 | 
						|
            coro.close()
 | 
						|
            raise RuntimeError(f"TaskGroup {self!r} is shutting down")
 | 
						|
        if context is None:
 | 
						|
            task = self._loop.create_task(coro, name=name)
 | 
						|
        else:
 | 
						|
            task = self._loop.create_task(coro, name=name, context=context)
 | 
						|
 | 
						|
        # optimization: Immediately call the done callback if the task is
 | 
						|
        # already done (e.g. if the coro was able to complete eagerly),
 | 
						|
        # and skip scheduling a done callback
 | 
						|
        if task.done():
 | 
						|
            self._on_task_done(task)
 | 
						|
        else:
 | 
						|
            self._tasks.add(task)
 | 
						|
            task.add_done_callback(self._on_task_done)
 | 
						|
        return task
 | 
						|
 | 
						|
    # Since Python 3.8 Tasks propagate all exceptions correctly,
 | 
						|
    # except for KeyboardInterrupt and SystemExit which are
 | 
						|
    # still considered special.
 | 
						|
 | 
						|
    def _is_base_error(self, exc: BaseException) -> bool:
 | 
						|
        assert isinstance(exc, BaseException)
 | 
						|
        return isinstance(exc, (SystemExit, KeyboardInterrupt))
 | 
						|
 | 
						|
    def _abort(self):
 | 
						|
        self._aborting = True
 | 
						|
 | 
						|
        for t in self._tasks:
 | 
						|
            if not t.done():
 | 
						|
                t.cancel()
 | 
						|
 | 
						|
    def _on_task_done(self, task):
 | 
						|
        self._tasks.discard(task)
 | 
						|
 | 
						|
        if self._on_completed_fut is not None and not self._tasks:
 | 
						|
            if not self._on_completed_fut.done():
 | 
						|
                self._on_completed_fut.set_result(True)
 | 
						|
 | 
						|
        if task.cancelled():
 | 
						|
            return
 | 
						|
 | 
						|
        exc = task.exception()
 | 
						|
        if exc is None:
 | 
						|
            return
 | 
						|
 | 
						|
        self._errors.append(exc)
 | 
						|
        if self._is_base_error(exc) and self._base_error is None:
 | 
						|
            self._base_error = exc
 | 
						|
 | 
						|
        if self._parent_task.done():
 | 
						|
            # Not sure if this case is possible, but we want to handle
 | 
						|
            # it anyways.
 | 
						|
            self._loop.call_exception_handler({
 | 
						|
                'message': f'Task {task!r} has errored out but its parent '
 | 
						|
                           f'task {self._parent_task} is already completed',
 | 
						|
                'exception': exc,
 | 
						|
                'task': task,
 | 
						|
            })
 | 
						|
            return
 | 
						|
 | 
						|
        if not self._aborting and not self._parent_cancel_requested:
 | 
						|
            # If parent task *is not* being cancelled, it means that we want
 | 
						|
            # to manually cancel it to abort whatever is being run right now
 | 
						|
            # in the TaskGroup.  But we want to mark parent task as
 | 
						|
            # "not cancelled" later in __aexit__.  Example situation that
 | 
						|
            # we need to handle:
 | 
						|
            #
 | 
						|
            #    async def foo():
 | 
						|
            #        try:
 | 
						|
            #            async with TaskGroup() as g:
 | 
						|
            #                g.create_task(crash_soon())
 | 
						|
            #                await something  # <- this needs to be canceled
 | 
						|
            #                                 #    by the TaskGroup, e.g.
 | 
						|
            #                                 #    foo() needs to be cancelled
 | 
						|
            #        except Exception:
 | 
						|
            #            # Ignore any exceptions raised in the TaskGroup
 | 
						|
            #            pass
 | 
						|
            #        await something_else     # this line has to be called
 | 
						|
            #                                 # after TaskGroup is finished.
 | 
						|
            self._abort()
 | 
						|
            self._parent_cancel_requested = True
 | 
						|
            self._parent_task.cancel()
 |