mirror of
https://github.com/python/cpython.git
synced 2025-08-30 13:38:43 +00:00
bpo-43352: Add a Barrier object in asyncio lib (GH-24903)
Co-authored-by: Yury Selivanov <yury@edgedb.com> Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
20e6e5636a
commit
d03acd7270
6 changed files with 856 additions and 5 deletions
|
@ -1,4 +1,4 @@
|
|||
"""Tests for lock.py"""
|
||||
"""Tests for locks.py"""
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
@ -9,7 +9,10 @@ import asyncio
|
|||
STR_RGX_REPR = (
|
||||
r'^<(?P<class>.*?) object at (?P<address>.*?)'
|
||||
r'\[(?P<extras>'
|
||||
r'(set|unset|locked|unlocked)(, value:\d)?(, waiters:\d+)?'
|
||||
r'(set|unset|locked|unlocked|filling|draining|resetting|broken)'
|
||||
r'(, value:\d)?'
|
||||
r'(, waiters:\d+)?'
|
||||
r'(, waiters:\d+\/\d+)?' # barrier
|
||||
r')\]>\Z'
|
||||
)
|
||||
RGX_REPR = re.compile(STR_RGX_REPR)
|
||||
|
@ -943,5 +946,576 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
|
|||
)
|
||||
|
||||
|
||||
class BarrierTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.N = 5
|
||||
|
||||
def make_tasks(self, n, coro):
|
||||
tasks = [asyncio.create_task(coro()) for _ in range(n)]
|
||||
return tasks
|
||||
|
||||
async def gather_tasks(self, n, coro):
|
||||
tasks = self.make_tasks(n, coro)
|
||||
res = await asyncio.gather(*tasks)
|
||||
return res, tasks
|
||||
|
||||
async def test_barrier(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
self.assertIn("filling", repr(barrier))
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"object Barrier can't be used in 'await' expression",
|
||||
):
|
||||
await barrier
|
||||
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
async def test_repr(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
waiters = []
|
||||
async def wait(barrier):
|
||||
await barrier.wait()
|
||||
|
||||
incr = 2
|
||||
for i in range(incr):
|
||||
waiters.append(asyncio.create_task(wait(barrier)))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier))
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
# create missing waiters
|
||||
for i in range(barrier.parties - barrier.n_waiting):
|
||||
waiters.append(asyncio.create_task(wait(barrier)))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertIn("draining", repr(barrier))
|
||||
|
||||
# add a part of waiters
|
||||
for i in range(incr):
|
||||
waiters.append(asyncio.create_task(wait(barrier)))
|
||||
await asyncio.sleep(0)
|
||||
# and reset
|
||||
await barrier.reset()
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertIn("resetting", repr(barrier))
|
||||
|
||||
# add a part of waiters again
|
||||
for i in range(incr):
|
||||
waiters.append(asyncio.create_task(wait(barrier)))
|
||||
await asyncio.sleep(0)
|
||||
# and abort
|
||||
await barrier.abort()
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertIn("broken", repr(barrier))
|
||||
self.assertTrue(barrier.broken)
|
||||
|
||||
# suppress unhandled exceptions
|
||||
await asyncio.gather(*waiters, return_exceptions=True)
|
||||
|
||||
async def test_barrier_parties(self):
|
||||
self.assertRaises(ValueError, lambda: asyncio.Barrier(0))
|
||||
self.assertRaises(ValueError, lambda: asyncio.Barrier(-4))
|
||||
|
||||
self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier)
|
||||
|
||||
async def test_context_manager(self):
|
||||
self.N = 3
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
async with barrier as i:
|
||||
results.append(i)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertListEqual(sorted(results), list(range(self.N)))
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_one_task(self):
|
||||
barrier = asyncio.Barrier(1)
|
||||
|
||||
async def f():
|
||||
async with barrier as i:
|
||||
return True
|
||||
|
||||
ret = await f()
|
||||
|
||||
self.assertTrue(ret)
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_one_task_twice(self):
|
||||
barrier = asyncio.Barrier(1)
|
||||
|
||||
t1 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
|
||||
t2 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(t1.result(), t2.result())
|
||||
self.assertEqual(t1.done(), t2.done())
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_task_by_task(self):
|
||||
self.N = 3
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
|
||||
t1 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 1)
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
t2 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 2)
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
t3 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await asyncio.wait([t1, t2, t3])
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_tasks_wait_twice(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
async with barrier:
|
||||
results.append(True)
|
||||
|
||||
async with barrier:
|
||||
results.append(False)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(len(results), self.N*2)
|
||||
self.assertEqual(results.count(True), self.N)
|
||||
self.assertEqual(results.count(False), self.N)
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_tasks_check_return_value(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
|
||||
async def coro():
|
||||
async with barrier:
|
||||
results1.append(True)
|
||||
|
||||
async with barrier as i:
|
||||
results2.append(True)
|
||||
return i
|
||||
|
||||
res, _ = await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(len(results1), self.N)
|
||||
self.assertTrue(all(results1))
|
||||
self.assertEqual(len(results2), self.N)
|
||||
self.assertTrue(all(results2))
|
||||
self.assertListEqual(sorted(res), list(range(self.N)))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_draining_state(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
async with barrier:
|
||||
# barrier state change to filling for the last task release
|
||||
results.append("draining" in repr(barrier))
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(len(results), self.N)
|
||||
self.assertEqual(results[-1], False)
|
||||
self.assertTrue(all(results[:self.N-1]))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_blocking_tasks_while_draining(self):
|
||||
rewait = 2
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
barrier_nowaiting = asyncio.Barrier(self.N - rewait)
|
||||
results = []
|
||||
rewait_n = rewait
|
||||
counter = 0
|
||||
|
||||
async def coro():
|
||||
nonlocal rewait_n
|
||||
|
||||
# first time waiting
|
||||
await barrier.wait()
|
||||
|
||||
# after wainting once for all tasks
|
||||
if rewait_n > 0:
|
||||
rewait_n -= 1
|
||||
# wait again only for rewait tasks
|
||||
await barrier.wait()
|
||||
else:
|
||||
# wait for end of draining state`
|
||||
await barrier_nowaiting.wait()
|
||||
# wait for other waiting tasks
|
||||
await barrier.wait()
|
||||
|
||||
# a success means that barrier_nowaiting
|
||||
# was waited for exactly N-rewait=3 times
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
async def test_filling_tasks_cancel_one(self):
|
||||
self.N = 3
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
await barrier.wait()
|
||||
results.append(True)
|
||||
|
||||
t1 = asyncio.create_task(coro())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 1)
|
||||
|
||||
t2 = asyncio.create_task(coro())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 2)
|
||||
|
||||
t1.cancel()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 1)
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await t1
|
||||
self.assertTrue(t1.cancelled())
|
||||
|
||||
t3 = asyncio.create_task(coro())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 2)
|
||||
|
||||
t4 = asyncio.create_task(coro())
|
||||
await asyncio.gather(t2, t3, t4)
|
||||
|
||||
self.assertEqual(len(results), self.N)
|
||||
self.assertTrue(all(results))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier(self):
|
||||
barrier = asyncio.Barrier(1)
|
||||
|
||||
asyncio.create_task(barrier.reset())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier_while_tasks_waiting(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
results.append(True)
|
||||
|
||||
async def coro_reset():
|
||||
await barrier.reset()
|
||||
|
||||
# N-1 tasks waiting on barrier with N parties
|
||||
tasks = self.make_tasks(self.N-1, coro)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# reset the barrier
|
||||
asyncio.create_task(coro_reset())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
self.assertEqual(len(results), self.N-1)
|
||||
self.assertTrue(all(results))
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertNotIn("resetting", repr(barrier))
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier_when_tasks_half_draining(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
rest_of_tasks = self.N//2
|
||||
|
||||
async def coro():
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# catch here waiting tasks
|
||||
results1.append(True)
|
||||
else:
|
||||
# here drained task ouside the barrier
|
||||
if rest_of_tasks == barrier._count:
|
||||
# tasks outside the barrier
|
||||
await barrier.reset()
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(results1, [True]*rest_of_tasks)
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertNotIn("resetting", repr(barrier))
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier_when_tasks_half_draining_half_blocking(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
blocking_tasks = self.N//2
|
||||
count = 0
|
||||
|
||||
async def coro():
|
||||
nonlocal count
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# here catch still waiting tasks
|
||||
results1.append(True)
|
||||
|
||||
# so now waiting again to reach nb_parties
|
||||
await barrier.wait()
|
||||
else:
|
||||
count += 1
|
||||
if count > blocking_tasks:
|
||||
# reset now: raise asyncio.BrokenBarrierError for waiting tasks
|
||||
await barrier.reset()
|
||||
|
||||
# so now waiting again to reach nb_parties
|
||||
await barrier.wait()
|
||||
else:
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# here no catch - blocked tasks go to wait
|
||||
results2.append(True)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(results1, [True]*blocking_tasks)
|
||||
self.assertEqual(results2, [])
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertNotIn("resetting", repr(barrier))
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
|
||||
async def coro1():
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
results1.append(True)
|
||||
finally:
|
||||
await barrier.wait()
|
||||
results2.append(True)
|
||||
|
||||
async def coro2():
|
||||
async with barrier:
|
||||
results2.append(True)
|
||||
|
||||
tasks = self.make_tasks(self.N-1, coro1)
|
||||
|
||||
# reset barrier, N-1 waiting tasks raise an BrokenBarrierError
|
||||
asyncio.create_task(barrier.reset())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# complete waiting tasks in the `finally`
|
||||
asyncio.create_task(coro2())
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
self.assertFalse(barrier.broken)
|
||||
self.assertEqual(len(results1), self.N-1)
|
||||
self.assertTrue(all(results1))
|
||||
self.assertEqual(len(results2), self.N)
|
||||
self.assertTrue(all(results2))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
|
||||
|
||||
async def test_reset_barrier_while_tasks_draining(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
results3 = []
|
||||
count = 0
|
||||
|
||||
async def coro():
|
||||
nonlocal count
|
||||
|
||||
i = await barrier.wait()
|
||||
count += 1
|
||||
if count == self.N:
|
||||
# last task exited from barrier
|
||||
await barrier.reset()
|
||||
|
||||
# wit here to reach the `parties`
|
||||
await barrier.wait()
|
||||
else:
|
||||
try:
|
||||
# second waiting
|
||||
await barrier.wait()
|
||||
|
||||
# N-1 tasks here
|
||||
results1.append(True)
|
||||
except Exception as e:
|
||||
# never goes here
|
||||
results2.append(True)
|
||||
|
||||
# Now, pass the barrier again
|
||||
# last wait, must be completed
|
||||
k = await barrier.wait()
|
||||
results3.append(True)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertFalse(barrier.broken)
|
||||
self.assertTrue(all(results1))
|
||||
self.assertEqual(len(results1), self.N-1)
|
||||
self.assertEqual(len(results2), 0)
|
||||
self.assertEqual(len(results3), self.N)
|
||||
self.assertTrue(all(results3))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
|
||||
async def test_abort_barrier(self):
|
||||
barrier = asyncio.Barrier(1)
|
||||
|
||||
asyncio.create_task(barrier.abort())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertTrue(barrier.broken)
|
||||
|
||||
async def test_abort_barrier_when_tasks_half_draining_half_blocking(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
blocking_tasks = self.N//2
|
||||
count = 0
|
||||
|
||||
async def coro():
|
||||
nonlocal count
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# here catch tasks waiting to drain
|
||||
results1.append(True)
|
||||
else:
|
||||
count += 1
|
||||
if count > blocking_tasks:
|
||||
# abort now: raise asyncio.BrokenBarrierError for all tasks
|
||||
await barrier.abort()
|
||||
else:
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# here catch blocked tasks (already drained)
|
||||
results2.append(True)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertTrue(barrier.broken)
|
||||
self.assertEqual(results1, [True]*blocking_tasks)
|
||||
self.assertEqual(results2, [True]*(self.N-blocking_tasks-1))
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertNotIn("resetting", repr(barrier))
|
||||
|
||||
async def test_abort_barrier_when_exception(self):
|
||||
# test from threading.Barrier: see `lock_tests.test_reset`
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
|
||||
async def coro():
|
||||
try:
|
||||
async with barrier as i :
|
||||
if i == self.N//2:
|
||||
raise RuntimeError
|
||||
async with barrier:
|
||||
results1.append(True)
|
||||
except asyncio.BrokenBarrierError:
|
||||
results2.append(True)
|
||||
except RuntimeError:
|
||||
await barrier.abort()
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertTrue(barrier.broken)
|
||||
self.assertEqual(len(results1), 0)
|
||||
self.assertEqual(len(results2), self.N-1)
|
||||
self.assertTrue(all(results2))
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
|
||||
async def test_abort_barrier_when_exception_then_resetting(self):
|
||||
# test from threading.Barrier: see `lock_tests.test_abort_and_reset``
|
||||
barrier1 = asyncio.Barrier(self.N)
|
||||
barrier2 = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
results3 = []
|
||||
|
||||
async def coro():
|
||||
try:
|
||||
i = await barrier1.wait()
|
||||
if i == self.N//2:
|
||||
raise RuntimeError
|
||||
await barrier1.wait()
|
||||
results1.append(True)
|
||||
except asyncio.BrokenBarrierError:
|
||||
results2.append(True)
|
||||
except RuntimeError:
|
||||
await barrier1.abort()
|
||||
|
||||
# Synchronize and reset the barrier. Must synchronize first so
|
||||
# that everyone has left it when we reset, and after so that no
|
||||
# one enters it before the reset.
|
||||
i = await barrier2.wait()
|
||||
if i == self.N//2:
|
||||
await barrier1.reset()
|
||||
await barrier2.wait()
|
||||
await barrier1.wait()
|
||||
results3.append(True)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertFalse(barrier1.broken)
|
||||
self.assertEqual(len(results1), 0)
|
||||
self.assertEqual(len(results2), self.N-1)
|
||||
self.assertTrue(all(results2))
|
||||
self.assertEqual(len(results3), self.N)
|
||||
self.assertTrue(all(results3))
|
||||
|
||||
self.assertEqual(barrier1.n_waiting, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue