GH-96764: rewrite asyncio.wait_for to use asyncio.timeout (#98518)

Changes `asyncio.wait_for` to use `asyncio.timeout` as its underlying implementation.
This commit is contained in:
Kumar Aditya 2023-02-17 00:18:21 +05:30 committed by GitHub
parent 226484e475
commit a5024a261a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 79 deletions

View file

@ -237,33 +237,6 @@ class AsyncioWaitForTest(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(FooException):
await foo()
async def test_wait_for_self_cancellation(self):
async def inner():
try:
await asyncio.sleep(0.3)
except asyncio.CancelledError:
try:
await asyncio.sleep(0.3)
except asyncio.CancelledError:
await asyncio.sleep(0.3)
return 42
inner_task = asyncio.create_task(inner())
wait = asyncio.wait_for(inner_task, timeout=0.1)
# Test that wait_for itself is properly cancellable
# even when the initial task holds up the initial cancellation.
task = asyncio.create_task(wait)
await asyncio.sleep(0.2)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertEqual(await inner_task, 42)
async def _test_cancel_wait_for(self, timeout):
loop = asyncio.get_running_loop()
@ -289,6 +262,106 @@ class AsyncioWaitForTest(unittest.IsolatedAsyncioTestCase):
async def test_cancel_wait_for(self):
await self._test_cancel_wait_for(60.0)
async def test_wait_for_cancel_suppressed(self):
# GH-86296: Supressing CancelledError is discouraged
# but if a task subpresses CancelledError and returns a value,
# `wait_for` should return the value instead of raising CancelledError.
# This is the same behavior as `asyncio.timeout`.
async def return_42():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
return 42
res = await asyncio.wait_for(return_42(), timeout=0.1)
self.assertEqual(res, 42)
async def test_wait_for_issue86296(self):
# GH-86296: The task should get cancelled and not run to completion.
# inner completes in one cycle of the event loop so it
# completes before the task is cancelled.
async def inner():
return 'done'
inner_task = asyncio.create_task(inner())
reached_end = False
async def wait_for_coro():
await asyncio.wait_for(inner_task, timeout=100)
await asyncio.sleep(1)
nonlocal reached_end
reached_end = True
task = asyncio.create_task(wait_for_coro())
self.assertFalse(task.done())
# Run the task
await asyncio.sleep(0)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(inner_task.done())
self.assertEqual(await inner_task, 'done')
self.assertFalse(reached_end)
class WaitForShieldTests(unittest.IsolatedAsyncioTestCase):
async def test_zero_timeout(self):
# `asyncio.shield` creates a new task which wraps the passed in
# awaitable and shields it from cancellation so with timeout=0
# the task returned by `asyncio.shield` aka shielded_task gets
# cancelled immediately and the task wrapped by it is scheduled
# to run.
async def coro():
await asyncio.sleep(0.01)
return 'done'
task = asyncio.create_task(coro())
with self.assertRaises(asyncio.TimeoutError):
shielded_task = asyncio.shield(task)
await asyncio.wait_for(shielded_task, timeout=0)
# Task is running in background
self.assertFalse(task.done())
self.assertFalse(task.cancelled())
self.assertTrue(shielded_task.cancelled())
# Wait for the task to complete
await asyncio.sleep(0.1)
self.assertTrue(task.done())
async def test_none_timeout(self):
# With timeout=None the timeout is disabled so it
# runs till completion.
async def coro():
await asyncio.sleep(0.1)
return 'done'
task = asyncio.create_task(coro())
await asyncio.wait_for(asyncio.shield(task), timeout=None)
self.assertTrue(task.done())
self.assertEqual(await task, "done")
async def test_shielded_timeout(self):
# shield prevents the task from being cancelled.
async def coro():
await asyncio.sleep(0.1)
return 'done'
task = asyncio.create_task(coro())
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(asyncio.shield(task), timeout=0.01)
self.assertFalse(task.done())
self.assertFalse(task.cancelled())
self.assertEqual(await task, "done")
if __name__ == '__main__':
unittest.main()