bpo-46752: Add TaskGroup; add Task..cancelled(),.uncancel() (GH-31270)

asyncio/taskgroups.py is an adaptation of taskgroup.py from EdgeDb, with the following key changes:

- Allow creating new tasks as long as the last task hasn't finished
- Raise [Base]ExceptionGroup (directly) rather than TaskGroupError deriving from MultiError
- Instead of monkey-patching the parent task's cancel() method,
  add a new public API to Task

The Task class has a new internal flag, `_cancel_requested`, which is set when `.cancel()` is called successfully. The `.cancelling()` method returns the value of this flag. Further `.cancel()` calls while this flag is set return False. To reset this flag, call `.uncancel()`.

Thus, a Task that catches and ignores `CancelledError` should call `.uncancel()` if it wants to be cancellable again; until it does so, it is deemed to be busy with uninterruptible cleanup.

This new Task API helps solve the problem where TaskGroup needs to distinguish between whether the parent task being cancelled "from the outside" vs. "from inside".

Co-authored-by: Yury Selivanov <yury@edgedb.com>
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
Guido van Rossum 2022-02-15 15:42:04 -08:00 committed by GitHub
parent 08ec80113b
commit 602630ac18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 1100 additions and 3 deletions

View file

@ -0,0 +1,694 @@
# Adapted with permission from the EdgeDB project.
import asyncio
from asyncio import taskgroups
import unittest
# To prevent a warning "test altered the execution environment"
def tearDownModule():
asyncio.set_event_loop_policy(None)
class MyExc(Exception):
pass
class MyBaseExc(BaseException):
pass
def get_error_types(eg):
return {type(exc) for exc in eg.exceptions}
class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
async def test_taskgroup_01(self):
async def foo1():
await asyncio.sleep(0.1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
self.assertEqual(t1.result(), 42)
self.assertEqual(t2.result(), 11)
async def test_taskgroup_02(self):
async def foo1():
await asyncio.sleep(0.1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
await asyncio.sleep(0.15)
t2 = g.create_task(foo2())
self.assertEqual(t1.result(), 42)
self.assertEqual(t2.result(), 11)
async def test_taskgroup_03(self):
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
await asyncio.sleep(0.15)
# cancel t1 explicitly, i.e. everything should continue
# working as expected.
t1.cancel()
t2 = g.create_task(foo2())
self.assertTrue(t1.cancelled())
self.assertEqual(t2.result(), 11)
async def test_taskgroup_04(self):
NUM = 0
t2_cancel = False
t2 = None
async def foo1():
await asyncio.sleep(0.1)
1 / 0
async def foo2():
nonlocal NUM, t2_cancel
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
t2_cancel = True
raise
NUM += 1
async def runner():
nonlocal NUM, t2
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
t2 = g.create_task(foo2())
NUM += 10
with self.assertRaises(ExceptionGroup) as cm:
await asyncio.create_task(runner())
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertEqual(NUM, 0)
self.assertTrue(t2_cancel)
self.assertTrue(t2.cancelled())
async def test_taskgroup_05(self):
NUM = 0
t2_cancel = False
runner_cancel = False
async def foo1():
await asyncio.sleep(0.1)
1 / 0
async def foo2():
nonlocal NUM, t2_cancel
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
t2_cancel = True
raise
NUM += 1
async def runner():
nonlocal NUM, runner_cancel
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
g.create_task(foo1())
g.create_task(foo1())
g.create_task(foo2())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
runner_cancel = True
raise
NUM += 10
# The 3 foo1 sub tasks can be racy when the host is busy - if the
# cancellation happens in the middle, we'll see partial sub errors here
with self.assertRaises(ExceptionGroup) as cm:
await asyncio.create_task(runner())
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertEqual(NUM, 0)
self.assertTrue(t2_cancel)
self.assertTrue(runner_cancel)
async def test_taskgroup_06(self):
NUM = 0
async def foo():
nonlocal NUM
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
NUM += 1
raise
async def runner():
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
self.assertEqual(NUM, 5)
async def test_taskgroup_07(self):
NUM = 0
async def foo():
nonlocal NUM
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
NUM += 1
raise
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
self.assertEqual(NUM, 15)
async def test_taskgroup_08(self):
async def foo():
await asyncio.sleep(0.1)
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_09(self):
t1 = t2 = None
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
nonlocal t1, t2
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
await asyncio.sleep(0.1)
1 / 0
try:
await runner()
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {ZeroDivisionError})
else:
self.fail('ExceptionGroup was not raised')
self.assertTrue(t1.cancelled())
self.assertTrue(t2.cancelled())
async def test_taskgroup_10(self):
t1 = t2 = None
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
nonlocal t1, t2
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
1 / 0
try:
await runner()
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {ZeroDivisionError})
else:
self.fail('ExceptionGroup was not raised')
self.assertTrue(t1.cancelled())
self.assertTrue(t2.cancelled())
async def test_taskgroup_11(self):
async def foo():
await asyncio.sleep(0.1)
1 / 0
async def runner():
async with taskgroups.TaskGroup():
async with taskgroups.TaskGroup() as g2:
for _ in range(5):
g2.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_12(self):
async def foo():
await asyncio.sleep(0.1)
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g1:
g1.create_task(asyncio.sleep(10))
async with taskgroups.TaskGroup() as g2:
for _ in range(5):
g2.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_13(self):
async def crash_after(t):
await asyncio.sleep(t)
raise ValueError(t)
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
g1.create_task(crash_after(0.1))
async with taskgroups.TaskGroup(name='g2') as g2:
g2.create_task(crash_after(0.2))
r = asyncio.create_task(runner())
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ValueError})
async def test_taskgroup_14(self):
async def crash_after(t):
await asyncio.sleep(t)
raise ValueError(t)
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
g1.create_task(crash_after(10))
async with taskgroups.TaskGroup(name='g2') as g2:
g2.create_task(crash_after(0.1))
r = asyncio.create_task(runner())
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
async def test_taskgroup_15(self):
async def crash_soon():
await asyncio.sleep(0.3)
1 / 0
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
await asyncio.sleep(0.5)
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_16(self):
async def crash_soon():
await asyncio.sleep(0.3)
1 / 0
async def nested_runner():
async with taskgroups.TaskGroup(name='g1') as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
await asyncio.sleep(0.5)
raise
async def runner():
t = asyncio.create_task(nested_runner())
await t
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_17(self):
NUM = 0
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
self.assertEqual(NUM, 10)
async def test_taskgroup_18(self):
NUM = 0
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
# This isn't a good idea, but we have to support
# this weird case.
raise MyExc
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
try:
await r
except ExceptionGroup as t:
self.assertEqual(get_error_types(t),{MyExc})
else:
self.fail('ExceptionGroup was not raised')
self.assertEqual(NUM, 10)
async def test_taskgroup_19(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise MyExc
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
r = asyncio.create_task(runner())
try:
await r
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
else:
self.fail('TasgGroupError was not raised')
async def test_taskgroup_20(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise KeyboardInterrupt
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(KeyboardInterrupt):
await runner()
async def test_taskgroup_20a(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise MyBaseExc
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(BaseExceptionGroup) as cm:
await runner()
self.assertEqual(
get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
)
async def _test_taskgroup_21(self):
# This test doesn't work as asyncio, currently, doesn't
# correctly propagate KeyboardInterrupt (or SystemExit) --
# those cause the event loop itself to crash.
# (Compare to the previous (passing) test -- that one raises
# a plain exception but raises KeyboardInterrupt in nested();
# this test does it the other way around.)
async def crash_soon():
await asyncio.sleep(0.1)
raise KeyboardInterrupt
async def nested():
try:
await asyncio.sleep(10)
finally:
raise TypeError
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(KeyboardInterrupt):
await runner()
async def test_taskgroup_21a(self):
async def crash_soon():
await asyncio.sleep(0.1)
raise MyBaseExc
async def nested():
try:
await asyncio.sleep(10)
finally:
raise TypeError
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(BaseExceptionGroup) as cm:
await runner()
self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
async def test_taskgroup_22(self):
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
g.create_task(foo2())
r = asyncio.create_task(runner())
await asyncio.sleep(0.05)
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_23(self):
async def do_job(delay):
await asyncio.sleep(delay)
async with taskgroups.TaskGroup() as g:
for count in range(10):
await asyncio.sleep(0.1)
g.create_task(do_job(0.3))
if count == 5:
self.assertLess(len(g._tasks), 5)
await asyncio.sleep(1.35)
self.assertEqual(len(g._tasks), 0)
async def test_taskgroup_24(self):
async def root(g):
await asyncio.sleep(0.1)
g.create_task(coro1(0.1))
g.create_task(coro1(0.2))
async def coro1(delay):
await asyncio.sleep(delay)
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(root(g))
await runner()
async def test_taskgroup_25(self):
nhydras = 0
async def hydra(g):
nonlocal nhydras
nhydras += 1
await asyncio.sleep(0.01)
g.create_task(hydra(g))
g.create_task(hydra(g))
async def hercules():
while nhydras < 10:
await asyncio.sleep(0.015)
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(hydra(g))
g.create_task(hercules())
with self.assertRaises(ExceptionGroup) as cm:
await runner()
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertGreaterEqual(nhydras, 10)

View file

@ -496,6 +496,51 @@ class BaseTaskTests:
# This also distinguishes from the initial has_cycle=None.
self.assertEqual(has_cycle, False)
def test_cancelling(self):
loop = asyncio.new_event_loop()
async def task():
await asyncio.sleep(10)
try:
t = self.new_task(loop, task())
self.assertFalse(t.cancelling())
self.assertNotIn(" cancelling ", repr(t))
self.assertTrue(t.cancel())
self.assertTrue(t.cancelling())
self.assertIn(" cancelling ", repr(t))
self.assertFalse(t.cancel())
with self.assertRaises(asyncio.CancelledError):
loop.run_until_complete(t)
finally:
loop.close()
def test_uncancel(self):
loop = asyncio.new_event_loop()
async def task():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
asyncio.current_task().uncancel()
await asyncio.sleep(10)
try:
t = self.new_task(loop, task())
loop.run_until_complete(asyncio.sleep(0.01))
self.assertTrue(t.cancel()) # Cancel first sleep
self.assertIn(" cancelling ", repr(t))
loop.run_until_complete(asyncio.sleep(0.01))
self.assertNotIn(" cancelling ", repr(t)) # after .uncancel()
self.assertTrue(t.cancel()) # Cancel second sleep
with self.assertRaises(asyncio.CancelledError):
loop.run_until_complete(t)
finally:
loop.close()
def test_cancel(self):
def gen():