mirror of
https://github.com/python/cpython.git
synced 2025-11-02 11:08:57 +00:00
Issue #9244: multiprocessing.pool: Worker crashes if result can't be encoded
This commit is contained in:
parent
fb0469112f
commit
2afcbf2249
2 changed files with 88 additions and 10 deletions
|
|
@ -42,6 +42,23 @@ def mapstar(args):
|
||||||
# Code run by worker processes
|
# Code run by worker processes
|
||||||
#
|
#
|
||||||
|
|
||||||
|
class MaybeEncodingError(Exception):
    """Wraps possible unpickleable errors, so they can be
    safely sent through the socket."""

    def __init__(self, exc, value):
        # Store reprs rather than the live objects: the original exception
        # or result value may itself be unpicklable, which is exactly the
        # failure this wrapper exists to report back across the queue.
        self.exc = repr(exc)
        self.value = repr(value)
        super(MaybeEncodingError, self).__init__(self.exc, self.value)

    def __str__(self):
        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
                                                             self.exc)

    def __repr__(self):
        return "<MaybeEncodingError: %s>" % str(self)
||||||
def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
|
def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
|
||||||
assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
|
assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
|
||||||
put = outqueue.put
|
put = outqueue.put
|
||||||
|
|
@ -70,7 +87,13 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
|
||||||
result = (True, func(*args, **kwds))
|
result = (True, func(*args, **kwds))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result = (False, e)
|
result = (False, e)
|
||||||
put((job, i, result))
|
try:
|
||||||
|
put((job, i, result))
|
||||||
|
except Exception as e:
|
||||||
|
wrapped = MaybeEncodingError(e, result[1])
|
||||||
|
debug("Possible encoding error while sending result: %s" % (
|
||||||
|
wrapped))
|
||||||
|
put((job, i, (False, wrapped)))
|
||||||
completed += 1
|
completed += 1
|
||||||
debug('worker exiting after %d tasks' % completed)
|
debug('worker exiting after %d tasks' % completed)
|
||||||
|
|
||||||
|
|
@ -235,16 +258,18 @@ class Pool(object):
|
||||||
for i, x in enumerate(task_batches)), result._set_length))
|
for i, x in enumerate(task_batches)), result._set_length))
|
||||||
return (item for chunk in result for item in chunk)
|
return (item for chunk in result for item in chunk)
|
||||||
|
|
||||||
def apply_async(self, func, args=(), kwds={}, callback=None,
        error_callback=None):
    '''
    Asynchronous version of `apply()` method.

    `error_callback`, if given, is invoked with the exception instance
    when the job fails instead of succeeding.
    '''
    # NOTE: the mutable default `kwds={}` is never mutated here, only
    # forwarded to the worker, so the shared-default pitfall does not apply.
    assert self._state == RUN
    result = ApplyResult(self._cache, callback, error_callback)
    self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
    return result
|
||||||
|
|
||||||
def map_async(self, func, iterable, chunksize=None, callback=None):
|
def map_async(self, func, iterable, chunksize=None, callback=None,
|
||||||
|
error_callback=None):
|
||||||
'''
|
'''
|
||||||
Asynchronous version of `map()` method.
|
Asynchronous version of `map()` method.
|
||||||
'''
|
'''
|
||||||
|
|
@ -260,7 +285,8 @@ class Pool(object):
|
||||||
chunksize = 0
|
chunksize = 0
|
||||||
|
|
||||||
task_batches = Pool._get_tasks(func, iterable, chunksize)
|
task_batches = Pool._get_tasks(func, iterable, chunksize)
|
||||||
result = MapResult(self._cache, chunksize, len(iterable), callback)
|
result = MapResult(self._cache, chunksize, len(iterable), callback,
|
||||||
|
error_callback=error_callback)
|
||||||
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
|
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
|
||||||
for i, x in enumerate(task_batches)), None))
|
for i, x in enumerate(task_batches)), None))
|
||||||
return result
|
return result
|
||||||
|
|
@ -459,12 +485,13 @@ class Pool(object):
|
||||||
|
|
||||||
class ApplyResult(object):
|
class ApplyResult(object):
|
||||||
|
|
||||||
def __init__(self, cache, callback, error_callback):
    # Condition guards the _ready/_value handoff between the pool's
    # result-handler thread (which calls _set) and user threads blocked
    # in wait()/get().
    self._cond = threading.Condition(threading.Lock())
    self._job = next(job_counter)  # unique id keying this result in `cache`
    self._cache = cache
    self._ready = False
    self._callback = callback              # invoked with the value on success
    self._error_callback = error_callback  # invoked with the exception on failure
    cache[self._job] = self
||||||
|
|
||||||
def ready(self):
|
def ready(self):
|
||||||
|
|
@ -495,6 +522,8 @@ class ApplyResult(object):
|
||||||
self._success, self._value = obj
|
self._success, self._value = obj
|
||||||
if self._callback and self._success:
|
if self._callback and self._success:
|
||||||
self._callback(self._value)
|
self._callback(self._value)
|
||||||
|
if self._error_callback and not self._success:
|
||||||
|
self._error_callback(self._value)
|
||||||
self._cond.acquire()
|
self._cond.acquire()
|
||||||
try:
|
try:
|
||||||
self._ready = True
|
self._ready = True
|
||||||
|
|
@ -509,8 +538,9 @@ class ApplyResult(object):
|
||||||
|
|
||||||
class MapResult(ApplyResult):
|
class MapResult(ApplyResult):
|
||||||
|
|
||||||
def __init__(self, cache, chunksize, length, callback):
|
def __init__(self, cache, chunksize, length, callback, error_callback):
|
||||||
ApplyResult.__init__(self, cache, callback)
|
ApplyResult.__init__(self, cache, callback,
|
||||||
|
error_callback=error_callback)
|
||||||
self._success = True
|
self._success = True
|
||||||
self._value = [None] * length
|
self._value = [None] * length
|
||||||
self._chunksize = chunksize
|
self._chunksize = chunksize
|
||||||
|
|
@ -535,10 +565,11 @@ class MapResult(ApplyResult):
|
||||||
self._cond.notify()
|
self._cond.notify()
|
||||||
finally:
|
finally:
|
||||||
self._cond.release()
|
self._cond.release()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self._success = False
|
self._success = False
|
||||||
self._value = result
|
self._value = result
|
||||||
|
if self._error_callback:
|
||||||
|
self._error_callback(self._value)
|
||||||
del self._cache[self._job]
|
del self._cache[self._job]
|
||||||
self._cond.acquire()
|
self._cond.acquire()
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -1011,6 +1011,7 @@ class _TestContainers(BaseTestCase):
|
||||||
def sqr(x, wait=0.0):
    """Return x*x after sleeping `wait` seconds (helper for pool tests)."""
    time.sleep(wait)
    return x*x
|
||||||
|
|
||||||
class _TestPool(BaseTestCase):
|
class _TestPool(BaseTestCase):
|
||||||
|
|
||||||
def test_apply(self):
|
def test_apply(self):
|
||||||
|
|
@ -1087,9 +1088,55 @@ class _TestPool(BaseTestCase):
|
||||||
join()
|
join()
|
||||||
self.assertTrue(join.elapsed < 0.2)
|
self.assertTrue(join.elapsed < 0.2)
|
||||||
|
|
||||||
class _TestPoolWorkerLifetime(BaseTestCase):
|
def raising():
    """Helper that always raises KeyError('key'); run inside a pool worker
    to exercise the error_callback path."""
    raise KeyError("key")
||||||
|
|
||||||
|
def unpickleable_result():
    """Return a lambda, which cannot be pickled when the worker tries to
    send it back to the parent process."""
    return lambda: 42
||||||
|
|
||||||
|
class _TestPoolWorkerErrors(BaseTestCase):
    """Tests for error propagation from pool workers (Issue #9244)."""
    ALLOWED_TYPES = ('processes', )

    def test_async_error_callback(self):
        # A failing task must invoke error_callback with the exception.
        p = multiprocessing.Pool(2)

        scratchpad = [None]
        def errback(exc):
            scratchpad[0] = exc

        res = p.apply_async(raising, error_callback=errback)
        self.assertRaises(KeyError, res.get)
        self.assertTrue(scratchpad[0])
        self.assertIsInstance(scratchpad[0], KeyError)

        p.close()
        p.join()

    def test_unpickleable_result(self):
        from multiprocessing.pool import MaybeEncodingError
        p = multiprocessing.Pool(2)

        # Make sure we don't lose pool processes because of encoding errors.
        for iteration in range(20):

            scratchpad = [None]
            def errback(exc):
                scratchpad[0] = exc

            res = p.apply_async(unpickleable_result, error_callback=errback)
            self.assertRaises(MaybeEncodingError, res.get)
            wrapped = scratchpad[0]
            self.assertTrue(wrapped)
            self.assertIsInstance(scratchpad[0], MaybeEncodingError)
            self.assertIsNotNone(wrapped.exc)
            self.assertIsNotNone(wrapped.value)

        p.close()
        p.join()
|
||||||
|
|
||||||
|
class _TestPoolWorkerLifetime(BaseTestCase):
|
||||||
|
ALLOWED_TYPES = ('processes', )
|
||||||
|
|
||||||
def test_pool_worker_lifetime(self):
|
def test_pool_worker_lifetime(self):
|
||||||
p = multiprocessing.Pool(3, maxtasksperchild=10)
|
p = multiprocessing.Pool(3, maxtasksperchild=10)
|
||||||
self.assertEqual(3, len(p._pool))
|
self.assertEqual(3, len(p._pool))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue