@@ -47,6 +47,8 @@ def __init__(self, config):
47
47
self .MagicMock = mock_module .MagicMock
48
48
self .NonCallableMock = mock_module .NonCallableMock
49
49
self .PropertyMock = mock_module .PropertyMock
50
+ if hasattr (mock_module , "AsyncMock" ):
51
+ self .AsyncMock = mock_module .AsyncMock
50
52
self .call = mock_module .call
51
53
self .ANY = mock_module .ANY
52
54
self .DEFAULT = mock_module .DEFAULT
@@ -275,6 +277,41 @@ def wrap_assert_called(*args, **kwargs):
275
277
assert_wrapper (_mock_module_originals ["assert_called" ], * args , ** kwargs )
276
278
277
279
280
+ def wrap_assert_not_awaited (* args , ** kwargs ):
281
+ __tracebackhide__ = True
282
+ assert_wrapper (_mock_module_originals ["assert_not_awaited" ], * args , ** kwargs )
283
+
284
+
285
+ def wrap_assert_awaited_with (* args , ** kwargs ):
286
+ __tracebackhide__ = True
287
+ assert_wrapper (_mock_module_originals ["assert_awaited_with" ], * args , ** kwargs )
288
+
289
+
290
+ def wrap_assert_awaited_once (* args , ** kwargs ):
291
+ __tracebackhide__ = True
292
+ assert_wrapper (_mock_module_originals ["assert_awaited_once" ], * args , ** kwargs )
293
+
294
+
295
+ def wrap_assert_awaited_once_with (* args , ** kwargs ):
296
+ __tracebackhide__ = True
297
+ assert_wrapper (_mock_module_originals ["assert_awaited_once_with" ], * args , ** kwargs )
298
+
299
+
300
+ def wrap_assert_has_awaits (* args , ** kwargs ):
301
+ __tracebackhide__ = True
302
+ assert_wrapper (_mock_module_originals ["assert_has_awaits" ], * args , ** kwargs )
303
+
304
+
305
+ def wrap_assert_any_await (* args , ** kwargs ):
306
+ __tracebackhide__ = True
307
+ assert_wrapper (_mock_module_originals ["assert_any_await" ], * args , ** kwargs )
308
+
309
+
310
+ def wrap_assert_awaited (* args , ** kwargs ):
311
+ __tracebackhide__ = True
312
+ assert_wrapper (_mock_module_originals ["assert_awaited" ], * args , ** kwargs )
313
+
314
+
278
315
def wrap_assert_methods (config ):
279
316
"""
280
317
Wrap assert methods of mock module so we can hide their traceback and
@@ -305,6 +342,26 @@ def wrap_assert_methods(config):
305
342
patcher .start ()
306
343
_mock_module_patches .append (patcher )
307
344
345
+ if hasattr (mock_module , "AsyncMock" ):
346
+ async_wrappers = {
347
+ "assert_awaited" : wrap_assert_awaited ,
348
+ "assert_awaited_once" : wrap_assert_awaited_once ,
349
+ "assert_awaited_with" : wrap_assert_awaited_with ,
350
+ "assert_awaited_once_with" : wrap_assert_awaited_once_with ,
351
+ "assert_any_await" : wrap_assert_any_await ,
352
+ "assert_has_awaits" : wrap_assert_has_awaits ,
353
+ "assert_not_awaited" : wrap_assert_not_awaited ,
354
+ }
355
+ for method , wrapper in async_wrappers .items ():
356
+ try :
357
+ original = getattr (mock_module .AsyncMock , method )
358
+ except AttributeError : # pragma: no cover
359
+ continue
360
+ _mock_module_originals [method ] = original
361
+ patcher = mock_module .patch .object (mock_module .AsyncMock , method , wrapper )
362
+ patcher .start ()
363
+ _mock_module_patches .append (patcher )
364
+
308
365
if hasattr (config , "add_cleanup" ):
309
366
add_cleanup = config .add_cleanup
310
367
else :
0 commit comments