bpo-28699: fix abnormal behaviour of pools in multiprocessing.pool (GH-693)

An exception raised at the very start of an iterable would cause pools to behave abnormally
(swallow the exception or hang).
Xiang Zhang 2017-03-29 11:58:54 +08:00 committed by GitHub
parent ec1f5df46e
commit 794623bdb2
3 changed files with 117 additions and 25 deletions
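
For context, here is the failure mode this patch targets: if the iterable handed to Pool.imap() (or imap_unordered/map_async) raises as soon as iteration begins, the task-handler thread used to die before it could report anything, so the caller either never saw the exception or blocked forever. A minimal repro sketch, written against the fixed behaviour (the function and generator names are illustrative, not part of the patch):

from multiprocessing import Pool

def square(x):
    return x * x

def exploding_iterable():
    # Raises on the very first item pulled from the iterable -- the case
    # this commit guards against.
    raise ValueError('bad iterable')
    yield 1  # never reached; keeps this function a generator

if __name__ == '__main__':
    with Pool(2) as pool:
        try:
            # Before the fix this call could hang or silently drop the error;
            # with _guarded_task_generation the ValueError is re-raised here.
            list(pool.imap(square, exploding_iterable()))
        except ValueError as exc:
            print('iterable error propagated:', exc)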

Lib/multiprocessing/pool.py

@@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
         try:
             result = (True, func(*args, **kwds))
         except Exception as e:
-            if wrap_exception:
+            if wrap_exception and func is not _helper_reraises_exception:
                 e = ExceptionWithTraceback(e, e.__traceback__)
             result = (False, e)
         try:
@@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
         completed += 1
     util.debug('worker exiting after %d tasks' % completed)
 
+def _helper_reraises_exception(ex):
+    'Pickle-able helper function for use by _guarded_task_generation.'
+    raise ex
+
 #
 # Class representing a process pool
 #
@@ -277,6 +281,17 @@ class Pool(object):
         return self._map_async(func, iterable, starmapstar, chunksize,
                                callback, error_callback)
 
+    def _guarded_task_generation(self, result_job, func, iterable):
+        '''Provides a generator of tasks for imap and imap_unordered with
+        appropriate handling for iterables which throw exceptions during
+        iteration.'''
+        try:
+            i = -1
+            for i, x in enumerate(iterable):
+                yield (result_job, i, func, (x,), {})
+        except Exception as e:
+            yield (result_job, i+1, _helper_reraises_exception, (e,), {})
+
     def imap(self, func, iterable, chunksize=1):
         '''
         Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
@@ -285,15 +300,23 @@ class Pool(object):
             raise ValueError("Pool not running")
         if chunksize == 1:
             result = IMapIterator(self._cache)
-            self._taskqueue.put((((result._job, i, func, (x,), {})
-                                  for i, x in enumerate(iterable)), result._set_length))
+            self._taskqueue.put(
+                (
+                    self._guarded_task_generation(result._job, func, iterable),
+                    result._set_length
+                ))
             return result
         else:
             assert chunksize > 1
             task_batches = Pool._get_tasks(func, iterable, chunksize)
             result = IMapIterator(self._cache)
-            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
-                                  for i, x in enumerate(task_batches)), result._set_length))
+            self._taskqueue.put(
+                (
+                    self._guarded_task_generation(result._job,
+                                                  mapstar,
+                                                  task_batches),
+                    result._set_length
+                ))
             return (item for chunk in result for item in chunk)
 
     def imap_unordered(self, func, iterable, chunksize=1):
@@ -304,15 +327,23 @@ class Pool(object):
             raise ValueError("Pool not running")
         if chunksize == 1:
             result = IMapUnorderedIterator(self._cache)
-            self._taskqueue.put((((result._job, i, func, (x,), {})
-                                  for i, x in enumerate(iterable)), result._set_length))
+            self._taskqueue.put(
+                (
+                    self._guarded_task_generation(result._job, func, iterable),
+                    result._set_length
+                ))
             return result
         else:
             assert chunksize > 1
             task_batches = Pool._get_tasks(func, iterable, chunksize)
             result = IMapUnorderedIterator(self._cache)
-            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
-                                  for i, x in enumerate(task_batches)), result._set_length))
+            self._taskqueue.put(
+                (
+                    self._guarded_task_generation(result._job,
+                                                  mapstar,
+                                                  task_batches),
+                    result._set_length
+                ))
             return (item for chunk in result for item in chunk)
 
     def apply_async(self, func, args=(), kwds={}, callback=None,
@@ -323,7 +354,7 @@ class Pool(object):
         if self._state != RUN:
             raise ValueError("Pool not running")
         result = ApplyResult(self._cache, callback, error_callback)
-        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
+        self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
         return result
 
     def map_async(self, func, iterable, chunksize=None, callback=None,
@@ -354,8 +385,14 @@ class Pool(object):
         task_batches = Pool._get_tasks(func, iterable, chunksize)
         result = MapResult(self._cache, chunksize, len(iterable), callback,
                            error_callback=error_callback)
-        self._taskqueue.put((((result._job, i, mapper, (x,), {})
-                              for i, x in enumerate(task_batches)), None))
+        self._taskqueue.put(
+            (
+                self._guarded_task_generation(result._job,
+                                              mapper,
+                                              task_batches),
+                None
+            )
+        )
         return result
 
     @staticmethod
@@ -377,33 +414,27 @@ class Pool(object):
 
         for taskseq, set_length in iter(taskqueue.get, None):
             task = None
-            i = -1
             try:
-                for i, task in enumerate(taskseq):
+                # iterating taskseq cannot fail
+                for task in taskseq:
                     if thread._state:
                         util.debug('task handler found thread._state != RUN')
                         break
                     try:
                         put(task)
                     except Exception as e:
-                        job, ind = task[:2]
+                        job, idx = task[:2]
                         try:
-                            cache[job]._set(ind, (False, e))
+                            cache[job]._set(idx, (False, e))
                         except KeyError:
                             pass
                 else:
                     if set_length:
                         util.debug('doing set_length()')
-                        set_length(i+1)
+                        idx = task[1] if task else -1
+                        set_length(idx + 1)
                     continue
                 break
-            except Exception as ex:
-                job, ind = task[:2] if task else (0, 0)
-                if job in cache:
-                    cache[job]._set(ind + 1, (False, ex))
-                if set_length:
-                    util.debug('doing set_length()')
-                    set_length(i+1)
             finally:
                 task = taskseq = job = None
         else:
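
The mechanism added by _guarded_task_generation is worth seeing in isolation: iteration over the user's iterable is wrapped in try/except, and if it blows up, the generator emits one final task whose callable simply re-raises the caught exception, so the error flows through the normal task/result pipeline instead of killing the task-handler thread. A standalone sketch of that pattern (the names and the toy consumer below are hypothetical, not from pool.py):

def reraise(exc):
    # Stand-in for _helper_reraises_exception: its only job is to raise
    # the exception it is handed, inside whatever runs the task.
    raise exc

def guarded_tasks(func, iterable):
    # Mirrors the shape of _guarded_task_generation: (index, callable, args).
    i = -1
    try:
        for i, x in enumerate(iterable):
            yield (i, func, (x,))
    except Exception as e:
        # One extra task that re-raises the iterable's exception downstream.
        yield (i + 1, reraise, (e,))

def run_tasks(tasks):
    # A trivial stand-in for the worker loop: run each task in order.
    return [func(*args) for _, func, args in tasks]

def flaky_numbers():
    yield 1
    yield 2
    raise RuntimeError('iterable failed mid-way')

if __name__ == '__main__':
    try:
        run_tasks(guarded_tasks(lambda x: x * x, flaky_numbers()))
    except RuntimeError as exc:
        print('error reached the consumer after two good tasks:', exc)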