bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)

This commit is contained in:
Dennis Sweeney 2021-04-11 00:51:35 -04:00 committed by GitHub
parent 9045919bfa
commit dfb45323ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 182 additions and 6 deletions

View file

@ -372,11 +372,8 @@ class AsyncGenAsyncioTest(unittest.TestCase):
self.loop = None
asyncio.set_event_loop_policy(None)
def test_async_gen_anext(self):
async def gen():
yield 1
yield 2
g = gen()
def check_async_iterator_anext(self, ait_class):
g = ait_class()
async def consume():
results = []
results.append(await anext(g))
@ -388,6 +385,66 @@ class AsyncGenAsyncioTest(unittest.TestCase):
with self.assertRaises(StopAsyncIteration):
self.loop.run_until_complete(consume())
async def test_2():
g1 = ait_class()
self.assertEqual(await anext(g1), 1)
self.assertEqual(await anext(g1), 2)
with self.assertRaises(StopAsyncIteration):
await anext(g1)
with self.assertRaises(StopAsyncIteration):
await anext(g1)
g2 = ait_class()
self.assertEqual(await anext(g2, "default"), 1)
self.assertEqual(await anext(g2, "default"), 2)
self.assertEqual(await anext(g2, "default"), "default")
self.assertEqual(await anext(g2, "default"), "default")
return "completed"
result = self.loop.run_until_complete(test_2())
self.assertEqual(result, "completed")
def test_async_generator_anext(self):
async def agen():
yield 1
yield 2
self.check_async_iterator_anext(agen)
def test_python_async_iterator_anext(self):
class MyAsyncIter:
"""Asynchronously yield 1, then 2."""
def __init__(self):
self.yielded = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.yielded >= 2:
raise StopAsyncIteration()
else:
self.yielded += 1
return self.yielded
self.check_async_iterator_anext(MyAsyncIter)
def test_python_async_iterator_types_coroutine_anext(self):
import types
class MyAsyncIterWithTypesCoro:
"""Asynchronously yield 1, then 2."""
def __init__(self):
self.yielded = 0
def __aiter__(self):
return self
@types.coroutine
def __anext__(self):
if False:
yield "this is a generator-based coroutine"
if self.yielded >= 2:
raise StopAsyncIteration()
else:
self.yielded += 1
return self.yielded
self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
def test_async_gen_aiter(self):
async def gen():
yield 1
@ -431,12 +488,85 @@ class AsyncGenAsyncioTest(unittest.TestCase):
await anext(gen(), 1, 3)
async def call_with_wrong_type_args():
await anext(1, gen())
async def call_with_kwarg():
await anext(aiterator=gen())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_few_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_many_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_wrong_type_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_kwarg())
def test_anext_bad_await(self):
async def bad_awaitable():
class BadAwaitable:
def __await__(self):
return 42
class MyAsyncIter:
def __aiter__(self):
return self
def __anext__(self):
return BadAwaitable()
regex = r"__await__.*iterator"
awaitable = anext(MyAsyncIter(), "default")
with self.assertRaisesRegex(TypeError, regex):
await awaitable
awaitable = anext(MyAsyncIter())
with self.assertRaisesRegex(TypeError, regex):
await awaitable
return "completed"
result = self.loop.run_until_complete(bad_awaitable())
self.assertEqual(result, "completed")
async def check_anext_returning_iterator(self, aiter_class):
awaitable = anext(aiter_class(), "default")
with self.assertRaises(TypeError):
await awaitable
awaitable = anext(aiter_class())
with self.assertRaises(TypeError):
await awaitable
return "completed"
def test_anext_return_iterator(self):
class WithIterAnext:
def __aiter__(self):
return self
def __anext__(self):
return iter("abc")
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
self.assertEqual(result, "completed")
def test_anext_return_generator(self):
class WithGenAnext:
def __aiter__(self):
return self
def __anext__(self):
yield
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
self.assertEqual(result, "completed")
def test_anext_await_raises(self):
class RaisingAwaitable:
def __await__(self):
raise ZeroDivisionError()
yield
class WithRaisingAwaitableAnext:
def __aiter__(self):
return self
def __anext__(self):
return RaisingAwaitable()
async def do_test():
awaitable = anext(WithRaisingAwaitableAnext())
with self.assertRaises(ZeroDivisionError):
await awaitable
awaitable = anext(WithRaisingAwaitableAnext(), "default")
with self.assertRaises(ZeroDivisionError):
await awaitable
return "completed"
result = self.loop.run_until_complete(do_test())
self.assertEqual(result, "completed")
def test_aiter_bad_args(self):
async def gen():