bpo-31540: Allow passing multiprocessing context to ProcessPoolExecutor (#3682)

This commit is contained in:
Thomas Moreau 2017-10-03 11:53:17 +02:00 committed by Antoine Pitrou
parent efb560eee2
commit e8c368df22
5 changed files with 170 additions and 40 deletions

View file

@ -19,6 +19,7 @@ from concurrent import futures
from concurrent.futures._base import (
PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future)
from concurrent.futures.process import BrokenProcessPool
from multiprocessing import get_context
def create_future(state=PENDING, exception=None, result=None):
@ -56,6 +57,15 @@ class MyObject(object):
pass
class EventfulGCObj():
def __init__(self, ctx):
mgr = get_context(ctx).Manager()
self.event = mgr.Event()
def __del__(self):
self.event.set()
def make_dummy_object(_):
return MyObject()
@ -77,7 +87,13 @@ class ExecutorMixin:
self.t1 = time.time()
try:
self.executor = self.executor_type(max_workers=self.worker_count)
if hasattr(self, "ctx"):
self.executor = self.executor_type(
max_workers=self.worker_count,
mp_context=get_context(self.ctx))
else:
self.executor = self.executor_type(
max_workers=self.worker_count)
except NotImplementedError as e:
self.skipTest(str(e))
self._prime_executor()
@ -107,8 +123,29 @@ class ThreadPoolMixin(ExecutorMixin):
executor_type = futures.ThreadPoolExecutor
class ProcessPoolMixin(ExecutorMixin):
class ProcessPoolForkMixin(ExecutorMixin):
executor_type = futures.ProcessPoolExecutor
ctx = "fork"
def setUp(self):
if sys.platform == "win32":
self.skipTest("require unix system")
super().setUp()
class ProcessPoolSpawnMixin(ExecutorMixin):
executor_type = futures.ProcessPoolExecutor
ctx = "spawn"
class ProcessPoolForkserverMixin(ExecutorMixin):
executor_type = futures.ProcessPoolExecutor
ctx = "forkserver"
def setUp(self):
if sys.platform == "win32":
self.skipTest("require unix system")
super().setUp()
class ExecutorShutdownTest:
@ -124,9 +161,17 @@ class ExecutorShutdownTest:
from concurrent.futures import {executor_type}
from time import sleep
from test.test_concurrent_futures import sleep_and_print
t = {executor_type}(5)
t.submit(sleep_and_print, 1.0, "apple")
""".format(executor_type=self.executor_type.__name__))
if __name__ == "__main__":
context = '{context}'
if context == "":
t = {executor_type}(5)
else:
from multiprocessing import get_context
context = get_context(context)
t = {executor_type}(5, mp_context=context)
t.submit(sleep_and_print, 1.0, "apple")
""".format(executor_type=self.executor_type.__name__,
context=getattr(self, "ctx", "")))
# Errors in atexit hooks don't change the process exit code, check
# stderr manually.
self.assertFalse(err)
@ -194,7 +239,7 @@ class ThreadPoolShutdownTest(ThreadPoolMixin, ExecutorShutdownTest, BaseTestCase
t.join()
class ProcessPoolShutdownTest(ProcessPoolMixin, ExecutorShutdownTest, BaseTestCase):
class ProcessPoolShutdownTest(ExecutorShutdownTest):
def _prime_executor(self):
pass
@ -233,6 +278,22 @@ class ProcessPoolShutdownTest(ProcessPoolMixin, ExecutorShutdownTest, BaseTestCa
call_queue.join_thread()
class ProcessPoolForkShutdownTest(ProcessPoolForkMixin, BaseTestCase,
ProcessPoolShutdownTest):
pass
class ProcessPoolForkserverShutdownTest(ProcessPoolForkserverMixin,
BaseTestCase,
ProcessPoolShutdownTest):
pass
class ProcessPoolSpawnShutdownTest(ProcessPoolSpawnMixin, BaseTestCase,
ProcessPoolShutdownTest):
pass
class WaitTests:
def test_first_completed(self):
@ -352,7 +413,17 @@ class ThreadPoolWaitTests(ThreadPoolMixin, WaitTests, BaseTestCase):
sys.setswitchinterval(oldswitchinterval)
class ProcessPoolWaitTests(ProcessPoolMixin, WaitTests, BaseTestCase):
class ProcessPoolForkWaitTests(ProcessPoolForkMixin, WaitTests, BaseTestCase):
pass
class ProcessPoolForkserverWaitTests(ProcessPoolForkserverMixin, WaitTests,
BaseTestCase):
pass
class ProcessPoolSpawnWaitTests(ProcessPoolSpawnMixin, BaseTestCase,
WaitTests):
pass
@ -440,7 +511,19 @@ class ThreadPoolAsCompletedTests(ThreadPoolMixin, AsCompletedTests, BaseTestCase
pass
class ProcessPoolAsCompletedTests(ProcessPoolMixin, AsCompletedTests, BaseTestCase):
class ProcessPoolForkAsCompletedTests(ProcessPoolForkMixin, AsCompletedTests,
BaseTestCase):
pass
class ProcessPoolForkserverAsCompletedTests(ProcessPoolForkserverMixin,
AsCompletedTests,
BaseTestCase):
pass
class ProcessPoolSpawnAsCompletedTests(ProcessPoolSpawnMixin, AsCompletedTests,
BaseTestCase):
pass
@ -540,7 +623,7 @@ class ThreadPoolExecutorTest(ThreadPoolMixin, ExecutorTest, BaseTestCase):
(os.cpu_count() or 1) * 5)
class ProcessPoolExecutorTest(ProcessPoolMixin, ExecutorTest, BaseTestCase):
class ProcessPoolExecutorTest(ExecutorTest):
def test_killed_child(self):
# When a child process is abruptly terminated, the whole pool gets
# "broken".
@ -595,6 +678,34 @@ class ProcessPoolExecutorTest(ProcessPoolMixin, ExecutorTest, BaseTestCase):
self.assertIn('raise RuntimeError(123) # some comment',
f1.getvalue())
def test_ressources_gced_in_workers(self):
# Ensure that argument for a job are correctly gc-ed after the job
# is finished
obj = EventfulGCObj(self.ctx)
future = self.executor.submit(id, obj)
future.result()
self.assertTrue(obj.event.wait(timeout=1))
class ProcessPoolForkExecutorTest(ProcessPoolForkMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
class ProcessPoolForkserverExecutorTest(ProcessPoolForkserverMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
class ProcessPoolSpawnExecutorTest(ProcessPoolSpawnMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
class FutureTests(BaseTestCase):
def test_done_callback_with_result(self):