GH-111693: Propagate correct asyncio.CancelledError instance out of asyncio.Condition.wait() (#111694)

Also fix a race condition in `asyncio.Semaphore.acquire()` when cancelled.
This commit is contained in:
Kristján Valur Jónsson 2024-01-08 19:57:48 +00:00 committed by GitHub
parent c6ca562138
commit 52161781a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 153 additions and 25 deletions

View file

@ -138,9 +138,6 @@ class Future:
exc = exceptions.CancelledError() exc = exceptions.CancelledError()
else: else:
exc = exceptions.CancelledError(self._cancel_message) exc = exceptions.CancelledError(self._cancel_message)
exc.__context__ = self._cancelled_exc
# Remove the reference since we don't need this anymore.
self._cancelled_exc = None
return exc return exc
def cancel(self, msg=None): def cancel(self, msg=None):

View file

@ -95,6 +95,8 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
This method blocks until the lock is unlocked, then sets it to This method blocks until the lock is unlocked, then sets it to
locked and returns True. locked and returns True.
""" """
# Implement fair scheduling, where each task always waits
# its turn. Jumping the queue if all are cancelled is an optimization.
if (not self._locked and (self._waiters is None or if (not self._locked and (self._waiters is None or
all(w.cancelled() for w in self._waiters))): all(w.cancelled() for w in self._waiters))):
self._locked = True self._locked = True
@ -105,19 +107,22 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
fut = self._get_loop().create_future() fut = self._get_loop().create_future()
self._waiters.append(fut) self._waiters.append(fut)
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try: try:
try: try:
await fut await fut
finally: finally:
self._waiters.remove(fut) self._waiters.remove(fut)
except exceptions.CancelledError: except exceptions.CancelledError:
# Currently the only exception designed to be able to occur here.
# Ensure the lock invariant: If lock is not claimed (or about
# to be claimed by us) and there is a Task in waiters,
# ensure that the Task at the head will run.
if not self._locked: if not self._locked:
self._wake_up_first() self._wake_up_first()
raise raise
# assert self._locked is False
self._locked = True self._locked = True
return True return True
@ -139,7 +144,7 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
raise RuntimeError('Lock is not acquired.') raise RuntimeError('Lock is not acquired.')
def _wake_up_first(self): def _wake_up_first(self):
"""Wake up the first waiter if it isn't done.""" """Ensure that the first waiter will wake up."""
if not self._waiters: if not self._waiters:
return return
try: try:
@ -147,9 +152,7 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
except StopIteration: except StopIteration:
return return
# .done() necessarily means that a waiter will wake up later on and # .done() means that the waiter is already set to wake up.
# either take the lock, or, if it was cancelled and lock wasn't
# taken already, will hit this again and wake up a new waiter.
if not fut.done(): if not fut.done():
fut.set_result(True) fut.set_result(True)
@ -269,17 +272,22 @@ class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
self._waiters.remove(fut) self._waiters.remove(fut)
finally: finally:
# Must reacquire lock even if wait is cancelled # Must re-acquire lock even if wait is cancelled.
cancelled = False # We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True: while True:
try: try:
await self.acquire() await self.acquire()
break break
except exceptions.CancelledError: except exceptions.CancelledError as e:
cancelled = True err = e
if cancelled: if err:
raise exceptions.CancelledError try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
async def wait_for(self, predicate): async def wait_for(self, predicate):
"""Wait until a predicate becomes true. """Wait until a predicate becomes true.
@ -357,6 +365,7 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
def locked(self): def locked(self):
"""Returns True if semaphore cannot be acquired immediately.""" """Returns True if semaphore cannot be acquired immediately."""
# Due to state, or FIFO rules (must allow others to run first).
return self._value == 0 or ( return self._value == 0 or (
any(not w.cancelled() for w in (self._waiters or ()))) any(not w.cancelled() for w in (self._waiters or ())))
@ -370,6 +379,7 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
True. True.
""" """
if not self.locked(): if not self.locked():
# Maintain FIFO, wait for others to start even if _value > 0.
self._value -= 1 self._value -= 1
return True return True
@ -378,22 +388,27 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
fut = self._get_loop().create_future() fut = self._get_loop().create_future()
self._waiters.append(fut) self._waiters.append(fut)
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try: try:
try: try:
await fut await fut
finally: finally:
self._waiters.remove(fut) self._waiters.remove(fut)
except exceptions.CancelledError: except exceptions.CancelledError:
if not fut.cancelled(): # Currently the only exception designed to be able to occur here.
if fut.done() and not fut.cancelled():
# Our Future was successfully set to True via _wake_up_next(),
# but we are not about to successfully acquire(). Therefore we
# must undo the bookkeeping already done and attempt to wake
# up someone else.
self._value += 1 self._value += 1
self._wake_up_next()
raise raise
if self._value > 0: finally:
self._wake_up_next() # New waiters may have arrived but had to wait due to FIFO.
# Wake up as many as are allowed.
while self._value > 0:
if not self._wake_up_next():
break # There was no-one to wake up.
return True return True
def release(self): def release(self):
@ -408,13 +423,15 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
def _wake_up_next(self): def _wake_up_next(self):
"""Wake up the first waiter that isn't done.""" """Wake up the first waiter that isn't done."""
if not self._waiters: if not self._waiters:
return return False
for fut in self._waiters: for fut in self._waiters:
if not fut.done(): if not fut.done():
self._value -= 1 self._value -= 1
fut.set_result(True) fut.set_result(True)
return # `fut` is now `done()` and not `cancelled()`.
return True
return False
class BoundedSemaphore(Semaphore): class BoundedSemaphore(Semaphore):

View file

@ -758,6 +758,63 @@ class ConditionTests(unittest.IsolatedAsyncioTestCase):
with self.assertRaises(asyncio.TimeoutError): with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5) await asyncio.wait_for(condition.wait(), timeout=0.5)
async def test_cancelled_error_wakeup(self):
# Test that a cancelled error, received when awaiting wakeup,
# will be re-raised un-modified.
wake = False
raised = None
cond = asyncio.Condition()
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition, cancel it there.
task.cancel(msg="foo")
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
async def test_cancelled_error_re_aquire(self):
# Test that a cancelled error, received when re-acquiring the lock,
# will be re-raised un-modified.
wake = False
raised = None
cond = asyncio.Condition()
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition
await cond.acquire()
wake = True
cond.notify()
await asyncio.sleep(0)
# Task is now trying to re-acquire the lock, cancel it there.
task.cancel(msg="foo")
cond.release()
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
class SemaphoreTests(unittest.IsolatedAsyncioTestCase): class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
@ -1044,6 +1101,62 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
self.assertEqual([2, 3], result) self.assertEqual([2, 3], result)
async def test_acquire_fifo_order_4(self):
# Test that a successful `acquire()` will wake up multiple Tasks
# that were waiting in the Semaphore queue due to FIFO rules.
sem = asyncio.Semaphore(0)
result = []
count = 0
async def c1(result):
# First task immediately waits for semaphore. It will be awoken by c2.
self.assertEqual(sem._value, 0)
await sem.acquire()
# We should have woken up all waiting tasks now.
self.assertEqual(sem._value, 0)
# Create a fourth task. It should run after c3, not c2.
nonlocal t4
t4 = asyncio.create_task(c4(result))
result.append(1)
return True
async def c2(result):
# The second task begins by releasing semaphore three times,
# for c1, c2, and c3.
sem.release()
sem.release()
sem.release()
self.assertEqual(sem._value, 2)
# It is locked, because c1 hasn't woken up yet.
self.assertTrue(sem.locked())
await sem.acquire()
result.append(2)
return True
async def c3(result):
await sem.acquire()
self.assertTrue(sem.locked())
result.append(3)
return True
async def c4(result):
result.append(4)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
t4 = None
await asyncio.sleep(0)
# Three tasks are in the queue, the first hasn't woken up yet.
self.assertEqual(sem._value, 2)
self.assertEqual(len(sem._waiters), 3)
await asyncio.sleep(0)
tasks = [t1, t2, t3, t4]
await asyncio.gather(*tasks)
self.assertEqual([1, 2, 3, 4], result)
class BarrierTests(unittest.IsolatedAsyncioTestCase): class BarrierTests(unittest.IsolatedAsyncioTestCase):

View file

@ -0,0 +1 @@
:meth:`asyncio.Condition.wait` now re-raises the same :exc:`CancelledError` instance that may have caused it to be interrupted. Fixed a race condition in :meth:`asyncio.Semaphore.acquire` when interrupted with a :exc:`CancelledError`.