Skip to content

Commit e972883

Browse files
authored
Merge pull request #320 from guan404ming/1-to-1-bool
🏷️ ufunc annotations for `is{finite,inf,nan.nat}`
2 parents 9f463f6 + 498c897 commit e972883

File tree

4 files changed

+72
-6
lines changed

4 files changed

+72
-6
lines changed

src/numpy-stubs/__init__.pyi

+5-5
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ from ._typing._char_codes import (
307307
_UnsignedIntegerCodes,
308308
_VoidCodes,
309309
)
310-
from ._typing._ufunc import _Call21Bool, _gufunc_2_1, _ufunc_1_1, _ufunc_1_2, _ufunc_2_1, _ufunc_2_2
310+
from ._typing._ufunc import _Call11Bool, _Call21Bool, _gufunc_2_1, _ufunc_1_1, _ufunc_1_2, _ufunc_2_1, _ufunc_2_2
311311
from .lib import scimath as emath
312312
from .lib._arraypad_impl import pad
313313
from .lib._arraysetops_impl import (
@@ -6989,10 +6989,10 @@ expm1: Final[_ufunc_1_1] = ...
69896989
fabs: Final[_ufunc_1_1] = ...
69906990
floor: Final[_ufunc_1_1] = ...
69916991
invert: Final[_ufunc_1_1] = ...
6992-
isfinite: Final[_ufunc_1_1] = ...
6993-
isinf: Final[_ufunc_1_1] = ...
6994-
isnan: Final[_ufunc_1_1] = ...
6995-
isnat: Final[_ufunc_1_1] = ...
6992+
isfinite: Final[_ufunc_1_1[_Call11Bool]] = ...
6993+
isinf: Final[_ufunc_1_1[_Call11Bool]] = ...
6994+
isnan: Final[_ufunc_1_1[_Call11Bool]] = ...
6995+
isnat: Final[_ufunc_1_1[_Call11Bool]] = ...
69966996
log: Final[_ufunc_1_1] = ...
69976997
log2: Final[_ufunc_1_1] = ...
69986998
log10: Final[_ufunc_1_1] = ...

src/numpy-stubs/_typing/_ufunc.pyi

+44-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ from typing_extensions import Never, TypeAliasType, TypeVar, Unpack
1414

1515
import numpy as np
1616
from numpy import _CastingKind, _OrderKACF # noqa: ICN003
17-
from numpy._typing import _NestedSequence
17+
from numpy._typing import _DTypeLikeBool, _NestedSequence
1818

1919
from ._array_like import ArrayLike, NDArray, _ArrayLikeBool_co, _ArrayLikeInt_co
2020
from ._dtype_like import DTypeLike, _DTypeLike
@@ -178,6 +178,49 @@ class _Call11(Protocol):
178178
**kwds: Unpack[_Kwargs2],
179179
) -> Any: ...
180180

181+
@type_check_only
182+
class _Call11Bool(Protocol):
183+
@overload # (scalar) -> bool
184+
def __call__(
185+
self,
186+
x: _ScalarLike_co,
187+
/,
188+
out: None = None,
189+
*,
190+
dtype: _DTypeLikeBool | None = None,
191+
**kwds: Unpack[_Kwargs2],
192+
) -> np.bool: ...
193+
@overload # (array-like, out: T) -> T
194+
def __call__(
195+
self,
196+
x: ArrayLike,
197+
/,
198+
out: _Out1[_ArrayT],
199+
*,
200+
dtype: _DTypeLikeBool | None = None,
201+
**kwds: Unpack[_Kwargs2],
202+
) -> _ArrayT: ...
203+
@overload # (array) -> Array[bool] | bool
204+
def __call__(
205+
self,
206+
x: _AnyArray,
207+
/,
208+
out: _Out1[_AnyArray] | None = None,
209+
*,
210+
dtype: _DTypeLikeBool | None = None,
211+
**kwds: Unpack[_Kwargs2],
212+
) -> NDArray[np.bool]: ...
213+
@overload # (array-like) -> Array[bool] | bool
214+
def __call__(
215+
self,
216+
x: ArrayLike,
217+
/,
218+
out: _Out1[_AnyArray] | None = None,
219+
*,
220+
dtype: _DTypeLikeBool | None = None,
221+
**kwds: Unpack[_Kwargs2],
222+
) -> NDArray[np.bool] | np.bool: ...
223+
181224
@type_check_only
182225
class _Call12(Protocol):
183226
@overload

test/static/accept/ufuncs.pyi

+17
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ i8: np.int64
88
f8: np.float64
99
AR_f8: npt.NDArray[np.float64]
1010
AR_i8: npt.NDArray[np.int64]
11+
AR_bool: npt.NDArray[np.bool_]
1112

1213
assert_type(np.absolute.types, list[str])
1314

@@ -36,3 +37,19 @@ assert_type(np.vecdot(AR_f8, AR_f8), Any)
3637

3738
assert_type(np.bitwise_count(i8), Any)
3839
assert_type(np.bitwise_count(AR_i8), npt.NDArray[Any])
40+
41+
assert_type(np.isnan(f8), np.bool_)
42+
assert_type(np.isnan(AR_f8), npt.NDArray[np.bool_])
43+
assert_type(np.isnan(AR_f8, out=AR_bool), npt.NDArray[np.bool_])
44+
45+
assert_type(np.isnat(f8), np.bool_)
46+
assert_type(np.isnat(AR_f8), npt.NDArray[np.bool_])
47+
assert_type(np.isnat(AR_f8, out=AR_bool), npt.NDArray[np.bool_])
48+
49+
assert_type(np.isinf(f8), np.bool_)
50+
assert_type(np.isinf(AR_f8), npt.NDArray[np.bool_])
51+
assert_type(np.isinf(AR_f8, out=AR_bool), npt.NDArray[np.bool_])
52+
53+
assert_type(np.isfinite(f8), np.bool_)
54+
assert_type(np.isfinite(AR_f8), npt.NDArray[np.bool_])
55+
assert_type(np.isfinite(AR_f8, out=AR_bool), npt.NDArray[np.bool_])

test/static/reject/ufuncs.pyi

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import numpy.typing as npt
33

4+
i8: np.int64
45
AR_f8: npt.NDArray[np.float64]
56

67
np.sin.nin + "foo" # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
@@ -40,3 +41,8 @@ np.matmul.accumulate() # type: ignore[call-arg] # pyright: ignore[reportCallIs
4041
np.frexp.at() # type: ignore[call-arg, misc] # pyright: ignore[reportCallIssue]
4142
np.divmod.at() # type: ignore[call-arg, misc] # pyright: ignore[reportCallIssue]
4243
np.matmul.at() # type: ignore[call-arg, misc] # pyright: ignore[reportCallIssue]
44+
45+
np.isnan(i8, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
46+
np.isnat(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
47+
np.isinf(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
48+
np.isfinite(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

0 commit comments

Comments
 (0)