mirror of
https://github.com/python/cpython.git
synced 2025-08-04 00:48:58 +00:00
bpo-43352: Add a Barrier object in asyncio lib (GH-24903)
Co-authored-by: Yury Selivanov <yury@edgedb.com> Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
20e6e5636a
commit
d03acd7270
6 changed files with 856 additions and 5 deletions
|
@ -1,7 +1,8 @@
|
|||
"""asyncio exceptions."""
|
||||
|
||||
|
||||
__all__ = ('CancelledError', 'InvalidStateError', 'TimeoutError',
|
||||
__all__ = ('BrokenBarrierError',
|
||||
'CancelledError', 'InvalidStateError', 'TimeoutError',
|
||||
'IncompleteReadError', 'LimitOverrunError',
|
||||
'SendfileNotAvailableError')
|
||||
|
||||
|
@ -55,3 +56,7 @@ class LimitOverrunError(Exception):
|
|||
|
||||
def __reduce__(self):
|
||||
return type(self), (self.args[0], self.consumed)
|
||||
|
||||
|
||||
class BrokenBarrierError(RuntimeError):
|
||||
"""Barrier is broken by barrier.abort() call."""
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
"""Synchronization primitives."""
|
||||
|
||||
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
|
||||
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
|
||||
'BoundedSemaphore', 'Barrier')
|
||||
|
||||
import collections
|
||||
import enum
|
||||
|
||||
from . import exceptions
|
||||
from . import mixins
|
||||
from . import tasks
|
||||
|
||||
|
||||
class _ContextManagerMixin:
|
||||
async def __aenter__(self):
|
||||
await self.acquire()
|
||||
|
@ -416,3 +417,155 @@ class BoundedSemaphore(Semaphore):
|
|||
if self._value >= self._bound_value:
|
||||
raise ValueError('BoundedSemaphore released too many times')
|
||||
super().release()
|
||||
|
||||
|
||||
|
||||
class _BarrierState(enum.Enum):
|
||||
FILLING = 'filling'
|
||||
DRAINING = 'draining'
|
||||
RESETTING = 'resetting'
|
||||
BROKEN = 'broken'
|
||||
|
||||
|
||||
class Barrier(mixins._LoopBoundMixin):
|
||||
"""Asyncio equivalent to threading.Barrier
|
||||
|
||||
Implements a Barrier primitive.
|
||||
Useful for synchronizing a fixed number of tasks at known synchronization
|
||||
points. Tasks block on 'wait()' and are simultaneously awoken once they
|
||||
have all made their call.
|
||||
"""
|
||||
|
||||
def __init__(self, parties):
|
||||
"""Create a barrier, initialised to 'parties' tasks."""
|
||||
if parties < 1:
|
||||
raise ValueError('parties must be > 0')
|
||||
|
||||
self._cond = Condition() # notify all tasks when state changes
|
||||
|
||||
self._parties = parties
|
||||
self._state = _BarrierState.FILLING
|
||||
self._count = 0 # count tasks in Barrier
|
||||
|
||||
def __repr__(self):
|
||||
res = super().__repr__()
|
||||
extra = f'{self._state.value}'
|
||||
if not self.broken:
|
||||
extra += f', waiters:{self.n_waiting}/{self.parties}'
|
||||
return f'<{res[1:-1]} [{extra}]>'
|
||||
|
||||
async def __aenter__(self):
|
||||
# wait for the barrier reaches the parties number
|
||||
# when start draining release and return index of waited task
|
||||
return await self.wait()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
async def wait(self):
|
||||
"""Wait for the barrier.
|
||||
|
||||
When the specified number of tasks have started waiting, they are all
|
||||
simultaneously awoken.
|
||||
Returns an unique and individual index number from 0 to 'parties-1'.
|
||||
"""
|
||||
async with self._cond:
|
||||
await self._block() # Block while the barrier drains or resets.
|
||||
try:
|
||||
index = self._count
|
||||
self._count += 1
|
||||
if index + 1 == self._parties:
|
||||
# We release the barrier
|
||||
await self._release()
|
||||
else:
|
||||
await self._wait()
|
||||
return index
|
||||
finally:
|
||||
self._count -= 1
|
||||
# Wake up any tasks waiting for barrier to drain.
|
||||
self._exit()
|
||||
|
||||
async def _block(self):
|
||||
# Block until the barrier is ready for us,
|
||||
# or raise an exception if it is broken.
|
||||
#
|
||||
# It is draining or resetting, wait until done
|
||||
# unless a CancelledError occurs
|
||||
await self._cond.wait_for(
|
||||
lambda: self._state not in (
|
||||
_BarrierState.DRAINING, _BarrierState.RESETTING
|
||||
)
|
||||
)
|
||||
|
||||
# see if the barrier is in a broken state
|
||||
if self._state is _BarrierState.BROKEN:
|
||||
raise exceptions.BrokenBarrierError("Barrier aborted")
|
||||
|
||||
async def _release(self):
|
||||
# Release the tasks waiting in the barrier.
|
||||
|
||||
# Enter draining state.
|
||||
# Next waiting tasks will be blocked until the end of draining.
|
||||
self._state = _BarrierState.DRAINING
|
||||
self._cond.notify_all()
|
||||
|
||||
async def _wait(self):
|
||||
# Wait in the barrier until we are released. Raise an exception
|
||||
# if the barrier is reset or broken.
|
||||
|
||||
# wait for end of filling
|
||||
# unless a CancelledError occurs
|
||||
await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
|
||||
|
||||
if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
|
||||
raise exceptions.BrokenBarrierError("Abort or reset of barrier")
|
||||
|
||||
def _exit(self):
|
||||
# If we are the last tasks to exit the barrier, signal any tasks
|
||||
# waiting for the barrier to drain.
|
||||
if self._count == 0:
|
||||
if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
|
||||
self._state = _BarrierState.FILLING
|
||||
self._cond.notify_all()
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the barrier to the initial state.
|
||||
|
||||
Any tasks currently waiting will get the BrokenBarrier exception
|
||||
raised.
|
||||
"""
|
||||
async with self._cond:
|
||||
if self._count > 0:
|
||||
if self._state is not _BarrierState.RESETTING:
|
||||
#reset the barrier, waking up tasks
|
||||
self._state = _BarrierState.RESETTING
|
||||
else:
|
||||
self._state = _BarrierState.FILLING
|
||||
self._cond.notify_all()
|
||||
|
||||
async def abort(self):
|
||||
"""Place the barrier into a 'broken' state.
|
||||
|
||||
Useful in case of error. Any currently waiting tasks and tasks
|
||||
attempting to 'wait()' will have BrokenBarrierError raised.
|
||||
"""
|
||||
async with self._cond:
|
||||
self._state = _BarrierState.BROKEN
|
||||
self._cond.notify_all()
|
||||
|
||||
@property
|
||||
def parties(self):
|
||||
"""Return the number of tasks required to trip the barrier."""
|
||||
return self._parties
|
||||
|
||||
@property
|
||||
def n_waiting(self):
|
||||
"""Return the number of tasks currently waiting at the barrier."""
|
||||
if self._state is _BarrierState.FILLING:
|
||||
return self._count
|
||||
return 0
|
||||
|
||||
@property
|
||||
def broken(self):
|
||||
"""Return True if the barrier is in a broken state."""
|
||||
return self._state is _BarrierState.BROKEN
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue