2020
2121"""
2222
23+ from __future__ import annotations
24+
2325import re
2426import sys
2527from datetime import datetime , timedelta
2628from json import loads as json_loads
27- from typing import Set
28- from warnings import catch_warnings
29+ from typing import Any , Callable , Set
30+ from warnings import WarningMessage , catch_warnings
2931
3032
3133def fail (msg = None ):
@@ -864,7 +866,7 @@ def assert_datetime_about_now_utc(actual, msg_fmt="{msg}"):
864866 fail (msg_fmt .format (msg = msg , actual = actual , now = now ))
865867
866868
867- class AssertRaisesContext ( object ) :
869+ class AssertRaisesContext :
868870 """A context manager to test for exceptions with certain properties.
869871
870872 When the context is left and no exception has been raised, an
@@ -906,7 +908,7 @@ def __init__(self, exception, msg_fmt="{msg}"):
906908 self ._exc_type = exception
907909 self ._exc_val = None
908910 self ._exception_name = getattr (exception , "__name__" , str (exception ))
909- self ._tests = []
911+ self ._tests : list [ Callable [[ Any ], object ]] = []
910912
911913 def __enter__ (self ):
912914 return self
@@ -929,7 +931,7 @@ def format_message(self, default_msg):
929931 exc_name = self ._exception_name ,
930932 )
931933
932- def add_test (self , cb ) :
934+ def add_test (self , cb : Callable [[ Any ], object ]) -> None :
933935 """Add a test callback.
934936
935937 This callback is called after determining that the right exception
@@ -1188,16 +1190,19 @@ class AssertWarnsContext(object):
11881190 def __init__ (self , warning_class , msg_fmt = "{msg}" ):
11891191 self ._warning_class = warning_class
11901192 self ._msg_fmt = msg_fmt
1191- self ._warning_context = None
1193+ self ._warning_context : catch_warnings [list [WarningMessage ]] | None = (
1194+ None
1195+ )
11921196 self ._warnings = []
1193- self ._tests = []
1197+ self ._tests : list [ Callable [[ Warning ], bool ]] = []
11941198
11951199 def __enter__ (self ):
11961200 self ._warning_context = catch_warnings (record = True )
11971201 self ._warnings = self ._warning_context .__enter__ ()
11981202 return self
11991203
12001204 def __exit__ (self , exc_type , exc_val , exc_tb ):
1205+ assert self ._warning_context is not None
12011206 self ._warning_context .__exit__ (exc_type , exc_val , exc_tb )
12021207 if not any (self ._is_expected_warning (w ) for w in self ._warnings ):
12031208 fail (self .format_message ())
@@ -1210,12 +1215,12 @@ def format_message(self):
12101215 exc_name = self ._warning_class .__name__ ,
12111216 )
12121217
1213- def _is_expected_warning (self , warning ):
1218+ def _is_expected_warning (self , warning ) -> bool :
12141219 if not issubclass (warning .category , self ._warning_class ):
12151220 return False
12161221 return all (test (warning ) for test in self ._tests )
12171222
1218- def add_test (self , cb ) :
1223+ def add_test (self , cb : Callable [[ Warning ], bool ]) -> None :
12191224 """Add a test callback.
12201225
12211226 This callback is called after determining that the right warning
0 commit comments