bpo-38093: Correctly returns AsyncMock for async subclasses. (GH-15947)

This commit is contained in:
Lisa Roach 2019-09-19 21:04:18 -07:00 committed by GitHub
parent 2702638eab
commit 8b03f943c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 178 additions and 67 deletions

View file

@ -382,35 +382,88 @@ class AsyncArguments(unittest.TestCase):
class AsyncContextManagerTest(unittest.TestCase):
class WithAsyncContextManager:
async def __aenter__(self, *args, **kwargs):
return self
async def __aexit__(self, *args, **kwargs):
pass
def test_magic_methods_are_async_mocks(self):
mock = MagicMock(self.WithAsyncContextManager())
self.assertIsInstance(mock.__aenter__, AsyncMock)
self.assertIsInstance(mock.__aexit__, AsyncMock)
class WithSyncContextManager:
def __enter__(self, *args, **kwargs):
return self
def __exit__(self, *args, **kwargs):
pass
class ProductionCode:
# Example real-world(ish) code
def __init__(self):
self.session = None
async def main(self):
async with self.session.post('https://python.org') as response:
val = await response.json()
return val
def test_async_magic_methods_are_async_mocks_with_magicmock(self):
cm_mock = MagicMock(self.WithAsyncContextManager())
self.assertIsInstance(cm_mock.__aenter__, AsyncMock)
self.assertIsInstance(cm_mock.__aexit__, AsyncMock)
def test_magicmock_has_async_magic_methods(self):
cm = MagicMock(name='magic_cm')
self.assertTrue(hasattr(cm, "__aenter__"))
self.assertTrue(hasattr(cm, "__aexit__"))
def test_magic_methods_are_async_functions(self):
cm = MagicMock(name='magic_cm')
self.assertIsInstance(cm.__aenter__, AsyncMock)
self.assertIsInstance(cm.__aexit__, AsyncMock)
# AsyncMocks are also coroutine functions
self.assertTrue(asyncio.iscoroutinefunction(cm.__aenter__))
self.assertTrue(asyncio.iscoroutinefunction(cm.__aexit__))
def test_set_return_value_of_aenter(self):
def inner_test(mock_type):
pc = self.ProductionCode()
pc.session = MagicMock(name='sessionmock')
cm = mock_type(name='magic_cm')
response = AsyncMock(name='response')
response.json = AsyncMock(return_value={'json': 123})
cm.__aenter__.return_value = response
pc.session.post.return_value = cm
result = asyncio.run(pc.main())
self.assertEqual(result, {'json': 123})
for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"test set return value of aenter with {mock_type}"):
inner_test(mock_type)
def test_mock_supports_async_context_manager(self):
called = False
instance = self.WithAsyncContextManager()
mock_instance = MagicMock(instance)
def inner_test(mock_type):
called = False
cm = self.WithAsyncContextManager()
cm_mock = mock_type(cm)
async def use_context_manager():
nonlocal called
async with mock_instance as result:
called = True
return result
async def use_context_manager():
nonlocal called
async with cm_mock as result:
called = True
return result
cm_result = asyncio.run(use_context_manager())
self.assertTrue(called)
self.assertTrue(cm_mock.__aenter__.called)
self.assertTrue(cm_mock.__aexit__.called)
cm_mock.__aenter__.assert_awaited()
cm_mock.__aexit__.assert_awaited()
# We mock __aenter__ so it does not return self
self.assertIsNot(cm_mock, cm_result)
for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"test context manager magics with {mock_type}"):
inner_test(mock_type)
result = asyncio.run(use_context_manager())
self.assertTrue(called)
self.assertTrue(mock_instance.__aenter__.called)
self.assertTrue(mock_instance.__aexit__.called)
self.assertIsNot(mock_instance, result)
self.assertIsInstance(result, AsyncMock)
def test_mock_customize_async_context_manager(self):
instance = self.WithAsyncContextManager()
@ -478,27 +531,30 @@ class AsyncIteratorTest(unittest.TestCase):
raise StopAsyncIteration
def test_mock_aiter_and_anext(self):
instance = self.WithAsyncIterator()
mock_instance = MagicMock(instance)
def test_aiter_set_return_value(self):
mock_iter = AsyncMock(name="tester")
mock_iter.__aiter__.return_value = [1, 2, 3]
async def main():
return [i async for i in mock_iter]
result = asyncio.run(main())
self.assertEqual(result, [1, 2, 3])
self.assertEqual(asyncio.iscoroutine(instance.__aiter__),
asyncio.iscoroutine(mock_instance.__aiter__))
self.assertEqual(asyncio.iscoroutine(instance.__anext__),
asyncio.iscoroutine(mock_instance.__anext__))
def test_mock_aiter_and_anext_asyncmock(self):
def inner_test(mock_type):
instance = self.WithAsyncIterator()
mock_instance = mock_type(instance)
# Check that the mock and the real thing bahave the same
# __aiter__ is not actually async, so not a coroutinefunction
self.assertFalse(asyncio.iscoroutinefunction(instance.__aiter__))
self.assertFalse(asyncio.iscoroutinefunction(mock_instance.__aiter__))
# __anext__ is async
self.assertTrue(asyncio.iscoroutinefunction(instance.__anext__))
self.assertTrue(asyncio.iscoroutinefunction(mock_instance.__anext__))
iterator = instance.__aiter__()
if asyncio.iscoroutine(iterator):
iterator = asyncio.run(iterator)
for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"test aiter and anext corourtine with {mock_type}"):
inner_test(mock_type)
mock_iterator = mock_instance.__aiter__()
if asyncio.iscoroutine(mock_iterator):
mock_iterator = asyncio.run(mock_iterator)
self.assertEqual(asyncio.iscoroutine(iterator.__aiter__),
asyncio.iscoroutine(mock_iterator.__aiter__))
self.assertEqual(asyncio.iscoroutine(iterator.__anext__),
asyncio.iscoroutine(mock_iterator.__anext__))
def test_mock_async_for(self):
async def iterate(iterator):
@ -509,19 +565,30 @@ class AsyncIteratorTest(unittest.TestCase):
return accumulator
expected = ["FOO", "BAR", "BAZ"]
with self.subTest("iterate through default value"):
mock_instance = MagicMock(self.WithAsyncIterator())
self.assertEqual([], asyncio.run(iterate(mock_instance)))
def test_default(mock_type):
mock_instance = mock_type(self.WithAsyncIterator())
self.assertEqual(asyncio.run(iterate(mock_instance)), [])
with self.subTest("iterate through set return_value"):
mock_instance = MagicMock(self.WithAsyncIterator())
def test_set_return_value(mock_type):
mock_instance = mock_type(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = expected[:]
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
self.assertEqual(asyncio.run(iterate(mock_instance)), expected)
with self.subTest("iterate through set return_value iterator"):
mock_instance = MagicMock(self.WithAsyncIterator())
def test_set_return_value_iter(mock_type):
mock_instance = mock_type(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = iter(expected[:])
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
self.assertEqual(asyncio.run(iterate(mock_instance)), expected)
for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"default value with {mock_type}"):
test_default(mock_type)
with self.subTest(f"set return_value with {mock_type}"):
test_set_return_value(mock_type)
with self.subTest(f"set return_value iterator with {mock_type}"):
test_set_return_value_iter(mock_type)
class AsyncMockAssert(unittest.TestCase):