asyncio.tasks: Fix as_completed, gather & wait to work with duplicate coroutines

This commit is contained in:
Yury Selivanov 2014-02-06 22:06:16 -05:00
parent 2ddb39a695
commit 622be340fd
2 changed files with 51 additions and 11 deletions

View file

@ -364,7 +364,7 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED):
if loop is None: if loop is None:
loop = events.get_event_loop() loop = events.get_event_loop()
fs = set(async(f, loop=loop) for f in fs) fs = {async(f, loop=loop) for f in set(fs)}
if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED):
raise ValueError('Invalid return_when value: {}'.format(return_when)) raise ValueError('Invalid return_when value: {}'.format(return_when))
@ -476,7 +476,7 @@ def as_completed(fs, *, loop=None, timeout=None):
""" """
loop = loop if loop is not None else events.get_event_loop() loop = loop if loop is not None else events.get_event_loop()
deadline = None if timeout is None else loop.time() + timeout deadline = None if timeout is None else loop.time() + timeout
todo = set(async(f, loop=loop) for f in fs) todo = {async(f, loop=loop) for f in set(fs)}
completed = collections.deque() completed = collections.deque()
@coroutine @coroutine
@ -568,7 +568,8 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
prevent the cancellation of one child to cause other children to prevent the cancellation of one child to cause other children to
be cancelled.) be cancelled.)
""" """
children = [async(fut, loop=loop) for fut in coros_or_futures] arg_to_fut = {arg: async(arg, loop=loop) for arg in set(coros_or_futures)}
children = [arg_to_fut[arg] for arg in coros_or_futures]
n = len(children) n = len(children)
if n == 0: if n == 0:
outer = futures.Future(loop=loop) outer = futures.Future(loop=loop)

View file

@ -483,6 +483,21 @@ class TaskTests(unittest.TestCase):
self.assertEqual(res, 42) self.assertEqual(res, 42)
def test_wait_duplicate_coroutines(self):
@asyncio.coroutine
def coro(s):
return s
c = coro('test')
task = asyncio.Task(
asyncio.wait([c, c, coro('spam')], loop=self.loop),
loop=self.loop)
done, pending = self.loop.run_until_complete(task)
self.assertFalse(pending)
self.assertEqual(set(f.result() for f in done), {'test', 'spam'})
def test_wait_errors(self): def test_wait_errors(self):
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, ValueError, self.loop.run_until_complete,
@ -757,14 +772,10 @@ class TaskTests(unittest.TestCase):
def test_as_completed_with_timeout(self): def test_as_completed_with_timeout(self):
def gen(): def gen():
when = yield yield
self.assertAlmostEqual(0.12, when) yield 0
when = yield 0 yield 0
self.assertAlmostEqual(0.1, when) yield 0.1
when = yield 0
self.assertAlmostEqual(0.15, when)
when = yield 0.1
self.assertAlmostEqual(0.12, when)
yield 0.02 yield 0.02
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
@ -840,6 +851,25 @@ class TaskTests(unittest.TestCase):
done, pending = loop.run_until_complete(waiter) done, pending = loop.run_until_complete(waiter)
self.assertEqual(set(f.result() for f in done), {'a', 'b'}) self.assertEqual(set(f.result() for f in done), {'a', 'b'})
def test_as_completed_duplicate_coroutines(self):
@asyncio.coroutine
def coro(s):
return s
@asyncio.coroutine
def runner():
result = []
c = coro('ham')
for f in asyncio.as_completed({c, c, coro('spam')}, loop=self.loop):
result.append((yield from f))
return result
fut = asyncio.Task(runner(), loop=self.loop)
self.loop.run_until_complete(fut)
result = fut.result()
self.assertEqual(set(result), {'ham', 'spam'})
self.assertEqual(len(result), 2)
def test_sleep(self): def test_sleep(self):
def gen(): def gen():
@ -1505,6 +1535,15 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
gen3.close() gen3.close()
gen4.close() gen4.close()
def test_duplicate_coroutines(self):
@asyncio.coroutine
def coro(s):
return s
c = coro('abc')
fut = asyncio.gather(c, c, coro('def'), c, loop=self.one_loop)
self._run_loop(self.one_loop)
self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc'])
def test_cancellation_broadcast(self): def test_cancellation_broadcast(self):
# Cancelling outer() cancels all children. # Cancelling outer() cancels all children.
proof = 0 proof = 0