bpo-43352: Add a Barrier object in asyncio lib (GH-24903)

Co-authored-by: Yury Selivanov <yury@edgedb.com>
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
Duprat 2022-03-25 23:01:21 +01:00 committed by GitHub
parent 20e6e5636a
commit d03acd7270
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 856 additions and 5 deletions

View file

@ -1,4 +1,4 @@
"""Tests for lock.py"""
"""Tests for locks.py"""
import unittest
from unittest import mock
@ -9,7 +9,10 @@ import asyncio
STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
r'\[(?P<extras>'
r'(set|unset|locked|unlocked)(, value:\d)?(, waiters:\d+)?'
r'(set|unset|locked|unlocked|filling|draining|resetting|broken)'
r'(, value:\d)?'
r'(, waiters:\d+)?'
r'(, waiters:\d+\/\d+)?' # barrier
r')\]>\Z'
)
RGX_REPR = re.compile(STR_RGX_REPR)
@ -943,5 +946,576 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
)
class BarrierTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
self.N = 5
def make_tasks(self, n, coro):
tasks = [asyncio.create_task(coro()) for _ in range(n)]
return tasks
async def gather_tasks(self, n, coro):
tasks = self.make_tasks(n, coro)
res = await asyncio.gather(*tasks)
return res, tasks
async def test_barrier(self):
barrier = asyncio.Barrier(self.N)
self.assertIn("filling", repr(barrier))
with self.assertRaisesRegex(
TypeError,
"object Barrier can't be used in 'await' expression",
):
await barrier
self.assertIn("filling", repr(barrier))
async def test_repr(self):
barrier = asyncio.Barrier(self.N)
self.assertTrue(RGX_REPR.match(repr(barrier)))
self.assertIn("filling", repr(barrier))
waiters = []
async def wait(barrier):
await barrier.wait()
incr = 2
for i in range(incr):
waiters.append(asyncio.create_task(wait(barrier)))
await asyncio.sleep(0)
self.assertTrue(RGX_REPR.match(repr(barrier)))
self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier))
self.assertIn("filling", repr(barrier))
# create missing waiters
for i in range(barrier.parties - barrier.n_waiting):
waiters.append(asyncio.create_task(wait(barrier)))
await asyncio.sleep(0)
self.assertTrue(RGX_REPR.match(repr(barrier)))
self.assertIn("draining", repr(barrier))
# add a part of waiters
for i in range(incr):
waiters.append(asyncio.create_task(wait(barrier)))
await asyncio.sleep(0)
# and reset
await barrier.reset()
self.assertTrue(RGX_REPR.match(repr(barrier)))
self.assertIn("resetting", repr(barrier))
# add a part of waiters again
for i in range(incr):
waiters.append(asyncio.create_task(wait(barrier)))
await asyncio.sleep(0)
# and abort
await barrier.abort()
self.assertTrue(RGX_REPR.match(repr(barrier)))
self.assertIn("broken", repr(barrier))
self.assertTrue(barrier.broken)
# suppress unhandled exceptions
await asyncio.gather(*waiters, return_exceptions=True)
async def test_barrier_parties(self):
self.assertRaises(ValueError, lambda: asyncio.Barrier(0))
self.assertRaises(ValueError, lambda: asyncio.Barrier(-4))
self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier)
async def test_context_manager(self):
self.N = 3
barrier = asyncio.Barrier(self.N)
results = []
async def coro():
async with barrier as i:
results.append(i)
await self.gather_tasks(self.N, coro)
self.assertListEqual(sorted(results), list(range(self.N)))
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_filling_one_task(self):
barrier = asyncio.Barrier(1)
async def f():
async with barrier as i:
return True
ret = await f()
self.assertTrue(ret)
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_filling_one_task_twice(self):
barrier = asyncio.Barrier(1)
t1 = asyncio.create_task(barrier.wait())
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 0)
t2 = asyncio.create_task(barrier.wait())
await asyncio.sleep(0)
self.assertEqual(t1.result(), t2.result())
self.assertEqual(t1.done(), t2.done())
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_filling_task_by_task(self):
self.N = 3
barrier = asyncio.Barrier(self.N)
t1 = asyncio.create_task(barrier.wait())
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 1)
self.assertIn("filling", repr(barrier))
t2 = asyncio.create_task(barrier.wait())
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 2)
self.assertIn("filling", repr(barrier))
t3 = asyncio.create_task(barrier.wait())
await asyncio.sleep(0)
await asyncio.wait([t1, t2, t3])
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_filling_tasks_wait_twice(self):
barrier = asyncio.Barrier(self.N)
results = []
async def coro():
async with barrier:
results.append(True)
async with barrier:
results.append(False)
await self.gather_tasks(self.N, coro)
self.assertEqual(len(results), self.N*2)
self.assertEqual(results.count(True), self.N)
self.assertEqual(results.count(False), self.N)
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_filling_tasks_check_return_value(self):
barrier = asyncio.Barrier(self.N)
results1 = []
results2 = []
async def coro():
async with barrier:
results1.append(True)
async with barrier as i:
results2.append(True)
return i
res, _ = await self.gather_tasks(self.N, coro)
self.assertEqual(len(results1), self.N)
self.assertTrue(all(results1))
self.assertEqual(len(results2), self.N)
self.assertTrue(all(results2))
self.assertListEqual(sorted(res), list(range(self.N)))
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_draining_state(self):
barrier = asyncio.Barrier(self.N)
results = []
async def coro():
async with barrier:
# barrier state change to filling for the last task release
results.append("draining" in repr(barrier))
await self.gather_tasks(self.N, coro)
self.assertEqual(len(results), self.N)
self.assertEqual(results[-1], False)
self.assertTrue(all(results[:self.N-1]))
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_blocking_tasks_while_draining(self):
rewait = 2
barrier = asyncio.Barrier(self.N)
barrier_nowaiting = asyncio.Barrier(self.N - rewait)
results = []
rewait_n = rewait
counter = 0
async def coro():
nonlocal rewait_n
# first time waiting
await barrier.wait()
# after wainting once for all tasks
if rewait_n > 0:
rewait_n -= 1
# wait again only for rewait tasks
await barrier.wait()
else:
# wait for end of draining state`
await barrier_nowaiting.wait()
# wait for other waiting tasks
await barrier.wait()
# a success means that barrier_nowaiting
# was waited for exactly N-rewait=3 times
await self.gather_tasks(self.N, coro)
async def test_filling_tasks_cancel_one(self):
self.N = 3
barrier = asyncio.Barrier(self.N)
results = []
async def coro():
await barrier.wait()
results.append(True)
t1 = asyncio.create_task(coro())
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 1)
t2 = asyncio.create_task(coro())
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 2)
t1.cancel()
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 1)
with self.assertRaises(asyncio.CancelledError):
await t1
self.assertTrue(t1.cancelled())
t3 = asyncio.create_task(coro())
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 2)
t4 = asyncio.create_task(coro())
await asyncio.gather(t2, t3, t4)
self.assertEqual(len(results), self.N)
self.assertTrue(all(results))
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_reset_barrier(self):
barrier = asyncio.Barrier(1)
asyncio.create_task(barrier.reset())
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 0)
self.assertFalse(barrier.broken)
async def test_reset_barrier_while_tasks_waiting(self):
barrier = asyncio.Barrier(self.N)
results = []
async def coro():
try:
await barrier.wait()
except asyncio.BrokenBarrierError:
results.append(True)
async def coro_reset():
await barrier.reset()
# N-1 tasks waiting on barrier with N parties
tasks = self.make_tasks(self.N-1, coro)
await asyncio.sleep(0)
# reset the barrier
asyncio.create_task(coro_reset())
await asyncio.gather(*tasks)
self.assertEqual(len(results), self.N-1)
self.assertTrue(all(results))
self.assertEqual(barrier.n_waiting, 0)
self.assertNotIn("resetting", repr(barrier))
self.assertFalse(barrier.broken)
async def test_reset_barrier_when_tasks_half_draining(self):
barrier = asyncio.Barrier(self.N)
results1 = []
rest_of_tasks = self.N//2
async def coro():
try:
await barrier.wait()
except asyncio.BrokenBarrierError:
# catch here waiting tasks
results1.append(True)
else:
# here drained task ouside the barrier
if rest_of_tasks == barrier._count:
# tasks outside the barrier
await barrier.reset()
await self.gather_tasks(self.N, coro)
self.assertEqual(results1, [True]*rest_of_tasks)
self.assertEqual(barrier.n_waiting, 0)
self.assertNotIn("resetting", repr(barrier))
self.assertFalse(barrier.broken)
async def test_reset_barrier_when_tasks_half_draining_half_blocking(self):
barrier = asyncio.Barrier(self.N)
results1 = []
results2 = []
blocking_tasks = self.N//2
count = 0
async def coro():
nonlocal count
try:
await barrier.wait()
except asyncio.BrokenBarrierError:
# here catch still waiting tasks
results1.append(True)
# so now waiting again to reach nb_parties
await barrier.wait()
else:
count += 1
if count > blocking_tasks:
# reset now: raise asyncio.BrokenBarrierError for waiting tasks
await barrier.reset()
# so now waiting again to reach nb_parties
await barrier.wait()
else:
try:
await barrier.wait()
except asyncio.BrokenBarrierError:
# here no catch - blocked tasks go to wait
results2.append(True)
await self.gather_tasks(self.N, coro)
self.assertEqual(results1, [True]*blocking_tasks)
self.assertEqual(results2, [])
self.assertEqual(barrier.n_waiting, 0)
self.assertNotIn("resetting", repr(barrier))
self.assertFalse(barrier.broken)
async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self):
barrier = asyncio.Barrier(self.N)
results1 = []
results2 = []
async def coro1():
try:
await barrier.wait()
except asyncio.BrokenBarrierError:
results1.append(True)
finally:
await barrier.wait()
results2.append(True)
async def coro2():
async with barrier:
results2.append(True)
tasks = self.make_tasks(self.N-1, coro1)
# reset barrier, N-1 waiting tasks raise an BrokenBarrierError
asyncio.create_task(barrier.reset())
await asyncio.sleep(0)
# complete waiting tasks in the `finally`
asyncio.create_task(coro2())
await asyncio.gather(*tasks)
self.assertFalse(barrier.broken)
self.assertEqual(len(results1), self.N-1)
self.assertTrue(all(results1))
self.assertEqual(len(results2), self.N)
self.assertTrue(all(results2))
self.assertEqual(barrier.n_waiting, 0)
async def test_reset_barrier_while_tasks_draining(self):
barrier = asyncio.Barrier(self.N)
results1 = []
results2 = []
results3 = []
count = 0
async def coro():
nonlocal count
i = await barrier.wait()
count += 1
if count == self.N:
# last task exited from barrier
await barrier.reset()
# wit here to reach the `parties`
await barrier.wait()
else:
try:
# second waiting
await barrier.wait()
# N-1 tasks here
results1.append(True)
except Exception as e:
# never goes here
results2.append(True)
# Now, pass the barrier again
# last wait, must be completed
k = await barrier.wait()
results3.append(True)
await self.gather_tasks(self.N, coro)
self.assertFalse(barrier.broken)
self.assertTrue(all(results1))
self.assertEqual(len(results1), self.N-1)
self.assertEqual(len(results2), 0)
self.assertEqual(len(results3), self.N)
self.assertTrue(all(results3))
self.assertEqual(barrier.n_waiting, 0)
async def test_abort_barrier(self):
barrier = asyncio.Barrier(1)
asyncio.create_task(barrier.abort())
await asyncio.sleep(0)
self.assertEqual(barrier.n_waiting, 0)
self.assertTrue(barrier.broken)
async def test_abort_barrier_when_tasks_half_draining_half_blocking(self):
barrier = asyncio.Barrier(self.N)
results1 = []
results2 = []
blocking_tasks = self.N//2
count = 0
async def coro():
nonlocal count
try:
await barrier.wait()
except asyncio.BrokenBarrierError:
# here catch tasks waiting to drain
results1.append(True)
else:
count += 1
if count > blocking_tasks:
# abort now: raise asyncio.BrokenBarrierError for all tasks
await barrier.abort()
else:
try:
await barrier.wait()
except asyncio.BrokenBarrierError:
# here catch blocked tasks (already drained)
results2.append(True)
await self.gather_tasks(self.N, coro)
self.assertTrue(barrier.broken)
self.assertEqual(results1, [True]*blocking_tasks)
self.assertEqual(results2, [True]*(self.N-blocking_tasks-1))
self.assertEqual(barrier.n_waiting, 0)
self.assertNotIn("resetting", repr(barrier))
async def test_abort_barrier_when_exception(self):
# test from threading.Barrier: see `lock_tests.test_reset`
barrier = asyncio.Barrier(self.N)
results1 = []
results2 = []
async def coro():
try:
async with barrier as i :
if i == self.N//2:
raise RuntimeError
async with barrier:
results1.append(True)
except asyncio.BrokenBarrierError:
results2.append(True)
except RuntimeError:
await barrier.abort()
await self.gather_tasks(self.N, coro)
self.assertTrue(barrier.broken)
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertTrue(all(results2))
self.assertEqual(barrier.n_waiting, 0)
async def test_abort_barrier_when_exception_then_resetting(self):
# test from threading.Barrier: see `lock_tests.test_abort_and_reset``
barrier1 = asyncio.Barrier(self.N)
barrier2 = asyncio.Barrier(self.N)
results1 = []
results2 = []
results3 = []
async def coro():
try:
i = await barrier1.wait()
if i == self.N//2:
raise RuntimeError
await barrier1.wait()
results1.append(True)
except asyncio.BrokenBarrierError:
results2.append(True)
except RuntimeError:
await barrier1.abort()
# Synchronize and reset the barrier. Must synchronize first so
# that everyone has left it when we reset, and after so that no
# one enters it before the reset.
i = await barrier2.wait()
if i == self.N//2:
await barrier1.reset()
await barrier2.wait()
await barrier1.wait()
results3.append(True)
await self.gather_tasks(self.N, coro)
self.assertFalse(barrier1.broken)
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertTrue(all(results2))
self.assertEqual(len(results3), self.N)
self.assertTrue(all(results3))
self.assertEqual(barrier1.n_waiting, 0)
if __name__ == '__main__':
unittest.main()