bpo-21423: Add an initializer argument to {Process,Thread}PoolExecutor (#4241)

* bpo-21423: Add an initializer argument to {Process,Thread}PoolExecutor

* Fix docstring
This commit is contained in:
Antoine Pitrou 2017-11-04 11:05:49 +01:00 committed by GitHub
parent b838cc3ff4
commit 63ff4131af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 246 additions and 81 deletions

View file

@ -7,6 +7,7 @@ test.support.import_module('multiprocessing.synchronize')
from test.support.script_helper import assert_python_ok
import contextlib
import itertools
import os
import sys
@ -17,7 +18,8 @@ import weakref
from concurrent import futures
from concurrent.futures._base import (
PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future)
PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future,
BrokenExecutor)
from concurrent.futures.process import BrokenProcessPool
from multiprocessing import get_context
@ -37,11 +39,12 @@ CANCELLED_AND_NOTIFIED_FUTURE = create_future(state=CANCELLED_AND_NOTIFIED)
EXCEPTION_FUTURE = create_future(state=FINISHED, exception=OSError())
SUCCESSFUL_FUTURE = create_future(state=FINISHED, result=42)
INITIALIZER_STATUS = 'uninitialized'
def mul(x, y):
return x * y
def sleep_and_raise(t):
time.sleep(t)
raise Exception('this is an exception')
@ -51,6 +54,17 @@ def sleep_and_print(t, msg):
print(msg)
sys.stdout.flush()
def init(x):
global INITIALIZER_STATUS
INITIALIZER_STATUS = x
def get_init_status():
return INITIALIZER_STATUS
def init_fail():
time.sleep(0.1) # let some futures be scheduled
raise ValueError('error in initializer')
class MyObject(object):
def my_method(self):
@ -81,6 +95,7 @@ class BaseTestCase(unittest.TestCase):
class ExecutorMixin:
worker_count = 5
executor_kwargs = {}
def setUp(self):
super().setUp()
@ -90,10 +105,12 @@ class ExecutorMixin:
if hasattr(self, "ctx"):
self.executor = self.executor_type(
max_workers=self.worker_count,
mp_context=get_context(self.ctx))
mp_context=get_context(self.ctx),
**self.executor_kwargs)
else:
self.executor = self.executor_type(
max_workers=self.worker_count)
max_workers=self.worker_count,
**self.executor_kwargs)
except NotImplementedError as e:
self.skipTest(str(e))
self._prime_executor()
@ -114,7 +131,6 @@ class ExecutorMixin:
# tests. This should reduce the probability of timeouts in the tests.
futures = [self.executor.submit(time.sleep, 0.1)
for _ in range(self.worker_count)]
for f in futures:
f.result()
@ -148,6 +164,90 @@ class ProcessPoolForkserverMixin(ExecutorMixin):
super().setUp()
def create_executor_tests(mixin, bases=(BaseTestCase,),
executor_mixins=(ThreadPoolMixin,
ProcessPoolForkMixin,
ProcessPoolForkserverMixin,
ProcessPoolSpawnMixin)):
def strip_mixin(name):
if name.endswith(('Mixin', 'Tests')):
return name[:-5]
elif name.endswith('Test'):
return name[:-4]
else:
return name
for exe in executor_mixins:
name = ("%s%sTest"
% (strip_mixin(exe.__name__), strip_mixin(mixin.__name__)))
cls = type(name, (mixin,) + (exe,) + bases, {})
globals()[name] = cls
class InitializerMixin(ExecutorMixin):
worker_count = 2
def setUp(self):
global INITIALIZER_STATUS
INITIALIZER_STATUS = 'uninitialized'
self.executor_kwargs = dict(initializer=init,
initargs=('initialized',))
super().setUp()
def test_initializer(self):
futures = [self.executor.submit(get_init_status)
for _ in range(self.worker_count)]
for f in futures:
self.assertEqual(f.result(), 'initialized')
class FailingInitializerMixin(ExecutorMixin):
worker_count = 2
def setUp(self):
self.executor_kwargs = dict(initializer=init_fail)
super().setUp()
def test_initializer(self):
with self._assert_logged('ValueError: error in initializer'):
try:
future = self.executor.submit(get_init_status)
except BrokenExecutor:
# Perhaps the executor is already broken
pass
else:
with self.assertRaises(BrokenExecutor):
future.result()
# At some point, the executor should break
t1 = time.time()
while not self.executor._broken:
if time.time() - t1 > 5:
self.fail("executor not broken after 5 s.")
time.sleep(0.01)
# ... and from this point submit() is guaranteed to fail
with self.assertRaises(BrokenExecutor):
self.executor.submit(get_init_status)
def _prime_executor(self):
pass
@contextlib.contextmanager
def _assert_logged(self, msg):
if self.executor_type is futures.ProcessPoolExecutor:
# No easy way to catch the child processes' stderr
yield
else:
with self.assertLogs('concurrent.futures', 'CRITICAL') as cm:
yield
self.assertTrue(any(msg in line for line in cm.output),
cm.output)
create_executor_tests(InitializerMixin)
create_executor_tests(FailingInitializerMixin)
class ExecutorShutdownTest:
def test_run_after_shutdown(self):
self.executor.shutdown()
@ -278,20 +378,11 @@ class ProcessPoolShutdownTest(ExecutorShutdownTest):
call_queue.join_thread()
class ProcessPoolForkShutdownTest(ProcessPoolForkMixin, BaseTestCase,
ProcessPoolShutdownTest):
pass
class ProcessPoolForkserverShutdownTest(ProcessPoolForkserverMixin,
BaseTestCase,
ProcessPoolShutdownTest):
pass
class ProcessPoolSpawnShutdownTest(ProcessPoolSpawnMixin, BaseTestCase,
ProcessPoolShutdownTest):
pass
create_executor_tests(ProcessPoolShutdownTest,
executor_mixins=(ProcessPoolForkMixin,
ProcessPoolForkserverMixin,
ProcessPoolSpawnMixin))
class WaitTests:
@ -413,18 +504,10 @@ class ThreadPoolWaitTests(ThreadPoolMixin, WaitTests, BaseTestCase):
sys.setswitchinterval(oldswitchinterval)
class ProcessPoolForkWaitTests(ProcessPoolForkMixin, WaitTests, BaseTestCase):
pass
class ProcessPoolForkserverWaitTests(ProcessPoolForkserverMixin, WaitTests,
BaseTestCase):
pass
class ProcessPoolSpawnWaitTests(ProcessPoolSpawnMixin, BaseTestCase,
WaitTests):
pass
create_executor_tests(WaitTests,
executor_mixins=(ProcessPoolForkMixin,
ProcessPoolForkserverMixin,
ProcessPoolSpawnMixin))
class AsCompletedTests:
@ -507,24 +590,7 @@ class AsCompletedTests:
self.assertEqual(str(cm.exception), '2 (of 4) futures unfinished')
class ThreadPoolAsCompletedTests(ThreadPoolMixin, AsCompletedTests, BaseTestCase):
pass
class ProcessPoolForkAsCompletedTests(ProcessPoolForkMixin, AsCompletedTests,
BaseTestCase):
pass
class ProcessPoolForkserverAsCompletedTests(ProcessPoolForkserverMixin,
AsCompletedTests,
BaseTestCase):
pass
class ProcessPoolSpawnAsCompletedTests(ProcessPoolSpawnMixin, AsCompletedTests,
BaseTestCase):
pass
create_executor_tests(AsCompletedTests)
class ExecutorTest:
@ -688,23 +754,10 @@ class ProcessPoolExecutorTest(ExecutorTest):
self.assertTrue(obj.event.wait(timeout=1))
class ProcessPoolForkExecutorTest(ProcessPoolForkMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
class ProcessPoolForkserverExecutorTest(ProcessPoolForkserverMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
class ProcessPoolSpawnExecutorTest(ProcessPoolSpawnMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
create_executor_tests(ProcessPoolExecutorTest,
executor_mixins=(ProcessPoolForkMixin,
ProcessPoolForkserverMixin,
ProcessPoolSpawnMixin))
class FutureTests(BaseTestCase):
@ -932,6 +985,7 @@ class FutureTests(BaseTestCase):
self.assertTrue(isinstance(f1.exception(timeout=5), OSError))
t.join()
@test.support.reap_threads
def test_main():
try: