mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 02:15:10 +00:00 
			
		
		
		
	 fa58e75a86
			
		
	
	
		fa58e75a86
		
			
		
	
	
	
	
		
			
			This prevents external cancellations of a task group's parent task from being dropped when an internal cancellation happens at the same time. Also strengthen the semantics of uncancel() to clear self._must_cancel when the cancellation count reaches zero. Co-Authored-By: Tin Tvrtković <tinchester@gmail.com> Co-Authored-By: Arthur Tacca
		
			
				
	
	
		
			904 lines
		
	
	
	
		
			26 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			904 lines
		
	
	
	
		
			26 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Adapted with permission from the EdgeDB project;
 | |
| # license: PSFL.
 | |
| 
 | |
| 
 | |
| import asyncio
 | |
| import contextvars
 | |
| import contextlib
 | |
| from asyncio import taskgroups
 | |
| import unittest
 | |
| import warnings
 | |
| 
 | |
| from test.test_asyncio.utils import await_without_task
 | |
| 
 | |
| 
 | |
def tearDownModule():
    """Put the default event loop policy back after this module's tests.

    Skipping this reset makes the test runner complain that the test
    "altered the execution environment".
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
 | |
| 
 | |
| 
 | |
class MyExc(Exception):
    """Plain ``Exception`` subclass raised by individual tests below."""
 | |
| 
 | |
| 
 | |
class MyBaseExc(BaseException):
    """``BaseException`` subclass that is deliberately *not* an ``Exception``.

    Tests that raise it expect a ``BaseExceptionGroup`` rather than an
    ``ExceptionGroup``.
    """
 | |
| 
 | |
| 
 | |
| def get_error_types(eg):
 | |
|     return {type(exc) for exc in eg.exceptions}
 | |
| 
 | |
| 
 | |
| class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
 | |
| 
 | |
|     async def test_taskgroup_01(self):
 | |
| 
 | |
|         async def foo1():
 | |
|             await asyncio.sleep(0.1)
 | |
|             return 42
 | |
| 
 | |
|         async def foo2():
 | |
|             await asyncio.sleep(0.2)
 | |
|             return 11
 | |
| 
 | |
|         async with taskgroups.TaskGroup() as g:
 | |
|             t1 = g.create_task(foo1())
 | |
|             t2 = g.create_task(foo2())
 | |
| 
 | |
|         self.assertEqual(t1.result(), 42)
 | |
|         self.assertEqual(t2.result(), 11)
 | |
| 
 | |
|     async def test_taskgroup_02(self):
 | |
| 
 | |
|         async def foo1():
 | |
|             await asyncio.sleep(0.1)
 | |
|             return 42
 | |
| 
 | |
|         async def foo2():
 | |
|             await asyncio.sleep(0.2)
 | |
|             return 11
 | |
| 
 | |
|         async with taskgroups.TaskGroup() as g:
 | |
|             t1 = g.create_task(foo1())
 | |
|             await asyncio.sleep(0.15)
 | |
|             t2 = g.create_task(foo2())
 | |
| 
 | |
|         self.assertEqual(t1.result(), 42)
 | |
|         self.assertEqual(t2.result(), 11)
 | |
| 
 | |
|     async def test_taskgroup_03(self):
 | |
| 
 | |
|         async def foo1():
 | |
|             await asyncio.sleep(1)
 | |
|             return 42
 | |
| 
 | |
|         async def foo2():
 | |
|             await asyncio.sleep(0.2)
 | |
|             return 11
 | |
| 
 | |
|         async with taskgroups.TaskGroup() as g:
 | |
|             t1 = g.create_task(foo1())
 | |
|             await asyncio.sleep(0.15)
 | |
|             # cancel t1 explicitly, i.e. everything should continue
 | |
|             # working as expected.
 | |
|             t1.cancel()
 | |
| 
 | |
|             t2 = g.create_task(foo2())
 | |
| 
 | |
|         self.assertTrue(t1.cancelled())
 | |
|         self.assertEqual(t2.result(), 11)
 | |
| 
 | |
|     async def test_taskgroup_04(self):
 | |
| 
 | |
|         NUM = 0
 | |
|         t2_cancel = False
 | |
|         t2 = None
 | |
| 
 | |
|         async def foo1():
 | |
|             await asyncio.sleep(0.1)
 | |
|             1 / 0
 | |
| 
 | |
|         async def foo2():
 | |
|             nonlocal NUM, t2_cancel
 | |
|             try:
 | |
|                 await asyncio.sleep(1)
 | |
|             except asyncio.CancelledError:
 | |
|                 t2_cancel = True
 | |
|                 raise
 | |
|             NUM += 1
 | |
| 
 | |
|         async def runner():
 | |
|             nonlocal NUM, t2
 | |
| 
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(foo1())
 | |
|                 t2 = g.create_task(foo2())
 | |
| 
 | |
|             NUM += 10
 | |
| 
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await asyncio.create_task(runner())
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | |
| 
 | |
|         self.assertEqual(NUM, 0)
 | |
|         self.assertTrue(t2_cancel)
 | |
|         self.assertTrue(t2.cancelled())
 | |
| 
 | |
|     async def test_cancel_children_on_child_error(self):
 | |
|         # When a child task raises an error, the rest of the children
 | |
|         # are cancelled and the errors are gathered into an EG.
 | |
| 
 | |
|         NUM = 0
 | |
|         t2_cancel = False
 | |
|         runner_cancel = False
 | |
| 
 | |
|         async def foo1():
 | |
|             await asyncio.sleep(0.1)
 | |
|             1 / 0
 | |
| 
 | |
|         async def foo2():
 | |
|             nonlocal NUM, t2_cancel
 | |
|             try:
 | |
|                 await asyncio.sleep(5)
 | |
|             except asyncio.CancelledError:
 | |
|                 t2_cancel = True
 | |
|                 raise
 | |
|             NUM += 1
 | |
| 
 | |
|         async def runner():
 | |
|             nonlocal NUM, runner_cancel
 | |
| 
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(foo1())
 | |
|                 g.create_task(foo1())
 | |
|                 g.create_task(foo1())
 | |
|                 g.create_task(foo2())
 | |
|                 try:
 | |
|                     await asyncio.sleep(10)
 | |
|                 except asyncio.CancelledError:
 | |
|                     runner_cancel = True
 | |
|                     raise
 | |
| 
 | |
|             NUM += 10
 | |
| 
 | |
|         # The 3 foo1 sub tasks can be racy when the host is busy - if the
 | |
|         # cancellation happens in the middle, we'll see partial sub errors here
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await asyncio.create_task(runner())
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | |
|         self.assertEqual(NUM, 0)
 | |
|         self.assertTrue(t2_cancel)
 | |
|         self.assertTrue(runner_cancel)
 | |
| 
 | |
|     async def test_cancellation(self):
 | |
| 
 | |
|         NUM = 0
 | |
| 
 | |
|         async def foo():
 | |
|             nonlocal NUM
 | |
|             try:
 | |
|                 await asyncio.sleep(5)
 | |
|             except asyncio.CancelledError:
 | |
|                 NUM += 1
 | |
|                 raise
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 for _ in range(5):
 | |
|                     g.create_task(foo())
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
|         with self.assertRaises(asyncio.CancelledError) as cm:
 | |
|             await r
 | |
| 
 | |
|         self.assertEqual(NUM, 5)
 | |
| 
 | |
|     async def test_taskgroup_07(self):
 | |
| 
 | |
|         NUM = 0
 | |
| 
 | |
|         async def foo():
 | |
|             nonlocal NUM
 | |
|             try:
 | |
|                 await asyncio.sleep(5)
 | |
|             except asyncio.CancelledError:
 | |
|                 NUM += 1
 | |
|                 raise
 | |
| 
 | |
|         async def runner():
 | |
|             nonlocal NUM
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 for _ in range(5):
 | |
|                     g.create_task(foo())
 | |
| 
 | |
|                 try:
 | |
|                     await asyncio.sleep(10)
 | |
|                 except asyncio.CancelledError:
 | |
|                     NUM += 10
 | |
|                     raise
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
|         with self.assertRaises(asyncio.CancelledError):
 | |
|             await r
 | |
| 
 | |
|         self.assertEqual(NUM, 15)
 | |
| 
 | |
|     async def test_taskgroup_08(self):
 | |
| 
 | |
|         async def foo():
 | |
|             try:
 | |
|                 await asyncio.sleep(10)
 | |
|             finally:
 | |
|                 1 / 0
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 for _ in range(5):
 | |
|                     g.create_task(foo())
 | |
| 
 | |
|                 await asyncio.sleep(10)
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await r
 | |
|         self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | |
| 
 | |
|     async def test_taskgroup_09(self):
 | |
| 
 | |
|         t1 = t2 = None
 | |
| 
 | |
|         async def foo1():
 | |
|             await asyncio.sleep(1)
 | |
|             return 42
 | |
| 
 | |
|         async def foo2():
 | |
|             await asyncio.sleep(2)
 | |
|             return 11
 | |
| 
 | |
|         async def runner():
 | |
|             nonlocal t1, t2
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 t1 = g.create_task(foo1())
 | |
|                 t2 = g.create_task(foo2())
 | |
|                 await asyncio.sleep(0.1)
 | |
|                 1 / 0
 | |
| 
 | |
|         try:
 | |
|             await runner()
 | |
|         except ExceptionGroup as t:
 | |
|             self.assertEqual(get_error_types(t), {ZeroDivisionError})
 | |
|         else:
 | |
|             self.fail('ExceptionGroup was not raised')
 | |
| 
 | |
|         self.assertTrue(t1.cancelled())
 | |
|         self.assertTrue(t2.cancelled())
 | |
| 
 | |
|     async def test_taskgroup_10(self):
 | |
| 
 | |
|         t1 = t2 = None
 | |
| 
 | |
|         async def foo1():
 | |
|             await asyncio.sleep(1)
 | |
|             return 42
 | |
| 
 | |
|         async def foo2():
 | |
|             await asyncio.sleep(2)
 | |
|             return 11
 | |
| 
 | |
|         async def runner():
 | |
|             nonlocal t1, t2
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 t1 = g.create_task(foo1())
 | |
|                 t2 = g.create_task(foo2())
 | |
|                 1 / 0
 | |
| 
 | |
|         try:
 | |
|             await runner()
 | |
|         except ExceptionGroup as t:
 | |
|             self.assertEqual(get_error_types(t), {ZeroDivisionError})
 | |
|         else:
 | |
|             self.fail('ExceptionGroup was not raised')
 | |
| 
 | |
|         self.assertTrue(t1.cancelled())
 | |
|         self.assertTrue(t2.cancelled())
 | |
| 
 | |
|     async def test_taskgroup_11(self):
 | |
| 
 | |
|         async def foo():
 | |
|             try:
 | |
|                 await asyncio.sleep(10)
 | |
|             finally:
 | |
|                 1 / 0
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup():
 | |
|                 async with taskgroups.TaskGroup() as g2:
 | |
|                     for _ in range(5):
 | |
|                         g2.create_task(foo())
 | |
| 
 | |
|                     await asyncio.sleep(10)
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await r
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
 | |
|         self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
 | |
| 
 | |
|     async def test_taskgroup_12(self):
 | |
| 
 | |
|         async def foo():
 | |
|             try:
 | |
|                 await asyncio.sleep(10)
 | |
|             finally:
 | |
|                 1 / 0
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g1:
 | |
|                 g1.create_task(asyncio.sleep(10))
 | |
| 
 | |
|                 async with taskgroups.TaskGroup() as g2:
 | |
|                     for _ in range(5):
 | |
|                         g2.create_task(foo())
 | |
| 
 | |
|                     await asyncio.sleep(10)
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await r
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
 | |
|         self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
 | |
| 
 | |
|     async def test_taskgroup_13(self):
 | |
| 
 | |
|         async def crash_after(t):
 | |
|             await asyncio.sleep(t)
 | |
|             raise ValueError(t)
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g1:
 | |
|                 g1.create_task(crash_after(0.1))
 | |
| 
 | |
|                 async with taskgroups.TaskGroup() as g2:
 | |
|                     g2.create_task(crash_after(10))
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await r
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {ValueError})
 | |
| 
 | |
|     async def test_taskgroup_14(self):
 | |
| 
 | |
|         async def crash_after(t):
 | |
|             await asyncio.sleep(t)
 | |
|             raise ValueError(t)
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g1:
 | |
|                 g1.create_task(crash_after(10))
 | |
| 
 | |
|                 async with taskgroups.TaskGroup() as g2:
 | |
|                     g2.create_task(crash_after(0.1))
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await r
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
 | |
|         self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
 | |
| 
 | |
|     async def test_taskgroup_15(self):
 | |
| 
 | |
|         async def crash_soon():
 | |
|             await asyncio.sleep(0.3)
 | |
|             1 / 0
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g1:
 | |
|                 g1.create_task(crash_soon())
 | |
|                 try:
 | |
|                     await asyncio.sleep(10)
 | |
|                 except asyncio.CancelledError:
 | |
|                     await asyncio.sleep(0.5)
 | |
|                     raise
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await r
 | |
|         self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | |
| 
 | |
|     async def test_taskgroup_16(self):
 | |
| 
 | |
|         async def crash_soon():
 | |
|             await asyncio.sleep(0.3)
 | |
|             1 / 0
 | |
| 
 | |
|         async def nested_runner():
 | |
|             async with taskgroups.TaskGroup() as g1:
 | |
|                 g1.create_task(crash_soon())
 | |
|                 try:
 | |
|                     await asyncio.sleep(10)
 | |
|                 except asyncio.CancelledError:
 | |
|                     await asyncio.sleep(0.5)
 | |
|                     raise
 | |
| 
 | |
|         async def runner():
 | |
|             t = asyncio.create_task(nested_runner())
 | |
|             await t
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await r
 | |
|         self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | |
| 
 | |
|     async def test_taskgroup_17(self):
 | |
|         NUM = 0
 | |
| 
 | |
|         async def runner():
 | |
|             nonlocal NUM
 | |
|             async with taskgroups.TaskGroup():
 | |
|                 try:
 | |
|                     await asyncio.sleep(10)
 | |
|                 except asyncio.CancelledError:
 | |
|                     NUM += 10
 | |
|                     raise
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
|         with self.assertRaises(asyncio.CancelledError):
 | |
|             await r
 | |
| 
 | |
|         self.assertEqual(NUM, 10)
 | |
| 
 | |
|     async def test_taskgroup_18(self):
 | |
|         NUM = 0
 | |
| 
 | |
|         async def runner():
 | |
|             nonlocal NUM
 | |
|             async with taskgroups.TaskGroup():
 | |
|                 try:
 | |
|                     await asyncio.sleep(10)
 | |
|                 except asyncio.CancelledError:
 | |
|                     NUM += 10
 | |
|                     # This isn't a good idea, but we have to support
 | |
|                     # this weird case.
 | |
|                     raise MyExc
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.1)
 | |
| 
 | |
|         self.assertFalse(r.done())
 | |
|         r.cancel()
 | |
| 
 | |
|         try:
 | |
|             await r
 | |
|         except ExceptionGroup as t:
 | |
|             self.assertEqual(get_error_types(t),{MyExc})
 | |
|         else:
 | |
|             self.fail('ExceptionGroup was not raised')
 | |
| 
 | |
|         self.assertEqual(NUM, 10)
 | |
| 
 | |
|     async def test_taskgroup_19(self):
 | |
|         async def crash_soon():
 | |
|             await asyncio.sleep(0.1)
 | |
|             1 / 0
 | |
| 
 | |
|         async def nested():
 | |
|             try:
 | |
|                 await asyncio.sleep(10)
 | |
|             finally:
 | |
|                 raise MyExc
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(crash_soon())
 | |
|                 await nested()
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         try:
 | |
|             await r
 | |
|         except ExceptionGroup as t:
 | |
|             self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
 | |
|         else:
 | |
|             self.fail('TasgGroupError was not raised')
 | |
| 
 | |
|     async def test_taskgroup_20(self):
 | |
|         async def crash_soon():
 | |
|             await asyncio.sleep(0.1)
 | |
|             1 / 0
 | |
| 
 | |
|         async def nested():
 | |
|             try:
 | |
|                 await asyncio.sleep(10)
 | |
|             finally:
 | |
|                 raise KeyboardInterrupt
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(crash_soon())
 | |
|                 await nested()
 | |
| 
 | |
|         with self.assertRaises(KeyboardInterrupt):
 | |
|             await runner()
 | |
| 
 | |
|     async def test_taskgroup_20a(self):
 | |
|         async def crash_soon():
 | |
|             await asyncio.sleep(0.1)
 | |
|             1 / 0
 | |
| 
 | |
|         async def nested():
 | |
|             try:
 | |
|                 await asyncio.sleep(10)
 | |
|             finally:
 | |
|                 raise MyBaseExc
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(crash_soon())
 | |
|                 await nested()
 | |
| 
 | |
|         with self.assertRaises(BaseExceptionGroup) as cm:
 | |
|             await runner()
 | |
| 
 | |
|         self.assertEqual(
 | |
|             get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
 | |
|         )
 | |
| 
 | |
|     async def _test_taskgroup_21(self):
 | |
|         # This test doesn't work as asyncio, currently, doesn't
 | |
|         # correctly propagate KeyboardInterrupt (or SystemExit) --
 | |
|         # those cause the event loop itself to crash.
 | |
|         # (Compare to the previous (passing) test -- that one raises
 | |
|         # a plain exception but raises KeyboardInterrupt in nested();
 | |
|         # this test does it the other way around.)
 | |
| 
 | |
|         async def crash_soon():
 | |
|             await asyncio.sleep(0.1)
 | |
|             raise KeyboardInterrupt
 | |
| 
 | |
|         async def nested():
 | |
|             try:
 | |
|                 await asyncio.sleep(10)
 | |
|             finally:
 | |
|                 raise TypeError
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(crash_soon())
 | |
|                 await nested()
 | |
| 
 | |
|         with self.assertRaises(KeyboardInterrupt):
 | |
|             await runner()
 | |
| 
 | |
|     async def test_taskgroup_21a(self):
 | |
| 
 | |
|         async def crash_soon():
 | |
|             await asyncio.sleep(0.1)
 | |
|             raise MyBaseExc
 | |
| 
 | |
|         async def nested():
 | |
|             try:
 | |
|                 await asyncio.sleep(10)
 | |
|             finally:
 | |
|                 raise TypeError
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(crash_soon())
 | |
|                 await nested()
 | |
| 
 | |
|         with self.assertRaises(BaseExceptionGroup) as cm:
 | |
|             await runner()
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
 | |
| 
 | |
|     async def test_taskgroup_22(self):
 | |
| 
 | |
|         async def foo1():
 | |
|             await asyncio.sleep(1)
 | |
|             return 42
 | |
| 
 | |
|         async def foo2():
 | |
|             await asyncio.sleep(2)
 | |
|             return 11
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(foo1())
 | |
|                 g.create_task(foo2())
 | |
| 
 | |
|         r = asyncio.create_task(runner())
 | |
|         await asyncio.sleep(0.05)
 | |
|         r.cancel()
 | |
| 
 | |
|         with self.assertRaises(asyncio.CancelledError):
 | |
|             await r
 | |
| 
 | |
|     async def test_taskgroup_23(self):
 | |
| 
 | |
|         async def do_job(delay):
 | |
|             await asyncio.sleep(delay)
 | |
| 
 | |
|         async with taskgroups.TaskGroup() as g:
 | |
|             for count in range(10):
 | |
|                 await asyncio.sleep(0.1)
 | |
|                 g.create_task(do_job(0.3))
 | |
|                 if count == 5:
 | |
|                     self.assertLess(len(g._tasks), 5)
 | |
|             await asyncio.sleep(1.35)
 | |
|             self.assertEqual(len(g._tasks), 0)
 | |
| 
 | |
|     async def test_taskgroup_24(self):
 | |
| 
 | |
|         async def root(g):
 | |
|             await asyncio.sleep(0.1)
 | |
|             g.create_task(coro1(0.1))
 | |
|             g.create_task(coro1(0.2))
 | |
| 
 | |
|         async def coro1(delay):
 | |
|             await asyncio.sleep(delay)
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(root(g))
 | |
| 
 | |
|         await runner()
 | |
| 
 | |
|     async def test_taskgroup_25(self):
 | |
|         nhydras = 0
 | |
| 
 | |
|         async def hydra(g):
 | |
|             nonlocal nhydras
 | |
|             nhydras += 1
 | |
|             await asyncio.sleep(0.01)
 | |
|             g.create_task(hydra(g))
 | |
|             g.create_task(hydra(g))
 | |
| 
 | |
|         async def hercules():
 | |
|             while nhydras < 10:
 | |
|                 await asyncio.sleep(0.015)
 | |
|             1 / 0
 | |
| 
 | |
|         async def runner():
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(hydra(g))
 | |
|                 g.create_task(hercules())
 | |
| 
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             await runner()
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | |
|         self.assertGreaterEqual(nhydras, 10)
 | |
| 
 | |
|     async def test_taskgroup_task_name(self):
 | |
|         async def coro():
 | |
|             await asyncio.sleep(0)
 | |
|         async with taskgroups.TaskGroup() as g:
 | |
|             t = g.create_task(coro(), name="yolo")
 | |
|             self.assertEqual(t.get_name(), "yolo")
 | |
| 
 | |
|     async def test_taskgroup_task_context(self):
 | |
|         cvar = contextvars.ContextVar('cvar')
 | |
| 
 | |
|         async def coro(val):
 | |
|             await asyncio.sleep(0)
 | |
|             cvar.set(val)
 | |
| 
 | |
|         async with taskgroups.TaskGroup() as g:
 | |
|             ctx = contextvars.copy_context()
 | |
|             self.assertIsNone(ctx.get(cvar))
 | |
|             t1 = g.create_task(coro(1), context=ctx)
 | |
|             await t1
 | |
|             self.assertEqual(1, ctx.get(cvar))
 | |
|             t2 = g.create_task(coro(2), context=ctx)
 | |
|             await t2
 | |
|             self.assertEqual(2, ctx.get(cvar))
 | |
| 
 | |
|     async def test_taskgroup_no_create_task_after_failure(self):
 | |
|         async def coro1():
 | |
|             await asyncio.sleep(0.001)
 | |
|             1 / 0
 | |
|         async def coro2(g):
 | |
|             try:
 | |
|                 await asyncio.sleep(1)
 | |
|             except asyncio.CancelledError:
 | |
|                 with self.assertRaises(RuntimeError):
 | |
|                     g.create_task(coro1())
 | |
| 
 | |
|         with self.assertRaises(ExceptionGroup) as cm:
 | |
|             async with taskgroups.TaskGroup() as g:
 | |
|                 g.create_task(coro1())
 | |
|                 g.create_task(coro2(g))
 | |
| 
 | |
|         self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
 | |
| 
 | |
|     async def test_taskgroup_context_manager_exit_raises(self):
 | |
|         # See https://github.com/python/cpython/issues/95289
 | |
|         class CustomException(Exception):
 | |
|             pass
 | |
| 
 | |
|         async def raise_exc():
 | |
|             raise CustomException
 | |
| 
 | |
|         @contextlib.asynccontextmanager
 | |
|         async def database():
 | |
|             try:
 | |
|                 yield
 | |
|             finally:
 | |
|                 raise CustomException
 | |
| 
 | |
|         async def main():
 | |
|             task = asyncio.current_task()
 | |
|             try:
 | |
|                 async with taskgroups.TaskGroup() as tg:
 | |
|                     async with database():
 | |
|                         tg.create_task(raise_exc())
 | |
|                         await asyncio.sleep(1)
 | |
|             except* CustomException as err:
 | |
|                 self.assertEqual(task.cancelling(), 0)
 | |
|                 self.assertEqual(len(err.exceptions), 2)
 | |
| 
 | |
|             else:
 | |
|                 self.fail('CustomException not raised')
 | |
| 
 | |
|         await asyncio.create_task(main())
 | |
| 
 | |
|     async def test_taskgroup_already_entered(self):
 | |
|         tg = taskgroups.TaskGroup()
 | |
|         async with tg:
 | |
|             with self.assertRaisesRegex(RuntimeError, "has already been entered"):
 | |
|                 async with tg:
 | |
|                     pass
 | |
| 
 | |
|     async def test_taskgroup_double_enter(self):
 | |
|         tg = taskgroups.TaskGroup()
 | |
|         async with tg:
 | |
|             pass
 | |
|         with self.assertRaisesRegex(RuntimeError, "has already been entered"):
 | |
|             async with tg:
 | |
|                 pass
 | |
| 
 | |
|     async def test_taskgroup_finished(self):
 | |
|         async def create_task_after_tg_finish():
 | |
|             tg = taskgroups.TaskGroup()
 | |
|             async with tg:
 | |
|                 pass
 | |
|             coro = asyncio.sleep(0)
 | |
|             with self.assertRaisesRegex(RuntimeError, "is finished"):
 | |
|                 tg.create_task(coro)
 | |
| 
 | |
|         # Make sure the coroutine was closed when submitted to the inactive tg
 | |
|         # (if not closed, a RuntimeWarning should have been raised)
 | |
|         with warnings.catch_warnings(record=True) as w:
 | |
|             await create_task_after_tg_finish()
 | |
|         self.assertEqual(len(w), 0)
 | |
| 
 | |
|     async def test_taskgroup_not_entered(self):
 | |
|         tg = taskgroups.TaskGroup()
 | |
|         coro = asyncio.sleep(0)
 | |
|         with self.assertRaisesRegex(RuntimeError, "has not been entered"):
 | |
|             tg.create_task(coro)
 | |
| 
 | |
|     async def test_taskgroup_without_parent_task(self):
 | |
|         tg = taskgroups.TaskGroup()
 | |
|         with self.assertRaisesRegex(RuntimeError, "parent task"):
 | |
|             await await_without_task(tg.__aenter__())
 | |
|         coro = asyncio.sleep(0)
 | |
|         with self.assertRaisesRegex(RuntimeError, "has not been entered"):
 | |
|             tg.create_task(coro)
 | |
| 
 | |
|     def test_coro_closed_when_tg_closed(self):
 | |
|         async def run_coro_after_tg_closes():
 | |
|             async with taskgroups.TaskGroup() as tg:
 | |
|                 pass
 | |
|             coro = asyncio.sleep(0)
 | |
|             with self.assertRaisesRegex(RuntimeError, "is finished"):
 | |
|                 tg.create_task(coro)
 | |
|         loop = asyncio.get_event_loop()
 | |
|         loop.run_until_complete(run_coro_after_tg_closes())
 | |
| 
 | |
|     async def test_cancelling_level_preserved(self):
 | |
|         async def raise_after(t, e):
 | |
|             await asyncio.sleep(t)
 | |
|             raise e()
 | |
| 
 | |
|         try:
 | |
|             async with asyncio.TaskGroup() as tg:
 | |
|                 tg.create_task(raise_after(0.0, RuntimeError))
 | |
|         except* RuntimeError:
 | |
|             pass
 | |
|         self.assertEqual(asyncio.current_task().cancelling(), 0)
 | |
| 
 | |
|     async def test_nested_groups_both_cancelled(self):
 | |
|         async def raise_after(t, e):
 | |
|             await asyncio.sleep(t)
 | |
|             raise e()
 | |
| 
 | |
|         try:
 | |
|             async with asyncio.TaskGroup() as outer_tg:
 | |
|                 try:
 | |
|                     async with asyncio.TaskGroup() as inner_tg:
 | |
|                         inner_tg.create_task(raise_after(0, RuntimeError))
 | |
|                         outer_tg.create_task(raise_after(0, ValueError))
 | |
|                 except* RuntimeError:
 | |
|                     pass
 | |
|                 else:
 | |
|                     self.fail("RuntimeError not raised")
 | |
|             self.assertEqual(asyncio.current_task().cancelling(), 1)
 | |
|         except* ValueError:
 | |
|             pass
 | |
|         else:
 | |
|             self.fail("ValueError not raised")
 | |
|         self.assertEqual(asyncio.current_task().cancelling(), 0)
 | |
| 
 | |
|     async def test_error_and_cancel(self):
 | |
|         event = asyncio.Event()
 | |
| 
 | |
|         async def raise_error():
 | |
|             event.set()
 | |
|             await asyncio.sleep(0)
 | |
|             raise RuntimeError()
 | |
| 
 | |
|         async def inner():
 | |
|             try:
 | |
|                 async with taskgroups.TaskGroup() as tg:
 | |
|                     tg.create_task(raise_error())
 | |
|                     await asyncio.sleep(1)
 | |
|                     self.fail("Sleep in group should have been cancelled")
 | |
|             except* RuntimeError:
 | |
|                 self.assertEqual(asyncio.current_task().cancelling(), 1)
 | |
|             self.assertEqual(asyncio.current_task().cancelling(), 1)
 | |
|             await asyncio.sleep(1)
 | |
|             self.fail("Sleep after group should have been cancelled")
 | |
| 
 | |
|         async def outer():
 | |
|             t = asyncio.create_task(inner())
 | |
|             await event.wait()
 | |
|             self.assertEqual(t.cancelling(), 0)
 | |
|             t.cancel()
 | |
|             self.assertEqual(t.cancelling(), 1)
 | |
|             with self.assertRaises(asyncio.CancelledError):
 | |
|                 await t
 | |
|             self.assertTrue(t.cancelled())
 | |
| 
 | |
|         await outer()
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Run this module's tests with unittest's CLI when executed as a script.
    unittest.main()
 |