bpo-26467: Adds AsyncMock for asyncio Mock library support (GH-9296)

This commit is contained in:
Lisa Roach 2019-05-20 09:19:53 -07:00 committed by GitHub
parent 0f72147ce2
commit 77b3b7701a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 1161 additions and 20 deletions

View file

@ -13,6 +13,7 @@ __all__ = (
'ANY',
'call',
'create_autospec',
'AsyncMock',
'FILTER_DIR',
'NonCallableMock',
'NonCallableMagicMock',
@ -24,13 +25,13 @@ __all__ = (
__version__ = '1.0'
import asyncio
import io
import inspect
import pprint
import sys
import builtins
from types import ModuleType, MethodType
from types import CodeType, ModuleType, MethodType
from unittest.util import safe_repr
from functools import wraps, partial
@ -43,6 +44,13 @@ FILTER_DIR = True
# Without this, the __class__ properties wouldn't be set correctly
_safe_super = super
def _is_async_obj(obj):
if getattr(obj, '__code__', None):
return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj)
else:
return False
def _is_instance_mock(obj):
# can't use isinstance on Mock objects because they override __class__
# The base class for all mocks is NonCallableMock
@ -355,7 +363,20 @@ class NonCallableMock(Base):
# every instance has its own class
# so we can create magic methods on the
# class without stomping on other mocks
new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__})
bases = (cls,)
if not issubclass(cls, AsyncMock):
# Check if spec is an async object or function
sig = inspect.signature(NonCallableMock.__init__)
bound_args = sig.bind_partial(cls, *args, **kw).arguments
spec_arg = [
arg for arg in bound_args.keys()
if arg.startswith('spec')
]
if spec_arg:
# what if spec_set is different than spec?
if _is_async_obj(bound_args[spec_arg[0]]):
bases = (AsyncMockMixin, cls,)
new = type(cls.__name__, bases, {'__doc__': cls.__doc__})
instance = object.__new__(new)
return instance
@ -431,6 +452,11 @@ class NonCallableMock(Base):
_eat_self=False):
_spec_class = None
_spec_signature = None
_spec_asyncs = []
for attr in dir(spec):
if asyncio.iscoroutinefunction(getattr(spec, attr, None)):
_spec_asyncs.append(attr)
if spec is not None and not _is_list(spec):
if isinstance(spec, type):
@ -448,7 +474,7 @@ class NonCallableMock(Base):
__dict__['_spec_set'] = spec_set
__dict__['_spec_signature'] = _spec_signature
__dict__['_mock_methods'] = spec
__dict__['_spec_asyncs'] = _spec_asyncs
def __get_return_value(self):
ret = self._mock_return_value
@ -886,7 +912,15 @@ class NonCallableMock(Base):
For non-callable mocks the callable variant will be used (rather than
any custom subclass)."""
_new_name = kw.get("_new_name")
if _new_name in self.__dict__['_spec_asyncs']:
return AsyncMock(**kw)
_type = type(self)
if issubclass(_type, MagicMock) and _new_name in _async_method_magics:
klass = AsyncMock
if issubclass(_type, AsyncMockMixin):
klass = MagicMock
if not issubclass(_type, CallableMixin):
if issubclass(_type, NonCallableMagicMock):
klass = MagicMock
@ -932,14 +966,12 @@ def _try_iter(obj):
return obj
class CallableMixin(Base):
def __init__(self, spec=None, side_effect=None, return_value=DEFAULT,
wraps=None, name=None, spec_set=None, parent=None,
_spec_state=None, _new_name='', _new_parent=None, **kwargs):
self.__dict__['_mock_return_value'] = return_value
_safe_super(CallableMixin, self).__init__(
spec, wraps, name, spec_set, parent,
_spec_state, _new_name, _new_parent, **kwargs
@ -1081,7 +1113,6 @@ class Mock(CallableMixin, NonCallableMock):
"""
def _dot_lookup(thing, comp, import_path):
try:
return getattr(thing, comp)
@ -1279,8 +1310,10 @@ class _patch(object):
if isinstance(original, type):
# If we're patching out a class and there is a spec
inherit = True
Klass = MagicMock
if spec is None and _is_async_obj(original):
Klass = AsyncMock
else:
Klass = MagicMock
_kwargs = {}
if new_callable is not None:
Klass = new_callable
@ -1292,7 +1325,9 @@ class _patch(object):
not_callable = '__call__' not in this_spec
else:
not_callable = not callable(this_spec)
if not_callable:
if _is_async_obj(this_spec):
Klass = AsyncMock
elif not_callable:
Klass = NonCallableMagicMock
if spec is not None:
@ -1733,7 +1768,7 @@ _non_defaults = {
'__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__',
'__getstate__', '__setstate__', '__getformat__', '__setformat__',
'__repr__', '__dir__', '__subclasses__', '__format__',
'__getnewargs_ex__',
'__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__',
}
@ -1750,6 +1785,11 @@ _magics = {
' '.join([magic_methods, numerics, inplace, right]).split()
}
# Magic methods used for async `with` statements
_async_method_magics = {"__aenter__", "__aexit__", "__anext__"}
# `__aiter__` is a plain function but used with async calls
_async_magics = _async_method_magics | {"__aiter__"}
_all_magics = _magics | _non_defaults
_unsupported_magics = {
@ -1779,6 +1819,7 @@ _return_values = {
'__float__': 1.0,
'__bool__': True,
'__index__': 1,
'__aexit__': False,
}
@ -1811,10 +1852,19 @@ def _get_iter(self):
return iter(ret_val)
return __iter__
def _get_async_iter(self):
def __aiter__():
ret_val = self.__aiter__._mock_return_value
if ret_val is DEFAULT:
return _AsyncIterator(iter([]))
return _AsyncIterator(iter(ret_val))
return __aiter__
_side_effect_methods = {
'__eq__': _get_eq,
'__ne__': _get_ne,
'__iter__': _get_iter,
'__aiter__': _get_async_iter
}
@ -1879,8 +1929,33 @@ class NonCallableMagicMock(MagicMixin, NonCallableMock):
self._mock_set_magics()
class AsyncMagicMixin:
def __init__(self, *args, **kw):
self._mock_set_async_magics() # make magic work for kwargs in init
_safe_super(AsyncMagicMixin, self).__init__(*args, **kw)
self._mock_set_async_magics() # fix magic broken by upper level init
class MagicMock(MagicMixin, Mock):
def _mock_set_async_magics(self):
these_magics = _async_magics
if getattr(self, "_mock_methods", None) is not None:
these_magics = _async_magics.intersection(self._mock_methods)
remove_magics = _async_magics - these_magics
for entry in remove_magics:
if entry in type(self).__dict__:
# remove unneeded magic methods
delattr(self, entry)
# don't overwrite existing attributes if called a second time
these_magics = these_magics - set(type(self).__dict__)
_type = type(self)
for entry in these_magics:
setattr(_type, entry, MagicProxy(entry, self))
class MagicMock(MagicMixin, AsyncMagicMixin, Mock):
"""
MagicMock is a subclass of Mock with default implementations
of most of the magic methods. You can use MagicMock without having to
@ -1920,6 +1995,218 @@ class MagicProxy(object):
return self.create_mock()
class AsyncMockMixin(Base):
awaited = _delegating_property('awaited')
await_count = _delegating_property('await_count')
await_args = _delegating_property('await_args')
await_args_list = _delegating_property('await_args_list')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# asyncio.iscoroutinefunction() checks _is_coroutine property to say if an
# object is a coroutine. Without this check it looks to see if it is a
# function/method, which in this case it is not (since it is an
# AsyncMock).
# It is set through __dict__ because when spec_set is True, this
# attribute is likely undefined.
self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine
self.__dict__['_mock_awaited'] = _AwaitEvent(self)
self.__dict__['_mock_await_count'] = 0
self.__dict__['_mock_await_args'] = None
self.__dict__['_mock_await_args_list'] = _CallList()
code_mock = NonCallableMock(spec_set=CodeType)
code_mock.co_flags = inspect.CO_COROUTINE
self.__dict__['__code__'] = code_mock
async def _mock_call(_mock_self, *args, **kwargs):
self = _mock_self
try:
result = super()._mock_call(*args, **kwargs)
except (BaseException, StopIteration) as e:
side_effect = self.side_effect
if side_effect is not None and not callable(side_effect):
raise
return await _raise(e)
_call = self.call_args
async def proxy():
try:
if inspect.isawaitable(result):
return await result
else:
return result
finally:
self.await_count += 1
self.await_args = _call
self.await_args_list.append(_call)
await self.awaited._notify()
return await proxy()
def assert_awaited(_mock_self):
"""
Assert that the mock was awaited at least once.
"""
self = _mock_self
if self.await_count == 0:
msg = f"Expected {self._mock_name or 'mock'} to have been awaited."
raise AssertionError(msg)
def assert_awaited_once(_mock_self):
"""
Assert that the mock was awaited exactly once.
"""
self = _mock_self
if not self.await_count == 1:
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
f" Awaited {self.await_count} times.")
raise AssertionError(msg)
def assert_awaited_with(_mock_self, *args, **kwargs):
"""
Assert that the last await was with the specified arguments.
"""
self = _mock_self
if self.await_args is None:
expected = self._format_mock_call_signature(args, kwargs)
raise AssertionError(f'Expected await: {expected}\nNot awaited')
def _error_message():
msg = self._format_mock_failure_message(args, kwargs)
return msg
expected = self._call_matcher((args, kwargs))
actual = self._call_matcher(self.await_args)
if expected != actual:
cause = expected if isinstance(expected, Exception) else None
raise AssertionError(_error_message()) from cause
def assert_awaited_once_with(_mock_self, *args, **kwargs):
"""
Assert that the mock was awaited exactly once and with the specified
arguments.
"""
self = _mock_self
if not self.await_count == 1:
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
f" Awaited {self.await_count} times.")
raise AssertionError(msg)
return self.assert_awaited_with(*args, **kwargs)
def assert_any_await(_mock_self, *args, **kwargs):
"""
Assert the mock has ever been awaited with the specified arguments.
"""
self = _mock_self
expected = self._call_matcher((args, kwargs))
actual = [self._call_matcher(c) for c in self.await_args_list]
if expected not in actual:
cause = expected if isinstance(expected, Exception) else None
expected_string = self._format_mock_call_signature(args, kwargs)
raise AssertionError(
'%s await not found' % expected_string
) from cause
def assert_has_awaits(_mock_self, calls, any_order=False):
"""
Assert the mock has been awaited with the specified calls.
The :attr:`await_args_list` list is checked for the awaits.
If `any_order` is False (the default) then the awaits must be
sequential. There can be extra calls before or after the
specified awaits.
If `any_order` is True then the awaits can be in any order, but
they must all appear in :attr:`await_args_list`.
"""
self = _mock_self
expected = [self._call_matcher(c) for c in calls]
cause = expected if isinstance(expected, Exception) else None
all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list)
if not any_order:
if expected not in all_awaits:
raise AssertionError(
f'Awaits not found.\nExpected: {_CallList(calls)}\n',
f'Actual: {self.await_args_list}'
) from cause
return
all_awaits = list(all_awaits)
not_found = []
for kall in expected:
try:
all_awaits.remove(kall)
except ValueError:
not_found.append(kall)
if not_found:
raise AssertionError(
'%r not all found in await list' % (tuple(not_found),)
) from cause
def assert_not_awaited(_mock_self):
"""
Assert that the mock was never awaited.
"""
self = _mock_self
if self.await_count != 0:
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
f" Awaited {self.await_count} times.")
raise AssertionError(msg)
def reset_mock(self, *args, **kwargs):
"""
See :func:`.Mock.reset_mock()`
"""
super().reset_mock(*args, **kwargs)
self.await_count = 0
self.await_args = None
self.await_args_list = _CallList()
class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock):
"""
Enhance :class:`Mock` with features allowing to mock
an async function.
The :class:`AsyncMock` object will behave so the object is
recognized as an async function, and the result of a call is an awaitable:
>>> mock = AsyncMock()
>>> asyncio.iscoroutinefunction(mock)
True
>>> inspect.isawaitable(mock())
True
The result of ``mock()`` is an async function which will have the outcome
of ``side_effect`` or ``return_value``:
- if ``side_effect`` is a function, the async function will return the
result of that function,
- if ``side_effect`` is an exception, the async function will raise the
exception,
- if ``side_effect`` is an iterable, the async function will return the
next value of the iterable, however, if the sequence of result is
exhausted, ``StopIteration`` is raised immediately,
- if ``side_effect`` is not defined, the async function will return the
value defined by ``return_value``, hence, by default, the async function
returns a new :class:`AsyncMock` object.
If the outcome of ``side_effect`` or ``return_value`` is an async function,
the mock async function obtained when the mock object is called will be this
async function itself (and not an async function returning an async
function).
The test author can also specify a wrapped object with ``wraps``. In this
case, the :class:`Mock` object behavior is the same as with an
:class:`.Mock` object: the wrapped object may have methods
defined as async function functions.
Based on Martin Richard's asyntest project.
"""
class _ANY(object):
"A helper object that compares equal to everything."
@ -2145,7 +2432,6 @@ class _Call(tuple):
call = _Call(from_kall=False)
def create_autospec(spec, spec_set=False, instance=False, _parent=None,
_name=None, **kwargs):
"""Create a mock object using another object as a spec. Attributes on the
@ -2171,7 +2457,10 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
spec = type(spec)
is_type = isinstance(spec, type)
if getattr(spec, '__code__', None):
is_async_func = asyncio.iscoroutinefunction(spec)
else:
is_async_func = False
_kwargs = {'spec': spec}
if spec_set:
_kwargs = {'spec_set': spec}
@ -2188,6 +2477,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
# descriptors don't have a spec
# because we don't know what type they return
_kwargs = {}
elif is_async_func:
if instance:
raise RuntimeError("Instance can not be True when create_autospec "
"is mocking an async function")
Klass = AsyncMock
elif not _callable(spec):
Klass = NonCallableMagicMock
elif is_type and instance and not _instance_callable(spec):
@ -2204,9 +2498,26 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
name=_name, **_kwargs)
if isinstance(spec, FunctionTypes):
wrapped_mock = mock
# should only happen at the top level because we don't
# recurse for functions
mock = _set_signature(mock, spec)
if is_async_func:
mock._is_coroutine = asyncio.coroutines._is_coroutine
mock.await_count = 0
mock.await_args = None
mock.await_args_list = _CallList()
for a in ('assert_awaited',
'assert_awaited_once',
'assert_awaited_with',
'assert_awaited_once_with',
'assert_any_await',
'assert_has_awaits',
'assert_not_awaited'):
def f(*args, **kwargs):
return getattr(wrapped_mock, a)(*args, **kwargs)
setattr(mock, a, f)
else:
_check_signature(spec, mock, is_type, instance)
@ -2250,9 +2561,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
skipfirst = _must_skip(spec, entry, is_type)
kwargs['_eat_self'] = skipfirst
new = MagicMock(parent=parent, name=entry, _new_name=entry,
_new_parent=parent,
**kwargs)
if asyncio.iscoroutinefunction(original):
child_klass = AsyncMock
else:
child_klass = MagicMock
new = child_klass(parent=parent, name=entry, _new_name=entry,
_new_parent=parent,
**kwargs)
mock._mock_children[entry] = new
_check_signature(original, new, skipfirst=skipfirst)
@ -2438,3 +2753,60 @@ def seal(mock):
continue
if m._mock_new_parent is mock:
seal(m)
async def _raise(exception):
raise exception
class _AsyncIterator:
"""
Wraps an iterator in an asynchronous iterator.
"""
def __init__(self, iterator):
self.iterator = iterator
code_mock = NonCallableMock(spec_set=CodeType)
code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE
self.__dict__['__code__'] = code_mock
def __aiter__(self):
return self
async def __anext__(self):
try:
return next(self.iterator)
except StopIteration:
pass
raise StopAsyncIteration
class _AwaitEvent:
def __init__(self, mock):
self._mock = mock
self._condition = None
async def _notify(self):
condition = self._get_condition()
try:
await condition.acquire()
condition.notify_all()
finally:
condition.release()
def _get_condition(self):
"""
Creation of condition is delayed, to minimize the chance of using the
wrong loop.
A user may create a mock with _AwaitEvent before selecting the
execution loop. Requiring a user to delay creation is error-prone and
inflexible. Instead, condition is created when user actually starts to
use the mock.
"""
# No synchronization is needed:
# - asyncio is thread unsafe
# - there are no awaits here, method will be executed without
# switching asyncio context.
if self._condition is None:
self._condition = asyncio.Condition()
return self._condition