Skip to content

Commit 2feaad9

Browse files
committed
use general Array protocol for other in comparisons
Signed-off-by: nstarman <[email protected]>
1 parent 6e8c89d commit 2feaad9

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

Diff for: src/_array_api_conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
]
6767
nitpick_ignore_regex = [
6868
("py:class", ".*array"),
69+
("py:class", ".*Array"),
6970
("py:class", ".*device"),
7071
("py:class", ".*Device"),
7172
("py:class", ".*dtype"),

Diff for: src/array_api_stubs/_draft/array_object.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __abs__(self: array, /) -> array:
159159
"""
160160
...
161161

162-
def __add__(self: array, other: int | float | array, /) -> array:
162+
def __add__(self: array, other: int | float | Array, /) -> array:
163163
"""
164164
Calculates the sum for each element of an array instance with the respective element of the array ``other``.
165165
@@ -513,7 +513,7 @@ def __dlpack_device__(self, /) -> tuple[Enum, int]:
513513
# Note that __eq__ returns an array while `object.__eq__` returns a bool.
514514
# Hence Mypy will complain that this violates the Liskov substitution
515515
# principle - ignore that.
516-
def __eq__(self: array, other: int | float | bool | array, /) -> array: # type: ignore[override]
516+
def __eq__(self: array, other: int | float | bool | Array, /) -> array: # type: ignore[override]
517517
r"""
518518
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
519519
@@ -577,7 +577,7 @@ def __float__(self, /) -> float:
577577
"""
578578
...
579579

580-
def __floordiv__(self: array, other: int | float | array, /) -> array:
580+
def __floordiv__(self: array, other: int | float | Array, /) -> array:
581581
"""
582582
Evaluates ``self_i // other_i`` for each element of an array instance with the respective element of the array ``other``.
583583
@@ -602,7 +602,7 @@ def __floordiv__(self: array, other: int | float | array, /) -> array:
602602
"""
603603
...
604604

605-
def __ge__(self: array, other: int | float | array, /) -> array:
605+
def __ge__(self: array, other: int | float | Array, /) -> array:
606606
"""
607607
Computes the truth value of ``self_i >= other_i`` for each element of an array instance with the respective element of the array ``other``.
608608
@@ -658,7 +658,7 @@ def __getitem__(
658658
"""
659659
...
660660

661-
def __gt__(self: array, other: int | float | array, /) -> array:
661+
def __gt__(self: array, other: int | float | Array, /) -> array:
662662
"""
663663
Computes the truth value of ``self_i > other_i`` for each element of an array instance with the respective element of the array ``other``.
664664
@@ -786,7 +786,7 @@ def __invert__(self: array, /) -> array:
786786
"""
787787
...
788788

789-
def __le__(self: array, other: int | float | array, /) -> array:
789+
def __le__(self: array, other: int | float | Array, /) -> array:
790790
"""
791791
Computes the truth value of ``self_i <= other_i`` for each element of an array instance with the respective element of the array ``other``.
792792
@@ -836,7 +836,7 @@ def __lshift__(self: array, other: int | array, /) -> array:
836836
"""
837837
...
838838

839-
def __lt__(self: array, other: int | float | array, /) -> array:
839+
def __lt__(self: array, other: int | float | Array, /) -> array:
840840
"""
841841
Computes the truth value of ``self_i < other_i`` for each element of an array instance with the respective element of the array ``other``.
842842
@@ -913,7 +913,7 @@ def __matmul__(self: array, other: array, /) -> array:
913913
"""
914914
...
915915

916-
def __mod__(self: array, other: int | float | array, /) -> array:
916+
def __mod__(self: array, other: int | float | Array, /) -> array:
917917
"""
918918
Evaluates ``self_i % other_i`` for each element of an array instance with the respective element of the array ``other``.
919919
@@ -938,7 +938,7 @@ def __mod__(self: array, other: int | float | array, /) -> array:
938938
"""
939939
...
940940

941-
def __mul__(self: array, other: int | float | array, /) -> array:
941+
def __mul__(self: array, other: int | float | Array, /) -> array:
942942
r"""
943943
Calculates the product for each element of an array instance with the respective element of the array ``other``.
944944
@@ -969,7 +969,7 @@ def __mul__(self: array, other: int | float | array, /) -> array:
969969
...
970970

971971
# See note above __eq__ method for explanation of the `type: ignore`
972-
def __ne__(self: array, other: int | float | bool | array, /) -> array: # type: ignore[override]
972+
def __ne__(self: array, other: int | float | bool | Array, /) -> array: # type: ignore[override]
973973
"""
974974
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
975975
@@ -1078,7 +1078,7 @@ def __pos__(self: array, /) -> array:
10781078
"""
10791079
...
10801080

1081-
def __pow__(self: array, other: int | float | array, /) -> array:
1081+
def __pow__(self: array, other: int | float | Array, /) -> array:
10821082
r"""
10831083
Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of an array instance to the power of ``other_i`` (the exponent), where ``other_i`` is the corresponding element of the array ``other``.
10841084
@@ -1163,7 +1163,7 @@ def __setitem__(
11631163
"""
11641164
...
11651165

1166-
def __sub__(self: array, other: int | float | array, /) -> array:
1166+
def __sub__(self: array, other: int | float | Array, /) -> array:
11671167
"""
11681168
Calculates the difference for each element of an array instance with the respective element of the array ``other``.
11691169
@@ -1192,7 +1192,7 @@ def __sub__(self: array, other: int | float | array, /) -> array:
11921192
"""
11931193
...
11941194

1195-
def __truediv__(self: array, other: int | float | array, /) -> array:
1195+
def __truediv__(self: array, other: int | float | Array, /) -> array:
11961196
r"""
11971197
Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``.
11981198

0 commit comments

Comments
 (0)