gh-124694: Add concurrent.futures.InterpreterPoolExecutor (gh-124548)

This is an implementation of InterpreterPoolExecutor that builds on ThreadPoolExecutor.

(Note that this is not tied to PEP 734, which is strictly about adding a new stdlib module.)

Possible future improvements:

* support passing a script for the initializer or to submit()
* support passing (most) arbitrary functions without pickling
* support passing closures
* optionally exec functions against __main__ instead of the their original module
This commit is contained in:
Eric Snow 2024-10-16 16:50:46 -06:00 committed by GitHub
parent a38fef4439
commit a5a7f5e16d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 828 additions and 40 deletions

View file

@ -43,19 +43,46 @@ if hasattr(os, 'register_at_fork'):
after_in_parent=_global_shutdown_lock.release)
class _WorkItem:
def __init__(self, future, fn, args, kwargs):
self.future = future
self.fn = fn
self.args = args
self.kwargs = kwargs
class WorkerContext:
def run(self):
@classmethod
def prepare(cls, initializer, initargs):
if initializer is not None:
if not callable(initializer):
raise TypeError("initializer must be a callable")
def create_context():
return cls(initializer, initargs)
def resolve_task(fn, args, kwargs):
return (fn, args, kwargs)
return create_context, resolve_task
def __init__(self, initializer, initargs):
self.initializer = initializer
self.initargs = initargs
def initialize(self):
if self.initializer is not None:
self.initializer(*self.initargs)
def finalize(self):
pass
def run(self, task):
fn, args, kwargs = task
return fn(*args, **kwargs)
class _WorkItem:
def __init__(self, future, task):
self.future = future
self.task = task
def run(self, ctx):
if not self.future.set_running_or_notify_cancel():
return
try:
result = self.fn(*self.args, **self.kwargs)
result = ctx.run(self.task)
except BaseException as exc:
self.future.set_exception(exc)
# Break a reference cycle with the exception 'exc'
@ -66,16 +93,15 @@ class _WorkItem:
__class_getitem__ = classmethod(types.GenericAlias)
def _worker(executor_reference, work_queue, initializer, initargs):
if initializer is not None:
try:
initializer(*initargs)
except BaseException:
_base.LOGGER.critical('Exception in initializer:', exc_info=True)
executor = executor_reference()
if executor is not None:
executor._initializer_failed()
return
def _worker(executor_reference, ctx, work_queue):
try:
ctx.initialize()
except BaseException:
_base.LOGGER.critical('Exception in initializer:', exc_info=True)
executor = executor_reference()
if executor is not None:
executor._initializer_failed()
return
try:
while True:
try:
@ -89,7 +115,7 @@ def _worker(executor_reference, work_queue, initializer, initargs):
work_item = work_queue.get(block=True)
if work_item is not None:
work_item.run()
work_item.run(ctx)
# Delete references to object. See GH-60488
del work_item
continue
@ -110,6 +136,8 @@ def _worker(executor_reference, work_queue, initializer, initargs):
del executor
except BaseException:
_base.LOGGER.critical('Exception in worker', exc_info=True)
finally:
ctx.finalize()
class BrokenThreadPool(_base.BrokenExecutor):
@ -120,11 +148,17 @@ class BrokenThreadPool(_base.BrokenExecutor):
class ThreadPoolExecutor(_base.Executor):
BROKEN = BrokenThreadPool
# Used to assign unique thread names when thread_name_prefix is not supplied.
_counter = itertools.count().__next__
@classmethod
def prepare_context(cls, initializer, initargs):
return WorkerContext.prepare(initializer, initargs)
def __init__(self, max_workers=None, thread_name_prefix='',
initializer=None, initargs=()):
initializer=None, initargs=(), **ctxkwargs):
"""Initializes a new ThreadPoolExecutor instance.
Args:
@ -133,6 +167,7 @@ class ThreadPoolExecutor(_base.Executor):
thread_name_prefix: An optional name prefix to give our threads.
initializer: A callable used to initialize worker threads.
initargs: A tuple of arguments to pass to the initializer.
ctxkwargs: Additional arguments to cls.prepare_context().
"""
if max_workers is None:
# ThreadPoolExecutor is often used to:
@ -146,8 +181,9 @@ class ThreadPoolExecutor(_base.Executor):
if max_workers <= 0:
raise ValueError("max_workers must be greater than 0")
if initializer is not None and not callable(initializer):
raise TypeError("initializer must be a callable")
(self._create_worker_context,
self._resolve_work_item_task,
) = type(self).prepare_context(initializer, initargs, **ctxkwargs)
self._max_workers = max_workers
self._work_queue = queue.SimpleQueue()
@ -158,13 +194,11 @@ class ThreadPoolExecutor(_base.Executor):
self._shutdown_lock = threading.Lock()
self._thread_name_prefix = (thread_name_prefix or
("ThreadPoolExecutor-%d" % self._counter()))
self._initializer = initializer
self._initargs = initargs
def submit(self, fn, /, *args, **kwargs):
with self._shutdown_lock, _global_shutdown_lock:
if self._broken:
raise BrokenThreadPool(self._broken)
raise self.BROKEN(self._broken)
if self._shutdown:
raise RuntimeError('cannot schedule new futures after shutdown')
@ -173,7 +207,8 @@ class ThreadPoolExecutor(_base.Executor):
'interpreter shutdown')
f = _base.Future()
w = _WorkItem(f, fn, args, kwargs)
task = self._resolve_work_item_task(fn, args, kwargs)
w = _WorkItem(f, task)
self._work_queue.put(w)
self._adjust_thread_count()
@ -196,9 +231,8 @@ class ThreadPoolExecutor(_base.Executor):
num_threads)
t = threading.Thread(name=thread_name, target=_worker,
args=(weakref.ref(self, weakref_cb),
self._work_queue,
self._initializer,
self._initargs))
self._create_worker_context(),
self._work_queue))
t.start()
self._threads.add(t)
_threads_queues[t] = self._work_queue
@ -214,7 +248,7 @@ class ThreadPoolExecutor(_base.Executor):
except queue.Empty:
break
if work_item is not None:
work_item.future.set_exception(BrokenThreadPool(self._broken))
work_item.future.set_exception(self.BROKEN(self._broken))
def shutdown(self, wait=True, *, cancel_futures=False):
with self._shutdown_lock: