bpo-43352: Add a Barrier object in asyncio lib (GH-24903)

Co-authored-by: Yury Selivanov <yury@edgedb.com>
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
Duprat 2022-03-25 23:01:21 +01:00 committed by GitHub
parent 20e6e5636a
commit d03acd7270
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 856 additions and 5 deletions

View file

@ -1,7 +1,8 @@
"""asyncio exceptions."""
__all__ = ('CancelledError', 'InvalidStateError', 'TimeoutError',
__all__ = ('BrokenBarrierError',
'CancelledError', 'InvalidStateError', 'TimeoutError',
'IncompleteReadError', 'LimitOverrunError',
'SendfileNotAvailableError')
@ -55,3 +56,7 @@ class LimitOverrunError(Exception):
def __reduce__(self):
return type(self), (self.args[0], self.consumed)
class BrokenBarrierError(RuntimeError):
"""Barrier is broken by barrier.abort() call."""

View file

@ -1,14 +1,15 @@
"""Synchronization primitives."""
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
'BoundedSemaphore', 'Barrier')
import collections
import enum
from . import exceptions
from . import mixins
from . import tasks
class _ContextManagerMixin:
async def __aenter__(self):
await self.acquire()
@ -416,3 +417,155 @@ class BoundedSemaphore(Semaphore):
if self._value >= self._bound_value:
raise ValueError('BoundedSemaphore released too many times')
super().release()
class _BarrierState(enum.Enum):
FILLING = 'filling'
DRAINING = 'draining'
RESETTING = 'resetting'
BROKEN = 'broken'
class Barrier(mixins._LoopBoundMixin):
"""Asyncio equivalent to threading.Barrier
Implements a Barrier primitive.
Useful for synchronizing a fixed number of tasks at known synchronization
points. Tasks block on 'wait()' and are simultaneously awoken once they
have all made their call.
"""
def __init__(self, parties):
"""Create a barrier, initialised to 'parties' tasks."""
if parties < 1:
raise ValueError('parties must be > 0')
self._cond = Condition() # notify all tasks when state changes
self._parties = parties
self._state = _BarrierState.FILLING
self._count = 0 # count tasks in Barrier
def __repr__(self):
res = super().__repr__()
extra = f'{self._state.value}'
if not self.broken:
extra += f', waiters:{self.n_waiting}/{self.parties}'
return f'<{res[1:-1]} [{extra}]>'
async def __aenter__(self):
# wait for the barrier reaches the parties number
# when start draining release and return index of waited task
return await self.wait()
async def __aexit__(self, *args):
pass
async def wait(self):
"""Wait for the barrier.
When the specified number of tasks have started waiting, they are all
simultaneously awoken.
Returns an unique and individual index number from 0 to 'parties-1'.
"""
async with self._cond:
await self._block() # Block while the barrier drains or resets.
try:
index = self._count
self._count += 1
if index + 1 == self._parties:
# We release the barrier
await self._release()
else:
await self._wait()
return index
finally:
self._count -= 1
# Wake up any tasks waiting for barrier to drain.
self._exit()
async def _block(self):
# Block until the barrier is ready for us,
# or raise an exception if it is broken.
#
# It is draining or resetting, wait until done
# unless a CancelledError occurs
await self._cond.wait_for(
lambda: self._state not in (
_BarrierState.DRAINING, _BarrierState.RESETTING
)
)
# see if the barrier is in a broken state
if self._state is _BarrierState.BROKEN:
raise exceptions.BrokenBarrierError("Barrier aborted")
async def _release(self):
# Release the tasks waiting in the barrier.
# Enter draining state.
# Next waiting tasks will be blocked until the end of draining.
self._state = _BarrierState.DRAINING
self._cond.notify_all()
async def _wait(self):
# Wait in the barrier until we are released. Raise an exception
# if the barrier is reset or broken.
# wait for end of filling
# unless a CancelledError occurs
await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
raise exceptions.BrokenBarrierError("Abort or reset of barrier")
def _exit(self):
# If we are the last tasks to exit the barrier, signal any tasks
# waiting for the barrier to drain.
if self._count == 0:
if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
self._state = _BarrierState.FILLING
self._cond.notify_all()
async def reset(self):
"""Reset the barrier to the initial state.
Any tasks currently waiting will get the BrokenBarrier exception
raised.
"""
async with self._cond:
if self._count > 0:
if self._state is not _BarrierState.RESETTING:
#reset the barrier, waking up tasks
self._state = _BarrierState.RESETTING
else:
self._state = _BarrierState.FILLING
self._cond.notify_all()
async def abort(self):
"""Place the barrier into a 'broken' state.
Useful in case of error. Any currently waiting tasks and tasks
attempting to 'wait()' will have BrokenBarrierError raised.
"""
async with self._cond:
self._state = _BarrierState.BROKEN
self._cond.notify_all()
@property
def parties(self):
"""Return the number of tasks required to trip the barrier."""
return self._parties
@property
def n_waiting(self):
"""Return the number of tasks currently waiting at the barrier."""
if self._state is _BarrierState.FILLING:
return self._count
return 0
@property
def broken(self):
"""Return True if the barrier is in a broken state."""
return self._state is _BarrierState.BROKEN