Skip to content

Commit e034026

Browse files
committed
♻️ refactor ufunc annotations for logical operations
1 parent a7d5f6e commit e034026

File tree

4 files changed

+187
-26
lines changed

4 files changed

+187
-26
lines changed

src/numpy-stubs/__init__.pyi

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,18 @@ from ._typing._char_codes import (
307307
_UnsignedIntegerCodes,
308308
_VoidCodes,
309309
)
310-
from ._typing._ufunc import _Call11Bool, _Call11Isnat, _Call21Bool, _gufunc_2_1, _ufunc_1_1, _ufunc_1_2, _ufunc_2_1, _ufunc_2_2
310+
from ._typing._ufunc import (
311+
_Call11Bool,
312+
_Call11Isnat
313+
_Call11LogicalNot,
314+
_Call21Bool,
315+
_Call21Logical,
316+
_gufunc_2_1,
317+
_ufunc_1_1,
318+
_ufunc_1_2,
319+
_ufunc_2_1,
320+
_ufunc_2_2,
321+
)
311322
from .lib import scimath as emath
312323
from .lib._arraypad_impl import pad
313324
from .lib._arraysetops_impl import (
@@ -7102,7 +7113,7 @@ log: Final[_ufunc_1_1] = ...
71027113
log2: Final[_ufunc_1_1] = ...
71037114
log10: Final[_ufunc_1_1] = ...
71047115
log1p: Final[_ufunc_1_1] = ...
7105-
logical_not: Final[_ufunc_1_1[_Call11Bool]] = ...
7116+
logical_not: Final[_ufunc_1_1[_Call11LogicalNot]] = ...
71067117
negative: Final[_ufunc_1_1] = ...
71077118
positive: Final[_ufunc_1_1] = ...
71087119
rad2deg: Final[_ufunc_1_1] = ...
@@ -7159,9 +7170,9 @@ ldexp: Final[_ufunc_2_1] = ...
71597170
left_shift: Final[_ufunc_2_1] = ...
71607171
logaddexp: Final[_ufunc_2_1] = ...
71617172
logaddexp2: Final[_ufunc_2_1] = ...
7162-
logical_and: Final[_ufunc_2_1[_Call21Bool]] = ...
7163-
logical_or: Final[_ufunc_2_1[_Call21Bool]] = ...
7164-
logical_xor: Final[_ufunc_2_1[_Call21Bool]] = ...
7173+
logical_and: Final[_ufunc_2_1[_Call21Logical]] = ...
7174+
logical_or: Final[_ufunc_2_1[_Call21Logical]] = ...
7175+
logical_xor: Final[_ufunc_2_1[_Call21Logical]] = ...
71657176
maximum: Final[_ufunc_2_1] = ...
71667177
minimum: Final[_ufunc_2_1] = ...
71677178
mod: Final[_ufunc_2_1] = ...

src/numpy-stubs/_typing/_ufunc.pyi

Lines changed: 161 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ import numpy as np
1616
from numpy import _CastingKind, _OrderKACF # noqa: ICN003
1717
from numpy._typing import _DTypeLikeBool, _NestedSequence
1818

19-
from ._array_like import ArrayLike, NDArray, _ArrayLike, _ArrayLikeBool_co, _ArrayLikeInt_co
20-
from ._dtype_like import DTypeLike, _DTypeLike
21-
from ._scalars import _ScalarLike_co
19+
from ._array_like import ArrayLike, NDArray, _ArrayLike, _ArrayLikeBool_co, _ArrayLikeInt_co, _ArrayLikeNumber_co
20+
from ._dtype_like import DTypeLike, _DTypeLike, _DTypeLikeObject
21+
from ._scalars import _NumberLike_co, _ScalarLike_co
2222
from ._shape import _ShapeLike
2323

2424
###
@@ -201,7 +201,7 @@ class _Call11Bool(Protocol):
201201
dtype: _DTypeLikeBool | None = None,
202202
**kwds: Unpack[_Kwargs2],
203203
) -> _ArrayT: ...
204-
@overload # (array) -> Array[bool] | bool
204+
@overload # (array) -> Array[bool]
205205
def __call__(
206206
self,
207207
x: _AnyArray,
@@ -223,39 +223,72 @@ class _Call11Bool(Protocol):
223223
) -> NDArray[np.bool] | np.bool: ...
224224

225225
@type_check_only
226-
class _Call11Isnat(Protocol):
227-
@overload # (scalar) -> bool
228-
def __call__(
226+
class _Call11LogicalNot(Protocol):
227+
@overload
228+
def __call__( # (scalar, dtype: np.object_) -> np.object_
229229
self,
230-
x: _TimeLike,
230+
x: _NumberLike_co,
231231
/,
232+
dtype: _DTypeLikeObject,
232233
out: None = None,
233-
*,
234-
dtype: _DTypeLikeBool | None = None,
235-
**kwds: Unpack[_Kwargs2],
234+
**kwargs: Unpack[_Kwargs2],
235+
) -> np.object_: ...
236+
@overload
237+
def __call__( # (scalar) -> bool
238+
self,
239+
x: _NumberLike_co,
240+
/,
241+
out: None = None,
242+
dtype: DTypeLike | None = None,
243+
**kwargs: Unpack[_Kwargs2],
236244
) -> np.bool: ...
237-
@overload # (array-like, out: T) -> T
238-
def __call__(
245+
@overload
246+
def __call__( # (array-like, dtype: np.object_) -> np.object_
239247
self,
240-
x: _ArrayLike[_TimeLike],
248+
x: _ArrayLikeNumber_co,
241249
/,
242-
out: _Out1[_ArrayT],
243-
*,
244-
dtype: _DTypeLikeBool | None = None,
245-
**kwds: Unpack[_Kwargs2],
250+
dtype: _DTypeLikeObject,
251+
out: None = None,
252+
**kwargs: Unpack[_Kwargs2],
253+
) -> np.object_: ...
254+
@overload
255+
def __call__( # (array-like, out: T) -> T
256+
self,
257+
x: _ArrayLikeNumber_co,
258+
/,
259+
out: _ArrayT | tuple[_ArrayT],
260+
dtype: DTypeLike | None = None,
261+
**kwargs: Unpack[_Kwargs2],
246262
) -> _ArrayT: ...
247263
@overload # (array) -> Array[bool]
248264
def __call__(
249265
self,
250-
x: NDArray[_TimeLike] | _NestedSequence[_TimeLike],
266+
x: _AnyArray,
251267
/,
252268
out: _Out1[_AnyArray] | None = None,
253269
*,
254-
dtype: _DTypeLikeBool | None = None,
270+
dtype: DTypeLike | None = None,
255271
**kwds: Unpack[_Kwargs2],
256272
) -> NDArray[np.bool]: ...
273+
@overload
274+
def __call__( # (array-like) -> Array[bool] | bool
275+
self,
276+
x: _ArrayLikeNumber_co,
277+
/,
278+
out: None = None,
279+
dtype: DTypeLike | None = None,
280+
**kwargs: Unpack[_Kwargs2],
281+
) -> NDArray[np.bool] | np.bool: ...
282+
@overload
283+
def __call__( # (?) -> ?
284+
self,
285+
x: _CanArrayUFunc,
286+
/,
287+
out: _Out1[_AnyArray] | None = None,
288+
dtype: DTypeLike | None = None,
289+
**kwargs: Unpack[_Kwargs2],
290+
) -> NDArray[np.bool] | np.bool: ...
257291

258-
@type_check_only
259292
class _Call12(Protocol):
260293
@overload
261294
def __call__(
@@ -418,6 +451,113 @@ class _Call21Bool(Protocol):
418451
**kwds: Unpack[_Kwargs3],
419452
) -> np.bool | NDArray[np.bool]: ...
420453

454+
@type_check_only
455+
class _Call21Logical(Protocol):
456+
@overload # (scalar, scalar, dtype: np.object_) -> np.object_
457+
def __call__(
458+
self,
459+
x1: _NumberLike_co,
460+
x2: _NumberLike_co,
461+
/,
462+
dtype: _DTypeLikeObject,
463+
out: None = None,
464+
**kwds: Unpack[_Kwargs3],
465+
) -> np.object_: ...
466+
@overload # (scalar, scalar) -> bool
467+
def __call__(
468+
self,
469+
x1: _NumberLike_co,
470+
x2: _NumberLike_co,
471+
/,
472+
out: None = None,
473+
dtype: DTypeLike | None = None,
474+
**kwds: Unpack[_Kwargs3],
475+
) -> np.bool: ...
476+
@overload # (array-like, array, dtype: np.object_) -> np.object_
477+
def __call__(
478+
self,
479+
x1: _ArrayLikeNumber_co,
480+
x2: _AnyArray,
481+
/,
482+
dtype: _DTypeLikeObject,
483+
out: None = None,
484+
**kwds: Unpack[_Kwargs3],
485+
) -> np.object_: ...
486+
@overload # (array-like, array, dtype: dtype[T]) -> Array[T]
487+
def __call__(
488+
self,
489+
x1: _ArrayLikeNumber_co,
490+
x2: _AnyArray,
491+
/,
492+
out: None = None,
493+
dtype: DTypeLike | None = None,
494+
**kwds: Unpack[_Kwargs3],
495+
) -> NDArray[np.bool]: ...
496+
@overload # (array, array-like, dtype: np.object_) -> np.object_
497+
def __call__(
498+
self,
499+
x1: _AnyArray,
500+
x2: _ArrayLikeNumber_co,
501+
/,
502+
dtype: _DTypeLikeObject,
503+
out: None = None,
504+
**kwds: Unpack[_Kwargs3],
505+
) -> np.object_: ...
506+
@overload # (array, array-like, dtype: dtype[T]) -> Array[T]
507+
def __call__(
508+
self,
509+
x1: _AnyArray,
510+
x2: _ArrayLikeNumber_co,
511+
/,
512+
out: None = None,
513+
dtype: DTypeLike | None = None,
514+
**kwds: Unpack[_Kwargs3],
515+
) -> NDArray[np.bool]: ...
516+
@overload # (array-like, array) -> Array[?]
517+
def __call__(
518+
self,
519+
x1: _ArrayLikeNumber_co,
520+
x2: _AnyArray,
521+
/,
522+
out: _Out1[_AnyArray] | None = None,
523+
*,
524+
dtype: DTypeLike | None = None,
525+
**kwds: Unpack[_Kwargs3],
526+
) -> NDArray[np.bool]: ...
527+
@overload # (array, array-like) -> Array[?]
528+
def __call__(
529+
self,
530+
x1: _AnyArray,
531+
x2: _ArrayLikeNumber_co,
532+
/,
533+
out: _Out1[_AnyArray] | None = None,
534+
*,
535+
dtype: DTypeLike | None = None,
536+
**kwds: Unpack[_Kwargs3],
537+
) -> NDArray[np.bool]: ...
538+
@overload # (array-like, array-like, out: T) -> T
539+
def __call__(
540+
self,
541+
x1: _ArrayLikeNumber_co,
542+
x2: ArrayLike,
543+
/,
544+
out: _ArrayT | tuple[_ArrayT],
545+
*,
546+
dtype: None = None,
547+
**kwds: Unpack[_Kwargs3],
548+
) -> _ArrayT: ...
549+
@overload # (array-like, array-like) -> Array[?] | ?
550+
def __call__(
551+
self,
552+
x1: _ArrayLikeNumber_co,
553+
x2: _ArrayLikeNumber_co,
554+
/,
555+
out: _Out1[_AnyArray] | None = None,
556+
*,
557+
dtype: DTypeLike | None = None,
558+
**kwds: Unpack[_Kwargs3],
559+
) -> NDArray[np.bool] | np.bool: ...
560+
421561
@type_check_only
422562
class _Call21(Protocol):
423563
@overload # (scalar, scalar, dtype: type[T]) -> T

test/static/accept/ufuncs.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,25 @@ assert_type(np.isfinite(AR_f8, out=AR_bool), npt.NDArray[np.bool_])
6666
assert_type(np.logical_not(True), np.bool_)
6767
assert_type(np.logical_not(AR_bool), npt.NDArray[np.bool_])
6868
assert_type(np.logical_not(AR_bool, out=AR_bool), npt.NDArray[np.bool_])
69+
assert_type(np.logical_not(AR_bool, dtype=np.object_), np.object_)
6970

7071
assert_type(np.logical_and(True, True), np.bool_)
7172
assert_type(np.logical_and(AR_bool, AR_bool), npt.NDArray[np.bool_])
7273
assert_type(np.logical_and(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
7374
assert_type(np.logical_and(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
7475
assert_type(np.logical_and(AR_bool, AR_i8), npt.NDArray[np.bool_])
76+
assert_type(np.logical_and(AR_bool, AR_bool, dtype=np.object_), np.object_)
7577

7678
assert_type(np.logical_or(True, True), np.bool_)
7779
assert_type(np.logical_or(AR_bool, AR_bool), npt.NDArray[np.bool_])
7880
assert_type(np.logical_or(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
7981
assert_type(np.logical_or(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
8082
assert_type(np.logical_or(AR_bool, AR_i8), npt.NDArray[np.bool_])
83+
assert_type(np.logical_or(AR_bool, AR_bool, dtype=np.object_), np.object_)
8184

8285
assert_type(np.logical_xor(True, True), np.bool_)
8386
assert_type(np.logical_xor(AR_bool, AR_bool), npt.NDArray[np.bool_])
8487
assert_type(np.logical_xor(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
8588
assert_type(np.logical_xor(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
8689
assert_type(np.logical_xor(AR_bool, AR_i8), npt.NDArray[np.bool_])
90+
assert_type(np.logical_xor(AR_bool, AR_bool, dtype=np.object_), np.object_)

test/static/reject/ufuncs.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import numpy.typing as npt
33

44
i8: np.int64
55
AR_f8: npt.NDArray[np.float64]
6+
dt64: np.datetime64
67

78
np.sin.nin + "foo" # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
89

@@ -47,3 +48,8 @@ np.isnat(i8, dtype=np.int64) # type: ignore[call-overload] # pyright: ignore[r
4748
np.isnat(i8) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
4849
np.isinf(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
4950
np.isfinite(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
51+
52+
np.logical_not(dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
53+
np.logical_and(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
54+
np.logical_or(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
55+
np.logical_xor(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

0 commit comments

Comments
 (0)