mirror of
https://github.com/python/cpython.git
synced 2025-10-10 00:43:41 +00:00
issue 8777
Add threading.Barrier
This commit is contained in:
parent
65ffae0aa3
commit
3be00037d6
4 changed files with 469 additions and 0 deletions
|
@ -597,3 +597,193 @@ class BoundedSemaphoreTests(BaseSemaphoreTests):
|
|||
sem.acquire()
|
||||
sem.release()
|
||||
self.assertRaises(ValueError, sem.release)
|
||||
|
||||
|
||||
class BarrierTests(BaseTestCase):
|
||||
"""
|
||||
Tests for Barrier objects.
|
||||
"""
|
||||
N = 5
|
||||
|
||||
def setUp(self):
|
||||
self.barrier = self.barriertype(self.N, timeout=0.1)
|
||||
def tearDown(self):
|
||||
self.barrier.abort()
|
||||
|
||||
def run_threads(self, f):
|
||||
b = Bunch(f, self.N-1)
|
||||
f()
|
||||
b.wait_for_finished()
|
||||
|
||||
def multipass(self, results, n):
|
||||
m = self.barrier.parties
|
||||
self.assertEqual(m, self.N)
|
||||
for i in range(n):
|
||||
results[0].append(True)
|
||||
self.assertEqual(len(results[1]), i * m)
|
||||
self.barrier.wait()
|
||||
results[1].append(True)
|
||||
self.assertEqual(len(results[0]), (i + 1) * m)
|
||||
self.barrier.wait()
|
||||
self.assertEqual(self.barrier.n_waiting, 0)
|
||||
self.assertFalse(self.barrier.broken)
|
||||
|
||||
def test_barrier(self, passes=1):
|
||||
"""
|
||||
Test that a barrier is passed in lockstep
|
||||
"""
|
||||
results = [[],[]]
|
||||
def f():
|
||||
self.multipass(results, passes)
|
||||
self.run_threads(f)
|
||||
|
||||
def test_barrier_10(self):
|
||||
"""
|
||||
Test that a barrier works for 10 consecutive runs
|
||||
"""
|
||||
return self.test_barrier(10)
|
||||
|
||||
def test_wait_return(self):
|
||||
"""
|
||||
test the return value from barrier.wait
|
||||
"""
|
||||
results = []
|
||||
def f():
|
||||
r = self.barrier.wait()
|
||||
results.append(r)
|
||||
|
||||
self.run_threads(f)
|
||||
self.assertEqual(sum(results), sum(range(self.N)))
|
||||
|
||||
def test_action(self):
|
||||
"""
|
||||
Test the 'action' callback
|
||||
"""
|
||||
results = []
|
||||
def action():
|
||||
results.append(True)
|
||||
barrier = self.barriertype(self.N, action)
|
||||
def f():
|
||||
barrier.wait()
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
self.run_threads(f)
|
||||
|
||||
def test_abort(self):
|
||||
"""
|
||||
Test that an abort will put the barrier in a broken state
|
||||
"""
|
||||
results1 = []
|
||||
results2 = []
|
||||
def f():
|
||||
try:
|
||||
i = self.barrier.wait()
|
||||
if i == self.N//2:
|
||||
raise RuntimeError
|
||||
self.barrier.wait()
|
||||
results1.append(True)
|
||||
except threading.BrokenBarrierError:
|
||||
results2.append(True)
|
||||
except RuntimeError:
|
||||
self.barrier.abort()
|
||||
pass
|
||||
|
||||
self.run_threads(f)
|
||||
self.assertEqual(len(results1), 0)
|
||||
self.assertEqual(len(results2), self.N-1)
|
||||
self.assertTrue(self.barrier.broken)
|
||||
|
||||
def test_reset(self):
|
||||
"""
|
||||
Test that a 'reset' on a barrier frees the waiting threads
|
||||
"""
|
||||
results1 = []
|
||||
results2 = []
|
||||
results3 = []
|
||||
def f():
|
||||
i = self.barrier.wait()
|
||||
if i == self.N//2:
|
||||
# Wait until the other threads are all in the barrier.
|
||||
while self.barrier.n_waiting < self.N-1:
|
||||
time.sleep(0.001)
|
||||
self.barrier.reset()
|
||||
else:
|
||||
try:
|
||||
self.barrier.wait()
|
||||
results1.append(True)
|
||||
except threading.BrokenBarrierError:
|
||||
results2.append(True)
|
||||
# Now, pass the barrier again
|
||||
self.barrier.wait()
|
||||
results3.append(True)
|
||||
|
||||
self.run_threads(f)
|
||||
self.assertEqual(len(results1), 0)
|
||||
self.assertEqual(len(results2), self.N-1)
|
||||
self.assertEqual(len(results3), self.N)
|
||||
|
||||
|
||||
def test_abort_and_reset(self):
|
||||
"""
|
||||
Test that a barrier can be reset after being broken.
|
||||
"""
|
||||
results1 = []
|
||||
results2 = []
|
||||
results3 = []
|
||||
barrier2 = self.barriertype(self.N)
|
||||
def f():
|
||||
try:
|
||||
i = self.barrier.wait()
|
||||
if i == self.N//2:
|
||||
raise RuntimeError
|
||||
self.barrier.wait()
|
||||
results1.append(True)
|
||||
except threading.BrokenBarrierError:
|
||||
results2.append(True)
|
||||
except RuntimeError:
|
||||
self.barrier.abort()
|
||||
pass
|
||||
# Synchronize and reset the barrier. Must synchronize first so
|
||||
# that everyone has left it when we reset, and after so that no
|
||||
# one enters it before the reset.
|
||||
if barrier2.wait() == self.N//2:
|
||||
self.barrier.reset()
|
||||
barrier2.wait()
|
||||
self.barrier.wait()
|
||||
results3.append(True)
|
||||
|
||||
self.run_threads(f)
|
||||
self.assertEqual(len(results1), 0)
|
||||
self.assertEqual(len(results2), self.N-1)
|
||||
self.assertEqual(len(results3), self.N)
|
||||
|
||||
def test_timeout(self):
|
||||
"""
|
||||
Test wait(timeout)
|
||||
"""
|
||||
def f():
|
||||
i = self.barrier.wait()
|
||||
if i == self.N // 2:
|
||||
# One thread is late!
|
||||
time.sleep(0.1)
|
||||
# Default timeout is 0.1, so this is shorter.
|
||||
self.assertRaises(threading.BrokenBarrierError,
|
||||
self.barrier.wait, 0.05)
|
||||
self.run_threads(f)
|
||||
|
||||
def test_default_timeout(self):
|
||||
"""
|
||||
Test the barrier's default timeout
|
||||
"""
|
||||
def f():
|
||||
i = self.barrier.wait()
|
||||
if i == self.N // 2:
|
||||
# One thread is later than the default timeout of 0.1s.
|
||||
time.sleep(0.15)
|
||||
self.assertRaises(threading.BrokenBarrierError, self.barrier.wait)
|
||||
self.run_threads(f)
|
||||
|
||||
def test_single_thread(self):
|
||||
b = self.barriertype(1)
|
||||
b.wait()
|
||||
b.wait()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue