bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269)

This commit is contained in:
Jason Fried 2019-11-20 16:27:51 -08:00 committed by Lisa Roach
parent e5d1f734db
commit 046442d02b
5 changed files with 102 additions and 41 deletions

View file

@ -1139,8 +1139,8 @@ class CallableMixin(Base):
_new_parent = _new_parent._mock_new_parent
def _execute_mock_call(self, /, *args, **kwargs):
# seperate from _increment_mock_call so that awaited functions are
# executed seperately from their call
# separate from _increment_mock_call so that awaited functions are
# executed separately from their call, also AsyncMock overrides this method
effect = self.side_effect
if effect is not None:
@ -2136,29 +2136,45 @@ class AsyncMockMixin(Base):
code_mock.co_flags = inspect.CO_COROUTINE
self.__dict__['__code__'] = code_mock
async def _mock_call(self, /, *args, **kwargs):
try:
result = super()._mock_call(*args, **kwargs)
except (BaseException, StopIteration) as e:
side_effect = self.side_effect
if side_effect is not None and not callable(side_effect):
raise
return await _raise(e)
async def _execute_mock_call(self, /, *args, **kwargs):
# This is nearly just like super(), except for sepcial handling
# of coroutines
_call = self.call_args
self.await_count += 1
self.await_args = _call
self.await_args_list.append(_call)
async def proxy():
try:
if inspect.isawaitable(result):
return await result
else:
return result
finally:
self.await_count += 1
self.await_args = _call
self.await_args_list.append(_call)
effect = self.side_effect
if effect is not None:
if _is_exception(effect):
raise effect
elif not _callable(effect):
try:
result = next(effect)
except StopIteration:
# It is impossible to propogate a StopIteration
# through coroutines because of PEP 479
raise StopAsyncIteration
if _is_exception(result):
raise result
elif asyncio.iscoroutinefunction(effect):
result = await effect(*args, **kwargs)
else:
result = effect(*args, **kwargs)
return await proxy()
if result is not DEFAULT:
return result
if self._mock_return_value is not DEFAULT:
return self.return_value
if self._mock_wraps is not None:
if asyncio.iscoroutinefunction(self._mock_wraps):
return await self._mock_wraps(*args, **kwargs)
return self._mock_wraps(*args, **kwargs)
return self.return_value
def assert_awaited(self):
"""
@ -2864,10 +2880,6 @@ def seal(mock):
seal(m)
async def _raise(exception):
raise exception
class _AsyncIterator:
"""
Wraps an iterator in an asynchronous iterator.