[3.11] gh-111085: Fix invalid state handling in TaskGroup and Timeout (GH-111111) (GH-111172)

asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError
if they are improperly used.

* When they are used without entering the context manager.
* When they are used after finishing.
* When the context manager is entered more than once (simultaneously or
  sequentially).
* If there is no current task when entering the context manager.

They now remain in a consistent state after an exception is thrown,
so subsequent operations can be performed correctly (if they are allowed).

(cherry picked from commit 6c23635f2b)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
Miss Islington (bot) 2023-10-21 21:40:07 +02:00 committed by GitHub
parent cf28c61c73
commit cf777399a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 121 additions and 10 deletions

View file

@ -54,16 +54,14 @@ class TaskGroup:
async def __aenter__(self): async def __aenter__(self):
if self._entered: if self._entered:
raise RuntimeError( raise RuntimeError(
f"TaskGroup {self!r} has been already entered") f"TaskGroup {self!r} has already been entered")
self._entered = True
if self._loop is None: if self._loop is None:
self._loop = events.get_running_loop() self._loop = events.get_running_loop()
self._parent_task = tasks.current_task(self._loop) self._parent_task = tasks.current_task(self._loop)
if self._parent_task is None: if self._parent_task is None:
raise RuntimeError( raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task') f'TaskGroup {self!r} cannot determine the parent task')
self._entered = True
return self return self

View file

@ -49,8 +49,9 @@ class Timeout:
def reschedule(self, when: Optional[float]) -> None: def reschedule(self, when: Optional[float]) -> None:
"""Reschedule the timeout.""" """Reschedule the timeout."""
assert self._state is not _State.CREATED
if self._state is not _State.ENTERED: if self._state is not _State.ENTERED:
if self._state is _State.CREATED:
raise RuntimeError("Timeout has not been entered")
raise RuntimeError( raise RuntimeError(
f"Cannot change state of {self._state.value} Timeout", f"Cannot change state of {self._state.value} Timeout",
) )
@ -82,11 +83,14 @@ class Timeout:
return f"<Timeout [{self._state.value}]{info_str}>" return f"<Timeout [{self._state.value}]{info_str}>"
async def __aenter__(self) -> "Timeout": async def __aenter__(self) -> "Timeout":
self._state = _State.ENTERED if self._state is not _State.CREATED:
self._task = tasks.current_task() raise RuntimeError("Timeout has already been entered")
self._cancelling = self._task.cancelling() task = tasks.current_task()
if self._task is None: if task is None:
raise RuntimeError("Timeout should be used inside a task") raise RuntimeError("Timeout should be used inside a task")
self._state = _State.ENTERED
self._task = task
self._cancelling = self._task.cancelling()
self.reschedule(self._when) self.reschedule(self._when)
return self return self

View file

@ -8,6 +8,8 @@ import contextlib
from asyncio import taskgroups from asyncio import taskgroups
import unittest import unittest
from test.test_asyncio.utils import await_without_task
# To prevent a warning "test altered the execution environment" # To prevent a warning "test altered the execution environment"
def tearDownModule(): def tearDownModule():
@ -779,6 +781,49 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
await asyncio.create_task(main()) await asyncio.create_task(main())
async def test_taskgroup_already_entered(self):
tg = taskgroups.TaskGroup()
async with tg:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass
async def test_taskgroup_double_enter(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass
async def test_taskgroup_finished(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro
async def test_taskgroup_not_entered(self):
tg = taskgroups.TaskGroup()
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro
async def test_taskgroup_without_parent_task(self):
tg = taskgroups.TaskGroup()
with self.assertRaisesRegex(RuntimeError, "parent task"):
await await_without_task(tg.__aenter__())
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -6,11 +6,12 @@ import time
import asyncio import asyncio
from asyncio import tasks from asyncio import tasks
from test.test_asyncio.utils import await_without_task
def tearDownModule(): def tearDownModule():
asyncio.set_event_loop_policy(None) asyncio.set_event_loop_policy(None)
class TimeoutTests(unittest.IsolatedAsyncioTestCase): class TimeoutTests(unittest.IsolatedAsyncioTestCase):
async def test_timeout_basic(self): async def test_timeout_basic(self):
@ -258,6 +259,51 @@ class TimeoutTests(unittest.IsolatedAsyncioTestCase):
cause = exc.exception.__cause__ cause = exc.exception.__cause__
assert isinstance(cause, asyncio.CancelledError) assert isinstance(cause, asyncio.CancelledError)
async def test_timeout_already_entered(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass
async def test_timeout_double_enter(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass
async def test_timeout_finished(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "finished"):
cm.reschedule(0.02)
async def test_timeout_expired(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expired"):
cm.reschedule(0.02)
async def test_timeout_expiring(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaises(asyncio.CancelledError):
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expiring"):
cm.reschedule(0.02)
async def test_timeout_not_entered(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)
async def test_timeout_without_task(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "task"):
await await_without_task(cm.__aenter__())
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
sock.family = family sock.family = family
sock.gettimeout.return_value = 0.0 sock.gettimeout.return_value = 0.0
return sock return sock
async def await_without_task(coro):
exc = None
def func():
try:
for _ in coro.__await__():
pass
except BaseException as err:
nonlocal exc
exc = err
asyncio.get_running_loop().call_soon(func)
await asyncio.sleep(0)
if exc is not None:
raise exc

View file

@ -0,0 +1,3 @@
Fix invalid state handling in :class:`asyncio.TaskGroup` and
:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
improperly used and are left in consistent state after this.