Skip to content

Commit 3c92e9b

Browse files
authored
♻️ tighten _Call21Bool dtype (#322)
* ♻️ tighten `_Call21Bool` dtype * 🔥 remove unused dtype in `_Call21Bool`
1 parent 8373f8c commit 3c92e9b

File tree

5 files changed

+30
-0
lines changed

5 files changed

+30
-0
lines changed

src/numpy-stubs/_typing/_ufunc.pyi

+4
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ class _Call21Bool(Protocol):
341341
x2: _ScalarLike_co,
342342
/,
343343
out: None = None,
344+
dtype: _DTypeLikeBool | None = None,
344345
**kwds: Unpack[_Kwargs3],
345346
) -> np.bool: ...
346347
@overload # (array-like, array) -> Array[bool]
@@ -350,6 +351,7 @@ class _Call21Bool(Protocol):
350351
x2: _AnyArray | _NestedSequence[_ScalarLike_co],
351352
/,
352353
out: _Out1[_AnyArray] | None = None,
354+
dtype: _DTypeLikeBool | None = None,
353355
**kwds: Unpack[_Kwargs3],
354356
) -> NDArray[np.bool]: ...
355357
@overload # (array, array-like) -> Array[bool]
@@ -359,6 +361,7 @@ class _Call21Bool(Protocol):
359361
x2: ArrayLike,
360362
/,
361363
out: _Out1[_AnyArray] | None = None,
364+
dtype: _DTypeLikeBool | None = None,
362365
**kwds: Unpack[_Kwargs3],
363366
) -> NDArray[np.bool]: ...
364367
@overload # (array-like, array-like, out: T) -> T
@@ -377,6 +380,7 @@ class _Call21Bool(Protocol):
377380
x2: ArrayLike,
378381
/,
379382
out: _Out1[_AnyArray] | None = None,
383+
dtype: _DTypeLikeBool | None = None,
380384
**kwds: Unpack[_Kwargs3],
381385
) -> np.bool | NDArray[np.bool]: ...
382386

test/static/accept/char.pyi

+6
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,32 @@ AR_SUT: BytesArray | StrArray | StringArray
1616
assert_type(np.char.equal(AR_U, AR_U), npt.NDArray[np.bool])
1717
assert_type(np.char.equal(AR_S, AR_S), npt.NDArray[np.bool])
1818
assert_type(np.char.equal(AR_T, AR_T), npt.NDArray[np.bool])
19+
assert_type(np.char.equal(AR_S, AR_U, dtype=np.bool_), npt.NDArray[np.bool])
1920

2021
assert_type(np.char.not_equal(AR_U, AR_U), npt.NDArray[np.bool])
2122
assert_type(np.char.not_equal(AR_S, AR_S), npt.NDArray[np.bool])
2223
assert_type(np.char.not_equal(AR_T, AR_T), npt.NDArray[np.bool])
24+
assert_type(np.char.not_equal(AR_S, AR_U, dtype=np.bool_), npt.NDArray[np.bool])
2325

2426
assert_type(np.char.greater_equal(AR_U, AR_U), npt.NDArray[np.bool])
2527
assert_type(np.char.greater_equal(AR_S, AR_S), npt.NDArray[np.bool])
2628
assert_type(np.char.greater_equal(AR_T, AR_T), npt.NDArray[np.bool])
29+
assert_type(np.char.greater_equal(AR_S, AR_U, dtype=np.bool_), npt.NDArray[np.bool])
2730

2831
assert_type(np.char.less_equal(AR_U, AR_U), npt.NDArray[np.bool])
2932
assert_type(np.char.less_equal(AR_S, AR_S), npt.NDArray[np.bool])
3033
assert_type(np.char.less_equal(AR_T, AR_T), npt.NDArray[np.bool])
34+
assert_type(np.char.less_equal(AR_S, AR_U, dtype=np.bool_), npt.NDArray[np.bool])
3135

3236
assert_type(np.char.greater(AR_U, AR_U), npt.NDArray[np.bool])
3337
assert_type(np.char.greater(AR_S, AR_S), npt.NDArray[np.bool])
3438
assert_type(np.char.greater(AR_T, AR_T), npt.NDArray[np.bool])
39+
assert_type(np.char.greater(AR_S, AR_U, dtype=np.bool_), npt.NDArray[np.bool])
3540

3641
assert_type(np.char.less(AR_U, AR_U), npt.NDArray[np.bool])
3742
assert_type(np.char.less(AR_S, AR_S), npt.NDArray[np.bool])
3843
assert_type(np.char.less(AR_T, AR_T), npt.NDArray[np.bool])
44+
assert_type(np.char.less(AR_S, AR_U, dtype=np.bool_), npt.NDArray[np.bool])
3945

4046
assert_type(np.char.multiply(AR_U, 5), StrArray)
4147
assert_type(np.char.multiply(AR_U, 5), StrArray)

test/static/accept/strings.pyi

+6
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,32 @@ AR_T: _StringArray
1717
assert_type(np.strings.equal(AR_U, AR_U), _BoolArray)
1818
assert_type(np.strings.equal(AR_S, AR_S), _BoolArray)
1919
assert_type(np.strings.equal(AR_T, AR_T), _BoolArray)
20+
assert_type(np.strings.equal(AR_S, AR_U, dtype=np.bool_), _BoolArray)
2021

2122
assert_type(np.strings.not_equal(AR_U, AR_U), _BoolArray)
2223
assert_type(np.strings.not_equal(AR_S, AR_S), _BoolArray)
2324
assert_type(np.strings.not_equal(AR_T, AR_T), _BoolArray)
25+
assert_type(np.strings.not_equal(AR_S, AR_U, dtype=np.bool_), _BoolArray)
2426

2527
assert_type(np.strings.greater_equal(AR_U, AR_U), _BoolArray)
2628
assert_type(np.strings.greater_equal(AR_S, AR_S), _BoolArray)
2729
assert_type(np.strings.greater_equal(AR_T, AR_T), _BoolArray)
30+
assert_type(np.strings.greater_equal(AR_S, AR_U, dtype=np.bool_), _BoolArray)
2831

2932
assert_type(np.strings.less_equal(AR_U, AR_U), _BoolArray)
3033
assert_type(np.strings.less_equal(AR_S, AR_S), _BoolArray)
3134
assert_type(np.strings.less_equal(AR_T, AR_T), _BoolArray)
35+
assert_type(np.strings.less_equal(AR_S, AR_U, dtype=np.bool_), _BoolArray)
3236

3337
assert_type(np.strings.greater(AR_U, AR_U), _BoolArray)
3438
assert_type(np.strings.greater(AR_S, AR_S), _BoolArray)
3539
assert_type(np.strings.greater(AR_T, AR_T), _BoolArray)
40+
assert_type(np.strings.greater(AR_S, AR_U, dtype=np.bool_), _BoolArray)
3641

3742
assert_type(np.strings.less(AR_U, AR_U), _BoolArray)
3843
assert_type(np.strings.less(AR_S, AR_S), _BoolArray)
3944
assert_type(np.strings.less(AR_T, AR_T), _BoolArray)
45+
assert_type(np.strings.less(AR_S, AR_U, dtype=np.bool_), _BoolArray)
4046

4147
# TODO(jorenham)
4248
# https://github.com/numpy/numtype/issues/230

test/static/reject/char.pyi

+7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ AR_S: npt.NDArray[np.bytes_]
1515
# np.char.isdecimal(AR_S)
1616
# np.char.isnumeric(AR_S)
1717

18+
np.char.equal(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
19+
np.char.not_equal(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
20+
np.char.greater_equal(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
21+
np.char.less_equal(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
22+
np.char.greater(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
23+
np.char.less(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
24+
1825
np.char.encode(AR_S) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
1926
np.char.decode(AR_U) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
2027

test/static/reject/strings.pyi

+7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ AR_S: npt.NDArray[np.bytes_]
1616
# np.strings.isdecimal(AR_S)
1717
# np.strings.isnumeric(AR_S)
1818

19+
np.strings.equal(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
20+
np.strings.not_equal(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
21+
np.strings.greater_equal(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
22+
np.strings.less_equal(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
23+
np.strings.greater(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
24+
np.strings.less(AR_S, AR_U, dtype=int) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
25+
1926
np.strings.encode(AR_S) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
2027
np.strings.decode(AR_U) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
2128

0 commit comments

Comments
 (0)