Skip to content

Commit f7df097

Browse files
committed
♻️ refactor ufunc annotations for logical operations
1 parent f0e2e91 commit f7df097

File tree

3 files changed

+37
-30
lines changed

3 files changed

+37
-30
lines changed

src/numpy-stubs/_typing/_ufunc.pyi

+22-22
Original file line numberDiff line numberDiff line change
@@ -225,45 +225,45 @@ class _Call11Bool(Protocol):
225225
@type_check_only
226226
class _Call11Logical(Protocol):
227227
@overload
228-
def __call__( # (scalar, dtype: np.object_) -> np.bool
228+
def __call__( # (scalar, dtype: np.object_) -> bool
229229
self,
230-
x: _NumberLike_co,
230+
x: _ScalarLike_co,
231231
/,
232232
dtype: _DTypeLike[np.object_],
233233
out: None = None,
234234
**kwargs: Unpack[_Kwargs2],
235-
) -> np.bool: ...
235+
) -> bool: ...
236236
@overload
237-
def __call__( # (scalar) -> bool
237+
def __call__( # (scalar) -> np.bool
238238
self,
239239
x: _NumberLike_co,
240240
/,
241241
out: None = None,
242-
dtype: DTypeLike | None = None,
242+
dtype: _DTypeLikeBool | None = None,
243243
**kwargs: Unpack[_Kwargs2],
244244
) -> np.bool: ...
245245
@overload
246246
def __call__( # (array-like, dtype: np.object_) -> np.object_
247247
self,
248-
x: _ArrayLikeNumber_co,
248+
x: _ArrayLikeNumber_co | _ArrayLikeObject_co,
249249
/,
250250
dtype: _DTypeLike[np.object_],
251251
out: None = None,
252252
**kwargs: Unpack[_Kwargs2],
253-
) -> NDArray[np.object_] | np.bool: ...
253+
) -> NDArray[np.object_] | bool: ...
254254
@overload
255255
def __call__( # (array-like, out: T) -> T
256256
self,
257257
x: _ArrayLikeNumber_co,
258258
/,
259-
out: _ArrayT | tuple[_ArrayT],
259+
out: _Out1[_ArrayT],
260260
dtype: DTypeLike | None = None,
261261
**kwargs: Unpack[_Kwargs2],
262262
) -> _ArrayT: ...
263263
@overload # (array) -> Array[bool]
264264
def __call__(
265265
self,
266-
x: _AnyArray,
266+
x: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number],
267267
/,
268268
out: _Out1[_AnyArray] | None = None,
269269
*,
@@ -276,7 +276,7 @@ class _Call11Logical(Protocol):
276276
x: _ArrayLikeNumber_co,
277277
/,
278278
out: None = None,
279-
dtype: DTypeLike | None = None,
279+
dtype: _DTypeLikeBool | None = None,
280280
**kwargs: Unpack[_Kwargs2],
281281
) -> NDArray[np.bool] | np.bool: ...
282282
@overload
@@ -456,21 +456,21 @@ class _Call21Logical(Protocol):
456456
@overload # (scalar, scalar, dtype: np.object_) -> np.object_
457457
def __call__(
458458
self,
459-
x1: _NumberLike_co,
460-
x2: _NumberLike_co,
459+
x1: _ScalarLike_co,
460+
x2: _ScalarLike_co,
461461
/,
462462
dtype: _DTypeLike[np.object_],
463463
out: None = None,
464464
**kwds: Unpack[_Kwargs3],
465-
) -> np.bool: ...
465+
) -> bool: ...
466466
@overload # (scalar, scalar) -> bool
467467
def __call__(
468468
self,
469469
x1: _NumberLike_co,
470470
x2: _NumberLike_co,
471471
/,
472472
out: None = None,
473-
dtype: DTypeLike | None = None,
473+
dtype: _DTypeLikeBool | None = None,
474474
**kwds: Unpack[_Kwargs3],
475475
) -> np.bool: ...
476476
@overload # (array-like, array, dtype: object_) -> Array[object_]
@@ -490,9 +490,9 @@ class _Call21Logical(Protocol):
490490
x2: NDArray[np.bool | np.number],
491491
/,
492492
out: None = None,
493-
dtype: DTypeLike | None = None,
493+
dtype: _DTypeLikeBool | None = None,
494494
**kwds: Unpack[_Kwargs3],
495-
) -> NDArray[np.bool | np.object_]: ...
495+
) -> NDArray[np.bool]: ...
496496
@overload # (array, array-like, dtype: object_) -> Array[object_]
497497
def __call__(
498498
self,
@@ -506,7 +506,7 @@ class _Call21Logical(Protocol):
506506
@overload # (array, array-like, dtype: dtype[T]) -> Array[T]
507507
def __call__(
508508
self,
509-
x1: NDArray[np.bool | np.number],
509+
x1: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number],
510510
x2: _ArrayLikeNumber_co,
511511
/,
512512
out: None = None,
@@ -530,27 +530,27 @@ class _Call21Logical(Protocol):
530530
x1: _ArrayLikeNumber_co,
531531
x2: _AnyArray,
532532
/,
533-
out: _Out1[_AnyArray] | None = None,
533+
out: _Out1[NDArray[np.bool]] | None = None,
534534
*,
535535
dtype: DTypeLike | None = None,
536536
**kwds: Unpack[_Kwargs3],
537-
) -> NDArray[np.bool | np.object_]: ...
537+
) -> NDArray[np.bool]: ...
538538
@overload # (array, array-like) -> Array[?]
539539
def __call__(
540540
self,
541541
x1: _AnyArray,
542542
x2: _ArrayLikeNumber_co,
543543
/,
544-
out: _Out1[_AnyArray] | None = None,
544+
out: _Out1[NDArray[np.bool]] | None = None,
545545
*,
546546
dtype: DTypeLike | None = None,
547547
**kwds: Unpack[_Kwargs3],
548548
) -> NDArray[np.bool]: ...
549549
@overload # (array-like, array-like) -> Array[?] | ?
550550
def __call__(
551551
self,
552-
x1: _ArrayLikeNumber_co,
553-
x2: _ArrayLikeNumber_co,
552+
x1: _ArrayLikeNumber_co | _ArrayLikeObject_co,
553+
x2: _ArrayLikeNumber_co | _ArrayLikeObject_co,
554554
/,
555555
out: _Out1[_AnyArray] | None = None,
556556
*,

test/static/accept/ufuncs.pyi

+7-7
Original file line numberDiff line numberDiff line change
@@ -66,25 +66,25 @@ assert_type(np.isfinite(AR_f8, out=AR_bool), npt.NDArray[np.bool_])
6666
assert_type(np.logical_not(True), np.bool_)
6767
assert_type(np.logical_not(AR_bool), npt.NDArray[np.bool_])
6868
assert_type(np.logical_not(AR_bool, out=AR_bool), npt.NDArray[np.bool_])
69-
assert_type(np.logical_not(AR_bool, dtype=np.object_), npt.NDArray[np.object_] | np.bool_)
69+
assert_type(np.logical_not(AR_bool, dtype=np.object_), npt.NDArray[np.object_] | bool)
7070

7171
assert_type(np.logical_and(True, True), np.bool_)
72-
assert_type(np.logical_and(AR_bool, AR_bool), npt.NDArray[np.bool_ | np.object_])
72+
assert_type(np.logical_and(AR_bool, AR_bool), npt.NDArray[np.bool_])
7373
assert_type(np.logical_and(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
7474
assert_type(np.logical_and(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
75-
assert_type(np.logical_and(AR_bool, AR_i8), npt.NDArray[np.bool_ | np.object_])
75+
assert_type(np.logical_and(AR_bool, AR_i8), npt.NDArray[np.bool_])
7676
assert_type(np.logical_and(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])
7777

7878
assert_type(np.logical_or(True, True), np.bool_)
79-
assert_type(np.logical_or(AR_bool, AR_bool), npt.NDArray[np.bool_ | np.object_])
79+
assert_type(np.logical_or(AR_bool, AR_bool), npt.NDArray[np.bool_])
8080
assert_type(np.logical_or(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
8181
assert_type(np.logical_or(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
82-
assert_type(np.logical_or(AR_bool, AR_i8), npt.NDArray[np.bool_ | np.object_])
82+
assert_type(np.logical_or(AR_bool, AR_i8), npt.NDArray[np.bool_])
8383
assert_type(np.logical_or(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])
8484

8585
assert_type(np.logical_xor(True, True), np.bool_)
86-
assert_type(np.logical_xor(AR_bool, AR_bool), npt.NDArray[np.bool_ | np.object_])
86+
assert_type(np.logical_xor(AR_bool, AR_bool), npt.NDArray[np.bool_])
8787
assert_type(np.logical_xor(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
8888
assert_type(np.logical_xor(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
89-
assert_type(np.logical_xor(AR_bool, AR_i8), npt.NDArray[np.bool_ | np.object_])
89+
assert_type(np.logical_xor(AR_bool, AR_i8), npt.NDArray[np.bool_])
9090
assert_type(np.logical_xor(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])

test/static/reject/ufuncs.pyi

+8-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,14 @@ np.isnat(i8) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,r
4949
np.isinf(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5050
np.isfinite(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5151

52-
np.logical_not(dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
52+
np.logical_not(i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
53+
np.logical_not(dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
54+
55+
np.logical_and(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5356
np.logical_and(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
57+
58+
np.logical_or(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5459
np.logical_or(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
60+
61+
np.logical_xor(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5562
np.logical_xor(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

0 commit comments

Comments
 (0)