Use relative imports in mock and its tests to help backporting (GH-18197)

* asyncio.run only available in 3.8+

* iscoroutinefunction has important bungfixes in 3.8

* IsolatedAsyncioTestCase only available in 3.8+
This commit is contained in:
Chris Withers 2020-01-27 14:11:19 +00:00 committed by GitHub
parent 997443c14c
commit c7dd3c7d87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 79 deletions

View file

@ -30,6 +30,7 @@ import inspect
import pprint import pprint
import sys import sys
import builtins import builtins
from asyncio import iscoroutinefunction
from types import CodeType, ModuleType, MethodType from types import CodeType, ModuleType, MethodType
from unittest.util import safe_repr from unittest.util import safe_repr
from functools import wraps, partial from functools import wraps, partial
@ -48,12 +49,12 @@ def _is_async_obj(obj):
return False return False
if hasattr(obj, '__func__'): if hasattr(obj, '__func__'):
obj = getattr(obj, '__func__') obj = getattr(obj, '__func__')
return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj) return iscoroutinefunction(obj) or inspect.isawaitable(obj)
def _is_async_func(func): def _is_async_func(func):
if getattr(func, '__code__', None): if getattr(func, '__code__', None):
return asyncio.iscoroutinefunction(func) return iscoroutinefunction(func)
else: else:
return False return False
@ -488,7 +489,7 @@ class NonCallableMock(Base):
_spec_asyncs = [] _spec_asyncs = []
for attr in dir(spec): for attr in dir(spec):
if asyncio.iscoroutinefunction(getattr(spec, attr, None)): if iscoroutinefunction(getattr(spec, attr, None)):
_spec_asyncs.append(attr) _spec_asyncs.append(attr)
if spec is not None and not _is_list(spec): if spec is not None and not _is_list(spec):
@ -2152,7 +2153,7 @@ class AsyncMockMixin(Base):
def __init__(self, /, *args, **kwargs): def __init__(self, /, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# asyncio.iscoroutinefunction() checks _is_coroutine property to say if an # iscoroutinefunction() checks _is_coroutine property to say if an
# object is a coroutine. Without this check it looks to see if it is a # object is a coroutine. Without this check it looks to see if it is a
# function/method, which in this case it is not (since it is an # function/method, which in this case it is not (since it is an
# AsyncMock). # AsyncMock).
@ -2188,7 +2189,7 @@ class AsyncMockMixin(Base):
raise StopAsyncIteration raise StopAsyncIteration
if _is_exception(result): if _is_exception(result):
raise result raise result
elif asyncio.iscoroutinefunction(effect): elif iscoroutinefunction(effect):
result = await effect(*args, **kwargs) result = await effect(*args, **kwargs)
else: else:
result = effect(*args, **kwargs) result = effect(*args, **kwargs)
@ -2200,7 +2201,7 @@ class AsyncMockMixin(Base):
return self.return_value return self.return_value
if self._mock_wraps is not None: if self._mock_wraps is not None:
if asyncio.iscoroutinefunction(self._mock_wraps): if iscoroutinefunction(self._mock_wraps):
return await self._mock_wraps(*args, **kwargs) return await self._mock_wraps(*args, **kwargs)
return self._mock_wraps(*args, **kwargs) return self._mock_wraps(*args, **kwargs)
@ -2337,7 +2338,7 @@ class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock):
recognized as an async function, and the result of a call is an awaitable: recognized as an async function, and the result of a call is an awaitable:
>>> mock = AsyncMock() >>> mock = AsyncMock()
>>> asyncio.iscoroutinefunction(mock) >>> iscoroutinefunction(mock)
True True
>>> inspect.isawaitable(mock()) >>> inspect.isawaitable(mock())
True True
@ -2710,7 +2711,7 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
skipfirst = _must_skip(spec, entry, is_type) skipfirst = _must_skip(spec, entry, is_type)
kwargs['_eat_self'] = skipfirst kwargs['_eat_self'] = skipfirst
if asyncio.iscoroutinefunction(original): if iscoroutinefunction(original):
child_klass = AsyncMock child_klass = AsyncMock
else: else:
child_klass = MagicMock child_klass = MagicMock

View file

@ -3,6 +3,8 @@ import inspect
import re import re
import unittest import unittest
from asyncio import run, iscoroutinefunction
from unittest import IsolatedAsyncioTestCase
from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock, from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
create_autospec, sentinel, _CallList) create_autospec, sentinel, _CallList)
@ -54,7 +56,7 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
def test_is_coroutine_function_patch(self): def test_is_coroutine_function_patch(self):
@patch.object(AsyncClass, 'async_method') @patch.object(AsyncClass, 'async_method')
def test_async(mock_method): def test_async(mock_method):
self.assertTrue(asyncio.iscoroutinefunction(mock_method)) self.assertTrue(iscoroutinefunction(mock_method))
test_async() test_async()
def test_is_async_patch(self): def test_is_async_patch(self):
@ -62,13 +64,13 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
def test_async(mock_method): def test_async(mock_method):
m = mock_method() m = mock_method()
self.assertTrue(inspect.isawaitable(m)) self.assertTrue(inspect.isawaitable(m))
asyncio.run(m) run(m)
@patch(f'{async_foo_name}.async_method') @patch(f'{async_foo_name}.async_method')
def test_no_parent_attribute(mock_method): def test_no_parent_attribute(mock_method):
m = mock_method() m = mock_method()
self.assertTrue(inspect.isawaitable(m)) self.assertTrue(inspect.isawaitable(m))
asyncio.run(m) run(m)
test_async() test_async()
test_no_parent_attribute() test_no_parent_attribute()
@ -107,7 +109,7 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
self.assertEqual(await async_func(), 1) self.assertEqual(await async_func(), 1)
self.assertEqual(await async_func_args(1, 2, c=3), 2) self.assertEqual(await async_func_args(1, 2, c=3), 2)
asyncio.run(test_async()) run(test_async())
self.assertTrue(inspect.iscoroutinefunction(async_func)) self.assertTrue(inspect.iscoroutinefunction(async_func))
@ -115,7 +117,7 @@ class AsyncPatchCMTest(unittest.TestCase):
def test_is_async_function_cm(self): def test_is_async_function_cm(self):
def test_async(): def test_async():
with patch.object(AsyncClass, 'async_method') as mock_method: with patch.object(AsyncClass, 'async_method') as mock_method:
self.assertTrue(asyncio.iscoroutinefunction(mock_method)) self.assertTrue(iscoroutinefunction(mock_method))
test_async() test_async()
@ -124,7 +126,7 @@ class AsyncPatchCMTest(unittest.TestCase):
with patch.object(AsyncClass, 'async_method') as mock_method: with patch.object(AsyncClass, 'async_method') as mock_method:
m = mock_method() m = mock_method()
self.assertTrue(inspect.isawaitable(m)) self.assertTrue(inspect.isawaitable(m))
asyncio.run(m) run(m)
test_async() test_async()
@ -141,31 +143,31 @@ class AsyncPatchCMTest(unittest.TestCase):
self.assertIsInstance(async_func, AsyncMock) self.assertIsInstance(async_func, AsyncMock)
self.assertTrue(inspect.iscoroutinefunction(async_func)) self.assertTrue(inspect.iscoroutinefunction(async_func))
asyncio.run(test_async()) run(test_async())
class AsyncMockTest(unittest.TestCase): class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self): def test_iscoroutinefunction_default(self):
mock = AsyncMock() mock = AsyncMock()
self.assertTrue(asyncio.iscoroutinefunction(mock)) self.assertTrue(iscoroutinefunction(mock))
def test_iscoroutinefunction_function(self): def test_iscoroutinefunction_function(self):
async def foo(): pass async def foo(): pass
mock = AsyncMock(foo) mock = AsyncMock(foo)
self.assertTrue(asyncio.iscoroutinefunction(mock)) self.assertTrue(iscoroutinefunction(mock))
self.assertTrue(inspect.iscoroutinefunction(mock)) self.assertTrue(inspect.iscoroutinefunction(mock))
def test_isawaitable(self): def test_isawaitable(self):
mock = AsyncMock() mock = AsyncMock()
m = mock() m = mock()
self.assertTrue(inspect.isawaitable(m)) self.assertTrue(inspect.isawaitable(m))
asyncio.run(m) run(m)
self.assertIn('assert_awaited', dir(mock)) self.assertIn('assert_awaited', dir(mock))
def test_iscoroutinefunction_normal_function(self): def test_iscoroutinefunction_normal_function(self):
def foo(): pass def foo(): pass
mock = AsyncMock(foo) mock = AsyncMock(foo)
self.assertTrue(asyncio.iscoroutinefunction(mock)) self.assertTrue(iscoroutinefunction(mock))
self.assertTrue(inspect.iscoroutinefunction(mock)) self.assertTrue(inspect.iscoroutinefunction(mock))
def test_future_isfuture(self): def test_future_isfuture(self):
@ -211,9 +213,9 @@ class AsyncAutospecTest(unittest.TestCase):
self.assertEqual(spec.await_args_list, []) self.assertEqual(spec.await_args_list, [])
spec.assert_not_awaited() spec.assert_not_awaited()
asyncio.run(main()) run(main())
self.assertTrue(asyncio.iscoroutinefunction(spec)) self.assertTrue(iscoroutinefunction(spec))
self.assertTrue(asyncio.iscoroutine(awaitable)) self.assertTrue(asyncio.iscoroutine(awaitable))
self.assertEqual(spec.await_count, 1) self.assertEqual(spec.await_count, 1)
self.assertEqual(spec.await_args, call(1, 2, c=3)) self.assertEqual(spec.await_args, call(1, 2, c=3))
@ -234,7 +236,7 @@ class AsyncAutospecTest(unittest.TestCase):
awaitable = mock_method(1, 2, c=3) awaitable = mock_method(1, 2, c=3)
self.assertIsInstance(mock_method.mock, AsyncMock) self.assertIsInstance(mock_method.mock, AsyncMock)
self.assertTrue(asyncio.iscoroutinefunction(mock_method)) self.assertTrue(iscoroutinefunction(mock_method))
self.assertTrue(asyncio.iscoroutine(awaitable)) self.assertTrue(asyncio.iscoroutine(awaitable))
self.assertTrue(inspect.isawaitable(awaitable)) self.assertTrue(inspect.isawaitable(awaitable))
@ -259,7 +261,7 @@ class AsyncAutospecTest(unittest.TestCase):
self.assertIsNone(mock_method.await_args) self.assertIsNone(mock_method.await_args)
self.assertEqual(mock_method.await_args_list, []) self.assertEqual(mock_method.await_args_list, [])
asyncio.run(test_async()) run(test_async())
class AsyncSpecTest(unittest.TestCase): class AsyncSpecTest(unittest.TestCase):
@ -313,14 +315,14 @@ class AsyncSpecTest(unittest.TestCase):
self.assertIsInstance(mock, AsyncMock) self.assertIsInstance(mock, AsyncMock)
m = mock() m = mock()
self.assertTrue(inspect.isawaitable(m)) self.assertTrue(inspect.isawaitable(m))
asyncio.run(m) run(m)
def test_spec_as_normal_positional_AsyncMock(self): def test_spec_as_normal_positional_AsyncMock(self):
mock = AsyncMock(normal_func) mock = AsyncMock(normal_func)
self.assertIsInstance(mock, AsyncMock) self.assertIsInstance(mock, AsyncMock)
m = mock() m = mock()
self.assertTrue(inspect.isawaitable(m)) self.assertTrue(inspect.isawaitable(m))
asyncio.run(m) run(m)
def test_spec_async_mock(self): def test_spec_async_mock(self):
@patch.object(AsyncClass, 'async_method', spec=True) @patch.object(AsyncClass, 'async_method', spec=True)
@ -370,13 +372,13 @@ class AsyncSpecSetTest(unittest.TestCase):
def test_is_async_AsyncMock(self): def test_is_async_AsyncMock(self):
mock = AsyncMock(spec_set=AsyncClass.async_method) mock = AsyncMock(spec_set=AsyncClass.async_method)
self.assertTrue(asyncio.iscoroutinefunction(mock)) self.assertTrue(iscoroutinefunction(mock))
self.assertIsInstance(mock, AsyncMock) self.assertIsInstance(mock, AsyncMock)
def test_is_child_AsyncMock(self): def test_is_child_AsyncMock(self):
mock = MagicMock(spec_set=AsyncClass) mock = MagicMock(spec_set=AsyncClass)
self.assertTrue(asyncio.iscoroutinefunction(mock.async_method)) self.assertTrue(iscoroutinefunction(mock.async_method))
self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method)) self.assertFalse(iscoroutinefunction(mock.normal_method))
self.assertIsInstance(mock.async_method, AsyncMock) self.assertIsInstance(mock.async_method, AsyncMock)
self.assertIsInstance(mock.normal_method, MagicMock) self.assertIsInstance(mock.normal_method, MagicMock)
self.assertIsInstance(mock, MagicMock) self.assertIsInstance(mock, MagicMock)
@ -389,7 +391,7 @@ class AsyncSpecSetTest(unittest.TestCase):
self.assertIsInstance(cm, MagicMock) self.assertIsInstance(cm, MagicMock)
class AsyncArguments(unittest.IsolatedAsyncioTestCase): class AsyncArguments(IsolatedAsyncioTestCase):
async def test_add_return_value(self): async def test_add_return_value(self):
async def addition(self, var): async def addition(self, var):
return var + 1 return var + 1
@ -536,8 +538,8 @@ class AsyncMagicMethods(unittest.TestCase):
self.assertIsInstance(m_mock.__aenter__, AsyncMock) self.assertIsInstance(m_mock.__aenter__, AsyncMock)
self.assertIsInstance(m_mock.__aexit__, AsyncMock) self.assertIsInstance(m_mock.__aexit__, AsyncMock)
# AsyncMocks are also coroutine functions # AsyncMocks are also coroutine functions
self.assertTrue(asyncio.iscoroutinefunction(m_mock.__aenter__)) self.assertTrue(iscoroutinefunction(m_mock.__aenter__))
self.assertTrue(asyncio.iscoroutinefunction(m_mock.__aexit__)) self.assertTrue(iscoroutinefunction(m_mock.__aexit__))
class AsyncContextManagerTest(unittest.TestCase): class AsyncContextManagerTest(unittest.TestCase):
@ -574,7 +576,7 @@ class AsyncContextManagerTest(unittest.TestCase):
response.json = AsyncMock(return_value={'json': 123}) response.json = AsyncMock(return_value={'json': 123})
cm.__aenter__.return_value = response cm.__aenter__.return_value = response
pc.session.post.return_value = cm pc.session.post.return_value = cm
result = asyncio.run(pc.main()) result = run(pc.main())
self.assertEqual(result, {'json': 123}) self.assertEqual(result, {'json': 123})
for mock_type in [AsyncMock, MagicMock]: for mock_type in [AsyncMock, MagicMock]:
@ -593,7 +595,7 @@ class AsyncContextManagerTest(unittest.TestCase):
called = True called = True
return result return result
cm_result = asyncio.run(use_context_manager()) cm_result = run(use_context_manager())
self.assertTrue(called) self.assertTrue(called)
self.assertTrue(cm_mock.__aenter__.called) self.assertTrue(cm_mock.__aenter__.called)
self.assertTrue(cm_mock.__aexit__.called) self.assertTrue(cm_mock.__aexit__.called)
@ -618,7 +620,7 @@ class AsyncContextManagerTest(unittest.TestCase):
async with mock_instance as result: async with mock_instance as result:
return result return result
self.assertIs(asyncio.run(use_context_manager()), expected_result) self.assertIs(run(use_context_manager()), expected_result)
def test_mock_customize_async_context_manager_with_coroutine(self): def test_mock_customize_async_context_manager_with_coroutine(self):
enter_called = False enter_called = False
@ -642,7 +644,7 @@ class AsyncContextManagerTest(unittest.TestCase):
async with mock_instance: async with mock_instance:
pass pass
asyncio.run(use_context_manager()) run(use_context_manager())
self.assertTrue(enter_called) self.assertTrue(enter_called)
self.assertTrue(exit_called) self.assertTrue(exit_called)
@ -654,7 +656,7 @@ class AsyncContextManagerTest(unittest.TestCase):
instance = self.WithAsyncContextManager() instance = self.WithAsyncContextManager()
mock_instance = MagicMock(instance) mock_instance = MagicMock(instance)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
asyncio.run(raise_in(mock_instance)) run(raise_in(mock_instance))
class AsyncIteratorTest(unittest.TestCase): class AsyncIteratorTest(unittest.TestCase):
@ -678,7 +680,7 @@ class AsyncIteratorTest(unittest.TestCase):
mock_iter.__aiter__.return_value = [1, 2, 3] mock_iter.__aiter__.return_value = [1, 2, 3]
async def main(): async def main():
return [i async for i in mock_iter] return [i async for i in mock_iter]
result = asyncio.run(main()) result = run(main())
self.assertEqual(result, [1, 2, 3]) self.assertEqual(result, [1, 2, 3])
def test_mock_aiter_and_anext_asyncmock(self): def test_mock_aiter_and_anext_asyncmock(self):
@ -687,11 +689,11 @@ class AsyncIteratorTest(unittest.TestCase):
mock_instance = mock_type(instance) mock_instance = mock_type(instance)
# Check that the mock and the real thing bahave the same # Check that the mock and the real thing bahave the same
# __aiter__ is not actually async, so not a coroutinefunction # __aiter__ is not actually async, so not a coroutinefunction
self.assertFalse(asyncio.iscoroutinefunction(instance.__aiter__)) self.assertFalse(iscoroutinefunction(instance.__aiter__))
self.assertFalse(asyncio.iscoroutinefunction(mock_instance.__aiter__)) self.assertFalse(iscoroutinefunction(mock_instance.__aiter__))
# __anext__ is async # __anext__ is async
self.assertTrue(asyncio.iscoroutinefunction(instance.__anext__)) self.assertTrue(iscoroutinefunction(instance.__anext__))
self.assertTrue(asyncio.iscoroutinefunction(mock_instance.__anext__)) self.assertTrue(iscoroutinefunction(mock_instance.__anext__))
for mock_type in [AsyncMock, MagicMock]: for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"test aiter and anext corourtine with {mock_type}"): with self.subTest(f"test aiter and anext corourtine with {mock_type}"):
@ -709,18 +711,18 @@ class AsyncIteratorTest(unittest.TestCase):
expected = ["FOO", "BAR", "BAZ"] expected = ["FOO", "BAR", "BAZ"]
def test_default(mock_type): def test_default(mock_type):
mock_instance = mock_type(self.WithAsyncIterator()) mock_instance = mock_type(self.WithAsyncIterator())
self.assertEqual(asyncio.run(iterate(mock_instance)), []) self.assertEqual(run(iterate(mock_instance)), [])
def test_set_return_value(mock_type): def test_set_return_value(mock_type):
mock_instance = mock_type(self.WithAsyncIterator()) mock_instance = mock_type(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = expected[:] mock_instance.__aiter__.return_value = expected[:]
self.assertEqual(asyncio.run(iterate(mock_instance)), expected) self.assertEqual(run(iterate(mock_instance)), expected)
def test_set_return_value_iter(mock_type): def test_set_return_value_iter(mock_type):
mock_instance = mock_type(self.WithAsyncIterator()) mock_instance = mock_type(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = iter(expected[:]) mock_instance.__aiter__.return_value = iter(expected[:])
self.assertEqual(asyncio.run(iterate(mock_instance)), expected) self.assertEqual(run(iterate(mock_instance)), expected)
for mock_type in [AsyncMock, MagicMock]: for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"default value with {mock_type}"): with self.subTest(f"default value with {mock_type}"):
@ -748,7 +750,7 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertWarns(RuntimeWarning): with self.assertWarns(RuntimeWarning):
# Will raise a warning because never awaited # Will raise a warning because never awaited
mock.async_method() mock.async_method()
self.assertTrue(asyncio.iscoroutinefunction(mock.async_method)) self.assertTrue(iscoroutinefunction(mock.async_method))
mock.async_method.assert_called() mock.async_method.assert_called()
mock.async_method.assert_called_once() mock.async_method.assert_called_once()
mock.async_method.assert_called_once_with() mock.async_method.assert_called_once_with()
@ -766,7 +768,7 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
mock.async_method.assert_awaited() mock.async_method.assert_awaited()
asyncio.run(self._await_coroutine(mock_coroutine)) run(self._await_coroutine(mock_coroutine))
# Assert we haven't re-called the function # Assert we haven't re-called the function
mock.async_method.assert_called_once() mock.async_method.assert_called_once()
mock.async_method.assert_awaited() mock.async_method.assert_awaited()
@ -780,7 +782,7 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_called() self.mock.assert_called()
asyncio.run(self._runnable_test()) run(self._runnable_test())
self.mock.assert_called_once() self.mock.assert_called_once()
self.mock.assert_awaited_once() self.mock.assert_awaited_once()
@ -794,7 +796,7 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
mock.async_method.assert_awaited() mock.async_method.assert_awaited()
mock.async_method.assert_called() mock.async_method.assert_called()
asyncio.run(self._await_coroutine(coroutine)) run(self._await_coroutine(coroutine))
mock.async_method.assert_awaited() mock.async_method.assert_awaited()
mock.async_method.assert_awaited_once() mock.async_method.assert_awaited_once()
@ -802,10 +804,10 @@ class AsyncMockAssert(unittest.TestCase):
mock = AsyncMock(AsyncClass) mock = AsyncMock(AsyncClass)
coroutine = mock.async_method() coroutine = mock.async_method()
mock.async_method.assert_called_once() mock.async_method.assert_called_once()
asyncio.run(self._await_coroutine(coroutine)) run(self._await_coroutine(coroutine))
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
# Cannot reuse already awaited coroutine # Cannot reuse already awaited coroutine
asyncio.run(self._await_coroutine(coroutine)) run(self._await_coroutine(coroutine))
mock.async_method.assert_awaited() mock.async_method.assert_awaited()
def test_assert_awaited_but_not_called(self): def test_assert_awaited_but_not_called(self):
@ -815,7 +817,7 @@ class AsyncMockAssert(unittest.TestCase):
self.mock.assert_called() self.mock.assert_called()
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
# You cannot await an AsyncMock, it must be a coroutine # You cannot await an AsyncMock, it must be a coroutine
asyncio.run(self._await_coroutine(self.mock)) run(self._await_coroutine(self.mock))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_awaited() self.mock.assert_awaited()
@ -909,17 +911,17 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_awaited() self.mock.assert_awaited()
asyncio.run(self._runnable_test()) run(self._runnable_test())
self.mock.assert_awaited() self.mock.assert_awaited()
def test_assert_awaited_once(self): def test_assert_awaited_once(self):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_awaited_once() self.mock.assert_awaited_once()
asyncio.run(self._runnable_test()) run(self._runnable_test())
self.mock.assert_awaited_once() self.mock.assert_awaited_once()
asyncio.run(self._runnable_test()) run(self._runnable_test())
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_awaited_once() self.mock.assert_awaited_once()
@ -928,15 +930,15 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertRaisesRegex(AssertionError, msg): with self.assertRaisesRegex(AssertionError, msg):
self.mock.assert_awaited_with('foo') self.mock.assert_awaited_with('foo')
asyncio.run(self._runnable_test()) run(self._runnable_test())
msg = 'expected await not found' msg = 'expected await not found'
with self.assertRaisesRegex(AssertionError, msg): with self.assertRaisesRegex(AssertionError, msg):
self.mock.assert_awaited_with('foo') self.mock.assert_awaited_with('foo')
asyncio.run(self._runnable_test('foo')) run(self._runnable_test('foo'))
self.mock.assert_awaited_with('foo') self.mock.assert_awaited_with('foo')
asyncio.run(self._runnable_test('SomethingElse')) run(self._runnable_test('SomethingElse'))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_awaited_with('foo') self.mock.assert_awaited_with('foo')
@ -944,10 +946,10 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_awaited_once_with('foo') self.mock.assert_awaited_once_with('foo')
asyncio.run(self._runnable_test('foo')) run(self._runnable_test('foo'))
self.mock.assert_awaited_once_with('foo') self.mock.assert_awaited_once_with('foo')
asyncio.run(self._runnable_test('foo')) run(self._runnable_test('foo'))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_awaited_once_with('foo') self.mock.assert_awaited_once_with('foo')
@ -955,14 +957,14 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_any_await('foo') self.mock.assert_any_await('foo')
asyncio.run(self._runnable_test('baz')) run(self._runnable_test('baz'))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_any_await('foo') self.mock.assert_any_await('foo')
asyncio.run(self._runnable_test('foo')) run(self._runnable_test('foo'))
self.mock.assert_any_await('foo') self.mock.assert_any_await('foo')
asyncio.run(self._runnable_test('SomethingElse')) run(self._runnable_test('SomethingElse'))
self.mock.assert_any_await('foo') self.mock.assert_any_await('foo')
def test_assert_has_awaits_no_order(self): def test_assert_has_awaits_no_order(self):
@ -972,25 +974,25 @@ class AsyncMockAssert(unittest.TestCase):
self.mock.assert_has_awaits(calls) self.mock.assert_has_awaits(calls)
self.assertEqual(len(cm.exception.args), 1) self.assertEqual(len(cm.exception.args), 1)
asyncio.run(self._runnable_test('foo')) run(self._runnable_test('foo'))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls) self.mock.assert_has_awaits(calls)
asyncio.run(self._runnable_test('foo')) run(self._runnable_test('foo'))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls) self.mock.assert_has_awaits(calls)
asyncio.run(self._runnable_test('baz')) run(self._runnable_test('baz'))
self.mock.assert_has_awaits(calls) self.mock.assert_has_awaits(calls)
asyncio.run(self._runnable_test('SomethingElse')) run(self._runnable_test('SomethingElse'))
self.mock.assert_has_awaits(calls) self.mock.assert_has_awaits(calls)
def test_awaits_asserts_with_any(self): def test_awaits_asserts_with_any(self):
class Foo: class Foo:
def __eq__(self, other): pass def __eq__(self, other): pass
asyncio.run(self._runnable_test(Foo(), 1)) run(self._runnable_test(Foo(), 1))
self.mock.assert_has_awaits([call(ANY, 1)]) self.mock.assert_has_awaits([call(ANY, 1)])
self.mock.assert_awaited_with(ANY, 1) self.mock.assert_awaited_with(ANY, 1)
@ -1005,7 +1007,7 @@ class AsyncMockAssert(unittest.TestCase):
async def _custom_mock_runnable_test(*args): async def _custom_mock_runnable_test(*args):
await mock_with_spec(*args) await mock_with_spec(*args)
asyncio.run(_custom_mock_runnable_test(Foo(), 1)) run(_custom_mock_runnable_test(Foo(), 1))
mock_with_spec.assert_has_awaits([call(ANY, 1)]) mock_with_spec.assert_has_awaits([call(ANY, 1)])
mock_with_spec.assert_awaited_with(ANY, 1) mock_with_spec.assert_awaited_with(ANY, 1)
mock_with_spec.assert_any_await(ANY, 1) mock_with_spec.assert_any_await(ANY, 1)
@ -1015,24 +1017,24 @@ class AsyncMockAssert(unittest.TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True) self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('baz')) run(self._runnable_test('baz'))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True) self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('bamf')) run(self._runnable_test('bamf'))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True) self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('foo')) run(self._runnable_test('foo'))
self.mock.assert_has_awaits(calls, any_order=True) self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('qux')) run(self._runnable_test('qux'))
self.mock.assert_has_awaits(calls, any_order=True) self.mock.assert_has_awaits(calls, any_order=True)
def test_assert_not_awaited(self): def test_assert_not_awaited(self):
self.mock.assert_not_awaited() self.mock.assert_not_awaited()
asyncio.run(self._runnable_test()) run(self._runnable_test())
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.mock.assert_not_awaited() self.mock.assert_not_awaited()
@ -1040,7 +1042,7 @@ class AsyncMockAssert(unittest.TestCase):
async def f(x=None): pass async def f(x=None): pass
self.mock = AsyncMock(spec=f) self.mock = AsyncMock(spec=f)
asyncio.run(self._runnable_test(1)) run(self._runnable_test(1))
with self.assertRaisesRegex( with self.assertRaisesRegex(
AssertionError, AssertionError,

View file

@ -1,8 +1,8 @@
import asyncio
import math import math
import unittest import unittest
import os import os
import sys import sys
from asyncio import iscoroutinefunction
from unittest.mock import AsyncMock, Mock, MagicMock, _magics from unittest.mock import AsyncMock, Mock, MagicMock, _magics
@ -286,8 +286,8 @@ class TestMockingMagicMethods(unittest.TestCase):
self.assertEqual(math.trunc(mock), mock.__trunc__()) self.assertEqual(math.trunc(mock), mock.__trunc__())
self.assertEqual(math.floor(mock), mock.__floor__()) self.assertEqual(math.floor(mock), mock.__floor__())
self.assertEqual(math.ceil(mock), mock.__ceil__()) self.assertEqual(math.ceil(mock), mock.__ceil__())
self.assertTrue(asyncio.iscoroutinefunction(mock.__aexit__)) self.assertTrue(iscoroutinefunction(mock.__aexit__))
self.assertTrue(asyncio.iscoroutinefunction(mock.__aenter__)) self.assertTrue(iscoroutinefunction(mock.__aenter__))
self.assertIsInstance(mock.__aenter__, AsyncMock) self.assertIsInstance(mock.__aenter__, AsyncMock)
self.assertIsInstance(mock.__aexit__, AsyncMock) self.assertIsInstance(mock.__aexit__, AsyncMock)
@ -312,8 +312,8 @@ class TestMockingMagicMethods(unittest.TestCase):
self.assertEqual(math.trunc(mock), mock.__trunc__()) self.assertEqual(math.trunc(mock), mock.__trunc__())
self.assertEqual(math.floor(mock), mock.__floor__()) self.assertEqual(math.floor(mock), mock.__floor__())
self.assertEqual(math.ceil(mock), mock.__ceil__()) self.assertEqual(math.ceil(mock), mock.__ceil__())
self.assertTrue(asyncio.iscoroutinefunction(mock.__aexit__)) self.assertTrue(iscoroutinefunction(mock.__aexit__))
self.assertTrue(asyncio.iscoroutinefunction(mock.__aenter__)) self.assertTrue(iscoroutinefunction(mock.__aenter__))
self.assertIsInstance(mock.__aenter__, AsyncMock) self.assertIsInstance(mock.__aenter__, AsyncMock)
self.assertIsInstance(mock.__aexit__, AsyncMock) self.assertIsInstance(mock.__aexit__, AsyncMock)