gh-117378: Only run the new multiprocessing SysPath test when appropriate (GH-126635)

The first version had it running two forkserver and one spawn tests underneath each of the _fork, _forkserver, and _spawn test suites that build off the generic one.

This adds to the existing complexity of the multiprocessing test suite by offering BaseTestCase classes another attribute to control which suites they are invoked under. Practicality vs purity here. :/

Net result: we don't over-run the new test and their internal logic is simplified.
This commit is contained in:
Gregory P. Smith 2024-11-10 13:17:05 -08:00 committed by GitHub
parent 0f6bb28ff3
commit ca878b6e45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -258,6 +258,9 @@ class TimingWrapper(object):
class BaseTestCase(object): class BaseTestCase(object):
ALLOWED_TYPES = ('processes', 'manager', 'threads') ALLOWED_TYPES = ('processes', 'manager', 'threads')
# If not empty, limit which start method suites run this class.
START_METHODS: set[str] = set()
start_method = None # set by install_tests_in_module_dict()
def assertTimingAlmostEqual(self, a, b): def assertTimingAlmostEqual(self, a, b):
if CHECK_TIMINGS: if CHECK_TIMINGS:
@ -6403,7 +6406,9 @@ class _TestAtExit(BaseTestCase):
class _TestSpawnedSysPath(BaseTestCase): class _TestSpawnedSysPath(BaseTestCase):
"""Test that sys.path is setup in forkserver and spawn processes.""" """Test that sys.path is setup in forkserver and spawn processes."""
ALLOWED_TYPES = ('processes',) ALLOWED_TYPES = {'processes'}
# Not applicable to fork which inherits everything from the process as is.
START_METHODS = {"forkserver", "spawn"}
def setUp(self): def setUp(self):
self._orig_sys_path = list(sys.path) self._orig_sys_path = list(sys.path)
@ -6415,11 +6420,8 @@ class _TestSpawnedSysPath(BaseTestCase):
sys.path[:] = [p for p in sys.path if p] # remove any existing ""s sys.path[:] = [p for p in sys.path if p] # remove any existing ""s
sys.path.insert(0, self._temp_dir) sys.path.insert(0, self._temp_dir)
sys.path.insert(0, "") # Replaced with an abspath in child. sys.path.insert(0, "") # Replaced with an abspath in child.
try: self.assertIn(self.start_method, self.START_METHODS)
self._ctx_forkserver = multiprocessing.get_context("forkserver") self._ctx = multiprocessing.get_context(self.start_method)
except ValueError:
self._ctx_forkserver = None
self._ctx_spawn = multiprocessing.get_context("spawn")
def tearDown(self): def tearDown(self):
sys.path[:] = self._orig_sys_path sys.path[:] = self._orig_sys_path
@ -6430,15 +6432,15 @@ class _TestSpawnedSysPath(BaseTestCase):
queue.put(tuple(sys.modules)) queue.put(tuple(sys.modules))
def test_forkserver_preload_imports_sys_path(self): def test_forkserver_preload_imports_sys_path(self):
ctx = self._ctx_forkserver if self._ctx.get_start_method() != "forkserver":
if not ctx: self.skipTest("forkserver specific test.")
self.skipTest("requires forkserver start method.")
self.assertNotIn(self._mod_name, sys.modules) self.assertNotIn(self._mod_name, sys.modules)
multiprocessing.forkserver._forkserver._stop() # Must be fresh. multiprocessing.forkserver._forkserver._stop() # Must be fresh.
ctx.set_forkserver_preload( self._ctx.set_forkserver_preload(
["test.test_multiprocessing_forkserver", self._mod_name]) ["test.test_multiprocessing_forkserver", self._mod_name])
q = ctx.Queue() q = self._ctx.Queue()
proc = ctx.Process(target=self.enq_imported_module_names, args=(q,)) proc = self._ctx.Process(
target=self.enq_imported_module_names, args=(q,))
proc.start() proc.start()
proc.join() proc.join()
child_imported_modules = q.get() child_imported_modules = q.get()
@ -6456,23 +6458,19 @@ class _TestSpawnedSysPath(BaseTestCase):
queue.put(None) queue.put(None)
def test_child_sys_path(self): def test_child_sys_path(self):
for ctx in (self._ctx_spawn, self._ctx_forkserver): q = self._ctx.Queue()
if not ctx: proc = self._ctx.Process(
continue target=self.enq_sys_path_and_import, args=(q, self._mod_name))
with self.subTest(f"{ctx.get_start_method()} start method"): proc.start()
q = ctx.Queue() proc.join()
proc = ctx.Process(target=self.enq_sys_path_and_import, child_sys_path = q.get()
args=(q, self._mod_name)) import_error = q.get()
proc.start() q.close()
proc.join() self.assertNotIn("", child_sys_path) # replaced by an abspath
child_sys_path = q.get() self.assertIn(self._temp_dir, child_sys_path) # our addition
import_error = q.get() # ignore the first element, it is the absolute "" replacement
q.close() self.assertEqual(child_sys_path[1:], sys.path[1:])
self.assertNotIn("", child_sys_path) # replaced by an abspath self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}")
self.assertIn(self._temp_dir, child_sys_path) # our addition
# ignore the first element, it is the absolute "" replacement
self.assertEqual(child_sys_path[1:], sys.path[1:])
self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}")
class MiscTestCase(unittest.TestCase): class MiscTestCase(unittest.TestCase):
@ -6669,6 +6667,8 @@ def install_tests_in_module_dict(remote_globs, start_method,
if base is BaseTestCase: if base is BaseTestCase:
continue continue
assert set(base.ALLOWED_TYPES) <= ALL_TYPES, base.ALLOWED_TYPES assert set(base.ALLOWED_TYPES) <= ALL_TYPES, base.ALLOWED_TYPES
if base.START_METHODS and start_method not in base.START_METHODS:
continue # class not intended for this start method.
for type_ in base.ALLOWED_TYPES: for type_ in base.ALLOWED_TYPES:
if only_type and type_ != only_type: if only_type and type_ != only_type:
continue continue
@ -6682,6 +6682,7 @@ def install_tests_in_module_dict(remote_globs, start_method,
Temp = hashlib_helper.requires_hashdigest('sha256')(Temp) Temp = hashlib_helper.requires_hashdigest('sha256')(Temp)
Temp.__name__ = Temp.__qualname__ = newname Temp.__name__ = Temp.__qualname__ = newname
Temp.__module__ = __module__ Temp.__module__ = __module__
Temp.start_method = start_method
remote_globs[newname] = Temp remote_globs[newname] = Temp
elif issubclass(base, unittest.TestCase): elif issubclass(base, unittest.TestCase):
if only_type: if only_type: