mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
bpo-43352: Add a Barrier object in asyncio lib (GH-24903)
Co-authored-by: Yury Selivanov <yury@edgedb.com> Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
20e6e5636a
commit
d03acd7270
6 changed files with 856 additions and 5 deletions
|
@ -186,11 +186,16 @@ Threading-like synchronization primitives that can be used in Tasks.
|
||||||
* - :class:`BoundedSemaphore`
|
* - :class:`BoundedSemaphore`
|
||||||
- A bounded semaphore.
|
- A bounded semaphore.
|
||||||
|
|
||||||
|
* - :class:`Barrier`
|
||||||
|
- A barrier object.
|
||||||
|
|
||||||
|
|
||||||
.. rubric:: Examples
|
.. rubric:: Examples
|
||||||
|
|
||||||
* :ref:`Using asyncio.Event <asyncio_example_sync_event>`.
|
* :ref:`Using asyncio.Event <asyncio_example_sync_event>`.
|
||||||
|
|
||||||
|
* :ref:`Using asyncio.Barrier <asyncio_example_barrier>`.
|
||||||
|
|
||||||
* See also the documentation of asyncio
|
* See also the documentation of asyncio
|
||||||
:ref:`synchronization primitives <asyncio-sync>`.
|
:ref:`synchronization primitives <asyncio-sync>`.
|
||||||
|
|
||||||
|
@ -206,6 +211,9 @@ Exceptions
|
||||||
* - :exc:`asyncio.CancelledError`
|
* - :exc:`asyncio.CancelledError`
|
||||||
- Raised when a Task is cancelled. See also :meth:`Task.cancel`.
|
- Raised when a Task is cancelled. See also :meth:`Task.cancel`.
|
||||||
|
|
||||||
|
* - :exc:`asyncio.BrokenBarrierError`
|
||||||
|
- Raised when a Barrier is broken. See also :meth:`Barrier.wait`.
|
||||||
|
|
||||||
|
|
||||||
.. rubric:: Examples
|
.. rubric:: Examples
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ asyncio has the following basic synchronization primitives:
|
||||||
* :class:`Condition`
|
* :class:`Condition`
|
||||||
* :class:`Semaphore`
|
* :class:`Semaphore`
|
||||||
* :class:`BoundedSemaphore`
|
* :class:`BoundedSemaphore`
|
||||||
|
* :class:`Barrier`
|
||||||
|
|
||||||
|
|
||||||
---------
|
---------
|
||||||
|
@ -340,6 +341,115 @@ BoundedSemaphore
|
||||||
.. versionchanged:: 3.10
|
.. versionchanged:: 3.10
|
||||||
Removed the *loop* parameter.
|
Removed the *loop* parameter.
|
||||||
|
|
||||||
|
|
||||||
|
Barrier
|
||||||
|
=======
|
||||||
|
|
||||||
|
.. class:: Barrier(parties, action=None)
|
||||||
|
|
||||||
|
A barrier object. Not thread-safe.
|
||||||
|
|
||||||
|
A barrier is a simple synchronization primitive that allows to block until
|
||||||
|
*parties* number of tasks are waiting on it.
|
||||||
|
Tasks can wait on the :meth:`~Barrier.wait` method and would be blocked until
|
||||||
|
the specified number of tasks end up waiting on :meth:`~Barrier.wait`.
|
||||||
|
At that point all of the waiting tasks would unblock simultaneously.
|
||||||
|
|
||||||
|
:keyword:`async with` can be used as an alternative to awaiting on
|
||||||
|
:meth:`~Barrier.wait`.
|
||||||
|
|
||||||
|
The barrier can be reused any number of times.
|
||||||
|
|
||||||
|
.. _asyncio_example_barrier:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
async def example_barrier():
|
||||||
|
# barrier with 3 parties
|
||||||
|
b = asyncio.Barrier(3)
|
||||||
|
|
||||||
|
# create 2 new waiting tasks
|
||||||
|
asyncio.create_task(b.wait())
|
||||||
|
asyncio.create_task(b.wait())
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
print(b)
|
||||||
|
|
||||||
|
# The third .wait() call passes the barrier
|
||||||
|
await b.wait()
|
||||||
|
print(b)
|
||||||
|
print("barrier passed")
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
print(b)
|
||||||
|
|
||||||
|
asyncio.run(example_barrier())
|
||||||
|
|
||||||
|
Result of this example is::
|
||||||
|
|
||||||
|
<asyncio.locks.Barrier object at 0x... [filling, waiters:2/3]>
|
||||||
|
<asyncio.locks.Barrier object at 0x... [draining, waiters:0/3]>
|
||||||
|
barrier passed
|
||||||
|
<asyncio.locks.Barrier object at 0x... [filling, waiters:0/3]>
|
||||||
|
|
||||||
|
.. versionadded:: 3.11
|
||||||
|
|
||||||
|
.. coroutinemethod:: wait()
|
||||||
|
|
||||||
|
Pass the barrier. When all the tasks party to the barrier have called
|
||||||
|
this function, they are all unblocked simultaneously.
|
||||||
|
|
||||||
|
When a waiting or blocked task in the barrier is cancelled,
|
||||||
|
this task exits the barrier which stays in the same state.
|
||||||
|
If the state of the barrier is "filling", the number of waiting task
|
||||||
|
decreases by 1.
|
||||||
|
|
||||||
|
The return value is an integer in the range of 0 to ``parties-1``, different
|
||||||
|
for each task. This can be used to select a task to do some special
|
||||||
|
housekeeping, e.g.::
|
||||||
|
|
||||||
|
...
|
||||||
|
async with barrier as position:
|
||||||
|
if position == 0:
|
||||||
|
# Only one task print this
|
||||||
|
print('End of *draining phasis*')
|
||||||
|
|
||||||
|
This method may raise a :class:`BrokenBarrierError` exception if the
|
||||||
|
barrier is broken or reset while a task is waiting.
|
||||||
|
It could raise a :exc:`CancelledError` if a task is cancelled.
|
||||||
|
|
||||||
|
.. coroutinemethod:: reset()
|
||||||
|
|
||||||
|
Return the barrier to the default, empty state. Any tasks waiting on it
|
||||||
|
will receive the :class:`BrokenBarrierError` exception.
|
||||||
|
|
||||||
|
If a barrier is broken it may be better to just leave it and create a new one.
|
||||||
|
|
||||||
|
.. coroutinemethod:: abort()
|
||||||
|
|
||||||
|
Put the barrier into a broken state. This causes any active or future
|
||||||
|
calls to :meth:`wait` to fail with the :class:`BrokenBarrierError`.
|
||||||
|
Use this for example if one of the taks needs to abort, to avoid infinite
|
||||||
|
waiting tasks.
|
||||||
|
|
||||||
|
.. attribute:: parties
|
||||||
|
|
||||||
|
The number of tasks required to pass the barrier.
|
||||||
|
|
||||||
|
.. attribute:: n_waiting
|
||||||
|
|
||||||
|
The number of tasks currently waiting in the barrier while filling.
|
||||||
|
|
||||||
|
.. attribute:: broken
|
||||||
|
|
||||||
|
A boolean that is ``True`` if the barrier is in the broken state.
|
||||||
|
|
||||||
|
|
||||||
|
.. exception:: BrokenBarrierError
|
||||||
|
|
||||||
|
This exception, a subclass of :exc:`RuntimeError`, is raised when the
|
||||||
|
:class:`Barrier` object is reset or broken.
|
||||||
|
|
||||||
---------
|
---------
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
"""asyncio exceptions."""
|
"""asyncio exceptions."""
|
||||||
|
|
||||||
|
|
||||||
__all__ = ('CancelledError', 'InvalidStateError', 'TimeoutError',
|
__all__ = ('BrokenBarrierError',
|
||||||
|
'CancelledError', 'InvalidStateError', 'TimeoutError',
|
||||||
'IncompleteReadError', 'LimitOverrunError',
|
'IncompleteReadError', 'LimitOverrunError',
|
||||||
'SendfileNotAvailableError')
|
'SendfileNotAvailableError')
|
||||||
|
|
||||||
|
@ -55,3 +56,7 @@ class LimitOverrunError(Exception):
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return type(self), (self.args[0], self.consumed)
|
return type(self), (self.args[0], self.consumed)
|
||||||
|
|
||||||
|
|
||||||
|
class BrokenBarrierError(RuntimeError):
|
||||||
|
"""Barrier is broken by barrier.abort() call."""
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
"""Synchronization primitives."""
|
"""Synchronization primitives."""
|
||||||
|
|
||||||
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
|
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
|
||||||
|
'BoundedSemaphore', 'Barrier')
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import enum
|
||||||
|
|
||||||
from . import exceptions
|
from . import exceptions
|
||||||
from . import mixins
|
from . import mixins
|
||||||
from . import tasks
|
from . import tasks
|
||||||
|
|
||||||
|
|
||||||
class _ContextManagerMixin:
|
class _ContextManagerMixin:
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
await self.acquire()
|
await self.acquire()
|
||||||
|
@ -416,3 +417,155 @@ class BoundedSemaphore(Semaphore):
|
||||||
if self._value >= self._bound_value:
|
if self._value >= self._bound_value:
|
||||||
raise ValueError('BoundedSemaphore released too many times')
|
raise ValueError('BoundedSemaphore released too many times')
|
||||||
super().release()
|
super().release()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class _BarrierState(enum.Enum):
|
||||||
|
FILLING = 'filling'
|
||||||
|
DRAINING = 'draining'
|
||||||
|
RESETTING = 'resetting'
|
||||||
|
BROKEN = 'broken'
|
||||||
|
|
||||||
|
|
||||||
|
class Barrier(mixins._LoopBoundMixin):
|
||||||
|
"""Asyncio equivalent to threading.Barrier
|
||||||
|
|
||||||
|
Implements a Barrier primitive.
|
||||||
|
Useful for synchronizing a fixed number of tasks at known synchronization
|
||||||
|
points. Tasks block on 'wait()' and are simultaneously awoken once they
|
||||||
|
have all made their call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, parties):
|
||||||
|
"""Create a barrier, initialised to 'parties' tasks."""
|
||||||
|
if parties < 1:
|
||||||
|
raise ValueError('parties must be > 0')
|
||||||
|
|
||||||
|
self._cond = Condition() # notify all tasks when state changes
|
||||||
|
|
||||||
|
self._parties = parties
|
||||||
|
self._state = _BarrierState.FILLING
|
||||||
|
self._count = 0 # count tasks in Barrier
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
res = super().__repr__()
|
||||||
|
extra = f'{self._state.value}'
|
||||||
|
if not self.broken:
|
||||||
|
extra += f', waiters:{self.n_waiting}/{self.parties}'
|
||||||
|
return f'<{res[1:-1]} [{extra}]>'
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
# wait for the barrier reaches the parties number
|
||||||
|
# when start draining release and return index of waited task
|
||||||
|
return await self.wait()
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def wait(self):
|
||||||
|
"""Wait for the barrier.
|
||||||
|
|
||||||
|
When the specified number of tasks have started waiting, they are all
|
||||||
|
simultaneously awoken.
|
||||||
|
Returns an unique and individual index number from 0 to 'parties-1'.
|
||||||
|
"""
|
||||||
|
async with self._cond:
|
||||||
|
await self._block() # Block while the barrier drains or resets.
|
||||||
|
try:
|
||||||
|
index = self._count
|
||||||
|
self._count += 1
|
||||||
|
if index + 1 == self._parties:
|
||||||
|
# We release the barrier
|
||||||
|
await self._release()
|
||||||
|
else:
|
||||||
|
await self._wait()
|
||||||
|
return index
|
||||||
|
finally:
|
||||||
|
self._count -= 1
|
||||||
|
# Wake up any tasks waiting for barrier to drain.
|
||||||
|
self._exit()
|
||||||
|
|
||||||
|
async def _block(self):
|
||||||
|
# Block until the barrier is ready for us,
|
||||||
|
# or raise an exception if it is broken.
|
||||||
|
#
|
||||||
|
# It is draining or resetting, wait until done
|
||||||
|
# unless a CancelledError occurs
|
||||||
|
await self._cond.wait_for(
|
||||||
|
lambda: self._state not in (
|
||||||
|
_BarrierState.DRAINING, _BarrierState.RESETTING
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# see if the barrier is in a broken state
|
||||||
|
if self._state is _BarrierState.BROKEN:
|
||||||
|
raise exceptions.BrokenBarrierError("Barrier aborted")
|
||||||
|
|
||||||
|
async def _release(self):
|
||||||
|
# Release the tasks waiting in the barrier.
|
||||||
|
|
||||||
|
# Enter draining state.
|
||||||
|
# Next waiting tasks will be blocked until the end of draining.
|
||||||
|
self._state = _BarrierState.DRAINING
|
||||||
|
self._cond.notify_all()
|
||||||
|
|
||||||
|
async def _wait(self):
|
||||||
|
# Wait in the barrier until we are released. Raise an exception
|
||||||
|
# if the barrier is reset or broken.
|
||||||
|
|
||||||
|
# wait for end of filling
|
||||||
|
# unless a CancelledError occurs
|
||||||
|
await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
|
||||||
|
|
||||||
|
if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
|
||||||
|
raise exceptions.BrokenBarrierError("Abort or reset of barrier")
|
||||||
|
|
||||||
|
def _exit(self):
|
||||||
|
# If we are the last tasks to exit the barrier, signal any tasks
|
||||||
|
# waiting for the barrier to drain.
|
||||||
|
if self._count == 0:
|
||||||
|
if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
|
||||||
|
self._state = _BarrierState.FILLING
|
||||||
|
self._cond.notify_all()
|
||||||
|
|
||||||
|
async def reset(self):
|
||||||
|
"""Reset the barrier to the initial state.
|
||||||
|
|
||||||
|
Any tasks currently waiting will get the BrokenBarrier exception
|
||||||
|
raised.
|
||||||
|
"""
|
||||||
|
async with self._cond:
|
||||||
|
if self._count > 0:
|
||||||
|
if self._state is not _BarrierState.RESETTING:
|
||||||
|
#reset the barrier, waking up tasks
|
||||||
|
self._state = _BarrierState.RESETTING
|
||||||
|
else:
|
||||||
|
self._state = _BarrierState.FILLING
|
||||||
|
self._cond.notify_all()
|
||||||
|
|
||||||
|
async def abort(self):
|
||||||
|
"""Place the barrier into a 'broken' state.
|
||||||
|
|
||||||
|
Useful in case of error. Any currently waiting tasks and tasks
|
||||||
|
attempting to 'wait()' will have BrokenBarrierError raised.
|
||||||
|
"""
|
||||||
|
async with self._cond:
|
||||||
|
self._state = _BarrierState.BROKEN
|
||||||
|
self._cond.notify_all()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parties(self):
|
||||||
|
"""Return the number of tasks required to trip the barrier."""
|
||||||
|
return self._parties
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_waiting(self):
|
||||||
|
"""Return the number of tasks currently waiting at the barrier."""
|
||||||
|
if self._state is _BarrierState.FILLING:
|
||||||
|
return self._count
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def broken(self):
|
||||||
|
"""Return True if the barrier is in a broken state."""
|
||||||
|
return self._state is _BarrierState.BROKEN
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
"""Tests for lock.py"""
|
"""Tests for locks.py"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
@ -9,7 +9,10 @@ import asyncio
|
||||||
STR_RGX_REPR = (
|
STR_RGX_REPR = (
|
||||||
r'^<(?P<class>.*?) object at (?P<address>.*?)'
|
r'^<(?P<class>.*?) object at (?P<address>.*?)'
|
||||||
r'\[(?P<extras>'
|
r'\[(?P<extras>'
|
||||||
r'(set|unset|locked|unlocked)(, value:\d)?(, waiters:\d+)?'
|
r'(set|unset|locked|unlocked|filling|draining|resetting|broken)'
|
||||||
|
r'(, value:\d)?'
|
||||||
|
r'(, waiters:\d+)?'
|
||||||
|
r'(, waiters:\d+\/\d+)?' # barrier
|
||||||
r')\]>\Z'
|
r')\]>\Z'
|
||||||
)
|
)
|
||||||
RGX_REPR = re.compile(STR_RGX_REPR)
|
RGX_REPR = re.compile(STR_RGX_REPR)
|
||||||
|
@ -943,5 +946,576 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BarrierTests(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
|
self.N = 5
|
||||||
|
|
||||||
|
def make_tasks(self, n, coro):
|
||||||
|
tasks = [asyncio.create_task(coro()) for _ in range(n)]
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
async def gather_tasks(self, n, coro):
|
||||||
|
tasks = self.make_tasks(n, coro)
|
||||||
|
res = await asyncio.gather(*tasks)
|
||||||
|
return res, tasks
|
||||||
|
|
||||||
|
async def test_barrier(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
self.assertIn("filling", repr(barrier))
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError,
|
||||||
|
"object Barrier can't be used in 'await' expression",
|
||||||
|
):
|
||||||
|
await barrier
|
||||||
|
|
||||||
|
self.assertIn("filling", repr(barrier))
|
||||||
|
|
||||||
|
async def test_repr(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
|
||||||
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||||
|
self.assertIn("filling", repr(barrier))
|
||||||
|
|
||||||
|
waiters = []
|
||||||
|
async def wait(barrier):
|
||||||
|
await barrier.wait()
|
||||||
|
|
||||||
|
incr = 2
|
||||||
|
for i in range(incr):
|
||||||
|
waiters.append(asyncio.create_task(wait(barrier)))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||||
|
self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier))
|
||||||
|
self.assertIn("filling", repr(barrier))
|
||||||
|
|
||||||
|
# create missing waiters
|
||||||
|
for i in range(barrier.parties - barrier.n_waiting):
|
||||||
|
waiters.append(asyncio.create_task(wait(barrier)))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||||
|
self.assertIn("draining", repr(barrier))
|
||||||
|
|
||||||
|
# add a part of waiters
|
||||||
|
for i in range(incr):
|
||||||
|
waiters.append(asyncio.create_task(wait(barrier)))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
# and reset
|
||||||
|
await barrier.reset()
|
||||||
|
|
||||||
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||||
|
self.assertIn("resetting", repr(barrier))
|
||||||
|
|
||||||
|
# add a part of waiters again
|
||||||
|
for i in range(incr):
|
||||||
|
waiters.append(asyncio.create_task(wait(barrier)))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
# and abort
|
||||||
|
await barrier.abort()
|
||||||
|
|
||||||
|
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||||
|
self.assertIn("broken", repr(barrier))
|
||||||
|
self.assertTrue(barrier.broken)
|
||||||
|
|
||||||
|
# suppress unhandled exceptions
|
||||||
|
await asyncio.gather(*waiters, return_exceptions=True)
|
||||||
|
|
||||||
|
async def test_barrier_parties(self):
|
||||||
|
self.assertRaises(ValueError, lambda: asyncio.Barrier(0))
|
||||||
|
self.assertRaises(ValueError, lambda: asyncio.Barrier(-4))
|
||||||
|
|
||||||
|
self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier)
|
||||||
|
|
||||||
|
async def test_context_manager(self):
|
||||||
|
self.N = 3
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
async with barrier as i:
|
||||||
|
results.append(i)
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertListEqual(sorted(results), list(range(self.N)))
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_filling_one_task(self):
|
||||||
|
barrier = asyncio.Barrier(1)
|
||||||
|
|
||||||
|
async def f():
|
||||||
|
async with barrier as i:
|
||||||
|
return True
|
||||||
|
|
||||||
|
ret = await f()
|
||||||
|
|
||||||
|
self.assertTrue(ret)
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_filling_one_task_twice(self):
|
||||||
|
barrier = asyncio.Barrier(1)
|
||||||
|
|
||||||
|
t1 = asyncio.create_task(barrier.wait())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
|
||||||
|
t2 = asyncio.create_task(barrier.wait())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
self.assertEqual(t1.result(), t2.result())
|
||||||
|
self.assertEqual(t1.done(), t2.done())
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_filling_task_by_task(self):
|
||||||
|
self.N = 3
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
|
||||||
|
t1 = asyncio.create_task(barrier.wait())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual(barrier.n_waiting, 1)
|
||||||
|
self.assertIn("filling", repr(barrier))
|
||||||
|
|
||||||
|
t2 = asyncio.create_task(barrier.wait())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual(barrier.n_waiting, 2)
|
||||||
|
self.assertIn("filling", repr(barrier))
|
||||||
|
|
||||||
|
t3 = asyncio.create_task(barrier.wait())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
await asyncio.wait([t1, t2, t3])
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_filling_tasks_wait_twice(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
async with barrier:
|
||||||
|
results.append(True)
|
||||||
|
|
||||||
|
async with barrier:
|
||||||
|
results.append(False)
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertEqual(len(results), self.N*2)
|
||||||
|
self.assertEqual(results.count(True), self.N)
|
||||||
|
self.assertEqual(results.count(False), self.N)
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_filling_tasks_check_return_value(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results1 = []
|
||||||
|
results2 = []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
async with barrier:
|
||||||
|
results1.append(True)
|
||||||
|
|
||||||
|
async with barrier as i:
|
||||||
|
results2.append(True)
|
||||||
|
return i
|
||||||
|
|
||||||
|
res, _ = await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertEqual(len(results1), self.N)
|
||||||
|
self.assertTrue(all(results1))
|
||||||
|
self.assertEqual(len(results2), self.N)
|
||||||
|
self.assertTrue(all(results2))
|
||||||
|
self.assertListEqual(sorted(res), list(range(self.N)))
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_draining_state(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
async with barrier:
|
||||||
|
# barrier state change to filling for the last task release
|
||||||
|
results.append("draining" in repr(barrier))
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertEqual(len(results), self.N)
|
||||||
|
self.assertEqual(results[-1], False)
|
||||||
|
self.assertTrue(all(results[:self.N-1]))
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_blocking_tasks_while_draining(self):
|
||||||
|
rewait = 2
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
barrier_nowaiting = asyncio.Barrier(self.N - rewait)
|
||||||
|
results = []
|
||||||
|
rewait_n = rewait
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
nonlocal rewait_n
|
||||||
|
|
||||||
|
# first time waiting
|
||||||
|
await barrier.wait()
|
||||||
|
|
||||||
|
# after wainting once for all tasks
|
||||||
|
if rewait_n > 0:
|
||||||
|
rewait_n -= 1
|
||||||
|
# wait again only for rewait tasks
|
||||||
|
await barrier.wait()
|
||||||
|
else:
|
||||||
|
# wait for end of draining state`
|
||||||
|
await barrier_nowaiting.wait()
|
||||||
|
# wait for other waiting tasks
|
||||||
|
await barrier.wait()
|
||||||
|
|
||||||
|
# a success means that barrier_nowaiting
|
||||||
|
# was waited for exactly N-rewait=3 times
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
async def test_filling_tasks_cancel_one(self):
|
||||||
|
self.N = 3
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
await barrier.wait()
|
||||||
|
results.append(True)
|
||||||
|
|
||||||
|
t1 = asyncio.create_task(coro())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual(barrier.n_waiting, 1)
|
||||||
|
|
||||||
|
t2 = asyncio.create_task(coro())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual(barrier.n_waiting, 2)
|
||||||
|
|
||||||
|
t1.cancel()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual(barrier.n_waiting, 1)
|
||||||
|
with self.assertRaises(asyncio.CancelledError):
|
||||||
|
await t1
|
||||||
|
self.assertTrue(t1.cancelled())
|
||||||
|
|
||||||
|
t3 = asyncio.create_task(coro())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual(barrier.n_waiting, 2)
|
||||||
|
|
||||||
|
t4 = asyncio.create_task(coro())
|
||||||
|
await asyncio.gather(t2, t3, t4)
|
||||||
|
|
||||||
|
self.assertEqual(len(results), self.N)
|
||||||
|
self.assertTrue(all(results))
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_reset_barrier(self):
|
||||||
|
barrier = asyncio.Barrier(1)
|
||||||
|
|
||||||
|
asyncio.create_task(barrier.reset())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_reset_barrier_while_tasks_waiting(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
try:
|
||||||
|
await barrier.wait()
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
results.append(True)
|
||||||
|
|
||||||
|
async def coro_reset():
|
||||||
|
await barrier.reset()
|
||||||
|
|
||||||
|
# N-1 tasks waiting on barrier with N parties
|
||||||
|
tasks = self.make_tasks(self.N-1, coro)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# reset the barrier
|
||||||
|
asyncio.create_task(coro_reset())
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
self.assertEqual(len(results), self.N-1)
|
||||||
|
self.assertTrue(all(results))
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertNotIn("resetting", repr(barrier))
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_reset_barrier_when_tasks_half_draining(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results1 = []
|
||||||
|
rest_of_tasks = self.N//2
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
try:
|
||||||
|
await barrier.wait()
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
# catch here waiting tasks
|
||||||
|
results1.append(True)
|
||||||
|
else:
|
||||||
|
# here drained task ouside the barrier
|
||||||
|
if rest_of_tasks == barrier._count:
|
||||||
|
# tasks outside the barrier
|
||||||
|
await barrier.reset()
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertEqual(results1, [True]*rest_of_tasks)
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertNotIn("resetting", repr(barrier))
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_reset_barrier_when_tasks_half_draining_half_blocking(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results1 = []
|
||||||
|
results2 = []
|
||||||
|
blocking_tasks = self.N//2
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
nonlocal count
|
||||||
|
try:
|
||||||
|
await barrier.wait()
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
# here catch still waiting tasks
|
||||||
|
results1.append(True)
|
||||||
|
|
||||||
|
# so now waiting again to reach nb_parties
|
||||||
|
await barrier.wait()
|
||||||
|
else:
|
||||||
|
count += 1
|
||||||
|
if count > blocking_tasks:
|
||||||
|
# reset now: raise asyncio.BrokenBarrierError for waiting tasks
|
||||||
|
await barrier.reset()
|
||||||
|
|
||||||
|
# so now waiting again to reach nb_parties
|
||||||
|
await barrier.wait()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await barrier.wait()
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
# here no catch - blocked tasks go to wait
|
||||||
|
results2.append(True)
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertEqual(results1, [True]*blocking_tasks)
|
||||||
|
self.assertEqual(results2, [])
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertNotIn("resetting", repr(barrier))
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
|
||||||
|
async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results1 = []
|
||||||
|
results2 = []
|
||||||
|
|
||||||
|
async def coro1():
|
||||||
|
try:
|
||||||
|
await barrier.wait()
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
results1.append(True)
|
||||||
|
finally:
|
||||||
|
await barrier.wait()
|
||||||
|
results2.append(True)
|
||||||
|
|
||||||
|
async def coro2():
|
||||||
|
async with barrier:
|
||||||
|
results2.append(True)
|
||||||
|
|
||||||
|
tasks = self.make_tasks(self.N-1, coro1)
|
||||||
|
|
||||||
|
# reset barrier, N-1 waiting tasks raise an BrokenBarrierError
|
||||||
|
asyncio.create_task(barrier.reset())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# complete waiting tasks in the `finally`
|
||||||
|
asyncio.create_task(coro2())
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
self.assertEqual(len(results1), self.N-1)
|
||||||
|
self.assertTrue(all(results1))
|
||||||
|
self.assertEqual(len(results2), self.N)
|
||||||
|
self.assertTrue(all(results2))
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reset_barrier_while_tasks_draining(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results1 = []
|
||||||
|
results2 = []
|
||||||
|
results3 = []
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
nonlocal count
|
||||||
|
|
||||||
|
i = await barrier.wait()
|
||||||
|
count += 1
|
||||||
|
if count == self.N:
|
||||||
|
# last task exited from barrier
|
||||||
|
await barrier.reset()
|
||||||
|
|
||||||
|
# wit here to reach the `parties`
|
||||||
|
await barrier.wait()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# second waiting
|
||||||
|
await barrier.wait()
|
||||||
|
|
||||||
|
# N-1 tasks here
|
||||||
|
results1.append(True)
|
||||||
|
except Exception as e:
|
||||||
|
# never goes here
|
||||||
|
results2.append(True)
|
||||||
|
|
||||||
|
# Now, pass the barrier again
|
||||||
|
# last wait, must be completed
|
||||||
|
k = await barrier.wait()
|
||||||
|
results3.append(True)
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertFalse(barrier.broken)
|
||||||
|
self.assertTrue(all(results1))
|
||||||
|
self.assertEqual(len(results1), self.N-1)
|
||||||
|
self.assertEqual(len(results2), 0)
|
||||||
|
self.assertEqual(len(results3), self.N)
|
||||||
|
self.assertTrue(all(results3))
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
|
||||||
|
async def test_abort_barrier(self):
|
||||||
|
barrier = asyncio.Barrier(1)
|
||||||
|
|
||||||
|
asyncio.create_task(barrier.abort())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertTrue(barrier.broken)
|
||||||
|
|
||||||
|
async def test_abort_barrier_when_tasks_half_draining_half_blocking(self):
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results1 = []
|
||||||
|
results2 = []
|
||||||
|
blocking_tasks = self.N//2
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
nonlocal count
|
||||||
|
try:
|
||||||
|
await barrier.wait()
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
# here catch tasks waiting to drain
|
||||||
|
results1.append(True)
|
||||||
|
else:
|
||||||
|
count += 1
|
||||||
|
if count > blocking_tasks:
|
||||||
|
# abort now: raise asyncio.BrokenBarrierError for all tasks
|
||||||
|
await barrier.abort()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await barrier.wait()
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
# here catch blocked tasks (already drained)
|
||||||
|
results2.append(True)
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertTrue(barrier.broken)
|
||||||
|
self.assertEqual(results1, [True]*blocking_tasks)
|
||||||
|
self.assertEqual(results2, [True]*(self.N-blocking_tasks-1))
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
self.assertNotIn("resetting", repr(barrier))
|
||||||
|
|
||||||
|
async def test_abort_barrier_when_exception(self):
|
||||||
|
# test from threading.Barrier: see `lock_tests.test_reset`
|
||||||
|
barrier = asyncio.Barrier(self.N)
|
||||||
|
results1 = []
|
||||||
|
results2 = []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
try:
|
||||||
|
async with barrier as i :
|
||||||
|
if i == self.N//2:
|
||||||
|
raise RuntimeError
|
||||||
|
async with barrier:
|
||||||
|
results1.append(True)
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
results2.append(True)
|
||||||
|
except RuntimeError:
|
||||||
|
await barrier.abort()
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertTrue(barrier.broken)
|
||||||
|
self.assertEqual(len(results1), 0)
|
||||||
|
self.assertEqual(len(results2), self.N-1)
|
||||||
|
self.assertTrue(all(results2))
|
||||||
|
self.assertEqual(barrier.n_waiting, 0)
|
||||||
|
|
||||||
|
async def test_abort_barrier_when_exception_then_resetting(self):
|
||||||
|
# test from threading.Barrier: see `lock_tests.test_abort_and_reset``
|
||||||
|
barrier1 = asyncio.Barrier(self.N)
|
||||||
|
barrier2 = asyncio.Barrier(self.N)
|
||||||
|
results1 = []
|
||||||
|
results2 = []
|
||||||
|
results3 = []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
try:
|
||||||
|
i = await barrier1.wait()
|
||||||
|
if i == self.N//2:
|
||||||
|
raise RuntimeError
|
||||||
|
await barrier1.wait()
|
||||||
|
results1.append(True)
|
||||||
|
except asyncio.BrokenBarrierError:
|
||||||
|
results2.append(True)
|
||||||
|
except RuntimeError:
|
||||||
|
await barrier1.abort()
|
||||||
|
|
||||||
|
# Synchronize and reset the barrier. Must synchronize first so
|
||||||
|
# that everyone has left it when we reset, and after so that no
|
||||||
|
# one enters it before the reset.
|
||||||
|
i = await barrier2.wait()
|
||||||
|
if i == self.N//2:
|
||||||
|
await barrier1.reset()
|
||||||
|
await barrier2.wait()
|
||||||
|
await barrier1.wait()
|
||||||
|
results3.append(True)
|
||||||
|
|
||||||
|
await self.gather_tasks(self.N, coro)
|
||||||
|
|
||||||
|
self.assertFalse(barrier1.broken)
|
||||||
|
self.assertEqual(len(results1), 0)
|
||||||
|
self.assertEqual(len(results2), self.N-1)
|
||||||
|
self.assertTrue(all(results2))
|
||||||
|
self.assertEqual(len(results3), self.N)
|
||||||
|
self.assertTrue(all(results3))
|
||||||
|
|
||||||
|
self.assertEqual(barrier1.n_waiting, 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Add an Barrier object in synchronization primitives of *asyncio* Lib in order to be consistant with Barrier from *threading* and *multiprocessing* libs*
|
Loading…
Add table
Add a link
Reference in a new issue