mirror of
https://github.com/python/cpython.git
synced 2025-07-20 09:45:21 +00:00

The main problem was that an unluckily timed task cancellation could cause
the semaphore to be stuck. There were also doubts about strict FIFO ordering
of tasks allowed to pass.
The Semaphore implementation was rewritten to be more similar to Lock.
Many tests for edge cases (including cancellation) were added.
(cherry picked from commit 24e0379624
)
Co-authored-by: Cyker Way <cykerway@gmail.com>
1621 lines
47 KiB
Python
1621 lines
47 KiB
Python
"""Tests for locks.py"""
|
|
|
|
import unittest
|
|
from unittest import mock
|
|
import re
|
|
|
|
import asyncio
|
|
import collections
|
|
|
|
STR_RGX_REPR = (
|
|
r'^<(?P<class>.*?) object at (?P<address>.*?)'
|
|
r'\[(?P<extras>'
|
|
r'(set|unset|locked|unlocked|filling|draining|resetting|broken)'
|
|
r'(, value:\d)?'
|
|
r'(, waiters:\d+)?'
|
|
r'(, waiters:\d+\/\d+)?' # barrier
|
|
r')\]>\Z'
|
|
)
|
|
RGX_REPR = re.compile(STR_RGX_REPR)
|
|
|
|
|
|
def tearDownModule():
|
|
asyncio.set_event_loop_policy(None)
|
|
|
|
|
|
class LockTests(unittest.IsolatedAsyncioTestCase):
|
|
|
|
async def test_repr(self):
|
|
lock = asyncio.Lock()
|
|
self.assertTrue(repr(lock).endswith('[unlocked]>'))
|
|
self.assertTrue(RGX_REPR.match(repr(lock)))
|
|
|
|
await lock.acquire()
|
|
self.assertTrue(repr(lock).endswith('[locked]>'))
|
|
self.assertTrue(RGX_REPR.match(repr(lock)))
|
|
|
|
async def test_lock(self):
|
|
lock = asyncio.Lock()
|
|
|
|
with self.assertRaisesRegex(
|
|
TypeError,
|
|
"object Lock can't be used in 'await' expression"
|
|
):
|
|
await lock
|
|
|
|
self.assertFalse(lock.locked())
|
|
|
|
async def test_lock_doesnt_accept_loop_parameter(self):
|
|
primitives_cls = [
|
|
asyncio.Lock,
|
|
asyncio.Condition,
|
|
asyncio.Event,
|
|
asyncio.Semaphore,
|
|
asyncio.BoundedSemaphore,
|
|
]
|
|
|
|
loop = asyncio.get_running_loop()
|
|
|
|
for cls in primitives_cls:
|
|
with self.assertRaisesRegex(
|
|
TypeError,
|
|
rf"{cls.__name__}\.__init__\(\) got an unexpected "
|
|
rf"keyword argument 'loop'"
|
|
):
|
|
cls(loop=loop)
|
|
|
|
async def test_lock_by_with_statement(self):
|
|
primitives = [
|
|
asyncio.Lock(),
|
|
asyncio.Condition(),
|
|
asyncio.Semaphore(),
|
|
asyncio.BoundedSemaphore(),
|
|
]
|
|
|
|
for lock in primitives:
|
|
await asyncio.sleep(0.01)
|
|
self.assertFalse(lock.locked())
|
|
with self.assertRaisesRegex(
|
|
TypeError,
|
|
r"object \w+ can't be used in 'await' expression"
|
|
):
|
|
with await lock:
|
|
pass
|
|
self.assertFalse(lock.locked())
|
|
|
|
async def test_acquire(self):
|
|
lock = asyncio.Lock()
|
|
result = []
|
|
|
|
self.assertTrue(await lock.acquire())
|
|
|
|
async def c1(result):
|
|
if await lock.acquire():
|
|
result.append(1)
|
|
return True
|
|
|
|
async def c2(result):
|
|
if await lock.acquire():
|
|
result.append(2)
|
|
return True
|
|
|
|
async def c3(result):
|
|
if await lock.acquire():
|
|
result.append(3)
|
|
return True
|
|
|
|
t1 = asyncio.create_task(c1(result))
|
|
t2 = asyncio.create_task(c2(result))
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
|
|
lock.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1], result)
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1], result)
|
|
|
|
t3 = asyncio.create_task(c3(result))
|
|
|
|
lock.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1, 2], result)
|
|
|
|
lock.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1, 2, 3], result)
|
|
|
|
self.assertTrue(t1.done())
|
|
self.assertTrue(t1.result())
|
|
self.assertTrue(t2.done())
|
|
self.assertTrue(t2.result())
|
|
self.assertTrue(t3.done())
|
|
self.assertTrue(t3.result())
|
|
|
|
async def test_acquire_cancel(self):
|
|
lock = asyncio.Lock()
|
|
self.assertTrue(await lock.acquire())
|
|
|
|
task = asyncio.create_task(lock.acquire())
|
|
asyncio.get_running_loop().call_soon(task.cancel)
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await task
|
|
self.assertFalse(lock._waiters)
|
|
|
|
async def test_cancel_race(self):
|
|
# Several tasks:
|
|
# - A acquires the lock
|
|
# - B is blocked in acquire()
|
|
# - C is blocked in acquire()
|
|
#
|
|
# Now, concurrently:
|
|
# - B is cancelled
|
|
# - A releases the lock
|
|
#
|
|
# If B's waiter is marked cancelled but not yet removed from
|
|
# _waiters, A's release() call will crash when trying to set
|
|
# B's waiter; instead, it should move on to C's waiter.
|
|
|
|
# Setup: A has the lock, b and c are waiting.
|
|
lock = asyncio.Lock()
|
|
|
|
async def lockit(name, blocker):
|
|
await lock.acquire()
|
|
try:
|
|
if blocker is not None:
|
|
await blocker
|
|
finally:
|
|
lock.release()
|
|
|
|
fa = asyncio.get_running_loop().create_future()
|
|
ta = asyncio.create_task(lockit('A', fa))
|
|
await asyncio.sleep(0)
|
|
self.assertTrue(lock.locked())
|
|
tb = asyncio.create_task(lockit('B', None))
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(len(lock._waiters), 1)
|
|
tc = asyncio.create_task(lockit('C', None))
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(len(lock._waiters), 2)
|
|
|
|
# Create the race and check.
|
|
# Without the fix this failed at the last assert.
|
|
fa.set_result(None)
|
|
tb.cancel()
|
|
self.assertTrue(lock._waiters[0].cancelled())
|
|
await asyncio.sleep(0)
|
|
self.assertFalse(lock.locked())
|
|
self.assertTrue(ta.done())
|
|
self.assertTrue(tb.cancelled())
|
|
await tc
|
|
|
|
async def test_cancel_release_race(self):
|
|
# Issue 32734
|
|
# Acquire 4 locks, cancel second, release first
|
|
# and 2 locks are taken at once.
|
|
loop = asyncio.get_running_loop()
|
|
lock = asyncio.Lock()
|
|
lock_count = 0
|
|
call_count = 0
|
|
|
|
async def lockit():
|
|
nonlocal lock_count
|
|
nonlocal call_count
|
|
call_count += 1
|
|
await lock.acquire()
|
|
lock_count += 1
|
|
|
|
def trigger():
|
|
t1.cancel()
|
|
lock.release()
|
|
|
|
await lock.acquire()
|
|
|
|
t1 = asyncio.create_task(lockit())
|
|
t2 = asyncio.create_task(lockit())
|
|
t3 = asyncio.create_task(lockit())
|
|
|
|
# Start scheduled tasks
|
|
await asyncio.sleep(0)
|
|
|
|
loop.call_soon(trigger)
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
# Wait for cancellation
|
|
await t1
|
|
|
|
# Make sure only one lock was taken
|
|
self.assertEqual(lock_count, 1)
|
|
# While 3 calls were made to lockit()
|
|
self.assertEqual(call_count, 3)
|
|
self.assertTrue(t1.cancelled() and t2.done())
|
|
|
|
# Cleanup the task that is stuck on acquire.
|
|
t3.cancel()
|
|
await asyncio.sleep(0)
|
|
self.assertTrue(t3.cancelled())
|
|
|
|
async def test_finished_waiter_cancelled(self):
|
|
lock = asyncio.Lock()
|
|
|
|
await lock.acquire()
|
|
self.assertTrue(lock.locked())
|
|
|
|
tb = asyncio.create_task(lock.acquire())
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(len(lock._waiters), 1)
|
|
|
|
# Create a second waiter, wake up the first, and cancel it.
|
|
# Without the fix, the second was not woken up.
|
|
tc = asyncio.create_task(lock.acquire())
|
|
tb.cancel()
|
|
lock.release()
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertTrue(lock.locked())
|
|
self.assertTrue(tb.cancelled())
|
|
|
|
# Cleanup
|
|
await tc
|
|
|
|
async def test_release_not_acquired(self):
|
|
lock = asyncio.Lock()
|
|
|
|
self.assertRaises(RuntimeError, lock.release)
|
|
|
|
async def test_release_no_waiters(self):
|
|
lock = asyncio.Lock()
|
|
await lock.acquire()
|
|
self.assertTrue(lock.locked())
|
|
|
|
lock.release()
|
|
self.assertFalse(lock.locked())
|
|
|
|
async def test_context_manager(self):
|
|
lock = asyncio.Lock()
|
|
self.assertFalse(lock.locked())
|
|
|
|
async with lock:
|
|
self.assertTrue(lock.locked())
|
|
|
|
self.assertFalse(lock.locked())
|
|
|
|
|
|
class EventTests(unittest.IsolatedAsyncioTestCase):
|
|
|
|
def test_repr(self):
|
|
ev = asyncio.Event()
|
|
self.assertTrue(repr(ev).endswith('[unset]>'))
|
|
match = RGX_REPR.match(repr(ev))
|
|
self.assertEqual(match.group('extras'), 'unset')
|
|
|
|
ev.set()
|
|
self.assertTrue(repr(ev).endswith('[set]>'))
|
|
self.assertTrue(RGX_REPR.match(repr(ev)))
|
|
|
|
ev._waiters.append(mock.Mock())
|
|
self.assertTrue('waiters:1' in repr(ev))
|
|
self.assertTrue(RGX_REPR.match(repr(ev)))
|
|
|
|
async def test_wait(self):
|
|
ev = asyncio.Event()
|
|
self.assertFalse(ev.is_set())
|
|
|
|
result = []
|
|
|
|
async def c1(result):
|
|
if await ev.wait():
|
|
result.append(1)
|
|
|
|
async def c2(result):
|
|
if await ev.wait():
|
|
result.append(2)
|
|
|
|
async def c3(result):
|
|
if await ev.wait():
|
|
result.append(3)
|
|
|
|
t1 = asyncio.create_task(c1(result))
|
|
t2 = asyncio.create_task(c2(result))
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
|
|
t3 = asyncio.create_task(c3(result))
|
|
|
|
ev.set()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([3, 1, 2], result)
|
|
|
|
self.assertTrue(t1.done())
|
|
self.assertIsNone(t1.result())
|
|
self.assertTrue(t2.done())
|
|
self.assertIsNone(t2.result())
|
|
self.assertTrue(t3.done())
|
|
self.assertIsNone(t3.result())
|
|
|
|
async def test_wait_on_set(self):
|
|
ev = asyncio.Event()
|
|
ev.set()
|
|
|
|
res = await ev.wait()
|
|
self.assertTrue(res)
|
|
|
|
async def test_wait_cancel(self):
|
|
ev = asyncio.Event()
|
|
|
|
wait = asyncio.create_task(ev.wait())
|
|
asyncio.get_running_loop().call_soon(wait.cancel)
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await wait
|
|
self.assertFalse(ev._waiters)
|
|
|
|
async def test_clear(self):
|
|
ev = asyncio.Event()
|
|
self.assertFalse(ev.is_set())
|
|
|
|
ev.set()
|
|
self.assertTrue(ev.is_set())
|
|
|
|
ev.clear()
|
|
self.assertFalse(ev.is_set())
|
|
|
|
async def test_clear_with_waiters(self):
|
|
ev = asyncio.Event()
|
|
result = []
|
|
|
|
async def c1(result):
|
|
if await ev.wait():
|
|
result.append(1)
|
|
return True
|
|
|
|
t = asyncio.create_task(c1(result))
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
|
|
ev.set()
|
|
ev.clear()
|
|
self.assertFalse(ev.is_set())
|
|
|
|
ev.set()
|
|
ev.set()
|
|
self.assertEqual(1, len(ev._waiters))
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1], result)
|
|
self.assertEqual(0, len(ev._waiters))
|
|
|
|
self.assertTrue(t.done())
|
|
self.assertTrue(t.result())
|
|
|
|
|
|
class ConditionTests(unittest.IsolatedAsyncioTestCase):
|
|
|
|
async def test_wait(self):
|
|
cond = asyncio.Condition()
|
|
result = []
|
|
|
|
async def c1(result):
|
|
await cond.acquire()
|
|
if await cond.wait():
|
|
result.append(1)
|
|
return True
|
|
|
|
async def c2(result):
|
|
await cond.acquire()
|
|
if await cond.wait():
|
|
result.append(2)
|
|
return True
|
|
|
|
async def c3(result):
|
|
await cond.acquire()
|
|
if await cond.wait():
|
|
result.append(3)
|
|
return True
|
|
|
|
t1 = asyncio.create_task(c1(result))
|
|
t2 = asyncio.create_task(c2(result))
|
|
t3 = asyncio.create_task(c3(result))
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
self.assertFalse(cond.locked())
|
|
|
|
self.assertTrue(await cond.acquire())
|
|
cond.notify()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
self.assertTrue(cond.locked())
|
|
|
|
cond.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1], result)
|
|
self.assertTrue(cond.locked())
|
|
|
|
cond.notify(2)
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1], result)
|
|
self.assertTrue(cond.locked())
|
|
|
|
cond.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1, 2], result)
|
|
self.assertTrue(cond.locked())
|
|
|
|
cond.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1, 2, 3], result)
|
|
self.assertTrue(cond.locked())
|
|
|
|
self.assertTrue(t1.done())
|
|
self.assertTrue(t1.result())
|
|
self.assertTrue(t2.done())
|
|
self.assertTrue(t2.result())
|
|
self.assertTrue(t3.done())
|
|
self.assertTrue(t3.result())
|
|
|
|
async def test_wait_cancel(self):
|
|
cond = asyncio.Condition()
|
|
await cond.acquire()
|
|
|
|
wait = asyncio.create_task(cond.wait())
|
|
asyncio.get_running_loop().call_soon(wait.cancel)
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await wait
|
|
self.assertFalse(cond._waiters)
|
|
self.assertTrue(cond.locked())
|
|
|
|
async def test_wait_cancel_contested(self):
|
|
cond = asyncio.Condition()
|
|
|
|
await cond.acquire()
|
|
self.assertTrue(cond.locked())
|
|
|
|
wait_task = asyncio.create_task(cond.wait())
|
|
await asyncio.sleep(0)
|
|
self.assertFalse(cond.locked())
|
|
|
|
# Notify, but contest the lock before cancelling
|
|
await cond.acquire()
|
|
self.assertTrue(cond.locked())
|
|
cond.notify()
|
|
asyncio.get_running_loop().call_soon(wait_task.cancel)
|
|
asyncio.get_running_loop().call_soon(cond.release)
|
|
|
|
try:
|
|
await wait_task
|
|
except asyncio.CancelledError:
|
|
# Should not happen, since no cancellation points
|
|
pass
|
|
|
|
self.assertTrue(cond.locked())
|
|
|
|
async def test_wait_cancel_after_notify(self):
|
|
# See bpo-32841
|
|
waited = False
|
|
|
|
cond = asyncio.Condition()
|
|
|
|
async def wait_on_cond():
|
|
nonlocal waited
|
|
async with cond:
|
|
waited = True # Make sure this area was reached
|
|
await cond.wait()
|
|
|
|
waiter = asyncio.create_task(wait_on_cond())
|
|
await asyncio.sleep(0) # Start waiting
|
|
|
|
await cond.acquire()
|
|
cond.notify()
|
|
await asyncio.sleep(0) # Get to acquire()
|
|
waiter.cancel()
|
|
await asyncio.sleep(0) # Activate cancellation
|
|
cond.release()
|
|
await asyncio.sleep(0) # Cancellation should occur
|
|
|
|
self.assertTrue(waiter.cancelled())
|
|
self.assertTrue(waited)
|
|
|
|
async def test_wait_unacquired(self):
|
|
cond = asyncio.Condition()
|
|
with self.assertRaises(RuntimeError):
|
|
await cond.wait()
|
|
|
|
async def test_wait_for(self):
|
|
cond = asyncio.Condition()
|
|
presult = False
|
|
|
|
def predicate():
|
|
return presult
|
|
|
|
result = []
|
|
|
|
async def c1(result):
|
|
await cond.acquire()
|
|
if await cond.wait_for(predicate):
|
|
result.append(1)
|
|
cond.release()
|
|
return True
|
|
|
|
t = asyncio.create_task(c1(result))
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
|
|
await cond.acquire()
|
|
cond.notify()
|
|
cond.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
|
|
presult = True
|
|
await cond.acquire()
|
|
cond.notify()
|
|
cond.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1], result)
|
|
|
|
self.assertTrue(t.done())
|
|
self.assertTrue(t.result())
|
|
|
|
async def test_wait_for_unacquired(self):
|
|
cond = asyncio.Condition()
|
|
|
|
# predicate can return true immediately
|
|
res = await cond.wait_for(lambda: [1, 2, 3])
|
|
self.assertEqual([1, 2, 3], res)
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
await cond.wait_for(lambda: False)
|
|
|
|
async def test_notify(self):
|
|
cond = asyncio.Condition()
|
|
result = []
|
|
|
|
async def c1(result):
|
|
await cond.acquire()
|
|
if await cond.wait():
|
|
result.append(1)
|
|
cond.release()
|
|
return True
|
|
|
|
async def c2(result):
|
|
await cond.acquire()
|
|
if await cond.wait():
|
|
result.append(2)
|
|
cond.release()
|
|
return True
|
|
|
|
async def c3(result):
|
|
await cond.acquire()
|
|
if await cond.wait():
|
|
result.append(3)
|
|
cond.release()
|
|
return True
|
|
|
|
t1 = asyncio.create_task(c1(result))
|
|
t2 = asyncio.create_task(c2(result))
|
|
t3 = asyncio.create_task(c3(result))
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
|
|
await cond.acquire()
|
|
cond.notify(1)
|
|
cond.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1], result)
|
|
|
|
await cond.acquire()
|
|
cond.notify(1)
|
|
cond.notify(2048)
|
|
cond.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1, 2, 3], result)
|
|
|
|
self.assertTrue(t1.done())
|
|
self.assertTrue(t1.result())
|
|
self.assertTrue(t2.done())
|
|
self.assertTrue(t2.result())
|
|
self.assertTrue(t3.done())
|
|
self.assertTrue(t3.result())
|
|
|
|
async def test_notify_all(self):
|
|
cond = asyncio.Condition()
|
|
|
|
result = []
|
|
|
|
async def c1(result):
|
|
await cond.acquire()
|
|
if await cond.wait():
|
|
result.append(1)
|
|
cond.release()
|
|
return True
|
|
|
|
async def c2(result):
|
|
await cond.acquire()
|
|
if await cond.wait():
|
|
result.append(2)
|
|
cond.release()
|
|
return True
|
|
|
|
t1 = asyncio.create_task(c1(result))
|
|
t2 = asyncio.create_task(c2(result))
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([], result)
|
|
|
|
await cond.acquire()
|
|
cond.notify_all()
|
|
cond.release()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1, 2], result)
|
|
|
|
self.assertTrue(t1.done())
|
|
self.assertTrue(t1.result())
|
|
self.assertTrue(t2.done())
|
|
self.assertTrue(t2.result())
|
|
|
|
def test_notify_unacquired(self):
|
|
cond = asyncio.Condition()
|
|
self.assertRaises(RuntimeError, cond.notify)
|
|
|
|
def test_notify_all_unacquired(self):
|
|
cond = asyncio.Condition()
|
|
self.assertRaises(RuntimeError, cond.notify_all)
|
|
|
|
async def test_repr(self):
|
|
cond = asyncio.Condition()
|
|
self.assertTrue('unlocked' in repr(cond))
|
|
self.assertTrue(RGX_REPR.match(repr(cond)))
|
|
|
|
await cond.acquire()
|
|
self.assertTrue('locked' in repr(cond))
|
|
|
|
cond._waiters.append(mock.Mock())
|
|
self.assertTrue('waiters:1' in repr(cond))
|
|
self.assertTrue(RGX_REPR.match(repr(cond)))
|
|
|
|
cond._waiters.append(mock.Mock())
|
|
self.assertTrue('waiters:2' in repr(cond))
|
|
self.assertTrue(RGX_REPR.match(repr(cond)))
|
|
|
|
async def test_context_manager(self):
|
|
cond = asyncio.Condition()
|
|
self.assertFalse(cond.locked())
|
|
async with cond:
|
|
self.assertTrue(cond.locked())
|
|
self.assertFalse(cond.locked())
|
|
|
|
async def test_explicit_lock(self):
|
|
async def f(lock=None, cond=None):
|
|
if lock is None:
|
|
lock = asyncio.Lock()
|
|
if cond is None:
|
|
cond = asyncio.Condition(lock)
|
|
self.assertIs(cond._lock, lock)
|
|
self.assertFalse(lock.locked())
|
|
self.assertFalse(cond.locked())
|
|
async with cond:
|
|
self.assertTrue(lock.locked())
|
|
self.assertTrue(cond.locked())
|
|
self.assertFalse(lock.locked())
|
|
self.assertFalse(cond.locked())
|
|
async with lock:
|
|
self.assertTrue(lock.locked())
|
|
self.assertTrue(cond.locked())
|
|
self.assertFalse(lock.locked())
|
|
self.assertFalse(cond.locked())
|
|
|
|
# All should work in the same way.
|
|
await f()
|
|
await f(asyncio.Lock())
|
|
lock = asyncio.Lock()
|
|
await f(lock, asyncio.Condition(lock))
|
|
|
|
async def test_ambiguous_loops(self):
|
|
loop = asyncio.new_event_loop()
|
|
self.addCleanup(loop.close)
|
|
|
|
async def wrong_loop_in_lock():
|
|
with self.assertRaises(TypeError):
|
|
asyncio.Lock(loop=loop) # actively disallowed since 3.10
|
|
lock = asyncio.Lock()
|
|
lock._loop = loop # use private API for testing
|
|
async with lock:
|
|
# acquired immediately via the fast-path
|
|
# without interaction with any event loop.
|
|
cond = asyncio.Condition(lock)
|
|
# cond.acquire() will trigger waiting on the lock
|
|
# and it will discover the event loop mismatch.
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"is bound to a different event loop",
|
|
):
|
|
await cond.acquire()
|
|
|
|
async def wrong_loop_in_cond():
|
|
# Same analogy here with the condition's loop.
|
|
lock = asyncio.Lock()
|
|
async with lock:
|
|
with self.assertRaises(TypeError):
|
|
asyncio.Condition(lock, loop=loop)
|
|
cond = asyncio.Condition(lock)
|
|
cond._loop = loop
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"is bound to a different event loop",
|
|
):
|
|
await cond.wait()
|
|
|
|
await wrong_loop_in_lock()
|
|
await wrong_loop_in_cond()
|
|
|
|
async def test_timeout_in_block(self):
|
|
condition = asyncio.Condition()
|
|
async with condition:
|
|
with self.assertRaises(asyncio.TimeoutError):
|
|
await asyncio.wait_for(condition.wait(), timeout=0.5)
|
|
|
|
|
|
class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
|
|
|
|
def test_initial_value_zero(self):
|
|
sem = asyncio.Semaphore(0)
|
|
self.assertTrue(sem.locked())
|
|
|
|
async def test_repr(self):
|
|
sem = asyncio.Semaphore()
|
|
self.assertTrue(repr(sem).endswith('[unlocked, value:1]>'))
|
|
self.assertTrue(RGX_REPR.match(repr(sem)))
|
|
|
|
await sem.acquire()
|
|
self.assertTrue(repr(sem).endswith('[locked]>'))
|
|
self.assertTrue('waiters' not in repr(sem))
|
|
self.assertTrue(RGX_REPR.match(repr(sem)))
|
|
|
|
if sem._waiters is None:
|
|
sem._waiters = collections.deque()
|
|
|
|
sem._waiters.append(mock.Mock())
|
|
self.assertTrue('waiters:1' in repr(sem))
|
|
self.assertTrue(RGX_REPR.match(repr(sem)))
|
|
|
|
sem._waiters.append(mock.Mock())
|
|
self.assertTrue('waiters:2' in repr(sem))
|
|
self.assertTrue(RGX_REPR.match(repr(sem)))
|
|
|
|
async def test_semaphore(self):
|
|
sem = asyncio.Semaphore()
|
|
self.assertEqual(1, sem._value)
|
|
|
|
with self.assertRaisesRegex(
|
|
TypeError,
|
|
"object Semaphore can't be used in 'await' expression",
|
|
):
|
|
await sem
|
|
|
|
self.assertFalse(sem.locked())
|
|
self.assertEqual(1, sem._value)
|
|
|
|
def test_semaphore_value(self):
|
|
self.assertRaises(ValueError, asyncio.Semaphore, -1)
|
|
|
|
async def test_acquire(self):
|
|
sem = asyncio.Semaphore(3)
|
|
result = []
|
|
|
|
self.assertTrue(await sem.acquire())
|
|
self.assertTrue(await sem.acquire())
|
|
self.assertFalse(sem.locked())
|
|
|
|
async def c1(result):
|
|
await sem.acquire()
|
|
result.append(1)
|
|
return True
|
|
|
|
async def c2(result):
|
|
await sem.acquire()
|
|
result.append(2)
|
|
return True
|
|
|
|
async def c3(result):
|
|
await sem.acquire()
|
|
result.append(3)
|
|
return True
|
|
|
|
async def c4(result):
|
|
await sem.acquire()
|
|
result.append(4)
|
|
return True
|
|
|
|
t1 = asyncio.create_task(c1(result))
|
|
t2 = asyncio.create_task(c2(result))
|
|
t3 = asyncio.create_task(c3(result))
|
|
|
|
await asyncio.sleep(0)
|
|
self.assertEqual([1], result)
|
|
self.assertTrue(sem.locked())
|
|
self.assertEqual(2, len(sem._waiters))
|
|
self.assertEqual(0, sem._value)
|
|
|
|
t4 = asyncio.create_task(c4(result))
|
|
|
|
sem.release()
|
|
sem.release()
|
|
self.assertEqual(2, sem._value)
|
|
|
|
await asyncio.sleep(0)
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(0, sem._value)
|
|
self.assertEqual(3, len(result))
|
|
self.assertTrue(sem.locked())
|
|
self.assertEqual(1, len(sem._waiters))
|
|
self.assertEqual(0, sem._value)
|
|
|
|
self.assertTrue(t1.done())
|
|
self.assertTrue(t1.result())
|
|
race_tasks = [t2, t3, t4]
|
|
done_tasks = [t for t in race_tasks if t.done() and t.result()]
|
|
self.assertEqual(2, len(done_tasks))
|
|
|
|
# cleanup locked semaphore
|
|
sem.release()
|
|
await asyncio.gather(*race_tasks)
|
|
|
|
async def test_acquire_cancel(self):
|
|
sem = asyncio.Semaphore()
|
|
await sem.acquire()
|
|
|
|
acquire = asyncio.create_task(sem.acquire())
|
|
asyncio.get_running_loop().call_soon(acquire.cancel)
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await acquire
|
|
self.assertTrue((not sem._waiters) or
|
|
all(waiter.done() for waiter in sem._waiters))
|
|
|
|
async def test_acquire_cancel_before_awoken(self):
|
|
sem = asyncio.Semaphore(value=0)
|
|
|
|
t1 = asyncio.create_task(sem.acquire())
|
|
t2 = asyncio.create_task(sem.acquire())
|
|
t3 = asyncio.create_task(sem.acquire())
|
|
t4 = asyncio.create_task(sem.acquire())
|
|
|
|
await asyncio.sleep(0)
|
|
|
|
t1.cancel()
|
|
t2.cancel()
|
|
sem.release()
|
|
|
|
await asyncio.sleep(0)
|
|
await asyncio.sleep(0)
|
|
num_done = sum(t.done() for t in [t3, t4])
|
|
self.assertEqual(num_done, 1)
|
|
self.assertTrue(t3.done())
|
|
self.assertFalse(t4.done())
|
|
|
|
t3.cancel()
|
|
t4.cancel()
|
|
await asyncio.sleep(0)
|
|
|
|
async def test_acquire_hang(self):
|
|
sem = asyncio.Semaphore(value=0)
|
|
|
|
t1 = asyncio.create_task(sem.acquire())
|
|
t2 = asyncio.create_task(sem.acquire())
|
|
await asyncio.sleep(0)
|
|
|
|
t1.cancel()
|
|
sem.release()
|
|
await asyncio.sleep(0)
|
|
await asyncio.sleep(0)
|
|
self.assertTrue(sem.locked())
|
|
self.assertTrue(t2.done())
|
|
|
|
async def test_acquire_no_hang(self):
|
|
|
|
sem = asyncio.Semaphore(1)
|
|
|
|
async def c1():
|
|
async with sem:
|
|
await asyncio.sleep(0)
|
|
t2.cancel()
|
|
|
|
async def c2():
|
|
async with sem:
|
|
self.assertFalse(True)
|
|
|
|
t1 = asyncio.create_task(c1())
|
|
t2 = asyncio.create_task(c2())
|
|
|
|
r1, r2 = await asyncio.gather(t1, t2, return_exceptions=True)
|
|
self.assertTrue(r1 is None)
|
|
self.assertTrue(isinstance(r2, asyncio.CancelledError))
|
|
|
|
await asyncio.wait_for(sem.acquire(), timeout=1.0)
|
|
|
|
def test_release_not_acquired(self):
|
|
sem = asyncio.BoundedSemaphore()
|
|
|
|
self.assertRaises(ValueError, sem.release)
|
|
|
|
async def test_release_no_waiters(self):
|
|
sem = asyncio.Semaphore()
|
|
await sem.acquire()
|
|
self.assertTrue(sem.locked())
|
|
|
|
sem.release()
|
|
self.assertFalse(sem.locked())
|
|
|
|
async def test_acquire_fifo_order(self):
|
|
sem = asyncio.Semaphore(1)
|
|
result = []
|
|
|
|
async def coro(tag):
|
|
await sem.acquire()
|
|
result.append(f'{tag}_1')
|
|
await asyncio.sleep(0.01)
|
|
sem.release()
|
|
|
|
await sem.acquire()
|
|
result.append(f'{tag}_2')
|
|
await asyncio.sleep(0.01)
|
|
sem.release()
|
|
|
|
async with asyncio.TaskGroup() as tg:
|
|
tg.create_task(coro('c1'))
|
|
tg.create_task(coro('c2'))
|
|
tg.create_task(coro('c3'))
|
|
|
|
self.assertEqual(
|
|
['c1_1', 'c2_1', 'c3_1', 'c1_2', 'c2_2', 'c3_2'],
|
|
result
|
|
)
|
|
|
|
async def test_acquire_fifo_order_2(self):
|
|
sem = asyncio.Semaphore(1)
|
|
result = []
|
|
|
|
async def c1(result):
|
|
await sem.acquire()
|
|
result.append(1)
|
|
return True
|
|
|
|
async def c2(result):
|
|
await sem.acquire()
|
|
result.append(2)
|
|
sem.release()
|
|
await sem.acquire()
|
|
result.append(4)
|
|
return True
|
|
|
|
async def c3(result):
|
|
await sem.acquire()
|
|
result.append(3)
|
|
return True
|
|
|
|
t1 = asyncio.create_task(c1(result))
|
|
t2 = asyncio.create_task(c2(result))
|
|
t3 = asyncio.create_task(c3(result))
|
|
|
|
await asyncio.sleep(0)
|
|
|
|
sem.release()
|
|
sem.release()
|
|
|
|
tasks = [t1, t2, t3]
|
|
await asyncio.gather(*tasks)
|
|
self.assertEqual([1, 2, 3, 4], result)
|
|
|
|
async def test_acquire_fifo_order_3(self):
|
|
sem = asyncio.Semaphore(0)
|
|
result = []
|
|
|
|
async def c1(result):
|
|
await sem.acquire()
|
|
result.append(1)
|
|
return True
|
|
|
|
async def c2(result):
|
|
await sem.acquire()
|
|
result.append(2)
|
|
return True
|
|
|
|
async def c3(result):
|
|
await sem.acquire()
|
|
result.append(3)
|
|
return True
|
|
|
|
t1 = asyncio.create_task(c1(result))
|
|
t2 = asyncio.create_task(c2(result))
|
|
t3 = asyncio.create_task(c3(result))
|
|
|
|
await asyncio.sleep(0)
|
|
|
|
t1.cancel()
|
|
|
|
await asyncio.sleep(0)
|
|
|
|
sem.release()
|
|
sem.release()
|
|
|
|
tasks = [t1, t2, t3]
|
|
await asyncio.gather(*tasks, return_exceptions=True)
|
|
self.assertEqual([2, 3], result)
|
|
|
|
|
|
class BarrierTests(unittest.IsolatedAsyncioTestCase):
|
|
|
|
async def asyncSetUp(self):
|
|
await super().asyncSetUp()
|
|
self.N = 5
|
|
|
|
def make_tasks(self, n, coro):
|
|
tasks = [asyncio.create_task(coro()) for _ in range(n)]
|
|
return tasks
|
|
|
|
async def gather_tasks(self, n, coro):
|
|
tasks = self.make_tasks(n, coro)
|
|
res = await asyncio.gather(*tasks)
|
|
return res, tasks
|
|
|
|
async def test_barrier(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
self.assertIn("filling", repr(barrier))
|
|
with self.assertRaisesRegex(
|
|
TypeError,
|
|
"object Barrier can't be used in 'await' expression",
|
|
):
|
|
await barrier
|
|
|
|
self.assertIn("filling", repr(barrier))
|
|
|
|
async def test_repr(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
|
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
|
self.assertIn("filling", repr(barrier))
|
|
|
|
waiters = []
|
|
async def wait(barrier):
|
|
await barrier.wait()
|
|
|
|
incr = 2
|
|
for i in range(incr):
|
|
waiters.append(asyncio.create_task(wait(barrier)))
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
|
self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier))
|
|
self.assertIn("filling", repr(barrier))
|
|
|
|
# create missing waiters
|
|
for i in range(barrier.parties - barrier.n_waiting):
|
|
waiters.append(asyncio.create_task(wait(barrier)))
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
|
self.assertIn("draining", repr(barrier))
|
|
|
|
# add a part of waiters
|
|
for i in range(incr):
|
|
waiters.append(asyncio.create_task(wait(barrier)))
|
|
await asyncio.sleep(0)
|
|
# and reset
|
|
await barrier.reset()
|
|
|
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
|
self.assertIn("resetting", repr(barrier))
|
|
|
|
# add a part of waiters again
|
|
for i in range(incr):
|
|
waiters.append(asyncio.create_task(wait(barrier)))
|
|
await asyncio.sleep(0)
|
|
# and abort
|
|
await barrier.abort()
|
|
|
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
|
self.assertIn("broken", repr(barrier))
|
|
self.assertTrue(barrier.broken)
|
|
|
|
# suppress unhandled exceptions
|
|
await asyncio.gather(*waiters, return_exceptions=True)
|
|
|
|
async def test_barrier_parties(self):
|
|
self.assertRaises(ValueError, lambda: asyncio.Barrier(0))
|
|
self.assertRaises(ValueError, lambda: asyncio.Barrier(-4))
|
|
|
|
self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier)
|
|
|
|
async def test_context_manager(self):
|
|
self.N = 3
|
|
barrier = asyncio.Barrier(self.N)
|
|
results = []
|
|
|
|
async def coro():
|
|
async with barrier as i:
|
|
results.append(i)
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertListEqual(sorted(results), list(range(self.N)))
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_filling_one_task(self):
|
|
barrier = asyncio.Barrier(1)
|
|
|
|
async def f():
|
|
async with barrier as i:
|
|
return True
|
|
|
|
ret = await f()
|
|
|
|
self.assertTrue(ret)
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_filling_one_task_twice(self):
|
|
barrier = asyncio.Barrier(1)
|
|
|
|
t1 = asyncio.create_task(barrier.wait())
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
|
|
t2 = asyncio.create_task(barrier.wait())
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertEqual(t1.result(), t2.result())
|
|
self.assertEqual(t1.done(), t2.done())
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_filling_task_by_task(self):
|
|
self.N = 3
|
|
barrier = asyncio.Barrier(self.N)
|
|
|
|
t1 = asyncio.create_task(barrier.wait())
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(barrier.n_waiting, 1)
|
|
self.assertIn("filling", repr(barrier))
|
|
|
|
t2 = asyncio.create_task(barrier.wait())
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(barrier.n_waiting, 2)
|
|
self.assertIn("filling", repr(barrier))
|
|
|
|
t3 = asyncio.create_task(barrier.wait())
|
|
await asyncio.sleep(0)
|
|
|
|
await asyncio.wait([t1, t2, t3])
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_filling_tasks_wait_twice(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results = []
|
|
|
|
async def coro():
|
|
async with barrier:
|
|
results.append(True)
|
|
|
|
async with barrier:
|
|
results.append(False)
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertEqual(len(results), self.N*2)
|
|
self.assertEqual(results.count(True), self.N)
|
|
self.assertEqual(results.count(False), self.N)
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_filling_tasks_check_return_value(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results1 = []
|
|
results2 = []
|
|
|
|
async def coro():
|
|
async with barrier:
|
|
results1.append(True)
|
|
|
|
async with barrier as i:
|
|
results2.append(True)
|
|
return i
|
|
|
|
res, _ = await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertEqual(len(results1), self.N)
|
|
self.assertTrue(all(results1))
|
|
self.assertEqual(len(results2), self.N)
|
|
self.assertTrue(all(results2))
|
|
self.assertListEqual(sorted(res), list(range(self.N)))
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_draining_state(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results = []
|
|
|
|
async def coro():
|
|
async with barrier:
|
|
# barrier state change to filling for the last task release
|
|
results.append("draining" in repr(barrier))
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertEqual(len(results), self.N)
|
|
self.assertEqual(results[-1], False)
|
|
self.assertTrue(all(results[:self.N-1]))
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_blocking_tasks_while_draining(self):
|
|
rewait = 2
|
|
barrier = asyncio.Barrier(self.N)
|
|
barrier_nowaiting = asyncio.Barrier(self.N - rewait)
|
|
results = []
|
|
rewait_n = rewait
|
|
counter = 0
|
|
|
|
async def coro():
|
|
nonlocal rewait_n
|
|
|
|
# first time waiting
|
|
await barrier.wait()
|
|
|
|
# after wainting once for all tasks
|
|
if rewait_n > 0:
|
|
rewait_n -= 1
|
|
# wait again only for rewait tasks
|
|
await barrier.wait()
|
|
else:
|
|
# wait for end of draining state`
|
|
await barrier_nowaiting.wait()
|
|
# wait for other waiting tasks
|
|
await barrier.wait()
|
|
|
|
# a success means that barrier_nowaiting
|
|
# was waited for exactly N-rewait=3 times
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
async def test_filling_tasks_cancel_one(self):
|
|
self.N = 3
|
|
barrier = asyncio.Barrier(self.N)
|
|
results = []
|
|
|
|
async def coro():
|
|
await barrier.wait()
|
|
results.append(True)
|
|
|
|
t1 = asyncio.create_task(coro())
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(barrier.n_waiting, 1)
|
|
|
|
t2 = asyncio.create_task(coro())
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(barrier.n_waiting, 2)
|
|
|
|
t1.cancel()
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(barrier.n_waiting, 1)
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await t1
|
|
self.assertTrue(t1.cancelled())
|
|
|
|
t3 = asyncio.create_task(coro())
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(barrier.n_waiting, 2)
|
|
|
|
t4 = asyncio.create_task(coro())
|
|
await asyncio.gather(t2, t3, t4)
|
|
|
|
self.assertEqual(len(results), self.N)
|
|
self.assertTrue(all(results))
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_reset_barrier(self):
|
|
barrier = asyncio.Barrier(1)
|
|
|
|
asyncio.create_task(barrier.reset())
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_reset_barrier_while_tasks_waiting(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results = []
|
|
|
|
async def coro():
|
|
try:
|
|
await barrier.wait()
|
|
except asyncio.BrokenBarrierError:
|
|
results.append(True)
|
|
|
|
async def coro_reset():
|
|
await barrier.reset()
|
|
|
|
# N-1 tasks waiting on barrier with N parties
|
|
tasks = self.make_tasks(self.N-1, coro)
|
|
await asyncio.sleep(0)
|
|
|
|
# reset the barrier
|
|
asyncio.create_task(coro_reset())
|
|
await asyncio.gather(*tasks)
|
|
|
|
self.assertEqual(len(results), self.N-1)
|
|
self.assertTrue(all(results))
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertNotIn("resetting", repr(barrier))
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_reset_barrier_when_tasks_half_draining(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results1 = []
|
|
rest_of_tasks = self.N//2
|
|
|
|
async def coro():
|
|
try:
|
|
await barrier.wait()
|
|
except asyncio.BrokenBarrierError:
|
|
# catch here waiting tasks
|
|
results1.append(True)
|
|
else:
|
|
# here drained task ouside the barrier
|
|
if rest_of_tasks == barrier._count:
|
|
# tasks outside the barrier
|
|
await barrier.reset()
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertEqual(results1, [True]*rest_of_tasks)
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertNotIn("resetting", repr(barrier))
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_reset_barrier_when_tasks_half_draining_half_blocking(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results1 = []
|
|
results2 = []
|
|
blocking_tasks = self.N//2
|
|
count = 0
|
|
|
|
async def coro():
|
|
nonlocal count
|
|
try:
|
|
await barrier.wait()
|
|
except asyncio.BrokenBarrierError:
|
|
# here catch still waiting tasks
|
|
results1.append(True)
|
|
|
|
# so now waiting again to reach nb_parties
|
|
await barrier.wait()
|
|
else:
|
|
count += 1
|
|
if count > blocking_tasks:
|
|
# reset now: raise asyncio.BrokenBarrierError for waiting tasks
|
|
await barrier.reset()
|
|
|
|
# so now waiting again to reach nb_parties
|
|
await barrier.wait()
|
|
else:
|
|
try:
|
|
await barrier.wait()
|
|
except asyncio.BrokenBarrierError:
|
|
# here no catch - blocked tasks go to wait
|
|
results2.append(True)
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertEqual(results1, [True]*blocking_tasks)
|
|
self.assertEqual(results2, [])
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertNotIn("resetting", repr(barrier))
|
|
self.assertFalse(barrier.broken)
|
|
|
|
async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results1 = []
|
|
results2 = []
|
|
|
|
async def coro1():
|
|
try:
|
|
await barrier.wait()
|
|
except asyncio.BrokenBarrierError:
|
|
results1.append(True)
|
|
finally:
|
|
await barrier.wait()
|
|
results2.append(True)
|
|
|
|
async def coro2():
|
|
async with barrier:
|
|
results2.append(True)
|
|
|
|
tasks = self.make_tasks(self.N-1, coro1)
|
|
|
|
# reset barrier, N-1 waiting tasks raise an BrokenBarrierError
|
|
asyncio.create_task(barrier.reset())
|
|
await asyncio.sleep(0)
|
|
|
|
# complete waiting tasks in the `finally`
|
|
asyncio.create_task(coro2())
|
|
|
|
await asyncio.gather(*tasks)
|
|
|
|
self.assertFalse(barrier.broken)
|
|
self.assertEqual(len(results1), self.N-1)
|
|
self.assertTrue(all(results1))
|
|
self.assertEqual(len(results2), self.N)
|
|
self.assertTrue(all(results2))
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
|
|
|
|
async def test_reset_barrier_while_tasks_draining(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results1 = []
|
|
results2 = []
|
|
results3 = []
|
|
count = 0
|
|
|
|
async def coro():
|
|
nonlocal count
|
|
|
|
i = await barrier.wait()
|
|
count += 1
|
|
if count == self.N:
|
|
# last task exited from barrier
|
|
await barrier.reset()
|
|
|
|
# wit here to reach the `parties`
|
|
await barrier.wait()
|
|
else:
|
|
try:
|
|
# second waiting
|
|
await barrier.wait()
|
|
|
|
# N-1 tasks here
|
|
results1.append(True)
|
|
except Exception as e:
|
|
# never goes here
|
|
results2.append(True)
|
|
|
|
# Now, pass the barrier again
|
|
# last wait, must be completed
|
|
k = await barrier.wait()
|
|
results3.append(True)
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertFalse(barrier.broken)
|
|
self.assertTrue(all(results1))
|
|
self.assertEqual(len(results1), self.N-1)
|
|
self.assertEqual(len(results2), 0)
|
|
self.assertEqual(len(results3), self.N)
|
|
self.assertTrue(all(results3))
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
|
|
async def test_abort_barrier(self):
|
|
barrier = asyncio.Barrier(1)
|
|
|
|
asyncio.create_task(barrier.abort())
|
|
await asyncio.sleep(0)
|
|
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertTrue(barrier.broken)
|
|
|
|
async def test_abort_barrier_when_tasks_half_draining_half_blocking(self):
|
|
barrier = asyncio.Barrier(self.N)
|
|
results1 = []
|
|
results2 = []
|
|
blocking_tasks = self.N//2
|
|
count = 0
|
|
|
|
async def coro():
|
|
nonlocal count
|
|
try:
|
|
await barrier.wait()
|
|
except asyncio.BrokenBarrierError:
|
|
# here catch tasks waiting to drain
|
|
results1.append(True)
|
|
else:
|
|
count += 1
|
|
if count > blocking_tasks:
|
|
# abort now: raise asyncio.BrokenBarrierError for all tasks
|
|
await barrier.abort()
|
|
else:
|
|
try:
|
|
await barrier.wait()
|
|
except asyncio.BrokenBarrierError:
|
|
# here catch blocked tasks (already drained)
|
|
results2.append(True)
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertTrue(barrier.broken)
|
|
self.assertEqual(results1, [True]*blocking_tasks)
|
|
self.assertEqual(results2, [True]*(self.N-blocking_tasks-1))
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
self.assertNotIn("resetting", repr(barrier))
|
|
|
|
async def test_abort_barrier_when_exception(self):
|
|
# test from threading.Barrier: see `lock_tests.test_reset`
|
|
barrier = asyncio.Barrier(self.N)
|
|
results1 = []
|
|
results2 = []
|
|
|
|
async def coro():
|
|
try:
|
|
async with barrier as i :
|
|
if i == self.N//2:
|
|
raise RuntimeError
|
|
async with barrier:
|
|
results1.append(True)
|
|
except asyncio.BrokenBarrierError:
|
|
results2.append(True)
|
|
except RuntimeError:
|
|
await barrier.abort()
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertTrue(barrier.broken)
|
|
self.assertEqual(len(results1), 0)
|
|
self.assertEqual(len(results2), self.N-1)
|
|
self.assertTrue(all(results2))
|
|
self.assertEqual(barrier.n_waiting, 0)
|
|
|
|
async def test_abort_barrier_when_exception_then_resetting(self):
|
|
# test from threading.Barrier: see `lock_tests.test_abort_and_reset``
|
|
barrier1 = asyncio.Barrier(self.N)
|
|
barrier2 = asyncio.Barrier(self.N)
|
|
results1 = []
|
|
results2 = []
|
|
results3 = []
|
|
|
|
async def coro():
|
|
try:
|
|
i = await barrier1.wait()
|
|
if i == self.N//2:
|
|
raise RuntimeError
|
|
await barrier1.wait()
|
|
results1.append(True)
|
|
except asyncio.BrokenBarrierError:
|
|
results2.append(True)
|
|
except RuntimeError:
|
|
await barrier1.abort()
|
|
|
|
# Synchronize and reset the barrier. Must synchronize first so
|
|
# that everyone has left it when we reset, and after so that no
|
|
# one enters it before the reset.
|
|
i = await barrier2.wait()
|
|
if i == self.N//2:
|
|
await barrier1.reset()
|
|
await barrier2.wait()
|
|
await barrier1.wait()
|
|
results3.append(True)
|
|
|
|
await self.gather_tasks(self.N, coro)
|
|
|
|
self.assertFalse(barrier1.broken)
|
|
self.assertEqual(len(results1), 0)
|
|
self.assertEqual(len(results2), self.N-1)
|
|
self.assertTrue(all(results2))
|
|
self.assertEqual(len(results3), self.N)
|
|
self.assertTrue(all(results3))
|
|
|
|
self.assertEqual(barrier1.n_waiting, 0)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|