mirror of
https://github.com/python/cpython.git
synced 2025-10-17 12:18:23 +00:00
asyncio: Locks refactor: use a separate context manager; remove Semaphore._locked.
This commit is contained in:
parent
ab27a9fc4b
commit
ab3c88983b
2 changed files with 95 additions and 22 deletions
|
@ -9,6 +9,36 @@ from . import futures
|
||||||
from . import tasks
|
from . import tasks
|
||||||
|
|
||||||
|
|
||||||
|
class _ContextManager:
|
||||||
|
"""Context manager.
|
||||||
|
|
||||||
|
This enables the following idiom for acquiring and releasing a
|
||||||
|
lock around a block:
|
||||||
|
|
||||||
|
with (yield from lock):
|
||||||
|
<block>
|
||||||
|
|
||||||
|
while failing loudly when accidentally using:
|
||||||
|
|
||||||
|
with lock:
|
||||||
|
<block>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, lock):
|
||||||
|
self._lock = lock
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
# We have no use for the "as ..." clause in the with
|
||||||
|
# statement for locks.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
try:
|
||||||
|
self._lock.release()
|
||||||
|
finally:
|
||||||
|
self._lock = None # Crudely prevent reuse.
|
||||||
|
|
||||||
|
|
||||||
class Lock:
|
class Lock:
|
||||||
"""Primitive lock objects.
|
"""Primitive lock objects.
|
||||||
|
|
||||||
|
@ -124,17 +154,29 @@ class Lock:
|
||||||
raise RuntimeError('Lock is not acquired.')
|
raise RuntimeError('Lock is not acquired.')
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
if not self._locked:
|
raise RuntimeError(
|
||||||
raise RuntimeError(
|
'"yield from" should be used as context manager expression')
|
||||||
'"yield from" should be used as context manager expression')
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
self.release()
|
# This must exist because __enter__ exists, even though that
|
||||||
|
# always raises; that's how the with-statement works.
|
||||||
|
pass
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
# This is not a coroutine. It is meant to enable the idiom:
|
||||||
|
#
|
||||||
|
# with (yield from lock):
|
||||||
|
# <block>
|
||||||
|
#
|
||||||
|
# as an alternative to:
|
||||||
|
#
|
||||||
|
# yield from lock.acquire()
|
||||||
|
# try:
|
||||||
|
# <block>
|
||||||
|
# finally:
|
||||||
|
# lock.release()
|
||||||
yield from self.acquire()
|
yield from self.acquire()
|
||||||
return self
|
return _ContextManager(self)
|
||||||
|
|
||||||
|
|
||||||
class Event:
|
class Event:
|
||||||
|
@ -311,14 +353,16 @@ class Condition:
|
||||||
self.notify(len(self._waiters))
|
self.notify(len(self._waiters))
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self._lock.__enter__()
|
raise RuntimeError(
|
||||||
|
'"yield from" should be used as context manager expression')
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
return self._lock.__exit__(*args)
|
pass
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
# See comment in Lock.__iter__().
|
||||||
yield from self.acquire()
|
yield from self.acquire()
|
||||||
return self
|
return _ContextManager(self)
|
||||||
|
|
||||||
|
|
||||||
class Semaphore:
|
class Semaphore:
|
||||||
|
@ -341,7 +385,6 @@ class Semaphore:
|
||||||
raise ValueError("Semaphore initial value must be >= 0")
|
raise ValueError("Semaphore initial value must be >= 0")
|
||||||
self._value = value
|
self._value = value
|
||||||
self._waiters = collections.deque()
|
self._waiters = collections.deque()
|
||||||
self._locked = (value == 0)
|
|
||||||
if loop is not None:
|
if loop is not None:
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
else:
|
else:
|
||||||
|
@ -349,7 +392,7 @@ class Semaphore:
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
res = super().__repr__()
|
res = super().__repr__()
|
||||||
extra = 'locked' if self._locked else 'unlocked,value:{}'.format(
|
extra = 'locked' if self.locked() else 'unlocked,value:{}'.format(
|
||||||
self._value)
|
self._value)
|
||||||
if self._waiters:
|
if self._waiters:
|
||||||
extra = '{},waiters:{}'.format(extra, len(self._waiters))
|
extra = '{},waiters:{}'.format(extra, len(self._waiters))
|
||||||
|
@ -357,7 +400,7 @@ class Semaphore:
|
||||||
|
|
||||||
def locked(self):
|
def locked(self):
|
||||||
"""Returns True if semaphore can not be acquired immediately."""
|
"""Returns True if semaphore can not be acquired immediately."""
|
||||||
return self._locked
|
return self._value == 0
|
||||||
|
|
||||||
@tasks.coroutine
|
@tasks.coroutine
|
||||||
def acquire(self):
|
def acquire(self):
|
||||||
|
@ -371,8 +414,6 @@ class Semaphore:
|
||||||
"""
|
"""
|
||||||
if not self._waiters and self._value > 0:
|
if not self._waiters and self._value > 0:
|
||||||
self._value -= 1
|
self._value -= 1
|
||||||
if self._value == 0:
|
|
||||||
self._locked = True
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
fut = futures.Future(loop=self._loop)
|
fut = futures.Future(loop=self._loop)
|
||||||
|
@ -380,8 +421,6 @@ class Semaphore:
|
||||||
try:
|
try:
|
||||||
yield from fut
|
yield from fut
|
||||||
self._value -= 1
|
self._value -= 1
|
||||||
if self._value == 0:
|
|
||||||
self._locked = True
|
|
||||||
return True
|
return True
|
||||||
finally:
|
finally:
|
||||||
self._waiters.remove(fut)
|
self._waiters.remove(fut)
|
||||||
|
@ -392,23 +431,22 @@ class Semaphore:
|
||||||
become larger than zero again, wake up that coroutine.
|
become larger than zero again, wake up that coroutine.
|
||||||
"""
|
"""
|
||||||
self._value += 1
|
self._value += 1
|
||||||
self._locked = False
|
|
||||||
for waiter in self._waiters:
|
for waiter in self._waiters:
|
||||||
if not waiter.done():
|
if not waiter.done():
|
||||||
waiter.set_result(True)
|
waiter.set_result(True)
|
||||||
break
|
break
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# TODO: This is questionable. How do we know the user actually
|
raise RuntimeError(
|
||||||
# wrote "with (yield from sema)" instead of "with sema"?
|
'"yield from" should be used as context manager expression')
|
||||||
return True
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
self.release()
|
pass
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
# See comment in Lock.__iter__().
|
||||||
yield from self.acquire()
|
yield from self.acquire()
|
||||||
return self
|
return _ContextManager(self)
|
||||||
|
|
||||||
|
|
||||||
class BoundedSemaphore(Semaphore):
|
class BoundedSemaphore(Semaphore):
|
||||||
|
|
|
@ -208,6 +208,24 @@ class LockTests(unittest.TestCase):
|
||||||
|
|
||||||
self.assertFalse(lock.locked())
|
self.assertFalse(lock.locked())
|
||||||
|
|
||||||
|
def test_context_manager_cant_reuse(self):
|
||||||
|
lock = asyncio.Lock(loop=self.loop)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def acquire_lock():
|
||||||
|
return (yield from lock)
|
||||||
|
|
||||||
|
# This spells "yield from lock" outside a generator.
|
||||||
|
cm = self.loop.run_until_complete(acquire_lock())
|
||||||
|
with cm:
|
||||||
|
self.assertTrue(lock.locked())
|
||||||
|
|
||||||
|
self.assertFalse(lock.locked())
|
||||||
|
|
||||||
|
with self.assertRaises(AttributeError):
|
||||||
|
with cm:
|
||||||
|
pass
|
||||||
|
|
||||||
def test_context_manager_no_yield(self):
|
def test_context_manager_no_yield(self):
|
||||||
lock = asyncio.Lock(loop=self.loop)
|
lock = asyncio.Lock(loop=self.loop)
|
||||||
|
|
||||||
|
@ -219,6 +237,8 @@ class LockTests(unittest.TestCase):
|
||||||
str(err),
|
str(err),
|
||||||
'"yield from" should be used as context manager expression')
|
'"yield from" should be used as context manager expression')
|
||||||
|
|
||||||
|
self.assertFalse(lock.locked())
|
||||||
|
|
||||||
|
|
||||||
class EventTests(unittest.TestCase):
|
class EventTests(unittest.TestCase):
|
||||||
|
|
||||||
|
@ -655,6 +675,8 @@ class ConditionTests(unittest.TestCase):
|
||||||
str(err),
|
str(err),
|
||||||
'"yield from" should be used as context manager expression')
|
'"yield from" should be used as context manager expression')
|
||||||
|
|
||||||
|
self.assertFalse(cond.locked())
|
||||||
|
|
||||||
|
|
||||||
class SemaphoreTests(unittest.TestCase):
|
class SemaphoreTests(unittest.TestCase):
|
||||||
|
|
||||||
|
@ -830,6 +852,19 @@ class SemaphoreTests(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(2, sem._value)
|
self.assertEqual(2, sem._value)
|
||||||
|
|
||||||
|
def test_context_manager_no_yield(self):
|
||||||
|
sem = asyncio.Semaphore(2, loop=self.loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with sem:
|
||||||
|
self.fail('RuntimeError is not raised in with expression')
|
||||||
|
except RuntimeError as err:
|
||||||
|
self.assertEqual(
|
||||||
|
str(err),
|
||||||
|
'"yield from" should be used as context manager expression')
|
||||||
|
|
||||||
|
self.assertEqual(2, sem._value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue