Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2afcbf2

Browse files
author
Ask Solem
committed
Issue #9244: multiprocessing.pool: Worker crashes if result can't be encoded
1 parent fb04691 commit 2afcbf2

2 files changed

Lines changed: 88 additions & 10 deletions

File tree

Lib/multiprocessing/pool.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ def mapstar(args):
4242
# Code run by worker processes
4343
#
4444

45+
class MaybeEncodingError(Exception):
46+
"""Wraps possible unpickleable errors, so they can be
47+
safely sent through the socket."""
48+
49+
def __init__(self, exc, value):
50+
self.exc = repr(exc)
51+
self.value = repr(value)
52+
super(MaybeEncodingError, self).__init__(self.exc, self.value)
53+
54+
def __str__(self):
55+
return "Error sending result: '%s'. Reason: '%s'" % (self.value,
56+
self.exc)
57+
58+
def __repr__(self):
59+
return "<MaybeEncodingError: %s>" % str(self)
60+
61+
4562
def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
4663
assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
4764
put = outqueue.put
@@ -70,7 +87,13 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
7087
result = (True, func(*args, **kwds))
7188
except Exception as e:
7289
result = (False, e)
73-
put((job, i, result))
90+
try:
91+
put((job, i, result))
92+
except Exception as e:
93+
wrapped = MaybeEncodingError(e, result[1])
94+
debug("Possible encoding error while sending result: %s" % (
95+
wrapped))
96+
put((job, i, (False, wrapped)))
7497
completed += 1
7598
debug('worker exiting after %d tasks' % completed)
7699

@@ -235,16 +258,18 @@ def imap_unordered(self, func, iterable, chunksize=1):
235258
for i, x in enumerate(task_batches)), result._set_length))
236259
return (item for chunk in result for item in chunk)
237260

238-
def apply_async(self, func, args=(), kwds={}, callback=None):
261+
def apply_async(self, func, args=(), kwds={}, callback=None,
262+
error_callback=None):
239263
'''
240264
Asynchronous version of `apply()` method.
241265
'''
242266
assert self._state == RUN
243-
result = ApplyResult(self._cache, callback)
267+
result = ApplyResult(self._cache, callback, error_callback)
244268
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
245269
return result
246270

247-
def map_async(self, func, iterable, chunksize=None, callback=None):
271+
def map_async(self, func, iterable, chunksize=None, callback=None,
272+
error_callback=None):
248273
'''
249274
Asynchronous version of `map()` method.
250275
'''
@@ -260,7 +285,8 @@ def map_async(self, func, iterable, chunksize=None, callback=None):
260285
chunksize = 0
261286

262287
task_batches = Pool._get_tasks(func, iterable, chunksize)
263-
result = MapResult(self._cache, chunksize, len(iterable), callback)
288+
result = MapResult(self._cache, chunksize, len(iterable), callback,
289+
error_callback=error_callback)
264290
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
265291
for i, x in enumerate(task_batches)), None))
266292
return result
@@ -459,12 +485,13 @@ def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
459485

460486
class ApplyResult(object):
461487

462-
def __init__(self, cache, callback):
488+
def __init__(self, cache, callback, error_callback):
463489
self._cond = threading.Condition(threading.Lock())
464490
self._job = next(job_counter)
465491
self._cache = cache
466492
self._ready = False
467493
self._callback = callback
494+
self._error_callback = error_callback
468495
cache[self._job] = self
469496

470497
def ready(self):
@@ -495,6 +522,8 @@ def _set(self, i, obj):
495522
self._success, self._value = obj
496523
if self._callback and self._success:
497524
self._callback(self._value)
525+
if self._error_callback and not self._success:
526+
self._error_callback(self._value)
498527
self._cond.acquire()
499528
try:
500529
self._ready = True
@@ -509,8 +538,9 @@ def _set(self, i, obj):
509538

510539
class MapResult(ApplyResult):
511540

512-
def __init__(self, cache, chunksize, length, callback):
513-
ApplyResult.__init__(self, cache, callback)
541+
def __init__(self, cache, chunksize, length, callback, error_callback):
542+
ApplyResult.__init__(self, cache, callback,
543+
error_callback=error_callback)
514544
self._success = True
515545
self._value = [None] * length
516546
self._chunksize = chunksize
@@ -535,10 +565,11 @@ def _set(self, i, success_result):
535565
self._cond.notify()
536566
finally:
537567
self._cond.release()
538-
539568
else:
540569
self._success = False
541570
self._value = result
571+
if self._error_callback:
572+
self._error_callback(self._value)
542573
del self._cache[self._job]
543574
self._cond.acquire()
544575
try:

Lib/test/test_multiprocessing.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,7 @@ def test_namespace(self):
10111011
def sqr(x, wait=0.0):
10121012
time.sleep(wait)
10131013
return x*x
1014+
10141015
class _TestPool(BaseTestCase):
10151016

10161017
def test_apply(self):
@@ -1087,9 +1088,55 @@ def test_terminate(self):
10871088
join()
10881089
self.assertTrue(join.elapsed < 0.2)
10891090

1090-
class _TestPoolWorkerLifetime(BaseTestCase):
1091+
def raising():
1092+
raise KeyError("key")
1093+
1094+
def unpickleable_result():
1095+
return lambda: 42
1096+
1097+
class _TestPoolWorkerErrors(BaseTestCase):
1098+
ALLOWED_TYPES = ('processes', )
1099+
1100+
def test_async_error_callback(self):
1101+
p = multiprocessing.Pool(2)
1102+
1103+
scratchpad = [None]
1104+
def errback(exc):
1105+
scratchpad[0] = exc
1106+
1107+
res = p.apply_async(raising, error_callback=errback)
1108+
self.assertRaises(KeyError, res.get)
1109+
self.assertTrue(scratchpad[0])
1110+
self.assertIsInstance(scratchpad[0], KeyError)
1111+
1112+
p.close()
1113+
p.join()
1114+
1115+
def test_unpickleable_result(self):
1116+
from multiprocessing.pool import MaybeEncodingError
1117+
p = multiprocessing.Pool(2)
1118+
1119+
# Make sure we don't lose pool processes because of encoding errors.
1120+
for iteration in range(20):
1121+
1122+
scratchpad = [None]
1123+
def errback(exc):
1124+
scratchpad[0] = exc
1125+
1126+
res = p.apply_async(unpickleable_result, error_callback=errback)
1127+
self.assertRaises(MaybeEncodingError, res.get)
1128+
wrapped = scratchpad[0]
1129+
self.assertTrue(wrapped)
1130+
self.assertIsInstance(scratchpad[0], MaybeEncodingError)
1131+
self.assertIsNotNone(wrapped.exc)
1132+
self.assertIsNotNone(wrapped.value)
10911133

1134+
p.close()
1135+
p.join()
1136+
1137+
class _TestPoolWorkerLifetime(BaseTestCase):
10921138
ALLOWED_TYPES = ('processes', )
1139+
10931140
def test_pool_worker_lifetime(self):
10941141
p = multiprocessing.Pool(3, maxtasksperchild=10)
10951142
self.assertEqual(3, len(p._pool))

0 commit comments

Comments
 (0)