mirror of
https://github.com/python/cpython.git
synced 2025-07-07 19:35:27 +00:00
78 lines
2.7 KiB
Python
78 lines
2.7 KiB
Python
"""Run a test case multiple times in parallel threads."""
|
|
|
|
import copy
|
|
import threading
|
|
import unittest
|
|
|
|
from unittest import TestCase
|
|
|
|
|
|
class ParallelTestCase(TestCase):
|
|
def __init__(self, test_case: TestCase, num_threads: int):
|
|
self.test_case = test_case
|
|
self.num_threads = num_threads
|
|
self._testMethodName = test_case._testMethodName
|
|
self._testMethodDoc = test_case._testMethodDoc
|
|
|
|
def __str__(self):
|
|
return f"{str(self.test_case)} [threads={self.num_threads}]"
|
|
|
|
def run_worker(self, test_case: TestCase, result: unittest.TestResult,
|
|
barrier: threading.Barrier):
|
|
barrier.wait()
|
|
test_case.run(result)
|
|
|
|
def run(self, result=None):
|
|
if result is None:
|
|
result = test_case.defaultTestResult()
|
|
startTestRun = getattr(result, 'startTestRun', None)
|
|
stopTestRun = getattr(result, 'stopTestRun', None)
|
|
if startTestRun is not None:
|
|
startTestRun()
|
|
else:
|
|
stopTestRun = None
|
|
|
|
# Called at the beginning of each test. See TestCase.run.
|
|
result.startTest(self)
|
|
|
|
cases = [copy.copy(self.test_case) for _ in range(self.num_threads)]
|
|
results = [unittest.TestResult() for _ in range(self.num_threads)]
|
|
|
|
barrier = threading.Barrier(self.num_threads)
|
|
threads = []
|
|
for i, (case, r) in enumerate(zip(cases, results)):
|
|
thread = threading.Thread(target=self.run_worker,
|
|
args=(case, r, barrier),
|
|
name=f"{str(self.test_case)}-{i}",
|
|
daemon=True)
|
|
threads.append(thread)
|
|
|
|
for thread in threads:
|
|
thread.start()
|
|
|
|
for threads in threads:
|
|
threads.join()
|
|
|
|
# Aggregate test results
|
|
if all(r.wasSuccessful() for r in results):
|
|
result.addSuccess(self)
|
|
|
|
# Note: We can't call result.addError, result.addFailure, etc. because
|
|
# we no longer have the original exception, just the string format.
|
|
for r in results:
|
|
if len(r.errors) > 0 or len(r.failures) > 0:
|
|
result._mirrorOutput = True
|
|
result.errors.extend(r.errors)
|
|
result.failures.extend(r.failures)
|
|
result.skipped.extend(r.skipped)
|
|
result.expectedFailures.extend(r.expectedFailures)
|
|
result.unexpectedSuccesses.extend(r.unexpectedSuccesses)
|
|
result.collectedDurations.extend(r.collectedDurations)
|
|
|
|
if any(r.shouldStop for r in results):
|
|
result.stop()
|
|
|
|
# Test has finished running
|
|
result.stopTest(self)
|
|
if stopTestRun is not None:
|
|
stopTestRun()
|