gh-129874: improve asyncio tests to use correct internal functions (#129887)

This commit is contained in:
Kumar Aditya 2025-02-09 17:35:39 +05:30 committed by GitHub
parent c88dacb391
commit 09fe550ecc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 110 additions and 7 deletions

View file

@ -267,12 +267,33 @@ class EagerTaskFactoryLoopTests:
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask Task = tasks._PyTask
def setUp(self):
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
@unittest.skipUnless(hasattr(tasks, '_CTask'), @unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module') 'requires the C _asyncio module')
class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = getattr(tasks, '_CTask', None) Task = getattr(tasks, '_CTask', None)
def setUp(self):
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
@unittest.skip("skip")
def test_issue105987(self): def test_issue105987(self):
code = """if 1: code = """if 1:
from _asyncio import _swap_current_task from _asyncio import _swap_current_task
@ -400,31 +421,83 @@ class BaseEagerTaskFactoryTests(BaseTaskCountingTests):
class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
Task = asyncio.Task Task = asyncio.tasks._CTask
def setUp(self):
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase): class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
Task = asyncio.Task Task = asyncio.tasks._CTask
def setUp(self):
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
Task = tasks._PyTask Task = tasks._PyTask
def setUp(self):
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
Task = tasks._PyTask Task = tasks._PyTask
def setUp(self):
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
@unittest.skipUnless(hasattr(tasks, '_CTask'), @unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module') 'requires the C _asyncio module')
class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
Task = getattr(tasks, '_CTask', None) Task = getattr(tasks, '_CTask', None)
def setUp(self):
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
@unittest.skipUnless(hasattr(tasks, '_CTask'), @unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module') 'requires the C _asyncio module')
class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
Task = getattr(tasks, '_CTask', None) Task = getattr(tasks, '_CTask', None)
def setUp(self):
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -40,7 +40,7 @@ class TestFreeThreading:
self.assertEqual(task.get_loop(), loop) self.assertEqual(task.get_loop(), loop)
self.assertFalse(task.done()) self.assertFalse(task.done())
current = self.current_task() current = asyncio.current_task()
self.assertEqual(current.get_loop(), loop) self.assertEqual(current.get_loop(), loop)
self.assertSetEqual(all_tasks, tasks | {current}) self.assertSetEqual(all_tasks, tasks | {current})
future.set_result(None) future.set_result(None)
@ -101,8 +101,12 @@ class TestFreeThreading:
async def func(): async def func():
nonlocal task nonlocal task
task = asyncio.current_task() task = asyncio.current_task()
def runner():
thread = Thread(target=lambda: asyncio.run(func())) with asyncio.Runner() as runner:
loop = runner.get_loop()
loop.set_task_factory(self.factory)
runner.run(func())
thread = Thread(target=runner)
thread.start() thread.start()
thread.join() thread.join()
wr = weakref.ref(task) wr = weakref.ref(task)
@ -164,7 +168,15 @@ class TestFreeThreading:
class TestPyFreeThreading(TestFreeThreading, TestCase): class TestPyFreeThreading(TestFreeThreading, TestCase):
all_tasks = staticmethod(asyncio.tasks._py_all_tasks) all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
current_task = staticmethod(asyncio.tasks._py_current_task)
def setUp(self):
self._old_current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._old_current_task
return super().tearDown()
def factory(self, loop, coro, **kwargs): def factory(self, loop, coro, **kwargs):
return asyncio.tasks._PyTask(coro, loop=loop, **kwargs) return asyncio.tasks._PyTask(coro, loop=loop, **kwargs)
@ -173,7 +185,16 @@ class TestPyFreeThreading(TestFreeThreading, TestCase):
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio") @unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
class TestCFreeThreading(TestFreeThreading, TestCase): class TestCFreeThreading(TestFreeThreading, TestCase):
all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None)) all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
def setUp(self):
self._old_current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
return super().setUp()
def tearDown(self):
asyncio.current_task = asyncio.tasks.current_task = self._old_current_task
return super().tearDown()
def factory(self, loop, coro, **kwargs): def factory(self, loop, coro, **kwargs):
return asyncio.tasks._CTask(coro, loop=loop, **kwargs) return asyncio.tasks._CTask(coro, loop=loop, **kwargs)

View file

@ -369,6 +369,8 @@ class TestCallStackC(CallStackTestBase, unittest.IsolatedAsyncioTestCase):
futures.future_discard_from_awaited_by = futures._c_future_discard_from_awaited_by futures.future_discard_from_awaited_by = futures._c_future_discard_from_awaited_by
asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = tasks._c_current_task
def tearDown(self): def tearDown(self):
futures = asyncio.futures futures = asyncio.futures
@ -390,6 +392,8 @@ class TestCallStackC(CallStackTestBase, unittest.IsolatedAsyncioTestCase):
futures.Future = self._Future futures.Future = self._Future
del self._Future del self._Future
asyncio.current_task = asyncio.tasks.current_task = self._current_task
@unittest.skipIf( @unittest.skipIf(
not hasattr(asyncio.futures, "_py_future_add_to_awaited_by"), not hasattr(asyncio.futures, "_py_future_add_to_awaited_by"),
@ -414,6 +418,9 @@ class TestCallStackPy(CallStackTestBase, unittest.IsolatedAsyncioTestCase):
futures.future_discard_from_awaited_by = futures._py_future_discard_from_awaited_by futures.future_discard_from_awaited_by = futures._py_future_discard_from_awaited_by
asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by
self._current_task = asyncio.current_task
asyncio.current_task = asyncio.tasks.current_task = tasks._py_current_task
def tearDown(self): def tearDown(self):
futures = asyncio.futures futures = asyncio.futures
@ -434,3 +441,5 @@ class TestCallStackPy(CallStackTestBase, unittest.IsolatedAsyncioTestCase):
asyncio.Future = self._Future asyncio.Future = self._Future
futures.Future = self._Future futures.Future = self._Future
del self._Future del self._Future
asyncio.current_task = asyncio.tasks.current_task = self._current_task