Skip to content

Commit a3cf8a1

Browse files
committed
♻️ refactor ufunc annotations for logical operations
1 parent b412b6d commit a3cf8a1

File tree

3 files changed

+37
-30
lines changed

3 files changed

+37
-30
lines changed

src/numpy-stubs/_typing/_ufunc.pyi

+22-22
Original file line numberDiff line numberDiff line change
@@ -224,45 +224,45 @@ class _Call11Bool(Protocol):
224224
@type_check_only
225225
class _Call11Logical(Protocol):
226226
@overload
227-
def __call__( # (scalar, dtype: np.object_) -> np.bool
227+
def __call__( # (scalar, dtype: np.object_) -> bool
228228
self,
229-
x: _NumberLike_co,
229+
x: _ScalarLike_co,
230230
/,
231231
dtype: _DTypeLike[np.object_],
232232
out: None = None,
233233
**kwargs: Unpack[_Kwargs2],
234-
) -> np.bool: ...
234+
) -> bool: ...
235235
@overload
236-
def __call__( # (scalar) -> bool
236+
def __call__( # (scalar) -> np.bool
237237
self,
238238
x: _NumberLike_co,
239239
/,
240240
out: None = None,
241-
dtype: DTypeLike | None = None,
241+
dtype: _DTypeLikeBool | None = None,
242242
**kwargs: Unpack[_Kwargs2],
243243
) -> np.bool: ...
244244
@overload
245245
def __call__( # (array-like, dtype: np.object_) -> np.object_
246246
self,
247-
x: _ArrayLikeNumber_co,
247+
x: _ArrayLikeNumber_co | _ArrayLikeObject_co,
248248
/,
249249
dtype: _DTypeLike[np.object_],
250250
out: None = None,
251251
**kwargs: Unpack[_Kwargs2],
252-
) -> NDArray[np.object_] | np.bool: ...
252+
) -> NDArray[np.object_] | bool: ...
253253
@overload
254254
def __call__( # (array-like, out: T) -> T
255255
self,
256256
x: _ArrayLikeNumber_co,
257257
/,
258-
out: _ArrayT | tuple[_ArrayT],
258+
out: _Out1[_ArrayT],
259259
dtype: DTypeLike | None = None,
260260
**kwargs: Unpack[_Kwargs2],
261261
) -> _ArrayT: ...
262262
@overload # (array) -> Array[bool]
263263
def __call__(
264264
self,
265-
x: _AnyArray,
265+
x: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number],
266266
/,
267267
out: _Out1[_AnyArray] | None = None,
268268
*,
@@ -275,7 +275,7 @@ class _Call11Logical(Protocol):
275275
x: _ArrayLikeNumber_co,
276276
/,
277277
out: None = None,
278-
dtype: DTypeLike | None = None,
278+
dtype: _DTypeLikeBool | None = None,
279279
**kwargs: Unpack[_Kwargs2],
280280
) -> NDArray[np.bool] | np.bool: ...
281281
@overload
@@ -451,21 +451,21 @@ class _Call21Logical(Protocol):
451451
@overload # (scalar, scalar, dtype: np.object_) -> np.object_
452452
def __call__(
453453
self,
454-
x1: _NumberLike_co,
455-
x2: _NumberLike_co,
454+
x1: _ScalarLike_co,
455+
x2: _ScalarLike_co,
456456
/,
457457
dtype: _DTypeLike[np.object_],
458458
out: None = None,
459459
**kwds: Unpack[_Kwargs3],
460-
) -> np.bool: ...
460+
) -> bool: ...
461461
@overload # (scalar, scalar) -> bool
462462
def __call__(
463463
self,
464464
x1: _NumberLike_co,
465465
x2: _NumberLike_co,
466466
/,
467467
out: None = None,
468-
dtype: DTypeLike | None = None,
468+
dtype: _DTypeLikeBool | None = None,
469469
**kwds: Unpack[_Kwargs3],
470470
) -> np.bool: ...
471471
@overload # (array-like, array, dtype: object_) -> Array[object_]
@@ -485,9 +485,9 @@ class _Call21Logical(Protocol):
485485
x2: NDArray[np.bool | np.number],
486486
/,
487487
out: None = None,
488-
dtype: DTypeLike | None = None,
488+
dtype: _DTypeLikeBool | None = None,
489489
**kwds: Unpack[_Kwargs3],
490-
) -> NDArray[np.bool | np.object_]: ...
490+
) -> NDArray[np.bool]: ...
491491
@overload # (array, array-like, dtype: object_) -> Array[object_]
492492
def __call__(
493493
self,
@@ -501,7 +501,7 @@ class _Call21Logical(Protocol):
501501
@overload # (array, array-like, dtype: dtype[T]) -> Array[T]
502502
def __call__(
503503
self,
504-
x1: NDArray[np.bool | np.number],
504+
x1: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number],
505505
x2: _ArrayLikeNumber_co,
506506
/,
507507
out: None = None,
@@ -525,27 +525,27 @@ class _Call21Logical(Protocol):
525525
x1: _ArrayLikeNumber_co,
526526
x2: _AnyArray,
527527
/,
528-
out: _Out1[_AnyArray] | None = None,
528+
out: _Out1[NDArray[np.bool]] | None = None,
529529
*,
530530
dtype: DTypeLike | None = None,
531531
**kwds: Unpack[_Kwargs3],
532-
) -> NDArray[np.bool | np.object_]: ...
532+
) -> NDArray[np.bool]: ...
533533
@overload # (array, array-like) -> Array[?]
534534
def __call__(
535535
self,
536536
x1: _AnyArray,
537537
x2: _ArrayLikeNumber_co,
538538
/,
539-
out: _Out1[_AnyArray] | None = None,
539+
out: _Out1[NDArray[np.bool]] | None = None,
540540
*,
541541
dtype: DTypeLike | None = None,
542542
**kwds: Unpack[_Kwargs3],
543543
) -> NDArray[np.bool]: ...
544544
@overload # (array-like, array-like) -> Array[?] | ?
545545
def __call__(
546546
self,
547-
x1: _ArrayLikeNumber_co,
548-
x2: _ArrayLikeNumber_co,
547+
x1: _ArrayLikeNumber_co | _ArrayLikeObject_co,
548+
x2: _ArrayLikeNumber_co | _ArrayLikeObject_co,
549549
/,
550550
out: _Out1[_AnyArray] | None = None,
551551
*,

test/static/accept/ufuncs.pyi

+7-7
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,25 @@ assert_type(np.isfinite(AR_f8, out=AR_bool), npt.NDArray[np.bool_])
5757
assert_type(np.logical_not(True), np.bool_)
5858
assert_type(np.logical_not(AR_bool), npt.NDArray[np.bool_])
5959
assert_type(np.logical_not(AR_bool, out=AR_bool), npt.NDArray[np.bool_])
60-
assert_type(np.logical_not(AR_bool, dtype=np.object_), npt.NDArray[np.object_] | np.bool_)
60+
assert_type(np.logical_not(AR_bool, dtype=np.object_), npt.NDArray[np.object_] | bool)
6161

6262
assert_type(np.logical_and(True, True), np.bool_)
63-
assert_type(np.logical_and(AR_bool, AR_bool), npt.NDArray[np.bool_ | np.object_])
63+
assert_type(np.logical_and(AR_bool, AR_bool), npt.NDArray[np.bool_])
6464
assert_type(np.logical_and(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
6565
assert_type(np.logical_and(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
66-
assert_type(np.logical_and(AR_bool, AR_i8), npt.NDArray[np.bool_ | np.object_])
66+
assert_type(np.logical_and(AR_bool, AR_i8), npt.NDArray[np.bool_])
6767
assert_type(np.logical_and(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])
6868

6969
assert_type(np.logical_or(True, True), np.bool_)
70-
assert_type(np.logical_or(AR_bool, AR_bool), npt.NDArray[np.bool_ | np.object_])
70+
assert_type(np.logical_or(AR_bool, AR_bool), npt.NDArray[np.bool_])
7171
assert_type(np.logical_or(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
7272
assert_type(np.logical_or(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
73-
assert_type(np.logical_or(AR_bool, AR_i8), npt.NDArray[np.bool_ | np.object_])
73+
assert_type(np.logical_or(AR_bool, AR_i8), npt.NDArray[np.bool_])
7474
assert_type(np.logical_or(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])
7575

7676
assert_type(np.logical_xor(True, True), np.bool_)
77-
assert_type(np.logical_xor(AR_bool, AR_bool), npt.NDArray[np.bool_ | np.object_])
77+
assert_type(np.logical_xor(AR_bool, AR_bool), npt.NDArray[np.bool_])
7878
assert_type(np.logical_xor(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
7979
assert_type(np.logical_xor(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
80-
assert_type(np.logical_xor(AR_bool, AR_i8), npt.NDArray[np.bool_ | np.object_])
80+
assert_type(np.logical_xor(AR_bool, AR_i8), npt.NDArray[np.bool_])
8181
assert_type(np.logical_xor(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])

test/static/reject/ufuncs.pyi

+8-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,14 @@ np.isnat(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[report
4848
np.isinf(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
4949
np.isfinite(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5050

51-
np.logical_not(dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
51+
np.logical_not(i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
52+
np.logical_not(dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
53+
54+
np.logical_and(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5255
np.logical_and(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
56+
57+
np.logical_or(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5358
np.logical_or(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
59+
60+
np.logical_xor(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
5461
np.logical_xor(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

0 commit comments

Comments
 (0)