Issue #18999: Make multiprocessing use context objects.

This allows different parts of a program to use different methods for
starting processes without interfering with each other.
This commit is contained in:
Richard Oudkerk 2013-10-16 16:41:56 +01:00
parent 3e4b52875e
commit b1694cf588
20 changed files with 733 additions and 611 deletions

View file

@ -98,8 +98,8 @@ necessary, see :ref:`multiprocessing-programming`.
Start methods Contexts and start methods
~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~
Depending on the platform, :mod:`multiprocessing` supports three ways Depending on the platform, :mod:`multiprocessing` supports three ways
to start a process. These *start methods* are to start a process. These *start methods* are
@ -132,7 +132,7 @@ to start a process. These *start methods* are
unnecessary resources are inherited. unnecessary resources are inherited.
Available on Unix platforms which support passing file descriptors Available on Unix platforms which support passing file descriptors
over unix pipes. over Unix pipes.
Before Python 3.4 *fork* was the only option available on Unix. Also, Before Python 3.4 *fork* was the only option available on Unix. Also,
prior to Python 3.4, child processes would inherit all the parents prior to Python 3.4, child processes would inherit all the parents
@ -153,18 +153,46 @@ example::
import multiprocessing as mp import multiprocessing as mp
def foo(): def foo(q):
print('hello') q.put('hello')
if __name__ == '__main__': if __name__ == '__main__':
mp.set_start_method('spawn') mp.set_start_method('spawn')
p = mp.Process(target=foo) q = mp.Queue()
p = mp.Process(target=foo, args=(q,))
p.start() p.start()
print(q.get())
p.join() p.join()
:func:`set_start_method` should not be used more than once in the :func:`set_start_method` should not be used more than once in the
program. program.
Alternatively, you can use :func:`get_context` to obtain a context
object. Context objects have the same API as the multiprocessing
module, and allow one to use multiple start methods in the same
program. ::
import multiprocessing as mp
def foo(q):
q.put('hello')
if __name__ == '__main__':
ctx = mp.get_context('spawn')
q = ctx.Queue()
p = ctx.Process(target=foo, args=(q,))
p.start()
print(q.get())
p.join()
Note that objects related to one context may not be compatible with
processes for a different context. In particular, locks created using
the *fork* context cannot be passed to a processes started using the
*spawn* or *forkserver* start methods.
A library which wants to use a particular start method should probably
use :func:`get_context` to avoid interfering with the choice of the
library user.
Exchanging objects between processes Exchanging objects between processes
@ -859,11 +887,30 @@ Miscellaneous
.. versionadded:: 3.4 .. versionadded:: 3.4
.. function:: get_start_method() .. function:: get_context(method=None)
Return the current start method. This can be ``'fork'``, Return a context object which has the same attributes as the
``'spawn'`` or ``'forkserver'``. ``'fork'`` is the default on :mod:`multiprocessing` module.
Unix, while ``'spawn'`` is the default on Windows.
If *method* is *None* then the default context is returned.
Otherwise *method* should be ``'fork'``, ``'spawn'``,
``'forkserver'``. :exc:`ValueError` is raised if the specified
start method is not available.
.. versionadded:: 3.4
.. function:: get_start_method(allow_none=False)
Return the name of start method used for starting processes.
If the start method has not been fixed and *allow_none* is false,
then the start method is fixed to the default and the name is
returned. If the start method has not been fixed and *allow_none*
is true then *None* is returned.
The return value can be ``'fork'``, ``'spawn'``, ``'forkserver'``
or *None*. ``'fork'`` is the default on Unix, while ``'spawn'`` is
the default on Windows.
.. versionadded:: 3.4 .. versionadded:: 3.4
@ -1785,7 +1832,7 @@ Process Pools
One can create a pool of processes which will carry out tasks submitted to it One can create a pool of processes which will carry out tasks submitted to it
with the :class:`Pool` class. with the :class:`Pool` class.
.. class:: Pool([processes[, initializer[, initargs[, maxtasksperchild]]]]) .. class:: Pool([processes[, initializer[, initargs[, maxtasksperchild [, context]]]]])
A process pool object which controls a pool of worker processes to which jobs A process pool object which controls a pool of worker processes to which jobs
can be submitted. It supports asynchronous results with timeouts and can be submitted. It supports asynchronous results with timeouts and
@ -1805,6 +1852,13 @@ with the :class:`Pool` class.
unused resources to be freed. The default *maxtasksperchild* is None, which unused resources to be freed. The default *maxtasksperchild* is None, which
means worker processes will live as long as the pool. means worker processes will live as long as the pool.
.. versionadded:: 3.4
*context* can be used to specify the context used for starting
the worker processes. Usually a pool is created using the
function :func:`multiprocessing.Pool` or the :meth:`Pool` method
of a context object. In both cases *context* is set
appropriately.
.. note:: .. note::
Worker processes within a :class:`Pool` typically live for the complete Worker processes within a :class:`Pool` typically live for the complete

View file

@ -12,27 +12,16 @@
# Licensed to PSF under a Contributor Agreement. # Licensed to PSF under a Contributor Agreement.
# #
__version__ = '0.70a1'
__all__ = [
'Process', 'current_process', 'active_children', 'freeze_support',
'Manager', 'Pipe', 'cpu_count', 'log_to_stderr', 'get_logger',
'allow_connection_pickling', 'BufferTooShort', 'TimeoutError',
'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition',
'Event', 'Barrier', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool',
'Value', 'Array', 'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING',
'set_executable', 'set_start_method', 'get_start_method',
'get_all_start_methods', 'set_forkserver_preload'
]
#
# Imports
#
import os
import sys import sys
from . import context
from .process import Process, current_process, active_children #
# Copy stuff from default context
#
globals().update((name, getattr(context._default_context, name))
for name in context._default_context.__all__)
__all__ = context._default_context.__all__
# #
# XXX These should not really be documented or public. # XXX These should not really be documented or public.
@ -47,240 +36,3 @@ SUBWARNING = 25
if '__main__' in sys.modules: if '__main__' in sys.modules:
sys.modules['__mp_main__'] = sys.modules['__main__'] sys.modules['__mp_main__'] = sys.modules['__main__']
#
# Exceptions
#
class ProcessError(Exception):
pass
class BufferTooShort(ProcessError):
pass
class TimeoutError(ProcessError):
pass
class AuthenticationError(ProcessError):
pass
#
# Definitions not depending on native semaphores
#
def Manager():
'''
Returns a manager associated with a running server process
The managers methods such as `Lock()`, `Condition()` and `Queue()`
can be used to create shared objects.
'''
from .managers import SyncManager
m = SyncManager()
m.start()
return m
def Pipe(duplex=True):
'''
Returns two connection object connected by a pipe
'''
from .connection import Pipe
return Pipe(duplex)
def cpu_count():
'''
Returns the number of CPUs in the system
'''
num = os.cpu_count()
if num is None:
raise NotImplementedError('cannot determine number of cpus')
else:
return num
def freeze_support():
'''
Check whether this is a fake forked process in a frozen executable.
If so then run code specified by commandline and exit.
'''
if sys.platform == 'win32' and getattr(sys, 'frozen', False):
from .spawn import freeze_support
freeze_support()
def get_logger():
'''
Return package logger -- if it does not already exist then it is created
'''
from .util import get_logger
return get_logger()
def log_to_stderr(level=None):
'''
Turn on logging and add a handler which prints to stderr
'''
from .util import log_to_stderr
return log_to_stderr(level)
def allow_connection_pickling():
'''
Install support for sending connections and sockets between processes
'''
# This is undocumented. In previous versions of multiprocessing
# its only effect was to make socket objects inheritable on Windows.
from . import connection
#
# Definitions depending on native semaphores
#
def Lock():
'''
Returns a non-recursive lock object
'''
from .synchronize import Lock
return Lock()
def RLock():
'''
Returns a recursive lock object
'''
from .synchronize import RLock
return RLock()
def Condition(lock=None):
'''
Returns a condition object
'''
from .synchronize import Condition
return Condition(lock)
def Semaphore(value=1):
'''
Returns a semaphore object
'''
from .synchronize import Semaphore
return Semaphore(value)
def BoundedSemaphore(value=1):
'''
Returns a bounded semaphore object
'''
from .synchronize import BoundedSemaphore
return BoundedSemaphore(value)
def Event():
'''
Returns an event object
'''
from .synchronize import Event
return Event()
def Barrier(parties, action=None, timeout=None):
'''
Returns a barrier object
'''
from .synchronize import Barrier
return Barrier(parties, action, timeout)
def Queue(maxsize=0):
'''
Returns a queue object
'''
from .queues import Queue
return Queue(maxsize)
def JoinableQueue(maxsize=0):
'''
Returns a queue object
'''
from .queues import JoinableQueue
return JoinableQueue(maxsize)
def SimpleQueue():
'''
Returns a queue object
'''
from .queues import SimpleQueue
return SimpleQueue()
def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None):
'''
Returns a process pool object
'''
from .pool import Pool
return Pool(processes, initializer, initargs, maxtasksperchild)
def RawValue(typecode_or_type, *args):
'''
Returns a shared object
'''
from .sharedctypes import RawValue
return RawValue(typecode_or_type, *args)
def RawArray(typecode_or_type, size_or_initializer):
'''
Returns a shared array
'''
from .sharedctypes import RawArray
return RawArray(typecode_or_type, size_or_initializer)
def Value(typecode_or_type, *args, lock=True):
'''
Returns a synchronized shared object
'''
from .sharedctypes import Value
return Value(typecode_or_type, *args, lock=lock)
def Array(typecode_or_type, size_or_initializer, *, lock=True):
'''
Returns a synchronized shared array
'''
from .sharedctypes import Array
return Array(typecode_or_type, size_or_initializer, lock=lock)
#
#
#
def set_executable(executable):
'''
Sets the path to a python.exe or pythonw.exe binary used to run
child processes instead of sys.executable when using the 'spawn'
start method. Useful for people embedding Python.
'''
from .spawn import set_executable
set_executable(executable)
def set_start_method(method):
'''
Set method for starting processes: 'fork', 'spawn' or 'forkserver'.
'''
from .popen import set_start_method
set_start_method(method)
def get_start_method():
'''
Get method for starting processes: 'fork', 'spawn' or 'forkserver'.
'''
from .popen import get_start_method
return get_start_method()
def get_all_start_methods():
'''
Get list of availables start methods, default first.
'''
from .popen import get_all_start_methods
return get_all_start_methods()
def set_forkserver_preload(module_names):
'''
Set list of module names to try to load in the forkserver process
when it is started. Properly chosen this can significantly reduce
the cost of starting a new process using the forkserver method.
The default list is ['__main__'].
'''
try:
from .forkserver import set_forkserver_preload
except ImportError:
pass
else:
set_forkserver_preload(module_names)

View file

@ -0,0 +1,348 @@
import os
import sys
import threading
from . import process
__all__ = [] # things are copied from here to __init__.py
#
# Exceptions
#
class ProcessError(Exception):
pass
class BufferTooShort(ProcessError):
pass
class TimeoutError(ProcessError):
pass
class AuthenticationError(ProcessError):
pass
#
# Base type for contexts
#
class BaseContext(object):
ProcessError = ProcessError
BufferTooShort = BufferTooShort
TimeoutError = TimeoutError
AuthenticationError = AuthenticationError
current_process = staticmethod(process.current_process)
active_children = staticmethod(process.active_children)
def cpu_count(self):
'''Returns the number of CPUs in the system'''
num = os.cpu_count()
if num is None:
raise NotImplementedError('cannot determine number of cpus')
else:
return num
def Manager(self):
'''Returns a manager associated with a running server process
The managers methods such as `Lock()`, `Condition()` and `Queue()`
can be used to create shared objects.
'''
from .managers import SyncManager
m = SyncManager(ctx=self.get_context())
m.start()
return m
def Pipe(self, duplex=True):
'''Returns two connection object connected by a pipe'''
from .connection import Pipe
return Pipe(duplex)
def Lock(self):
'''Returns a non-recursive lock object'''
from .synchronize import Lock
return Lock(ctx=self.get_context())
def RLock(self):
'''Returns a recursive lock object'''
from .synchronize import RLock
return RLock(ctx=self.get_context())
def Condition(self, lock=None):
'''Returns a condition object'''
from .synchronize import Condition
return Condition(lock, ctx=self.get_context())
def Semaphore(self, value=1):
'''Returns a semaphore object'''
from .synchronize import Semaphore
return Semaphore(value, ctx=self.get_context())
def BoundedSemaphore(self, value=1):
'''Returns a bounded semaphore object'''
from .synchronize import BoundedSemaphore
return BoundedSemaphore(value, ctx=self.get_context())
def Event(self):
'''Returns an event object'''
from .synchronize import Event
return Event(ctx=self.get_context())
def Barrier(self, parties, action=None, timeout=None):
'''Returns a barrier object'''
from .synchronize import Barrier
return Barrier(parties, action, timeout, ctx=self.get_context())
def Queue(self, maxsize=0):
'''Returns a queue object'''
from .queues import Queue
return Queue(maxsize, ctx=self.get_context())
def JoinableQueue(self, maxsize=0):
'''Returns a queue object'''
from .queues import JoinableQueue
return JoinableQueue(maxsize, ctx=self.get_context())
def SimpleQueue(self):
'''Returns a queue object'''
from .queues import SimpleQueue
return SimpleQueue(ctx=self.get_context())
def Pool(self, processes=None, initializer=None, initargs=(),
maxtasksperchild=None):
'''Returns a process pool object'''
from .pool import Pool
return Pool(processes, initializer, initargs, maxtasksperchild,
context=self.get_context())
def RawValue(self, typecode_or_type, *args):
'''Returns a shared object'''
from .sharedctypes import RawValue
return RawValue(typecode_or_type, *args)
def RawArray(self, typecode_or_type, size_or_initializer):
'''Returns a shared array'''
from .sharedctypes import RawArray
return RawArray(typecode_or_type, size_or_initializer)
def Value(self, typecode_or_type, *args, lock=True):
'''Returns a synchronized shared object'''
from .sharedctypes import Value
return Value(typecode_or_type, *args, lock=lock,
ctx=self.get_context())
def Array(self, typecode_or_type, size_or_initializer, *, lock=True):
'''Returns a synchronized shared array'''
from .sharedctypes import Array
return Array(typecode_or_type, size_or_initializer, lock=lock,
ctx=self.get_context())
def freeze_support(self):
'''Check whether this is a fake forked process in a frozen executable.
If so then run code specified by commandline and exit.
'''
if sys.platform == 'win32' and getattr(sys, 'frozen', False):
from .spawn import freeze_support
freeze_support()
def get_logger(self):
'''Return package logger -- if it does not already exist then
it is created.
'''
from .util import get_logger
return get_logger()
def log_to_stderr(self, level=None):
'''Turn on logging and add a handler which prints to stderr'''
from .util import log_to_stderr
return log_to_stderr(level)
def allow_connection_pickling(self):
'''Install support for sending connections and sockets
between processes
'''
# This is undocumented. In previous versions of multiprocessing
# its only effect was to make socket objects inheritable on Windows.
from . import connection
def set_executable(self, executable):
'''Sets the path to a python.exe or pythonw.exe binary used to run
child processes instead of sys.executable when using the 'spawn'
start method. Useful for people embedding Python.
'''
from .spawn import set_executable
set_executable(executable)
def set_forkserver_preload(self, module_names):
'''Set list of module names to try to load in forkserver process.
This is really just a hint.
'''
from .forkserver import set_forkserver_preload
set_forkserver_preload(module_names)
def get_context(self, method=None):
if method is None:
return self
try:
ctx = _concrete_contexts[method]
except KeyError:
raise ValueError('cannot find context for %r' % method)
ctx._check_available()
return ctx
def get_start_method(self, allow_none=False):
return self._name
def set_start_method(self, method=None):
raise ValueError('cannot set start method of concrete context')
def _check_available(self):
pass
#
# Type of default context -- underlying context can be set at most once
#
class Process(process.BaseProcess):
_start_method = None
@staticmethod
def _Popen(process_obj):
return _default_context.get_context().Process._Popen(process_obj)
class DefaultContext(BaseContext):
Process = Process
def __init__(self, context):
self._default_context = context
self._actual_context = None
def get_context(self, method=None):
if method is None:
if self._actual_context is None:
self._actual_context = self._default_context
return self._actual_context
else:
return super().get_context(method)
def set_start_method(self, method, force=False):
if self._actual_context is not None and not force:
raise RuntimeError('context has already been set')
if method is None and force:
self._actual_context = None
return
self._actual_context = self.get_context(method)
def get_start_method(self, allow_none=False):
if self._actual_context is None:
if allow_none:
return None
self._actual_context = self._default_context
return self._actual_context._name
def get_all_start_methods(self):
if sys.platform == 'win32':
return ['spawn']
else:
from . import reduction
if reduction.HAVE_SEND_HANDLE:
return ['fork', 'spawn', 'forkserver']
else:
return ['fork', 'spawn']
DefaultContext.__all__ = list(x for x in dir(DefaultContext) if x[0] != '_')
#
# Context types for fixed start method
#
if sys.platform != 'win32':
class ForkProcess(process.BaseProcess):
_start_method = 'fork'
@staticmethod
def _Popen(process_obj):
from .popen_fork import Popen
return Popen(process_obj)
class SpawnProcess(process.BaseProcess):
_start_method = 'spawn'
@staticmethod
def _Popen(process_obj):
from .popen_spawn_posix import Popen
return Popen(process_obj)
class ForkServerProcess(process.BaseProcess):
_start_method = 'forkserver'
@staticmethod
def _Popen(process_obj):
from .popen_forkserver import Popen
return Popen(process_obj)
class ForkContext(BaseContext):
_name = 'fork'
Process = ForkProcess
class SpawnContext(BaseContext):
_name = 'spawn'
Process = SpawnProcess
class ForkServerContext(BaseContext):
_name = 'forkserver'
Process = ForkServerProcess
def _check_available(self):
from . import reduction
if not reduction.HAVE_SEND_HANDLE:
raise ValueError('forkserver start method not available')
_concrete_contexts = {
'fork': ForkContext(),
'spawn': SpawnContext(),
'forkserver': ForkServerContext(),
}
_default_context = DefaultContext(_concrete_contexts['fork'])
else:
class SpawnProcess(process.BaseProcess):
_start_method = 'spawn'
@staticmethod
def _Popen(process_obj):
from .popen_spawn_win32 import Popen
return Popen(process_obj)
class SpawnContext(BaseContext):
_name = 'spawn'
Process = SpawnProcess
_concrete_contexts = {
'spawn': SpawnContext(),
}
_default_context = DefaultContext(_concrete_contexts['spawn'])
#
# Force the start method
#
def _force_start_method(method):
_default_context._actual_context = _concrete_contexts[method]
#
# Check that the current thread is spawning a child process
#
_tls = threading.local()
def get_spawning_popen():
return getattr(_tls, 'spawning_popen', None)
def set_spawning_popen(popen):
_tls.spawning_popen = popen
def assert_spawning(obj):
if get_spawning_popen() is None:
raise RuntimeError(
'%s objects should only be shared between processes'
' through inheritance' % type(obj).__name__
)

View file

@ -24,31 +24,34 @@ __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
MAXFDS_TO_SEND = 256 MAXFDS_TO_SEND = 256
UNSIGNED_STRUCT = struct.Struct('Q') # large enough for pid_t UNSIGNED_STRUCT = struct.Struct('Q') # large enough for pid_t
_forkserver_address = None
_forkserver_alive_fd = None
_inherited_fds = None
_lock = threading.Lock()
_preload_modules = ['__main__']
# #
# Public function # Forkserver class
# #
def set_forkserver_preload(modules_names): class ForkServer(object):
def __init__(self):
self._forkserver_address = None
self._forkserver_alive_fd = None
self._inherited_fds = None
self._lock = threading.Lock()
self._preload_modules = ['__main__']
def set_forkserver_preload(self, modules_names):
'''Set list of module names to try to load in forkserver process.''' '''Set list of module names to try to load in forkserver process.'''
global _preload_modules if not all(type(mod) is str for mod in self._preload_modules):
_preload_modules = modules_names raise TypeError('module_names must be a list of strings')
self._preload_modules = modules_names
def get_inherited_fds(self):
def get_inherited_fds():
'''Return list of fds inherited from parent process. '''Return list of fds inherited from parent process.
This returns None if the current process was not started by fork server. This returns None if the current process was not started by fork
server.
''' '''
return _inherited_fds return self._inherited_fds
def connect_to_new_process(self, fds):
def connect_to_new_process(fds):
'''Request forkserver to create a child process. '''Request forkserver to create a child process.
Returns a pair of fds (status_r, data_w). The calling process can read Returns a pair of fds (status_r, data_w). The calling process can read
@ -56,14 +59,15 @@ def connect_to_new_process(fds):
The calling process should write to data_w the pickled preparation and The calling process should write to data_w the pickled preparation and
process data. process data.
''' '''
self.ensure_running()
if len(fds) + 4 >= MAXFDS_TO_SEND: if len(fds) + 4 >= MAXFDS_TO_SEND:
raise ValueError('too many fds') raise ValueError('too many fds')
with socket.socket(socket.AF_UNIX) as client: with socket.socket(socket.AF_UNIX) as client:
client.connect(_forkserver_address) client.connect(self._forkserver_address)
parent_r, child_w = os.pipe() parent_r, child_w = os.pipe()
child_r, parent_w = os.pipe() child_r, parent_w = os.pipe()
allfds = [child_r, child_w, _forkserver_alive_fd, allfds = [child_r, child_w, self._forkserver_alive_fd,
semaphore_tracker._semaphore_tracker_fd] semaphore_tracker.getfd()]
allfds += fds allfds += fds
try: try:
reduction.sendfds(client, allfds) reduction.sendfds(client, allfds)
@ -76,27 +80,26 @@ def connect_to_new_process(fds):
os.close(child_r) os.close(child_r)
os.close(child_w) os.close(child_w)
def ensure_running(self):
def ensure_running():
'''Make sure that a fork server is running. '''Make sure that a fork server is running.
This can be called from any process. Note that usually a child This can be called from any process. Note that usually a child
process will just reuse the forkserver started by its parent, so process will just reuse the forkserver started by its parent, so
ensure_running() will do nothing. ensure_running() will do nothing.
''' '''
global _forkserver_address, _forkserver_alive_fd with self._lock:
with _lock: semaphore_tracker.ensure_running()
if _forkserver_alive_fd is not None: if self._forkserver_alive_fd is not None:
return return
assert all(type(mod) is str for mod in _preload_modules)
cmd = ('from multiprocessing.forkserver import main; ' + cmd = ('from multiprocessing.forkserver import main; ' +
'main(%d, %d, %r, **%r)') 'main(%d, %d, %r, **%r)')
if _preload_modules: if self._preload_modules:
desired_keys = {'main_path', 'sys_path'} desired_keys = {'main_path', 'sys_path'}
data = spawn.get_preparation_data('ignore') data = spawn.get_preparation_data('ignore')
data = dict((x,y) for (x,y) in data.items() if x in desired_keys) data = dict((x,y) for (x,y) in data.items()
if x in desired_keys)
else: else:
data = {} data = {}
@ -111,18 +114,23 @@ def ensure_running():
alive_r, alive_w = os.pipe() alive_r, alive_w = os.pipe()
try: try:
fds_to_pass = [listener.fileno(), alive_r] fds_to_pass = [listener.fileno(), alive_r]
cmd %= (listener.fileno(), alive_r, _preload_modules, data) cmd %= (listener.fileno(), alive_r, self._preload_modules,
data)
exe = spawn.get_executable() exe = spawn.get_executable()
args = [exe] + util._args_from_interpreter_flags() + ['-c', cmd] args = [exe] + util._args_from_interpreter_flags()
args += ['-c', cmd]
pid = util.spawnv_passfds(exe, args, fds_to_pass) pid = util.spawnv_passfds(exe, args, fds_to_pass)
except: except:
os.close(alive_w) os.close(alive_w)
raise raise
finally: finally:
os.close(alive_r) os.close(alive_r)
_forkserver_address = address self._forkserver_address = address
_forkserver_alive_fd = alive_w self._forkserver_alive_fd = alive_w
#
#
#
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
'''Run forkserver.''' '''Run forkserver.'''
@ -151,8 +159,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN) handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \ with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
selectors.DefaultSelector() as selector: selectors.DefaultSelector() as selector:
global _forkserver_address _forkserver._forkserver_address = listener.getsockname()
_forkserver_address = listener.getsockname()
selector.register(listener, selectors.EVENT_READ) selector.register(listener, selectors.EVENT_READ)
selector.register(alive_r, selectors.EVENT_READ) selector.register(alive_r, selectors.EVENT_READ)
@ -187,13 +194,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
if e.errno != errno.ECONNABORTED: if e.errno != errno.ECONNABORTED:
raise raise
#
# Code to bootstrap new process
#
def _serve_one(s, listener, alive_r, handler): def _serve_one(s, listener, alive_r, handler):
global _inherited_fds, _forkserver_alive_fd
# close unnecessary stuff and reset SIGCHLD handler # close unnecessary stuff and reset SIGCHLD handler
listener.close() listener.close()
os.close(alive_r) os.close(alive_r)
@ -203,8 +204,9 @@ def _serve_one(s, listener, alive_r, handler):
fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
s.close() s.close()
assert len(fds) <= MAXFDS_TO_SEND assert len(fds) <= MAXFDS_TO_SEND
child_r, child_w, _forkserver_alive_fd, stfd, *_inherited_fds = fds (child_r, child_w, _forkserver._forkserver_alive_fd,
semaphore_tracker._semaphore_tracker_fd = stfd stfd, *_forkserver._inherited_fds) = fds
semaphore_tracker._semaphore_tracker._fd = stfd
# send pid to client processes # send pid to client processes
write_unsigned(child_w, os.getpid()) write_unsigned(child_w, os.getpid())
@ -253,3 +255,13 @@ def write_unsigned(fd, n):
if nbytes == 0: if nbytes == 0:
raise RuntimeError('should not get here') raise RuntimeError('should not get here')
msg = msg[nbytes:] msg = msg[nbytes:]
#
#
#
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload

View file

@ -16,7 +16,7 @@ import tempfile
import threading import threading
import _multiprocessing import _multiprocessing
from . import popen from . import context
from . import reduction from . import reduction
from . import util from . import util
@ -50,7 +50,7 @@ if sys.platform == 'win32':
self._state = (self.size, self.name) self._state = (self.size, self.name)
def __getstate__(self): def __getstate__(self):
popen.assert_spawning(self) context.assert_spawning(self)
return self._state return self._state
def __setstate__(self, state): def __setstate__(self, state):

View file

@ -23,11 +23,12 @@ from time import time as _time
from traceback import format_exc from traceback import format_exc
from . import connection from . import connection
from . import context
from . import pool from . import pool
from . import process from . import process
from . import popen
from . import reduction from . import reduction
from . import util from . import util
from . import get_context
# #
# Register some things for pickling # Register some things for pickling
@ -438,7 +439,8 @@ class BaseManager(object):
_registry = {} _registry = {}
_Server = Server _Server = Server
def __init__(self, address=None, authkey=None, serializer='pickle'): def __init__(self, address=None, authkey=None, serializer='pickle',
ctx=None):
if authkey is None: if authkey is None:
authkey = process.current_process().authkey authkey = process.current_process().authkey
self._address = address # XXX not final address if eg ('', 0) self._address = address # XXX not final address if eg ('', 0)
@ -447,6 +449,7 @@ class BaseManager(object):
self._state.value = State.INITIAL self._state.value = State.INITIAL
self._serializer = serializer self._serializer = serializer
self._Listener, self._Client = listener_client[serializer] self._Listener, self._Client = listener_client[serializer]
self._ctx = ctx or get_context()
def get_server(self): def get_server(self):
''' '''
@ -478,7 +481,7 @@ class BaseManager(object):
reader, writer = connection.Pipe(duplex=False) reader, writer = connection.Pipe(duplex=False)
# spawn process which runs a server # spawn process which runs a server
self._process = process.Process( self._process = self._ctx.Process(
target=type(self)._run_server, target=type(self)._run_server,
args=(self._registry, self._address, self._authkey, args=(self._registry, self._address, self._authkey,
self._serializer, writer, initializer, initargs), self._serializer, writer, initializer, initargs),
@ -800,7 +803,7 @@ class BaseProxy(object):
def __reduce__(self): def __reduce__(self):
kwds = {} kwds = {}
if popen.get_spawning_popen() is not None: if context.get_spawning_popen() is not None:
kwds['authkey'] = self._authkey kwds['authkey'] = self._authkey
if getattr(self, '_isauto', False): if getattr(self, '_isauto', False):

View file

@ -24,7 +24,7 @@ import traceback
# If threading is available then ThreadPool should be provided. Therefore # If threading is available then ThreadPool should be provided. Therefore
# we avoid top-level imports which are liable to fail on some systems. # we avoid top-level imports which are liable to fail on some systems.
from . import util from . import util
from . import Process, cpu_count, TimeoutError, SimpleQueue from . import get_context, cpu_count, TimeoutError
# #
# Constants representing the state of a pool # Constants representing the state of a pool
@ -137,10 +137,12 @@ class Pool(object):
''' '''
Class which supports an async version of applying functions to arguments. Class which supports an async version of applying functions to arguments.
''' '''
Process = Process def Process(self, *args, **kwds):
return self._ctx.Process(*args, **kwds)
def __init__(self, processes=None, initializer=None, initargs=(), def __init__(self, processes=None, initializer=None, initargs=(),
maxtasksperchild=None): maxtasksperchild=None, context=None):
self._ctx = context or get_context()
self._setup_queues() self._setup_queues()
self._taskqueue = queue.Queue() self._taskqueue = queue.Queue()
self._cache = {} self._cache = {}
@ -232,8 +234,8 @@ class Pool(object):
self._repopulate_pool() self._repopulate_pool()
def _setup_queues(self): def _setup_queues(self):
self._inqueue = SimpleQueue() self._inqueue = self._ctx.SimpleQueue()
self._outqueue = SimpleQueue() self._outqueue = self._ctx.SimpleQueue()
self._quick_put = self._inqueue._writer.send self._quick_put = self._inqueue._writer.send
self._quick_get = self._outqueue._reader.recv self._quick_get = self._outqueue._reader.recv

View file

@ -1,78 +0,0 @@
import sys
import threading
__all__ = ['Popen', 'get_spawning_popen', 'set_spawning_popen',
'assert_spawning']
#
# Check that the current thread is spawning a child process
#
_tls = threading.local()
def get_spawning_popen():
return getattr(_tls, 'spawning_popen', None)
def set_spawning_popen(popen):
_tls.spawning_popen = popen
def assert_spawning(obj):
if get_spawning_popen() is None:
raise RuntimeError(
'%s objects should only be shared between processes'
' through inheritance' % type(obj).__name__
)
#
#
#
_Popen = None
def Popen(process_obj):
if _Popen is None:
set_start_method()
return _Popen(process_obj)
def get_start_method():
if _Popen is None:
set_start_method()
return _Popen.method
def set_start_method(meth=None, *, start_helpers=True):
global _Popen
try:
modname = _method_to_module[meth]
__import__(modname)
except (KeyError, ImportError):
raise ValueError('could not use start method %r' % meth)
module = sys.modules[modname]
if start_helpers:
module.Popen.ensure_helpers_running()
_Popen = module.Popen
if sys.platform == 'win32':
_method_to_module = {
None: 'multiprocessing.popen_spawn_win32',
'spawn': 'multiprocessing.popen_spawn_win32',
}
def get_all_start_methods():
return ['spawn']
else:
_method_to_module = {
None: 'multiprocessing.popen_fork',
'fork': 'multiprocessing.popen_fork',
'spawn': 'multiprocessing.popen_spawn_posix',
'forkserver': 'multiprocessing.popen_forkserver',
}
def get_all_start_methods():
from . import reduction
if reduction.HAVE_SEND_HANDLE:
return ['fork', 'spawn', 'forkserver']
else:
return ['fork', 'spawn']

View file

@ -81,7 +81,3 @@ class Popen(object):
os.close(child_w) os.close(child_w)
util.Finalize(self, os.close, (parent_r,)) util.Finalize(self, os.close, (parent_r,))
self.sentinel = parent_r self.sentinel = parent_r
@staticmethod
def ensure_helpers_running():
pass

View file

@ -4,8 +4,8 @@ import os
from . import reduction from . import reduction
if not reduction.HAVE_SEND_HANDLE: if not reduction.HAVE_SEND_HANDLE:
raise ImportError('No support for sending fds between processes') raise ImportError('No support for sending fds between processes')
from . import context
from . import forkserver from . import forkserver
from . import popen
from . import popen_fork from . import popen_fork
from . import spawn from . import spawn
from . import util from . import util
@ -42,12 +42,12 @@ class Popen(popen_fork.Popen):
def _launch(self, process_obj): def _launch(self, process_obj):
prep_data = spawn.get_preparation_data(process_obj._name) prep_data = spawn.get_preparation_data(process_obj._name)
buf = io.BytesIO() buf = io.BytesIO()
popen.set_spawning_popen(self) context.set_spawning_popen(self)
try: try:
reduction.dump(prep_data, buf) reduction.dump(prep_data, buf)
reduction.dump(process_obj, buf) reduction.dump(process_obj, buf)
finally: finally:
popen.set_spawning_popen(None) context.set_spawning_popen(None)
self.sentinel, w = forkserver.connect_to_new_process(self._fds) self.sentinel, w = forkserver.connect_to_new_process(self._fds)
util.Finalize(self, os.close, (self.sentinel,)) util.Finalize(self, os.close, (self.sentinel,))
@ -67,9 +67,3 @@ class Popen(popen_fork.Popen):
# The process ended abnormally perhaps because of a signal # The process ended abnormally perhaps because of a signal
self.returncode = 255 self.returncode = 255
return self.returncode return self.returncode
@staticmethod
def ensure_helpers_running():
from . import semaphore_tracker
semaphore_tracker.ensure_running()
forkserver.ensure_running()

View file

@ -2,7 +2,7 @@ import fcntl
import io import io
import os import os
from . import popen from . import context
from . import popen_fork from . import popen_fork
from . import reduction from . import reduction
from . import spawn from . import spawn
@ -41,16 +41,16 @@ class Popen(popen_fork.Popen):
def _launch(self, process_obj): def _launch(self, process_obj):
from . import semaphore_tracker from . import semaphore_tracker
tracker_fd = semaphore_tracker._semaphore_tracker_fd tracker_fd = semaphore_tracker.getfd()
self._fds.append(tracker_fd) self._fds.append(tracker_fd)
prep_data = spawn.get_preparation_data(process_obj._name) prep_data = spawn.get_preparation_data(process_obj._name)
fp = io.BytesIO() fp = io.BytesIO()
popen.set_spawning_popen(self) context.set_spawning_popen(self)
try: try:
reduction.dump(prep_data, fp) reduction.dump(prep_data, fp)
reduction.dump(process_obj, fp) reduction.dump(process_obj, fp)
finally: finally:
popen.set_spawning_popen(None) context.set_spawning_popen(None)
parent_r = child_w = child_r = parent_w = None parent_r = child_w = child_r = parent_w = None
try: try:
@ -70,8 +70,3 @@ class Popen(popen_fork.Popen):
for fd in (child_r, child_w, parent_w): for fd in (child_r, child_w, parent_w):
if fd is not None: if fd is not None:
os.close(fd) os.close(fd)
@staticmethod
def ensure_helpers_running():
from . import semaphore_tracker
semaphore_tracker.ensure_running()

View file

@ -4,8 +4,8 @@ import signal
import sys import sys
import _winapi import _winapi
from . import context
from . import spawn from . import spawn
from . import popen
from . import reduction from . import reduction
from . import util from . import util
@ -60,15 +60,15 @@ class Popen(object):
util.Finalize(self, _winapi.CloseHandle, (self.sentinel,)) util.Finalize(self, _winapi.CloseHandle, (self.sentinel,))
# send information to child # send information to child
popen.set_spawning_popen(self) context.set_spawning_popen(self)
try: try:
reduction.dump(prep_data, to_child) reduction.dump(prep_data, to_child)
reduction.dump(process_obj, to_child) reduction.dump(process_obj, to_child)
finally: finally:
popen.set_spawning_popen(None) context.set_spawning_popen(None)
def duplicate_for_child(self, handle): def duplicate_for_child(self, handle):
assert self is popen.get_spawning_popen() assert self is context.get_spawning_popen()
return reduction.duplicate(handle, self.sentinel) return reduction.duplicate(handle, self.sentinel)
def wait(self, timeout=None): def wait(self, timeout=None):
@ -97,7 +97,3 @@ class Popen(object):
except OSError: except OSError:
if self.wait(timeout=1.0) is None: if self.wait(timeout=1.0) is None:
raise raise
@staticmethod
def ensure_helpers_running():
pass

View file

@ -7,7 +7,7 @@
# Licensed to PSF under a Contributor Agreement. # Licensed to PSF under a Contributor Agreement.
# #
__all__ = ['Process', 'current_process', 'active_children'] __all__ = ['BaseProcess', 'current_process', 'active_children']
# #
# Imports # Imports
@ -59,13 +59,14 @@ def _cleanup():
# The `Process` class # The `Process` class
# #
class Process(object): class BaseProcess(object):
''' '''
Process objects represent activity that is run in a separate process Process objects represent activity that is run in a separate process
The class is analogous to `threading.Thread` The class is analogous to `threading.Thread`
''' '''
_Popen = None def _Popen(self):
raise NotImplementedError
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, def __init__(self, group=None, target=None, name=None, args=(), kwargs={},
*, daemon=None): *, daemon=None):
@ -101,11 +102,7 @@ class Process(object):
assert not _current_process._config.get('daemon'), \ assert not _current_process._config.get('daemon'), \
'daemonic processes are not allowed to have children' 'daemonic processes are not allowed to have children'
_cleanup() _cleanup()
if self._Popen is not None: self._popen = self._Popen(self)
Popen = self._Popen
else:
from .popen import Popen
self._popen = Popen(self)
self._sentinel = self._popen.sentinel self._sentinel = self._popen.sentinel
_children.add(self) _children.add(self)
@ -229,10 +226,12 @@ class Process(object):
## ##
def _bootstrap(self): def _bootstrap(self):
from . import util from . import util, context
global _current_process, _process_counter, _children global _current_process, _process_counter, _children
try: try:
if self._start_method is not None:
context._force_start_method(self._start_method)
_process_counter = itertools.count(1) _process_counter = itertools.count(1)
_children = set() _children = set()
if sys.stdin is not None: if sys.stdin is not None:
@ -282,7 +281,7 @@ class Process(object):
class AuthenticationString(bytes): class AuthenticationString(bytes):
def __reduce__(self): def __reduce__(self):
from .popen import get_spawning_popen from .context import get_spawning_popen
if get_spawning_popen() is None: if get_spawning_popen() is None:
raise TypeError( raise TypeError(
'Pickling an AuthenticationString object is ' 'Pickling an AuthenticationString object is '
@ -294,7 +293,7 @@ class AuthenticationString(bytes):
# Create object representing the main process # Create object representing the main process
# #
class _MainProcess(Process): class _MainProcess(BaseProcess):
def __init__(self): def __init__(self):
self._identity = () self._identity = ()

View file

@ -22,8 +22,7 @@ from queue import Empty, Full
import _multiprocessing import _multiprocessing
from . import connection from . import connection
from . import popen from . import context
from . import synchronize
from .util import debug, info, Finalize, register_after_fork, is_exiting from .util import debug, info, Finalize, register_after_fork, is_exiting
from .reduction import ForkingPickler from .reduction import ForkingPickler
@ -34,18 +33,18 @@ from .reduction import ForkingPickler
class Queue(object): class Queue(object):
def __init__(self, maxsize=0): def __init__(self, maxsize=0, *, ctx):
if maxsize <= 0: if maxsize <= 0:
maxsize = _multiprocessing.SemLock.SEM_VALUE_MAX maxsize = _multiprocessing.SemLock.SEM_VALUE_MAX
self._maxsize = maxsize self._maxsize = maxsize
self._reader, self._writer = connection.Pipe(duplex=False) self._reader, self._writer = connection.Pipe(duplex=False)
self._rlock = synchronize.Lock() self._rlock = ctx.Lock()
self._opid = os.getpid() self._opid = os.getpid()
if sys.platform == 'win32': if sys.platform == 'win32':
self._wlock = None self._wlock = None
else: else:
self._wlock = synchronize.Lock() self._wlock = ctx.Lock()
self._sem = synchronize.BoundedSemaphore(maxsize) self._sem = ctx.BoundedSemaphore(maxsize)
# For use by concurrent.futures # For use by concurrent.futures
self._ignore_epipe = False self._ignore_epipe = False
@ -55,7 +54,7 @@ class Queue(object):
register_after_fork(self, Queue._after_fork) register_after_fork(self, Queue._after_fork)
def __getstate__(self): def __getstate__(self):
popen.assert_spawning(self) context.assert_spawning(self)
return (self._ignore_epipe, self._maxsize, self._reader, self._writer, return (self._ignore_epipe, self._maxsize, self._reader, self._writer,
self._rlock, self._wlock, self._sem, self._opid) self._rlock, self._wlock, self._sem, self._opid)
@ -279,10 +278,10 @@ _sentinel = object()
class JoinableQueue(Queue): class JoinableQueue(Queue):
def __init__(self, maxsize=0): def __init__(self, maxsize=0, *, ctx):
Queue.__init__(self, maxsize) Queue.__init__(self, maxsize, ctx=ctx)
self._unfinished_tasks = synchronize.Semaphore(0) self._unfinished_tasks = ctx.Semaphore(0)
self._cond = synchronize.Condition() self._cond = ctx.Condition()
def __getstate__(self): def __getstate__(self):
return Queue.__getstate__(self) + (self._cond, self._unfinished_tasks) return Queue.__getstate__(self) + (self._cond, self._unfinished_tasks)
@ -332,20 +331,20 @@ class JoinableQueue(Queue):
class SimpleQueue(object): class SimpleQueue(object):
def __init__(self): def __init__(self, *, ctx):
self._reader, self._writer = connection.Pipe(duplex=False) self._reader, self._writer = connection.Pipe(duplex=False)
self._rlock = synchronize.Lock() self._rlock = ctx.Lock()
self._poll = self._reader.poll self._poll = self._reader.poll
if sys.platform == 'win32': if sys.platform == 'win32':
self._wlock = None self._wlock = None
else: else:
self._wlock = synchronize.Lock() self._wlock = ctx.Lock()
def empty(self): def empty(self):
return not self._poll() return not self._poll()
def __getstate__(self): def __getstate__(self):
popen.assert_spawning(self) context.assert_spawning(self)
return (self._reader, self._writer, self._rlock, self._wlock) return (self._reader, self._writer, self._rlock, self._wlock)
def __setstate__(self, state): def __setstate__(self, state):

View file

@ -15,7 +15,7 @@ import pickle
import socket import socket
import sys import sys
from . import popen from . import context
from . import util from . import util
__all__ = ['send_handle', 'recv_handle', 'ForkingPickler', 'register', 'dump'] __all__ = ['send_handle', 'recv_handle', 'ForkingPickler', 'register', 'dump']
@ -183,7 +183,7 @@ else:
def DupFd(fd): def DupFd(fd):
'''Return a wrapper for an fd.''' '''Return a wrapper for an fd.'''
popen_obj = popen.get_spawning_popen() popen_obj = context.get_spawning_popen()
if popen_obj is not None: if popen_obj is not None:
return popen_obj.DupFd(popen_obj.duplicate_for_child(fd)) return popen_obj.DupFd(popen_obj.duplicate_for_child(fd))
elif HAVE_SEND_HANDLE: elif HAVE_SEND_HANDLE:

View file

@ -26,25 +26,30 @@ from . import current_process
__all__ = ['ensure_running', 'register', 'unregister'] __all__ = ['ensure_running', 'register', 'unregister']
_semaphore_tracker_fd = None class SemaphoreTracker(object):
_lock = threading.Lock()
def __init__(self):
self._lock = threading.Lock()
self._fd = None
def ensure_running(): def getfd(self):
self.ensure_running()
return self._fd
def ensure_running(self):
'''Make sure that semaphore tracker process is running. '''Make sure that semaphore tracker process is running.
This can be run from any process. Usually a child process will use This can be run from any process. Usually a child process will use
the semaphore created by its parent.''' the semaphore created by its parent.'''
global _semaphore_tracker_fd with self._lock:
with _lock: if self._fd is not None:
if _semaphore_tracker_fd is not None:
return return
fds_to_pass = [] fds_to_pass = []
try: try:
fds_to_pass.append(sys.stderr.fileno()) fds_to_pass.append(sys.stderr.fileno())
except Exception: except Exception:
pass pass
cmd = 'from multiprocessing.semaphore_tracker import main; main(%d)' cmd = 'from multiprocessing.semaphore_tracker import main;main(%d)'
r, w = os.pipe() r, w = os.pipe()
try: try:
fds_to_pass.append(r) fds_to_pass.append(r)
@ -57,31 +62,36 @@ def ensure_running():
os.close(w) os.close(w)
raise raise
else: else:
_semaphore_tracker_fd = w self._fd = w
finally: finally:
os.close(r) os.close(r)
def register(self, name):
def register(name):
'''Register name of semaphore with semaphore tracker.''' '''Register name of semaphore with semaphore tracker.'''
_send('REGISTER', name) self._send('REGISTER', name)
def unregister(self, name):
def unregister(name):
'''Unregister name of semaphore with semaphore tracker.''' '''Unregister name of semaphore with semaphore tracker.'''
_send('UNREGISTER', name) self._send('UNREGISTER', name)
def _send(self, cmd, name):
def _send(cmd, name): self.ensure_running()
msg = '{0}:{1}\n'.format(cmd, name).encode('ascii') msg = '{0}:{1}\n'.format(cmd, name).encode('ascii')
if len(name) > 512: if len(name) > 512:
# posix guarantees that writes to a pipe of less than PIPE_BUF # posix guarantees that writes to a pipe of less than PIPE_BUF
# bytes are atomic, and that PIPE_BUF >= 512 # bytes are atomic, and that PIPE_BUF >= 512
raise ValueError('name too long') raise ValueError('name too long')
nbytes = os.write(_semaphore_tracker_fd, msg) nbytes = os.write(self._fd, msg)
assert nbytes == len(msg) assert nbytes == len(msg)
_semaphore_tracker = SemaphoreTracker()
ensure_running = _semaphore_tracker.ensure_running
register = _semaphore_tracker.register
unregister = _semaphore_tracker.unregister
getfd = _semaphore_tracker.getfd
def main(fd): def main(fd):
'''Run semaphore tracker.''' '''Run semaphore tracker.'''
# protect the process from ^C and "killall python" etc # protect the process from ^C and "killall python" etc

View file

@ -11,10 +11,10 @@ import ctypes
import weakref import weakref
from . import heap from . import heap
from . import get_context
from .synchronize import RLock from .context import assert_spawning
from .reduction import ForkingPickler from .reduction import ForkingPickler
from .popen import assert_spawning
__all__ = ['RawValue', 'RawArray', 'Value', 'Array', 'copy', 'synchronized'] __all__ = ['RawValue', 'RawArray', 'Value', 'Array', 'copy', 'synchronized']
@ -66,7 +66,7 @@ def RawArray(typecode_or_type, size_or_initializer):
result.__init__(*size_or_initializer) result.__init__(*size_or_initializer)
return result return result
def Value(typecode_or_type, *args, lock=True): def Value(typecode_or_type, *args, lock=True, ctx=None):
''' '''
Return a synchronization wrapper for a Value Return a synchronization wrapper for a Value
''' '''
@ -74,12 +74,13 @@ def Value(typecode_or_type, *args, lock=True):
if lock is False: if lock is False:
return obj return obj
if lock in (True, None): if lock in (True, None):
lock = RLock() ctx = ctx or get_context()
lock = ctx.RLock()
if not hasattr(lock, 'acquire'): if not hasattr(lock, 'acquire'):
raise AttributeError("'%r' has no method 'acquire'" % lock) raise AttributeError("'%r' has no method 'acquire'" % lock)
return synchronized(obj, lock) return synchronized(obj, lock, ctx=ctx)
def Array(typecode_or_type, size_or_initializer, *, lock=True): def Array(typecode_or_type, size_or_initializer, *, lock=True, ctx=None):
''' '''
Return a synchronization wrapper for a RawArray Return a synchronization wrapper for a RawArray
''' '''
@ -87,25 +88,27 @@ def Array(typecode_or_type, size_or_initializer, *, lock=True):
if lock is False: if lock is False:
return obj return obj
if lock in (True, None): if lock in (True, None):
lock = RLock() ctx = ctx or get_context()
lock = ctx.RLock()
if not hasattr(lock, 'acquire'): if not hasattr(lock, 'acquire'):
raise AttributeError("'%r' has no method 'acquire'" % lock) raise AttributeError("'%r' has no method 'acquire'" % lock)
return synchronized(obj, lock) return synchronized(obj, lock, ctx=ctx)
def copy(obj): def copy(obj):
new_obj = _new_value(type(obj)) new_obj = _new_value(type(obj))
ctypes.pointer(new_obj)[0] = obj ctypes.pointer(new_obj)[0] = obj
return new_obj return new_obj
def synchronized(obj, lock=None): def synchronized(obj, lock=None, ctx=None):
assert not isinstance(obj, SynchronizedBase), 'object already synchronized' assert not isinstance(obj, SynchronizedBase), 'object already synchronized'
ctx = ctx or get_context()
if isinstance(obj, ctypes._SimpleCData): if isinstance(obj, ctypes._SimpleCData):
return Synchronized(obj, lock) return Synchronized(obj, lock, ctx)
elif isinstance(obj, ctypes.Array): elif isinstance(obj, ctypes.Array):
if obj._type_ is ctypes.c_char: if obj._type_ is ctypes.c_char:
return SynchronizedString(obj, lock) return SynchronizedString(obj, lock, ctx)
return SynchronizedArray(obj, lock) return SynchronizedArray(obj, lock, ctx)
else: else:
cls = type(obj) cls = type(obj)
try: try:
@ -115,7 +118,7 @@ def synchronized(obj, lock=None):
d = dict((name, make_property(name)) for name in names) d = dict((name, make_property(name)) for name in names)
classname = 'Synchronized' + cls.__name__ classname = 'Synchronized' + cls.__name__
scls = class_cache[cls] = type(classname, (SynchronizedBase,), d) scls = class_cache[cls] = type(classname, (SynchronizedBase,), d)
return scls(obj, lock) return scls(obj, lock, ctx)
# #
# Functions for pickling/unpickling # Functions for pickling/unpickling
@ -175,9 +178,13 @@ class_cache = weakref.WeakKeyDictionary()
class SynchronizedBase(object): class SynchronizedBase(object):
def __init__(self, obj, lock=None): def __init__(self, obj, lock=None, ctx=None):
self._obj = obj self._obj = obj
self._lock = lock or RLock() if lock:
self._lock = lock
else:
ctx = ctx or get_context(force=True)
self._lock = ctx.RLock()
self.acquire = self._lock.acquire self.acquire = self._lock.acquire
self.release = self._lock.release self.release = self._lock.release

View file

@ -12,9 +12,9 @@ import os
import pickle import pickle
import sys import sys
from . import get_start_method, set_start_method
from . import process from . import process
from . import util from . import util
from . import popen
__all__ = ['_main', 'freeze_support', 'set_executable', 'get_executable', __all__ = ['_main', 'freeze_support', 'set_executable', 'get_executable',
'get_preparation_data', 'get_command_line', 'import_main_path'] 'get_preparation_data', 'get_command_line', 'import_main_path']
@ -91,7 +91,7 @@ def spawn_main(pipe_handle, parent_pid=None, tracker_fd=None):
fd = msvcrt.open_osfhandle(new_handle, os.O_RDONLY) fd = msvcrt.open_osfhandle(new_handle, os.O_RDONLY)
else: else:
from . import semaphore_tracker from . import semaphore_tracker
semaphore_tracker._semaphore_tracker_fd = tracker_fd semaphore_tracker._semaphore_tracker._fd = tracker_fd
fd = pipe_handle fd = pipe_handle
exitcode = _main(fd) exitcode = _main(fd)
sys.exit(exitcode) sys.exit(exitcode)
@ -154,7 +154,7 @@ def get_preparation_data(name):
sys_argv=sys.argv, sys_argv=sys.argv,
orig_dir=process.ORIGINAL_DIR, orig_dir=process.ORIGINAL_DIR,
dir=os.getcwd(), dir=os.getcwd(),
start_method=popen.get_start_method(), start_method=get_start_method(),
) )
if sys.platform != 'win32' or (not WINEXE and not WINSERVICE): if sys.platform != 'win32' or (not WINEXE and not WINSERVICE):
@ -204,7 +204,7 @@ def prepare(data):
process.ORIGINAL_DIR = data['orig_dir'] process.ORIGINAL_DIR = data['orig_dir']
if 'start_method' in data: if 'start_method' in data:
popen.set_start_method(data['start_method'], start_helpers=False) set_start_method(data['start_method'])
if 'main_path' in data: if 'main_path' in data:
import_main_path(data['main_path']) import_main_path(data['main_path'])

View file

@ -20,7 +20,7 @@ import _multiprocessing
from time import time as _time from time import time as _time
from . import popen from . import context
from . import process from . import process
from . import util from . import util
@ -50,14 +50,15 @@ class SemLock(object):
_rand = tempfile._RandomNameSequence() _rand = tempfile._RandomNameSequence()
def __init__(self, kind, value, maxvalue): def __init__(self, kind, value, maxvalue, *, ctx):
unlink_immediately = (sys.platform == 'win32' or ctx = ctx or get_context()
popen.get_start_method() == 'fork') ctx = ctx.get_context()
unlink_now = sys.platform == 'win32' or ctx._name == 'fork'
for i in range(100): for i in range(100):
try: try:
sl = self._semlock = _multiprocessing.SemLock( sl = self._semlock = _multiprocessing.SemLock(
kind, value, maxvalue, self._make_name(), kind, value, maxvalue, self._make_name(),
unlink_immediately) unlink_now)
except FileExistsError: except FileExistsError:
pass pass
else: else:
@ -99,10 +100,10 @@ class SemLock(object):
return self._semlock.__exit__(*args) return self._semlock.__exit__(*args)
def __getstate__(self): def __getstate__(self):
popen.assert_spawning(self) context.assert_spawning(self)
sl = self._semlock sl = self._semlock
if sys.platform == 'win32': if sys.platform == 'win32':
h = popen.get_spawning_popen().duplicate_for_child(sl.handle) h = context.get_spawning_popen().duplicate_for_child(sl.handle)
else: else:
h = sl.handle h = sl.handle
return (h, sl.kind, sl.maxvalue, sl.name) return (h, sl.kind, sl.maxvalue, sl.name)
@ -123,8 +124,8 @@ class SemLock(object):
class Semaphore(SemLock): class Semaphore(SemLock):
def __init__(self, value=1): def __init__(self, value=1, *, ctx):
SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX) SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX, ctx=ctx)
def get_value(self): def get_value(self):
return self._semlock._get_value() return self._semlock._get_value()
@ -142,8 +143,8 @@ class Semaphore(SemLock):
class BoundedSemaphore(Semaphore): class BoundedSemaphore(Semaphore):
def __init__(self, value=1): def __init__(self, value=1, *, ctx):
SemLock.__init__(self, SEMAPHORE, value, value) SemLock.__init__(self, SEMAPHORE, value, value, ctx=ctx)
def __repr__(self): def __repr__(self):
try: try:
@ -159,8 +160,8 @@ class BoundedSemaphore(Semaphore):
class Lock(SemLock): class Lock(SemLock):
def __init__(self): def __init__(self, *, ctx):
SemLock.__init__(self, SEMAPHORE, 1, 1) SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx)
def __repr__(self): def __repr__(self):
try: try:
@ -184,8 +185,8 @@ class Lock(SemLock):
class RLock(SemLock): class RLock(SemLock):
def __init__(self): def __init__(self, *, ctx):
SemLock.__init__(self, RECURSIVE_MUTEX, 1, 1) SemLock.__init__(self, RECURSIVE_MUTEX, 1, 1, ctx=ctx)
def __repr__(self): def __repr__(self):
try: try:
@ -210,15 +211,15 @@ class RLock(SemLock):
class Condition(object): class Condition(object):
def __init__(self, lock=None): def __init__(self, lock=None, *, ctx):
self._lock = lock or RLock() self._lock = lock or ctx.RLock()
self._sleeping_count = Semaphore(0) self._sleeping_count = ctx.Semaphore(0)
self._woken_count = Semaphore(0) self._woken_count = ctx.Semaphore(0)
self._wait_semaphore = Semaphore(0) self._wait_semaphore = ctx.Semaphore(0)
self._make_methods() self._make_methods()
def __getstate__(self): def __getstate__(self):
popen.assert_spawning(self) context.assert_spawning(self)
return (self._lock, self._sleeping_count, return (self._lock, self._sleeping_count,
self._woken_count, self._wait_semaphore) self._woken_count, self._wait_semaphore)
@ -332,9 +333,9 @@ class Condition(object):
class Event(object): class Event(object):
def __init__(self): def __init__(self, *, ctx):
self._cond = Condition(Lock()) self._cond = ctx.Condition(ctx.Lock())
self._flag = Semaphore(0) self._flag = ctx.Semaphore(0)
def is_set(self): def is_set(self):
self._cond.acquire() self._cond.acquire()
@ -383,11 +384,11 @@ class Event(object):
class Barrier(threading.Barrier): class Barrier(threading.Barrier):
def __init__(self, parties, action=None, timeout=None): def __init__(self, parties, action=None, timeout=None, *, ctx):
import struct import struct
from .heap import BufferWrapper from .heap import BufferWrapper
wrapper = BufferWrapper(struct.calcsize('i') * 2) wrapper = BufferWrapper(struct.calcsize('i') * 2)
cond = Condition() cond = ctx.Condition()
self.__setstate__((parties, action, timeout, cond, wrapper)) self.__setstate__((parties, action, timeout, cond, wrapper))
self._state = 0 self._state = 0
self._count = 0 self._count = 0

View file

@ -3555,6 +3555,32 @@ class TestIgnoreEINTR(unittest.TestCase):
conn.close() conn.close()
class TestStartMethod(unittest.TestCase): class TestStartMethod(unittest.TestCase):
@classmethod
def _check_context(cls, conn):
conn.send(multiprocessing.get_start_method())
def check_context(self, ctx):
r, w = ctx.Pipe(duplex=False)
p = ctx.Process(target=self._check_context, args=(w,))
p.start()
w.close()
child_method = r.recv()
r.close()
p.join()
self.assertEqual(child_method, ctx.get_start_method())
def test_context(self):
for method in ('fork', 'spawn', 'forkserver'):
try:
ctx = multiprocessing.get_context(method)
except ValueError:
continue
self.assertEqual(ctx.get_start_method(), method)
self.assertIs(ctx.get_context(), ctx)
self.assertRaises(ValueError, ctx.set_start_method, 'spawn')
self.assertRaises(ValueError, ctx.set_start_method, None)
self.check_context(ctx)
def test_set_get(self): def test_set_get(self):
multiprocessing.set_forkserver_preload(PRELOAD) multiprocessing.set_forkserver_preload(PRELOAD)
count = 0 count = 0
@ -3562,13 +3588,19 @@ class TestStartMethod(unittest.TestCase):
try: try:
for method in ('fork', 'spawn', 'forkserver'): for method in ('fork', 'spawn', 'forkserver'):
try: try:
multiprocessing.set_start_method(method) multiprocessing.set_start_method(method, force=True)
except ValueError: except ValueError:
continue continue
self.assertEqual(multiprocessing.get_start_method(), method) self.assertEqual(multiprocessing.get_start_method(), method)
ctx = multiprocessing.get_context()
self.assertEqual(ctx.get_start_method(), method)
self.assertTrue(type(ctx).__name__.lower().startswith(method))
self.assertTrue(
ctx.Process.__name__.lower().startswith(method))
self.check_context(multiprocessing)
count += 1 count += 1
finally: finally:
multiprocessing.set_start_method(old_method) multiprocessing.set_start_method(old_method, force=True)
self.assertGreaterEqual(count, 1) self.assertGreaterEqual(count, 1)
def test_get_all(self): def test_get_all(self):
@ -3753,9 +3785,9 @@ def install_tests_in_module_dict(remote_globs, start_method):
multiprocessing.process._cleanup() multiprocessing.process._cleanup()
dangling[0] = multiprocessing.process._dangling.copy() dangling[0] = multiprocessing.process._dangling.copy()
dangling[1] = threading._dangling.copy() dangling[1] = threading._dangling.copy()
old_start_method[0] = multiprocessing.get_start_method() old_start_method[0] = multiprocessing.get_start_method(allow_none=True)
try: try:
multiprocessing.set_start_method(start_method) multiprocessing.set_start_method(start_method, force=True)
except ValueError: except ValueError:
raise unittest.SkipTest(start_method + raise unittest.SkipTest(start_method +
' start method not supported') ' start method not supported')
@ -3771,7 +3803,7 @@ def install_tests_in_module_dict(remote_globs, start_method):
multiprocessing.get_logger().setLevel(LOG_LEVEL) multiprocessing.get_logger().setLevel(LOG_LEVEL)
def tearDownModule(): def tearDownModule():
multiprocessing.set_start_method(old_start_method[0]) multiprocessing.set_start_method(old_start_method[0], force=True)
# pause a bit so we don't get warning about dangling threads/processes # pause a bit so we don't get warning about dangling threads/processes
time.sleep(0.5) time.sleep(0.5)
multiprocessing.process._cleanup() multiprocessing.process._cleanup()