bpo-31234: Add test.support.wait_threads_exit() (#3578)

Use _thread.count() to wait until threads exit. The new context
manager prevents the "dangling thread" warning.
This commit is contained in:
Victor Stinner 2017-09-14 13:07:24 -07:00 committed by GitHub
parent b8c7be2c52
commit ff40ecda73
6 changed files with 161 additions and 109 deletions

View file

@ -31,6 +31,9 @@ class Bunch(object):
self.started = [] self.started = []
self.finished = [] self.finished = []
self._can_exit = not wait_before_exit self._can_exit = not wait_before_exit
self.wait_thread = support.wait_threads_exit()
self.wait_thread.__enter__()
def task(): def task():
tid = threading.get_ident() tid = threading.get_ident()
self.started.append(tid) self.started.append(tid)
@ -40,6 +43,7 @@ class Bunch(object):
self.finished.append(tid) self.finished.append(tid)
while not self._can_exit: while not self._can_exit:
_wait() _wait()
try: try:
for i in range(n): for i in range(n):
start_new_thread(task, ()) start_new_thread(task, ())
@ -54,13 +58,8 @@ class Bunch(object):
def wait_for_finished(self): def wait_for_finished(self):
while len(self.finished) < self.n: while len(self.finished) < self.n:
_wait() _wait()
# Wait a little bit longer to prevent the "threading_cleanup() # Wait for threads exit
# failed to cleanup X threads" warning. The loop above is a weak self.wait_thread.__exit__(None, None, None)
# synchronization. At the C level, t_bootstrap() can still be
# running and so _thread.count() still accounts the "almost dead"
# thead.
for _ in range(self.n):
_wait()
def do_finish(self): def do_finish(self):
self._can_exit = True self._can_exit = True
@ -227,20 +226,23 @@ class LockTests(BaseLockTests):
# Lock needs to be released before re-acquiring. # Lock needs to be released before re-acquiring.
lock = self.locktype() lock = self.locktype()
phase = [] phase = []
def f(): def f():
lock.acquire() lock.acquire()
phase.append(None) phase.append(None)
lock.acquire() lock.acquire()
phase.append(None) phase.append(None)
start_new_thread(f, ())
while len(phase) == 0: with support.wait_threads_exit():
start_new_thread(f, ())
while len(phase) == 0:
_wait()
_wait() _wait()
_wait() self.assertEqual(len(phase), 1)
self.assertEqual(len(phase), 1) lock.release()
lock.release() while len(phase) == 1:
while len(phase) == 1: _wait()
_wait() self.assertEqual(len(phase), 2)
self.assertEqual(len(phase), 2)
def test_different_thread(self): def test_different_thread(self):
# Lock can be released from a different thread. # Lock can be released from a different thread.

View file

@ -2072,6 +2072,41 @@ def reap_threads(func):
return decorator return decorator
@contextlib.contextmanager
def wait_threads_exit(timeout=60.0):
"""
bpo-31234: Context manager to wait until all threads created in the with
statement exit.
Use _thread.count() to check if threads exited. Indirectly, wait until
threads exit the internal t_bootstrap() C function of the _thread module.
threading_setup() and threading_cleanup() are designed to emit a warning
if a test leaves running threads in the background. This context manager
is designed to cleanup threads started by the _thread.start_new_thread()
which doesn't allow to wait for thread exit, whereas thread.Thread has a
join() method.
"""
old_count = _thread._count()
try:
yield
finally:
start_time = time.monotonic()
deadline = start_time + timeout
while True:
count = _thread._count()
if count <= old_count:
break
if time.monotonic() > deadline:
dt = time.monotonic() - start_time
msg = (f"wait_threads() failed to cleanup {count - old_count} "
f"threads after {dt:.1f} seconds "
f"(count: {count}, old count: {old_count})")
raise AssertionError(msg)
time.sleep(0.010)
gc_collect()
def reap_children(): def reap_children():
"""Use this function at the end of test_main() whenever sub-processes """Use this function at the end of test_main() whenever sub-processes
are started. This will help ensure that no extra children (zombies) are started. This will help ensure that no extra children (zombies)

View file

@ -271,6 +271,9 @@ class ThreadableTest:
self.server_ready.set() self.server_ready.set()
def _setUp(self): def _setUp(self):
self.wait_threads = support.wait_threads_exit()
self.wait_threads.__enter__()
self.server_ready = threading.Event() self.server_ready = threading.Event()
self.client_ready = threading.Event() self.client_ready = threading.Event()
self.done = threading.Event() self.done = threading.Event()
@ -297,6 +300,7 @@ class ThreadableTest:
def _tearDown(self): def _tearDown(self):
self.__tearDown() self.__tearDown()
self.done.wait() self.done.wait()
self.wait_threads.__exit__(None, None, None)
if self.queue.qsize(): if self.queue.qsize():
exc = self.queue.get() exc = self.queue.get()

View file

@ -59,12 +59,13 @@ class ThreadRunningTests(BasicThreadTest):
self.done_mutex.release() self.done_mutex.release()
def test_starting_threads(self): def test_starting_threads(self):
# Basic test for thread creation. with support.wait_threads_exit():
for i in range(NUMTASKS): # Basic test for thread creation.
self.newtask() for i in range(NUMTASKS):
verbose_print("waiting for tasks to complete...") self.newtask()
self.done_mutex.acquire() verbose_print("waiting for tasks to complete...")
verbose_print("all tasks done") self.done_mutex.acquire()
verbose_print("all tasks done")
def test_stack_size(self): def test_stack_size(self):
# Various stack size tests. # Various stack size tests.
@ -94,12 +95,13 @@ class ThreadRunningTests(BasicThreadTest):
verbose_print("trying stack_size = (%d)" % tss) verbose_print("trying stack_size = (%d)" % tss)
self.next_ident = 0 self.next_ident = 0
self.created = 0 self.created = 0
for i in range(NUMTASKS): with support.wait_threads_exit():
self.newtask() for i in range(NUMTASKS):
self.newtask()
verbose_print("waiting for all tasks to complete") verbose_print("waiting for all tasks to complete")
self.done_mutex.acquire() self.done_mutex.acquire()
verbose_print("all tasks done") verbose_print("all tasks done")
thread.stack_size(0) thread.stack_size(0)
@ -109,25 +111,28 @@ class ThreadRunningTests(BasicThreadTest):
mut = thread.allocate_lock() mut = thread.allocate_lock()
mut.acquire() mut.acquire()
started = [] started = []
def task(): def task():
started.append(None) started.append(None)
mut.acquire() mut.acquire()
mut.release() mut.release()
thread.start_new_thread(task, ())
while not started: with support.wait_threads_exit():
time.sleep(POLL_SLEEP) thread.start_new_thread(task, ())
self.assertEqual(thread._count(), orig + 1) while not started:
# Allow the task to finish. time.sleep(POLL_SLEEP)
mut.release() self.assertEqual(thread._count(), orig + 1)
# The only reliable way to be sure that the thread ended from the # Allow the task to finish.
# interpreter's point of view is to wait for the function object to be mut.release()
# destroyed. # The only reliable way to be sure that the thread ended from the
done = [] # interpreter's point of view is to wait for the function object to be
wr = weakref.ref(task, lambda _: done.append(None)) # destroyed.
del task done = []
while not done: wr = weakref.ref(task, lambda _: done.append(None))
time.sleep(POLL_SLEEP) del task
self.assertEqual(thread._count(), orig) while not done:
time.sleep(POLL_SLEEP)
self.assertEqual(thread._count(), orig)
def test_save_exception_state_on_error(self): def test_save_exception_state_on_error(self):
# See issue #14474 # See issue #14474
@ -140,16 +145,14 @@ class ThreadRunningTests(BasicThreadTest):
except ValueError: except ValueError:
pass pass
real_write(self, *args) real_write(self, *args)
c = thread._count()
started = thread.allocate_lock() started = thread.allocate_lock()
with support.captured_output("stderr") as stderr: with support.captured_output("stderr") as stderr:
real_write = stderr.write real_write = stderr.write
stderr.write = mywrite stderr.write = mywrite
started.acquire() started.acquire()
thread.start_new_thread(task, ()) with support.wait_threads_exit():
started.acquire() thread.start_new_thread(task, ())
while thread._count() > c: started.acquire()
time.sleep(POLL_SLEEP)
self.assertIn("Traceback", stderr.getvalue()) self.assertIn("Traceback", stderr.getvalue())
@ -181,13 +184,14 @@ class Barrier:
class BarrierTest(BasicThreadTest): class BarrierTest(BasicThreadTest):
def test_barrier(self): def test_barrier(self):
self.bar = Barrier(NUMTASKS) with support.wait_threads_exit():
self.running = NUMTASKS self.bar = Barrier(NUMTASKS)
for i in range(NUMTASKS): self.running = NUMTASKS
thread.start_new_thread(self.task2, (i,)) for i in range(NUMTASKS):
verbose_print("waiting for tasks to end") thread.start_new_thread(self.task2, (i,))
self.done_mutex.acquire() verbose_print("waiting for tasks to end")
verbose_print("tasks done") self.done_mutex.acquire()
verbose_print("tasks done")
def task2(self, ident): def task2(self, ident):
for i in range(NUMTRIPS): for i in range(NUMTRIPS):
@ -225,11 +229,10 @@ class TestForkInThread(unittest.TestCase):
@unittest.skipUnless(hasattr(os, 'fork'), 'need os.fork') @unittest.skipUnless(hasattr(os, 'fork'), 'need os.fork')
@support.reap_threads @support.reap_threads
def test_forkinthread(self): def test_forkinthread(self):
running = True
status = "not set" status = "not set"
def thread1(): def thread1():
nonlocal running, status nonlocal status
# fork in a thread # fork in a thread
pid = os.fork() pid = os.fork()
@ -244,13 +247,11 @@ class TestForkInThread(unittest.TestCase):
# parent # parent
os.close(self.write_fd) os.close(self.write_fd)
pid, status = os.waitpid(pid, 0) pid, status = os.waitpid(pid, 0)
running = False
thread.start_new_thread(thread1, ()) with support.wait_threads_exit():
self.assertEqual(os.read(self.read_fd, 2), b"OK", thread.start_new_thread(thread1, ())
"Unable to fork() in thread") self.assertEqual(os.read(self.read_fd, 2), b"OK",
while running: "Unable to fork() in thread")
time.sleep(POLL_SLEEP)
self.assertEqual(status, 0) self.assertEqual(status, 0)
def tearDown(self): def tearDown(self):

View file

@ -125,9 +125,10 @@ class ThreadTests(BaseTestCase):
done.set() done.set()
done = threading.Event() done = threading.Event()
ident = [] ident = []
_thread.start_new_thread(f, ()) with support.wait_threads_exit():
done.wait() tid = _thread.start_new_thread(f, ())
self.assertIsNotNone(ident[0]) done.wait()
self.assertEqual(ident[0], tid)
# Kill the "immortal" _DummyThread # Kill the "immortal" _DummyThread
del threading._active[ident[0]] del threading._active[ident[0]]
@ -165,9 +166,10 @@ class ThreadTests(BaseTestCase):
mutex = threading.Lock() mutex = threading.Lock()
mutex.acquire() mutex.acquire()
tid = _thread.start_new_thread(f, (mutex,)) with support.wait_threads_exit():
# Wait for the thread to finish. tid = _thread.start_new_thread(f, (mutex,))
mutex.acquire() # Wait for the thread to finish.
mutex.acquire()
self.assertIn(tid, threading._active) self.assertIn(tid, threading._active)
self.assertIsInstance(threading._active[tid], threading._DummyThread) self.assertIsInstance(threading._active[tid], threading._DummyThread)
#Issue 29376 #Issue 29376

View file

@ -4,8 +4,8 @@ import unittest
import signal import signal
import os import os
import sys import sys
from test.support import run_unittest, import_module from test import support
thread = import_module('_thread') thread = support.import_module('_thread')
import time import time
if (sys.platform[:3] == 'win'): if (sys.platform[:3] == 'win'):
@ -39,13 +39,15 @@ def send_signals():
class ThreadSignals(unittest.TestCase): class ThreadSignals(unittest.TestCase):
def test_signals(self): def test_signals(self):
# Test signal handling semantics of threads. with support.wait_threads_exit():
# We spawn a thread, have the thread send two signals, and # Test signal handling semantics of threads.
# wait for it to finish. Check that we got both signals # We spawn a thread, have the thread send two signals, and
# and that they were run by the main thread. # wait for it to finish. Check that we got both signals
signalled_all.acquire() # and that they were run by the main thread.
self.spawnSignallingThread() signalled_all.acquire()
signalled_all.acquire() self.spawnSignallingThread()
signalled_all.acquire()
# the signals that we asked the kernel to send # the signals that we asked the kernel to send
# will come back, but we don't know when. # will come back, but we don't know when.
# (it might even be after the thread exits # (it might even be after the thread exits
@ -115,17 +117,19 @@ class ThreadSignals(unittest.TestCase):
# thread. # thread.
def other_thread(): def other_thread():
rlock.acquire() rlock.acquire()
thread.start_new_thread(other_thread, ())
# Wait until we can't acquire it without blocking... with support.wait_threads_exit():
while rlock.acquire(blocking=False): thread.start_new_thread(other_thread, ())
rlock.release() # Wait until we can't acquire it without blocking...
time.sleep(0.01) while rlock.acquire(blocking=False):
signal.alarm(1) rlock.release()
t1 = time.time() time.sleep(0.01)
self.assertRaises(KeyboardInterrupt, rlock.acquire, timeout=5) signal.alarm(1)
dt = time.time() - t1 t1 = time.time()
# See rationale above in test_lock_acquire_interruption self.assertRaises(KeyboardInterrupt, rlock.acquire, timeout=5)
self.assertLess(dt, 3.0) dt = time.time() - t1
# See rationale above in test_lock_acquire_interruption
self.assertLess(dt, 3.0)
finally: finally:
signal.signal(signal.SIGALRM, oldalrm) signal.signal(signal.SIGALRM, oldalrm)
@ -133,6 +137,7 @@ class ThreadSignals(unittest.TestCase):
self.sig_recvd = False self.sig_recvd = False
def my_handler(signal, frame): def my_handler(signal, frame):
self.sig_recvd = True self.sig_recvd = True
old_handler = signal.signal(signal.SIGUSR1, my_handler) old_handler = signal.signal(signal.SIGUSR1, my_handler)
try: try:
def other_thread(): def other_thread():
@ -147,14 +152,16 @@ class ThreadSignals(unittest.TestCase):
# the lock acquisition. Then we'll let it run. # the lock acquisition. Then we'll let it run.
time.sleep(0.5) time.sleep(0.5)
lock.release() lock.release()
thread.start_new_thread(other_thread, ())
# Wait until we can't acquire it without blocking... with support.wait_threads_exit():
while lock.acquire(blocking=False): thread.start_new_thread(other_thread, ())
lock.release() # Wait until we can't acquire it without blocking...
time.sleep(0.01) while lock.acquire(blocking=False):
result = lock.acquire() # Block while we receive a signal. lock.release()
self.assertTrue(self.sig_recvd) time.sleep(0.01)
self.assertTrue(result) result = lock.acquire() # Block while we receive a signal.
self.assertTrue(self.sig_recvd)
self.assertTrue(result)
finally: finally:
signal.signal(signal.SIGUSR1, old_handler) signal.signal(signal.SIGUSR1, old_handler)
@ -193,19 +200,20 @@ class ThreadSignals(unittest.TestCase):
os.kill(process_pid, signal.SIGUSR1) os.kill(process_pid, signal.SIGUSR1)
done.release() done.release()
# Send the signals from the non-main thread, since the main thread with support.wait_threads_exit():
# is the only one that can process signals. # Send the signals from the non-main thread, since the main thread
thread.start_new_thread(send_signals, ()) # is the only one that can process signals.
timed_acquire() thread.start_new_thread(send_signals, ())
# Wait for thread to finish timed_acquire()
done.acquire() # Wait for thread to finish
# This allows for some timing and scheduling imprecision done.acquire()
self.assertLess(self.end - self.start, 2.0) # This allows for some timing and scheduling imprecision
self.assertGreater(self.end - self.start, 0.3) self.assertLess(self.end - self.start, 2.0)
# If the signal is received several times before PyErr_CheckSignals() self.assertGreater(self.end - self.start, 0.3)
# is called, the handler will get called less than 40 times. Just # If the signal is received several times before PyErr_CheckSignals()
# check it's been called at least once. # is called, the handler will get called less than 40 times. Just
self.assertGreater(self.sigs_recvd, 0) # check it's been called at least once.
self.assertGreater(self.sigs_recvd, 0)
finally: finally:
signal.signal(signal.SIGUSR1, old_handler) signal.signal(signal.SIGUSR1, old_handler)
@ -219,7 +227,7 @@ def test_main():
oldsigs = registerSignals(handle_signals, handle_signals, handle_signals) oldsigs = registerSignals(handle_signals, handle_signals, handle_signals)
try: try:
run_unittest(ThreadSignals) support.run_unittest(ThreadSignals)
finally: finally:
registerSignals(*oldsigs) registerSignals(*oldsigs)