20
20
21
21
"""
22
22
23
+ from __future__ import annotations
24
+
23
25
import re
24
26
import sys
25
27
from datetime import datetime , timedelta
26
28
from json import loads as json_loads
27
- from typing import Set
28
- from warnings import catch_warnings
29
+ from typing import Any , Callable , Set
30
+ from warnings import WarningMessage , catch_warnings
29
31
30
32
31
33
def fail (msg = None ):
@@ -864,7 +866,7 @@ def assert_datetime_about_now_utc(actual, msg_fmt="{msg}"):
864
866
fail (msg_fmt .format (msg = msg , actual = actual , now = now ))
865
867
866
868
867
- class AssertRaisesContext ( object ) :
869
+ class AssertRaisesContext :
868
870
"""A context manager to test for exceptions with certain properties.
869
871
870
872
When the context is left and no exception has been raised, an
@@ -906,7 +908,7 @@ def __init__(self, exception, msg_fmt="{msg}"):
906
908
self ._exc_type = exception
907
909
self ._exc_val = None
908
910
self ._exception_name = getattr (exception , "__name__" , str (exception ))
909
- self ._tests = []
911
+ self ._tests : list [ Callable [[ Any ], object ]] = []
910
912
911
913
def __enter__ (self ):
912
914
return self
@@ -929,7 +931,7 @@ def format_message(self, default_msg):
929
931
exc_name = self ._exception_name ,
930
932
)
931
933
932
- def add_test (self , cb ) :
934
+ def add_test (self , cb : Callable [[ Any ], object ]) -> None :
933
935
"""Add a test callback.
934
936
935
937
This callback is called after determining that the right exception
@@ -1188,16 +1190,19 @@ class AssertWarnsContext(object):
1188
1190
def __init__ (self , warning_class , msg_fmt = "{msg}" ):
1189
1191
self ._warning_class = warning_class
1190
1192
self ._msg_fmt = msg_fmt
1191
- self ._warning_context = None
1193
+ self ._warning_context : catch_warnings [list [WarningMessage ]] | None = (
1194
+ None
1195
+ )
1192
1196
self ._warnings = []
1193
- self ._tests = []
1197
+ self ._tests : list [ Callable [[ Warning ], bool ]] = []
1194
1198
1195
1199
def __enter__ (self ):
1196
1200
self ._warning_context = catch_warnings (record = True )
1197
1201
self ._warnings = self ._warning_context .__enter__ ()
1198
1202
return self
1199
1203
1200
1204
def __exit__ (self , exc_type , exc_val , exc_tb ):
1205
+ assert self ._warning_context is not None
1201
1206
self ._warning_context .__exit__ (exc_type , exc_val , exc_tb )
1202
1207
if not any (self ._is_expected_warning (w ) for w in self ._warnings ):
1203
1208
fail (self .format_message ())
@@ -1210,12 +1215,12 @@ def format_message(self):
1210
1215
exc_name = self ._warning_class .__name__ ,
1211
1216
)
1212
1217
1213
- def _is_expected_warning (self , warning ):
1218
+ def _is_expected_warning (self , warning ) -> bool :
1214
1219
if not issubclass (warning .category , self ._warning_class ):
1215
1220
return False
1216
1221
return all (test (warning ) for test in self ._tests )
1217
1222
1218
- def add_test (self , cb ) :
1223
+ def add_test (self , cb : Callable [[ Warning ], bool ]) -> None :
1219
1224
"""Add a test callback.
1220
1225
1221
1226
This callback is called after determining that the right warning
0 commit comments