Skip to content

Commit 222a9a2

Browse files
committed
Add a few type annotations to implementation
1 parent 2740a24 commit 222a9a2

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

Diff for: asserts/__init__.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
2121
"""
2222

23+
from __future__ import annotations
24+
2325
import re
2426
import sys
2527
from datetime import datetime, timedelta
2628
from json import loads as json_loads
27-
from typing import Set
28-
from warnings import catch_warnings
29+
from typing import Any, Callable, Set
30+
from warnings import WarningMessage, catch_warnings
2931

3032

3133
def fail(msg=None):
@@ -864,7 +866,7 @@ def assert_datetime_about_now_utc(actual, msg_fmt="{msg}"):
864866
fail(msg_fmt.format(msg=msg, actual=actual, now=now))
865867

866868

867-
class AssertRaisesContext(object):
869+
class AssertRaisesContext:
868870
"""A context manager to test for exceptions with certain properties.
869871
870872
When the context is left and no exception has been raised, an
@@ -906,7 +908,7 @@ def __init__(self, exception, msg_fmt="{msg}"):
906908
self._exc_type = exception
907909
self._exc_val = None
908910
self._exception_name = getattr(exception, "__name__", str(exception))
909-
self._tests = []
911+
self._tests: list[Callable[[Any], object]] = []
910912

911913
def __enter__(self):
912914
return self
@@ -929,7 +931,7 @@ def format_message(self, default_msg):
929931
exc_name=self._exception_name,
930932
)
931933

932-
def add_test(self, cb):
934+
def add_test(self, cb: Callable[[Any], object]) -> None:
933935
"""Add a test callback.
934936
935937
This callback is called after determining that the right exception
@@ -1188,16 +1190,19 @@ class AssertWarnsContext(object):
11881190
def __init__(self, warning_class, msg_fmt="{msg}"):
11891191
self._warning_class = warning_class
11901192
self._msg_fmt = msg_fmt
1191-
self._warning_context = None
1193+
self._warning_context: catch_warnings[list[WarningMessage]] | None = (
1194+
None
1195+
)
11921196
self._warnings = []
1193-
self._tests = []
1197+
self._tests: list[Callable[[Warning], bool]] = []
11941198

11951199
def __enter__(self):
11961200
self._warning_context = catch_warnings(record=True)
11971201
self._warnings = self._warning_context.__enter__()
11981202
return self
11991203

12001204
def __exit__(self, exc_type, exc_val, exc_tb):
1205+
assert self._warning_context is not None
12011206
self._warning_context.__exit__(exc_type, exc_val, exc_tb)
12021207
if not any(self._is_expected_warning(w) for w in self._warnings):
12031208
fail(self.format_message())
@@ -1210,12 +1215,12 @@ def format_message(self):
12101215
exc_name=self._warning_class.__name__,
12111216
)
12121217

1213-
def _is_expected_warning(self, warning):
1218+
def _is_expected_warning(self, warning) -> bool:
12141219
if not issubclass(warning.category, self._warning_class):
12151220
return False
12161221
return all(test(warning) for test in self._tests)
12171222

1218-
def add_test(self, cb):
1223+
def add_test(self, cb: Callable[[Warning], bool]) -> None:
12191224
"""Add a test callback.
12201225
12211226
This callback is called after determining that the right warning

Diff for: test_asserts.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1269,8 +1269,9 @@ def extra_test(warning):
12691269
def test_assert_warns__add_test_not_called(self):
12701270
called = Box(False)
12711271

1272-
def extra_test(_):
1272+
def extra_test(_: Warning) -> bool:
12731273
called.value = True
1274+
return False
12741275

12751276
with assert_raises(AssertionError):
12761277
with assert_warns(UserWarning) as context:
@@ -1342,10 +1343,7 @@ def test_assert_warns_regex__not_issued__default_message(self):
13421343
pass
13431344

13441345
def test_assert_warns_regex__not_issued__custom_message(self):
1345-
expected = (
1346-
"no ImportWarning matching 'abc' issued;ImportWarning;"
1347-
"ImportWarning;abc"
1348-
)
1346+
expected = "no ImportWarning matching 'abc' issued;ImportWarning;ImportWarning;abc"
13491347
with _assert_raises_assertion(expected):
13501348
msg_fmt = "{msg};{exc_type.__name__};{exc_name};{pattern}"
13511349
with assert_warns_regex(ImportWarning, r"abc", msg_fmt=msg_fmt):

0 commit comments

Comments
 (0)