@@ -1139,8 +1139,8 @@ def _increment_mock_call(self, /, *args, **kwargs):
11391139 _new_parent = _new_parent ._mock_new_parent
11401140
11411141 def _execute_mock_call (self , / , * args , ** kwargs ):
1142- # seperate from _increment_mock_call so that awaited functions are
1143- # executed seperately from their call
1142+ # separate from _increment_mock_call so that awaited functions are
1143+ # executed separately from their call, also AsyncMock overrides this method
11441144
11451145 effect = self .side_effect
11461146 if effect is not None :
@@ -2136,29 +2136,45 @@ def __init__(self, /, *args, **kwargs):
21362136 code_mock .co_flags = inspect .CO_COROUTINE
21372137 self .__dict__ ['__code__' ] = code_mock
21382138
2139- async def _mock_call (self , / , * args , ** kwargs ):
2140- try :
2141- result = super ()._mock_call (* args , ** kwargs )
2142- except (BaseException , StopIteration ) as e :
2143- side_effect = self .side_effect
2144- if side_effect is not None and not callable (side_effect ):
2145- raise
2146- return await _raise (e )
2139+ async def _execute_mock_call (self , / , * args , ** kwargs ):
2140+ # This is nearly just like super(), except for sepcial handling
2141+ # of coroutines
21472142
21482143 _call = self .call_args
2144+ self .await_count += 1
2145+ self .await_args = _call
2146+ self .await_args_list .append (_call )
21492147
2150- async def proxy ():
2151- try :
2152- if inspect .isawaitable (result ):
2153- return await result
2154- else :
2155- return result
2156- finally :
2157- self .await_count += 1
2158- self .await_args = _call
2159- self .await_args_list .append (_call )
2148+ effect = self .side_effect
2149+ if effect is not None :
2150+ if _is_exception (effect ):
2151+ raise effect
2152+ elif not _callable (effect ):
2153+ try :
2154+ result = next (effect )
2155+ except StopIteration :
2156+ # It is impossible to propogate a StopIteration
2157+ # through coroutines because of PEP 479
2158+ raise StopAsyncIteration
2159+ if _is_exception (result ):
2160+ raise result
2161+ elif asyncio .iscoroutinefunction (effect ):
2162+ result = await effect (* args , ** kwargs )
2163+ else :
2164+ result = effect (* args , ** kwargs )
21602165
2161- return await proxy ()
2166+ if result is not DEFAULT :
2167+ return result
2168+
2169+ if self ._mock_return_value is not DEFAULT :
2170+ return self .return_value
2171+
2172+ if self ._mock_wraps is not None :
2173+ if asyncio .iscoroutinefunction (self ._mock_wraps ):
2174+ return await self ._mock_wraps (* args , ** kwargs )
2175+ return self ._mock_wraps (* args , ** kwargs )
2176+
2177+ return self .return_value
21622178
21632179 def assert_awaited (self ):
21642180 """
@@ -2864,10 +2880,6 @@ def seal(mock):
28642880 seal (m )
28652881
28662882
2867- async def _raise (exception ):
2868- raise exception
2869-
2870-
28712883class _AsyncIterator :
28722884 """
28732885 Wraps an iterator in an asynchronous iterator.
0 commit comments