mirror of
https://github.com/python/cpython.git
synced 2025-07-16 07:45:20 +00:00
[3.11] gh-111085: Fix invalid state handling in TaskGroup and Timeout (GH-111111) (GH-111172)
asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError
if they are improperly used.
* When they are used without entering the context manager.
* When they are used after finishing.
* When the context manager is entered more than once (simultaneously or
sequentially).
* If there is no current task when entering the context manager.
They now remain in a consistent state after an exception is thrown,
so subsequent operations can be performed correctly (if they are allowed).
(cherry picked from commit 6c23635f2b
)
Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
parent
cf28c61c73
commit
cf777399a9
6 changed files with 121 additions and 10 deletions
|
@ -54,16 +54,14 @@ class TaskGroup:
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
if self._entered:
|
if self._entered:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"TaskGroup {self!r} has been already entered")
|
f"TaskGroup {self!r} has already been entered")
|
||||||
self._entered = True
|
|
||||||
|
|
||||||
if self._loop is None:
|
if self._loop is None:
|
||||||
self._loop = events.get_running_loop()
|
self._loop = events.get_running_loop()
|
||||||
|
|
||||||
self._parent_task = tasks.current_task(self._loop)
|
self._parent_task = tasks.current_task(self._loop)
|
||||||
if self._parent_task is None:
|
if self._parent_task is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'TaskGroup {self!r} cannot determine the parent task')
|
f'TaskGroup {self!r} cannot determine the parent task')
|
||||||
|
self._entered = True
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -49,8 +49,9 @@ class Timeout:
|
||||||
|
|
||||||
def reschedule(self, when: Optional[float]) -> None:
|
def reschedule(self, when: Optional[float]) -> None:
|
||||||
"""Reschedule the timeout."""
|
"""Reschedule the timeout."""
|
||||||
assert self._state is not _State.CREATED
|
|
||||||
if self._state is not _State.ENTERED:
|
if self._state is not _State.ENTERED:
|
||||||
|
if self._state is _State.CREATED:
|
||||||
|
raise RuntimeError("Timeout has not been entered")
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot change state of {self._state.value} Timeout",
|
f"Cannot change state of {self._state.value} Timeout",
|
||||||
)
|
)
|
||||||
|
@ -82,11 +83,14 @@ class Timeout:
|
||||||
return f"<Timeout [{self._state.value}]{info_str}>"
|
return f"<Timeout [{self._state.value}]{info_str}>"
|
||||||
|
|
||||||
async def __aenter__(self) -> "Timeout":
|
async def __aenter__(self) -> "Timeout":
|
||||||
self._state = _State.ENTERED
|
if self._state is not _State.CREATED:
|
||||||
self._task = tasks.current_task()
|
raise RuntimeError("Timeout has already been entered")
|
||||||
self._cancelling = self._task.cancelling()
|
task = tasks.current_task()
|
||||||
if self._task is None:
|
if task is None:
|
||||||
raise RuntimeError("Timeout should be used inside a task")
|
raise RuntimeError("Timeout should be used inside a task")
|
||||||
|
self._state = _State.ENTERED
|
||||||
|
self._task = task
|
||||||
|
self._cancelling = self._task.cancelling()
|
||||||
self.reschedule(self._when)
|
self.reschedule(self._when)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,8 @@ import contextlib
|
||||||
from asyncio import taskgroups
|
from asyncio import taskgroups
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from test.test_asyncio.utils import await_without_task
|
||||||
|
|
||||||
|
|
||||||
# To prevent a warning "test altered the execution environment"
|
# To prevent a warning "test altered the execution environment"
|
||||||
def tearDownModule():
|
def tearDownModule():
|
||||||
|
@ -779,6 +781,49 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
await asyncio.create_task(main())
|
await asyncio.create_task(main())
|
||||||
|
|
||||||
|
async def test_taskgroup_already_entered(self):
|
||||||
|
tg = taskgroups.TaskGroup()
|
||||||
|
async with tg:
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||||
|
async with tg:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def test_taskgroup_double_enter(self):
|
||||||
|
tg = taskgroups.TaskGroup()
|
||||||
|
async with tg:
|
||||||
|
pass
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||||
|
async with tg:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def test_taskgroup_finished(self):
|
||||||
|
tg = taskgroups.TaskGroup()
|
||||||
|
async with tg:
|
||||||
|
pass
|
||||||
|
coro = asyncio.sleep(0)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "is finished"):
|
||||||
|
tg.create_task(coro)
|
||||||
|
# We still have to await coro to avoid a warning
|
||||||
|
await coro
|
||||||
|
|
||||||
|
async def test_taskgroup_not_entered(self):
|
||||||
|
tg = taskgroups.TaskGroup()
|
||||||
|
coro = asyncio.sleep(0)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||||
|
tg.create_task(coro)
|
||||||
|
# We still have to await coro to avoid a warning
|
||||||
|
await coro
|
||||||
|
|
||||||
|
async def test_taskgroup_without_parent_task(self):
|
||||||
|
tg = taskgroups.TaskGroup()
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "parent task"):
|
||||||
|
await await_without_task(tg.__aenter__())
|
||||||
|
coro = asyncio.sleep(0)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||||
|
tg.create_task(coro)
|
||||||
|
# We still have to await coro to avoid a warning
|
||||||
|
await coro
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -6,11 +6,12 @@ import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import tasks
|
from asyncio import tasks
|
||||||
|
|
||||||
|
from test.test_asyncio.utils import await_without_task
|
||||||
|
|
||||||
|
|
||||||
def tearDownModule():
|
def tearDownModule():
|
||||||
asyncio.set_event_loop_policy(None)
|
asyncio.set_event_loop_policy(None)
|
||||||
|
|
||||||
|
|
||||||
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
|
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
async def test_timeout_basic(self):
|
async def test_timeout_basic(self):
|
||||||
|
@ -258,6 +259,51 @@ class TimeoutTests(unittest.IsolatedAsyncioTestCase):
|
||||||
cause = exc.exception.__cause__
|
cause = exc.exception.__cause__
|
||||||
assert isinstance(cause, asyncio.CancelledError)
|
assert isinstance(cause, asyncio.CancelledError)
|
||||||
|
|
||||||
|
async def test_timeout_already_entered(self):
|
||||||
|
async with asyncio.timeout(0.01) as cm:
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||||
|
async with cm:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def test_timeout_double_enter(self):
|
||||||
|
async with asyncio.timeout(0.01) as cm:
|
||||||
|
pass
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||||
|
async with cm:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def test_timeout_finished(self):
|
||||||
|
async with asyncio.timeout(0.01) as cm:
|
||||||
|
pass
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "finished"):
|
||||||
|
cm.reschedule(0.02)
|
||||||
|
|
||||||
|
async def test_timeout_expired(self):
|
||||||
|
with self.assertRaises(TimeoutError):
|
||||||
|
async with asyncio.timeout(0.01) as cm:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "expired"):
|
||||||
|
cm.reschedule(0.02)
|
||||||
|
|
||||||
|
async def test_timeout_expiring(self):
|
||||||
|
async with asyncio.timeout(0.01) as cm:
|
||||||
|
with self.assertRaises(asyncio.CancelledError):
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "expiring"):
|
||||||
|
cm.reschedule(0.02)
|
||||||
|
|
||||||
|
async def test_timeout_not_entered(self):
|
||||||
|
cm = asyncio.timeout(0.01)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||||
|
cm.reschedule(0.02)
|
||||||
|
|
||||||
|
async def test_timeout_without_task(self):
|
||||||
|
cm = asyncio.timeout(0.01)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "task"):
|
||||||
|
await await_without_task(cm.__aenter__())
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||||
|
cm.reschedule(0.02)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
|
||||||
sock.family = family
|
sock.family = family
|
||||||
sock.gettimeout.return_value = 0.0
|
sock.gettimeout.return_value = 0.0
|
||||||
return sock
|
return sock
|
||||||
|
|
||||||
|
|
||||||
|
async def await_without_task(coro):
|
||||||
|
exc = None
|
||||||
|
def func():
|
||||||
|
try:
|
||||||
|
for _ in coro.__await__():
|
||||||
|
pass
|
||||||
|
except BaseException as err:
|
||||||
|
nonlocal exc
|
||||||
|
exc = err
|
||||||
|
asyncio.get_running_loop().call_soon(func)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if exc is not None:
|
||||||
|
raise exc
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
Fix invalid state handling in :class:`asyncio.TaskGroup` and
|
||||||
|
:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
|
||||||
|
improperly used and are left in consistent state after this.
|
Loading…
Add table
Add a link
Reference in a new issue