mirror of
https://github.com/python/cpython.git
synced 2025-08-31 05:58:33 +00:00
PEP 0492 -- Coroutines with async and await syntax. Issue #24017.
This commit is contained in:
parent
4e6bf4b3da
commit
7544508f02
72 changed files with 9261 additions and 5739 deletions
862
Lib/test/test_coroutines.py
Normal file
862
Lib/test/test_coroutines.py
Normal file
|
@ -0,0 +1,862 @@
|
|||
import contextlib
|
||||
import gc
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
import warnings
|
||||
from test import support
|
||||
|
||||
|
||||
class AsyncYieldFrom:
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
|
||||
def __await__(self):
|
||||
yield from self.obj
|
||||
|
||||
|
||||
class AsyncYield:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __await__(self):
|
||||
yield self.value
|
||||
|
||||
|
||||
def run_async(coro):
|
||||
assert coro.__class__ is types.GeneratorType
|
||||
|
||||
buffer = []
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
buffer.append(coro.send(None))
|
||||
except StopIteration as ex:
|
||||
result = ex.args[0] if ex.args else None
|
||||
break
|
||||
return buffer, result
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def silence_coro_gc():
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
yield
|
||||
support.gc_collect()
|
||||
|
||||
|
||||
class AsyncBadSyntaxTest(unittest.TestCase):
|
||||
|
||||
def test_badsyntax_1(self):
|
||||
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
|
||||
import test.badsyntax_async1
|
||||
|
||||
def test_badsyntax_2(self):
|
||||
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
|
||||
import test.badsyntax_async2
|
||||
|
||||
def test_badsyntax_3(self):
|
||||
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
|
||||
import test.badsyntax_async3
|
||||
|
||||
def test_badsyntax_4(self):
|
||||
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
|
||||
import test.badsyntax_async4
|
||||
|
||||
def test_badsyntax_5(self):
|
||||
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
|
||||
import test.badsyntax_async5
|
||||
|
||||
def test_badsyntax_6(self):
|
||||
with self.assertRaisesRegex(
|
||||
SyntaxError, "'yield' inside async function"):
|
||||
|
||||
import test.badsyntax_async6
|
||||
|
||||
def test_badsyntax_7(self):
|
||||
with self.assertRaisesRegex(
|
||||
SyntaxError, "'yield from' inside async function"):
|
||||
|
||||
import test.badsyntax_async7
|
||||
|
||||
def test_badsyntax_8(self):
|
||||
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
|
||||
import test.badsyntax_async8
|
||||
|
||||
def test_badsyntax_9(self):
|
||||
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
|
||||
import test.badsyntax_async9
|
||||
|
||||
|
||||
class CoroutineTest(unittest.TestCase):
|
||||
|
||||
def test_gen_1(self):
|
||||
def gen(): yield
|
||||
self.assertFalse(hasattr(gen, '__await__'))
|
||||
|
||||
def test_func_1(self):
|
||||
async def foo():
|
||||
return 10
|
||||
|
||||
f = foo()
|
||||
self.assertIsInstance(f, types.GeneratorType)
|
||||
self.assertTrue(bool(foo.__code__.co_flags & 0x80))
|
||||
self.assertTrue(bool(foo.__code__.co_flags & 0x20))
|
||||
self.assertTrue(bool(f.gi_code.co_flags & 0x80))
|
||||
self.assertTrue(bool(f.gi_code.co_flags & 0x20))
|
||||
self.assertEqual(run_async(f), ([], 10))
|
||||
|
||||
def bar(): pass
|
||||
self.assertFalse(bool(bar.__code__.co_flags & 0x80))
|
||||
|
||||
def test_func_2(self):
|
||||
async def foo():
|
||||
raise StopIteration
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "generator raised StopIteration"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
def test_func_3(self):
|
||||
async def foo():
|
||||
raise StopIteration
|
||||
|
||||
with silence_coro_gc():
|
||||
self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$')
|
||||
|
||||
def test_func_4(self):
|
||||
async def foo():
|
||||
raise StopIteration
|
||||
|
||||
check = lambda: self.assertRaisesRegex(
|
||||
TypeError, "coroutine-objects do not support iteration")
|
||||
|
||||
with check():
|
||||
list(foo())
|
||||
|
||||
with check():
|
||||
tuple(foo())
|
||||
|
||||
with check():
|
||||
sum(foo())
|
||||
|
||||
with check():
|
||||
iter(foo())
|
||||
|
||||
with check():
|
||||
next(foo())
|
||||
|
||||
with silence_coro_gc(), check():
|
||||
for i in foo():
|
||||
pass
|
||||
|
||||
with silence_coro_gc(), check():
|
||||
[i for i in foo()]
|
||||
|
||||
def test_func_5(self):
|
||||
@types.coroutine
|
||||
def bar():
|
||||
yield 1
|
||||
|
||||
async def foo():
|
||||
await bar()
|
||||
|
||||
check = lambda: self.assertRaisesRegex(
|
||||
TypeError, "coroutine-objects do not support iteration")
|
||||
|
||||
with check():
|
||||
for el in foo(): pass
|
||||
|
||||
# the following should pass without an error
|
||||
for el in bar():
|
||||
self.assertEqual(el, 1)
|
||||
self.assertEqual([el for el in bar()], [1])
|
||||
self.assertEqual(tuple(bar()), (1,))
|
||||
self.assertEqual(next(iter(bar())), 1)
|
||||
|
||||
def test_func_6(self):
|
||||
@types.coroutine
|
||||
def bar():
|
||||
yield 1
|
||||
yield 2
|
||||
|
||||
async def foo():
|
||||
await bar()
|
||||
|
||||
f = foo()
|
||||
self.assertEquals(f.send(None), 1)
|
||||
self.assertEquals(f.send(None), 2)
|
||||
with self.assertRaises(StopIteration):
|
||||
f.send(None)
|
||||
|
||||
def test_func_7(self):
|
||||
async def bar():
|
||||
return 10
|
||||
|
||||
def foo():
|
||||
yield from bar()
|
||||
|
||||
with silence_coro_gc(), self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"cannot 'yield from' a coroutine object from a generator"):
|
||||
|
||||
list(foo())
|
||||
|
||||
def test_func_8(self):
|
||||
@types.coroutine
|
||||
def bar():
|
||||
return (yield from foo())
|
||||
|
||||
async def foo():
|
||||
return 'spam'
|
||||
|
||||
self.assertEqual(run_async(bar()), ([], 'spam') )
|
||||
|
||||
def test_func_9(self):
|
||||
async def foo(): pass
|
||||
|
||||
with self.assertWarnsRegex(
|
||||
RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):
|
||||
|
||||
foo()
|
||||
support.gc_collect()
|
||||
|
||||
def test_await_1(self):
|
||||
|
||||
async def foo():
|
||||
await 1
|
||||
with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
|
||||
run_async(foo())
|
||||
|
||||
def test_await_2(self):
|
||||
async def foo():
|
||||
await []
|
||||
with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
|
||||
run_async(foo())
|
||||
|
||||
def test_await_3(self):
|
||||
async def foo():
|
||||
await AsyncYieldFrom([1, 2, 3])
|
||||
|
||||
self.assertEqual(run_async(foo()), ([1, 2, 3], None))
|
||||
|
||||
def test_await_4(self):
|
||||
async def bar():
|
||||
return 42
|
||||
|
||||
async def foo():
|
||||
return await bar()
|
||||
|
||||
self.assertEqual(run_async(foo()), ([], 42))
|
||||
|
||||
def test_await_5(self):
|
||||
class Awaitable:
|
||||
def __await__(self):
|
||||
return
|
||||
|
||||
async def foo():
|
||||
return (await Awaitable())
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "__await__.*returned non-iterator of type"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
def test_await_6(self):
|
||||
class Awaitable:
|
||||
def __await__(self):
|
||||
return iter([52])
|
||||
|
||||
async def foo():
|
||||
return (await Awaitable())
|
||||
|
||||
self.assertEqual(run_async(foo()), ([52], None))
|
||||
|
||||
def test_await_7(self):
|
||||
class Awaitable:
|
||||
def __await__(self):
|
||||
yield 42
|
||||
return 100
|
||||
|
||||
async def foo():
|
||||
return (await Awaitable())
|
||||
|
||||
self.assertEqual(run_async(foo()), ([42], 100))
|
||||
|
||||
def test_await_8(self):
|
||||
class Awaitable:
|
||||
pass
|
||||
|
||||
async def foo():
|
||||
return (await Awaitable())
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "object Awaitable can't be used in 'await' expression"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
def test_await_9(self):
|
||||
def wrap():
|
||||
return bar
|
||||
|
||||
async def bar():
|
||||
return 42
|
||||
|
||||
async def foo():
|
||||
b = bar()
|
||||
|
||||
db = {'b': lambda: wrap}
|
||||
|
||||
class DB:
|
||||
b = wrap
|
||||
|
||||
return (await bar() + await wrap()() + await db['b']()()() +
|
||||
await bar() * 1000 + await DB.b()())
|
||||
|
||||
async def foo2():
|
||||
return -await bar()
|
||||
|
||||
self.assertEqual(run_async(foo()), ([], 42168))
|
||||
self.assertEqual(run_async(foo2()), ([], -42))
|
||||
|
||||
def test_await_10(self):
|
||||
async def baz():
|
||||
return 42
|
||||
|
||||
async def bar():
|
||||
return baz()
|
||||
|
||||
async def foo():
|
||||
return await (await bar())
|
||||
|
||||
self.assertEqual(run_async(foo()), ([], 42))
|
||||
|
||||
def test_await_11(self):
|
||||
def ident(val):
|
||||
return val
|
||||
|
||||
async def bar():
|
||||
return 'spam'
|
||||
|
||||
async def foo():
|
||||
return ident(val=await bar())
|
||||
|
||||
async def foo2():
|
||||
return await bar(), 'ham'
|
||||
|
||||
self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))
|
||||
|
||||
def test_await_12(self):
|
||||
async def coro():
|
||||
return 'spam'
|
||||
|
||||
class Awaitable:
|
||||
def __await__(self):
|
||||
return coro()
|
||||
|
||||
async def foo():
|
||||
return await Awaitable()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "__await__\(\) returned a coroutine"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
def test_await_13(self):
|
||||
class Awaitable:
|
||||
def __await__(self):
|
||||
return self
|
||||
|
||||
async def foo():
|
||||
return await Awaitable()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "__await__.*returned non-iterator of type"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
def test_with_1(self):
|
||||
class Manager:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
async def __aenter__(self):
|
||||
await AsyncYieldFrom(['enter-1-' + self.name,
|
||||
'enter-2-' + self.name])
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
await AsyncYieldFrom(['exit-1-' + self.name,
|
||||
'exit-2-' + self.name])
|
||||
|
||||
if self.name == 'B':
|
||||
return True
|
||||
|
||||
|
||||
async def foo():
|
||||
async with Manager("A") as a, Manager("B") as b:
|
||||
await AsyncYieldFrom([('managers', a.name, b.name)])
|
||||
1/0
|
||||
|
||||
f = foo()
|
||||
result, _ = run_async(f)
|
||||
|
||||
self.assertEqual(
|
||||
result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
|
||||
('managers', 'A', 'B'),
|
||||
'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
|
||||
)
|
||||
|
||||
async def foo():
|
||||
async with Manager("A") as a, Manager("C") as c:
|
||||
await AsyncYieldFrom([('managers', a.name, c.name)])
|
||||
1/0
|
||||
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
run_async(foo())
|
||||
|
||||
def test_with_2(self):
|
||||
class CM:
|
||||
def __aenter__(self):
|
||||
pass
|
||||
|
||||
async def foo():
|
||||
async with CM():
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(AttributeError, '__aexit__'):
|
||||
run_async(foo())
|
||||
|
||||
def test_with_3(self):
|
||||
class CM:
|
||||
def __aexit__(self):
|
||||
pass
|
||||
|
||||
async def foo():
|
||||
async with CM():
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(AttributeError, '__aenter__'):
|
||||
run_async(foo())
|
||||
|
||||
def test_with_4(self):
|
||||
class CM:
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self):
|
||||
pass
|
||||
|
||||
async def foo():
|
||||
async with CM():
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(AttributeError, '__aexit__'):
|
||||
run_async(foo())
|
||||
|
||||
def test_with_5(self):
|
||||
# While this test doesn't make a lot of sense,
|
||||
# it's a regression test for an early bug with opcodes
|
||||
# generation
|
||||
|
||||
class CM:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
pass
|
||||
|
||||
async def func():
|
||||
async with CM():
|
||||
assert (1, ) == 1
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
run_async(func())
|
||||
|
||||
def test_with_6(self):
|
||||
class CM:
|
||||
def __aenter__(self):
|
||||
return 123
|
||||
|
||||
def __aexit__(self, *e):
|
||||
return 456
|
||||
|
||||
async def foo():
|
||||
async with CM():
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "object int can't be used in 'await' expression"):
|
||||
# it's important that __aexit__ wasn't called
|
||||
run_async(foo())
|
||||
|
||||
def test_with_7(self):
|
||||
class CM:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
def __aexit__(self, *e):
|
||||
return 456
|
||||
|
||||
async def foo():
|
||||
async with CM():
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "object int can't be used in 'await' expression"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
def test_for_1(self):
|
||||
aiter_calls = 0
|
||||
|
||||
class AsyncIter:
|
||||
def __init__(self):
|
||||
self.i = 0
|
||||
|
||||
async def __aiter__(self):
|
||||
nonlocal aiter_calls
|
||||
aiter_calls += 1
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
self.i += 1
|
||||
|
||||
if not (self.i % 10):
|
||||
await AsyncYield(self.i * 10)
|
||||
|
||||
if self.i > 100:
|
||||
raise StopAsyncIteration
|
||||
|
||||
return self.i, self.i
|
||||
|
||||
|
||||
buffer = []
|
||||
async def test1():
|
||||
async for i1, i2 in AsyncIter():
|
||||
buffer.append(i1 + i2)
|
||||
|
||||
yielded, _ = run_async(test1())
|
||||
# Make sure that __aiter__ was called only once
|
||||
self.assertEqual(aiter_calls, 1)
|
||||
self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
|
||||
self.assertEqual(buffer, [i*2 for i in range(1, 101)])
|
||||
|
||||
|
||||
buffer = []
|
||||
async def test2():
|
||||
nonlocal buffer
|
||||
async for i in AsyncIter():
|
||||
buffer.append(i[0])
|
||||
if i[0] == 20:
|
||||
break
|
||||
else:
|
||||
buffer.append('what?')
|
||||
buffer.append('end')
|
||||
|
||||
yielded, _ = run_async(test2())
|
||||
# Make sure that __aiter__ was called only once
|
||||
self.assertEqual(aiter_calls, 2)
|
||||
self.assertEqual(yielded, [100, 200])
|
||||
self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])
|
||||
|
||||
|
||||
buffer = []
|
||||
async def test3():
|
||||
nonlocal buffer
|
||||
async for i in AsyncIter():
|
||||
if i[0] > 20:
|
||||
continue
|
||||
buffer.append(i[0])
|
||||
else:
|
||||
buffer.append('what?')
|
||||
buffer.append('end')
|
||||
|
||||
yielded, _ = run_async(test3())
|
||||
# Make sure that __aiter__ was called only once
|
||||
self.assertEqual(aiter_calls, 3)
|
||||
self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
|
||||
self.assertEqual(buffer, [i for i in range(1, 21)] +
|
||||
['what?', 'end'])
|
||||
|
||||
def test_for_2(self):
|
||||
tup = (1, 2, 3)
|
||||
refs_before = sys.getrefcount(tup)
|
||||
|
||||
async def foo():
|
||||
async for i in tup:
|
||||
print('never going to happen')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "async for' requires an object.*__aiter__.*tuple"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
self.assertEqual(sys.getrefcount(tup), refs_before)
|
||||
|
||||
def test_for_3(self):
|
||||
class I:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
aiter = I()
|
||||
refs_before = sys.getrefcount(aiter)
|
||||
|
||||
async def foo():
|
||||
async for i in aiter:
|
||||
print('never going to happen')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"async for' received an invalid object.*__aiter.*\: I"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
self.assertEqual(sys.getrefcount(aiter), refs_before)
|
||||
|
||||
def test_for_4(self):
|
||||
class I:
|
||||
async def __aiter__(self):
|
||||
return self
|
||||
|
||||
def __anext__(self):
|
||||
return ()
|
||||
|
||||
aiter = I()
|
||||
refs_before = sys.getrefcount(aiter)
|
||||
|
||||
async def foo():
|
||||
async for i in aiter:
|
||||
print('never going to happen')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"async for' received an invalid object.*__anext__.*tuple"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
self.assertEqual(sys.getrefcount(aiter), refs_before)
|
||||
|
||||
def test_for_5(self):
|
||||
class I:
|
||||
async def __aiter__(self):
|
||||
return self
|
||||
|
||||
def __anext__(self):
|
||||
return 123
|
||||
|
||||
async def foo():
|
||||
async for i in I():
|
||||
print('never going to happen')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"async for' received an invalid object.*__anext.*int"):
|
||||
|
||||
run_async(foo())
|
||||
|
||||
def test_for_6(self):
|
||||
I = 0
|
||||
|
||||
class Manager:
|
||||
async def __aenter__(self):
|
||||
nonlocal I
|
||||
I += 10000
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
nonlocal I
|
||||
I += 100000
|
||||
|
||||
class Iterable:
|
||||
def __init__(self):
|
||||
self.i = 0
|
||||
|
||||
async def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.i > 10:
|
||||
raise StopAsyncIteration
|
||||
self.i += 1
|
||||
return self.i
|
||||
|
||||
##############
|
||||
|
||||
manager = Manager()
|
||||
iterable = Iterable()
|
||||
mrefs_before = sys.getrefcount(manager)
|
||||
irefs_before = sys.getrefcount(iterable)
|
||||
|
||||
async def main():
|
||||
nonlocal I
|
||||
|
||||
async with manager:
|
||||
async for i in iterable:
|
||||
I += 1
|
||||
I += 1000
|
||||
|
||||
run_async(main())
|
||||
self.assertEqual(I, 111011)
|
||||
|
||||
self.assertEqual(sys.getrefcount(manager), mrefs_before)
|
||||
self.assertEqual(sys.getrefcount(iterable), irefs_before)
|
||||
|
||||
##############
|
||||
|
||||
async def main():
|
||||
nonlocal I
|
||||
|
||||
async with Manager():
|
||||
async for i in Iterable():
|
||||
I += 1
|
||||
I += 1000
|
||||
|
||||
async with Manager():
|
||||
async for i in Iterable():
|
||||
I += 1
|
||||
I += 1000
|
||||
|
||||
run_async(main())
|
||||
self.assertEqual(I, 333033)
|
||||
|
||||
##############
|
||||
|
||||
async def main():
|
||||
nonlocal I
|
||||
|
||||
async with Manager():
|
||||
I += 100
|
||||
async for i in Iterable():
|
||||
I += 1
|
||||
else:
|
||||
I += 10000000
|
||||
I += 1000
|
||||
|
||||
async with Manager():
|
||||
I += 100
|
||||
async for i in Iterable():
|
||||
I += 1
|
||||
else:
|
||||
I += 10000000
|
||||
I += 1000
|
||||
|
||||
run_async(main())
|
||||
self.assertEqual(I, 20555255)
|
||||
|
||||
|
||||
class CoroAsyncIOCompatTest(unittest.TestCase):
|
||||
|
||||
def test_asyncio_1(self):
|
||||
import asyncio
|
||||
|
||||
class MyException(Exception):
|
||||
pass
|
||||
|
||||
buffer = []
|
||||
|
||||
class CM:
|
||||
async def __aenter__(self):
|
||||
buffer.append(1)
|
||||
await asyncio.sleep(0.01)
|
||||
buffer.append(2)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await asyncio.sleep(0.01)
|
||||
buffer.append(exc_type.__name__)
|
||||
|
||||
async def f():
|
||||
async with CM() as c:
|
||||
await asyncio.sleep(0.01)
|
||||
raise MyException
|
||||
buffer.append('unreachable')
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(f())
|
||||
except MyException:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
self.assertEqual(buffer, [1, 2, 'MyException'])
|
||||
|
||||
|
||||
class SysSetCoroWrapperTest(unittest.TestCase):
|
||||
|
||||
def test_set_wrapper_1(self):
|
||||
async def foo():
|
||||
return 'spam'
|
||||
|
||||
wrapped = None
|
||||
def wrap(gen):
|
||||
nonlocal wrapped
|
||||
wrapped = gen
|
||||
return gen
|
||||
|
||||
self.assertIsNone(sys.get_coroutine_wrapper())
|
||||
|
||||
sys.set_coroutine_wrapper(wrap)
|
||||
self.assertIs(sys.get_coroutine_wrapper(), wrap)
|
||||
try:
|
||||
f = foo()
|
||||
self.assertTrue(wrapped)
|
||||
|
||||
self.assertEqual(run_async(f), ([], 'spam'))
|
||||
finally:
|
||||
sys.set_coroutine_wrapper(None)
|
||||
|
||||
self.assertIsNone(sys.get_coroutine_wrapper())
|
||||
|
||||
wrapped = None
|
||||
with silence_coro_gc():
|
||||
foo()
|
||||
self.assertFalse(wrapped)
|
||||
|
||||
def test_set_wrapper_2(self):
|
||||
self.assertIsNone(sys.get_coroutine_wrapper())
|
||||
with self.assertRaisesRegex(TypeError, "callable expected, got int"):
|
||||
sys.set_coroutine_wrapper(1)
|
||||
self.assertIsNone(sys.get_coroutine_wrapper())
|
||||
|
||||
|
||||
class CAPITest(unittest.TestCase):
|
||||
|
||||
def test_tp_await_1(self):
|
||||
from _testcapi import awaitType as at
|
||||
|
||||
async def foo():
|
||||
future = at(iter([1]))
|
||||
return (await future)
|
||||
|
||||
self.assertEqual(foo().send(None), 1)
|
||||
|
||||
def test_tp_await_2(self):
|
||||
# Test tp_await to __await__ mapping
|
||||
from _testcapi import awaitType as at
|
||||
future = at(iter([1]))
|
||||
self.assertEqual(next(future.__await__()), 1)
|
||||
|
||||
def test_tp_await_3(self):
|
||||
from _testcapi import awaitType as at
|
||||
|
||||
async def foo():
|
||||
future = at(1)
|
||||
return (await future)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "__await__.*returned non-iterator of type 'int'"):
|
||||
self.assertEqual(foo().send(None), 1)
|
||||
|
||||
|
||||
def test_main():
|
||||
support.run_unittest(AsyncBadSyntaxTest,
|
||||
CoroutineTest,
|
||||
CoroAsyncIOCompatTest,
|
||||
SysSetCoroWrapperTest,
|
||||
CAPITest)
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
test_main()
|
Loading…
Add table
Add a link
Reference in a new issue