Skip to content

Commit 17c7c40

Browse files
authored
Merge pull request #145 from crusaderky/np_generics
ENH: disallow numpy generics
2 parents 50a155a + e7fcd34 commit 17c7c40

File tree

2 files changed

+118
-69
lines changed

2 files changed

+118
-69
lines changed

array_api_strict/_array_object.py

+44-43
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,16 @@ def _check_allowed_dtypes(
233233

234234
return other
235235

236-
def _check_device(self, other: Array | bool | int | float | complex) -> None:
237-
"""Check that other is on a device compatible with the current array"""
238-
if isinstance(other, (bool, int, float, complex)):
239-
return
240-
elif isinstance(other, Array):
236+
def _check_type_device(self, other: Array | bool | int | float | complex) -> None:
237+
"""Check that other is either a Python scalar or an array on a device
238+
compatible with the current array.
239+
"""
240+
if isinstance(other, Array):
241241
if self.device != other.device:
242242
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
243-
else:
244-
raise TypeError(f"Expected Array | python scalar; got {type(other)}")
243+
# Disallow subclasses of Python scalars, such as np.float64 and np.complex128
244+
elif type(other) not in (bool, int, float, complex):
245+
raise TypeError(f"Expected Array or Python scalar; got {type(other)}")
245246

246247
# Helper function to match the type promotion rules in the spec
247248
def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
@@ -542,7 +543,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
542543
"""
543544
Performs the operation __add__.
544545
"""
545-
self._check_device(other)
546+
self._check_type_device(other)
546547
other = self._check_allowed_dtypes(other, "numeric", "__add__")
547548
if other is NotImplemented:
548549
return other
@@ -554,7 +555,7 @@ def __and__(self, other: Array | bool | int, /) -> Array:
554555
"""
555556
Performs the operation __and__.
556557
"""
557-
self._check_device(other)
558+
self._check_type_device(other)
558559
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
559560
if other is NotImplemented:
560561
return other
@@ -651,7 +652,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
651652
"""
652653
Performs the operation __eq__.
653654
"""
654-
self._check_device(other)
655+
self._check_type_device(other)
655656
# Even though "all" dtypes are allowed, we still require them to be
656657
# promotable with each other.
657658
other = self._check_allowed_dtypes(other, "all", "__eq__")
@@ -677,7 +678,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
677678
"""
678679
Performs the operation __floordiv__.
679680
"""
680-
self._check_device(other)
681+
self._check_type_device(other)
681682
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
682683
if other is NotImplemented:
683684
return other
@@ -689,7 +690,7 @@ def __ge__(self, other: Array | int | float, /) -> Array:
689690
"""
690691
Performs the operation __ge__.
691692
"""
692-
self._check_device(other)
693+
self._check_type_device(other)
693694
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
694695
if other is NotImplemented:
695696
return other
@@ -741,7 +742,7 @@ def __gt__(self, other: Array | int | float, /) -> Array:
741742
"""
742743
Performs the operation __gt__.
743744
"""
744-
self._check_device(other)
745+
self._check_type_device(other)
745746
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
746747
if other is NotImplemented:
747748
return other
@@ -796,7 +797,7 @@ def __le__(self, other: Array | int | float, /) -> Array:
796797
"""
797798
Performs the operation __le__.
798799
"""
799-
self._check_device(other)
800+
self._check_type_device(other)
800801
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
801802
if other is NotImplemented:
802803
return other
@@ -808,7 +809,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
808809
"""
809810
Performs the operation __lshift__.
810811
"""
811-
self._check_device(other)
812+
self._check_type_device(other)
812813
other = self._check_allowed_dtypes(other, "integer", "__lshift__")
813814
if other is NotImplemented:
814815
return other
@@ -820,7 +821,7 @@ def __lt__(self, other: Array | int | float, /) -> Array:
820821
"""
821822
Performs the operation __lt__.
822823
"""
823-
self._check_device(other)
824+
self._check_type_device(other)
824825
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
825826
if other is NotImplemented:
826827
return other
@@ -832,7 +833,7 @@ def __matmul__(self, other: Array, /) -> Array:
832833
"""
833834
Performs the operation __matmul__.
834835
"""
835-
self._check_device(other)
836+
self._check_type_device(other)
836837
# matmul is not defined for scalars, but without this, we may get
837838
# the wrong error message from asarray.
838839
other = self._check_allowed_dtypes(other, "numeric", "__matmul__")
@@ -845,7 +846,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
845846
"""
846847
Performs the operation __mod__.
847848
"""
848-
self._check_device(other)
849+
self._check_type_device(other)
849850
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
850851
if other is NotImplemented:
851852
return other
@@ -857,7 +858,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
857858
"""
858859
Performs the operation __mul__.
859860
"""
860-
self._check_device(other)
861+
self._check_type_device(other)
861862
other = self._check_allowed_dtypes(other, "numeric", "__mul__")
862863
if other is NotImplemented:
863864
return other
@@ -869,7 +870,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
869870
"""
870871
Performs the operation __ne__.
871872
"""
872-
self._check_device(other)
873+
self._check_type_device(other)
873874
other = self._check_allowed_dtypes(other, "all", "__ne__")
874875
if other is NotImplemented:
875876
return other
@@ -890,7 +891,7 @@ def __or__(self, other: Array | bool | int, /) -> Array:
890891
"""
891892
Performs the operation __or__.
892893
"""
893-
self._check_device(other)
894+
self._check_type_device(other)
894895
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
895896
if other is NotImplemented:
896897
return other
@@ -913,7 +914,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array:
913914
"""
914915
from ._elementwise_functions import pow # type: ignore[attr-defined]
915916

916-
self._check_device(other)
917+
self._check_type_device(other)
917918
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
918919
if other is NotImplemented:
919920
return other
@@ -925,7 +926,7 @@ def __rshift__(self, other: Array | int, /) -> Array:
925926
"""
926927
Performs the operation __rshift__.
927928
"""
928-
self._check_device(other)
929+
self._check_type_device(other)
929930
other = self._check_allowed_dtypes(other, "integer", "__rshift__")
930931
if other is NotImplemented:
931932
return other
@@ -961,7 +962,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
961962
"""
962963
Performs the operation __sub__.
963964
"""
964-
self._check_device(other)
965+
self._check_type_device(other)
965966
other = self._check_allowed_dtypes(other, "numeric", "__sub__")
966967
if other is NotImplemented:
967968
return other
@@ -975,7 +976,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
975976
"""
976977
Performs the operation __truediv__.
977978
"""
978-
self._check_device(other)
979+
self._check_type_device(other)
979980
other = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
980981
if other is NotImplemented:
981982
return other
@@ -987,7 +988,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
987988
"""
988989
Performs the operation __xor__.
989990
"""
990-
self._check_device(other)
991+
self._check_type_device(other)
991992
other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
992993
if other is NotImplemented:
993994
return other
@@ -999,7 +1000,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
9991000
"""
10001001
Performs the operation __iadd__.
10011002
"""
1002-
self._check_device(other)
1003+
self._check_type_device(other)
10031004
other = self._check_allowed_dtypes(other, "numeric", "__iadd__")
10041005
if other is NotImplemented:
10051006
return other
@@ -1010,7 +1011,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
10101011
"""
10111012
Performs the operation __radd__.
10121013
"""
1013-
self._check_device(other)
1014+
self._check_type_device(other)
10141015
other = self._check_allowed_dtypes(other, "numeric", "__radd__")
10151016
if other is NotImplemented:
10161017
return other
@@ -1022,7 +1023,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
10221023
"""
10231024
Performs the operation __iand__.
10241025
"""
1025-
self._check_device(other)
1026+
self._check_type_device(other)
10261027
other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__")
10271028
if other is NotImplemented:
10281029
return other
@@ -1033,7 +1034,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
10331034
"""
10341035
Performs the operation __rand__.
10351036
"""
1036-
self._check_device(other)
1037+
self._check_type_device(other)
10371038
other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
10381039
if other is NotImplemented:
10391040
return other
@@ -1045,7 +1046,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
10451046
"""
10461047
Performs the operation __ifloordiv__.
10471048
"""
1048-
self._check_device(other)
1049+
self._check_type_device(other)
10491050
other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
10501051
if other is NotImplemented:
10511052
return other
@@ -1056,7 +1057,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array:
10561057
"""
10571058
Performs the operation __rfloordiv__.
10581059
"""
1059-
self._check_device(other)
1060+
self._check_type_device(other)
10601061
other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
10611062
if other is NotImplemented:
10621063
return other
@@ -1068,7 +1069,7 @@ def __ilshift__(self, other: Array | int, /) -> Array:
10681069
"""
10691070
Performs the operation __ilshift__.
10701071
"""
1071-
self._check_device(other)
1072+
self._check_type_device(other)
10721073
other = self._check_allowed_dtypes(other, "integer", "__ilshift__")
10731074
if other is NotImplemented:
10741075
return other
@@ -1079,7 +1080,7 @@ def __rlshift__(self, other: Array | int, /) -> Array:
10791080
"""
10801081
Performs the operation __rlshift__.
10811082
"""
1082-
self._check_device(other)
1083+
self._check_type_device(other)
10831084
other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
10841085
if other is NotImplemented:
10851086
return other
@@ -1096,7 +1097,7 @@ def __imatmul__(self, other: Array, /) -> Array:
10961097
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
10971098
if other is NotImplemented:
10981099
return other
1099-
self._check_device(other)
1100+
self._check_type_device(other)
11001101
res = self._array.__imatmul__(other._array)
11011102
return self.__class__._new(res, device=self.device)
11021103

@@ -1109,7 +1110,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
11091110
other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__")
11101111
if other is NotImplemented:
11111112
return other
1112-
self._check_device(other)
1113+
self._check_type_device(other)
11131114
res = self._array.__rmatmul__(other._array)
11141115
return self.__class__._new(res, device=self.device)
11151116

@@ -1130,7 +1131,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
11301131
other = self._check_allowed_dtypes(other, "real numeric", "__rmod__")
11311132
if other is NotImplemented:
11321133
return other
1133-
self._check_device(other)
1134+
self._check_type_device(other)
11341135
self, other = self._normalize_two_args(self, other)
11351136
res = self._array.__rmod__(other._array)
11361137
return self.__class__._new(res, device=self.device)
@@ -1152,7 +1153,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
11521153
other = self._check_allowed_dtypes(other, "numeric", "__rmul__")
11531154
if other is NotImplemented:
11541155
return other
1155-
self._check_device(other)
1156+
self._check_type_device(other)
11561157
self, other = self._normalize_two_args(self, other)
11571158
res = self._array.__rmul__(other._array)
11581159
return self.__class__._new(res, device=self.device)
@@ -1171,7 +1172,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
11711172
"""
11721173
Performs the operation __ror__.
11731174
"""
1174-
self._check_device(other)
1175+
self._check_type_device(other)
11751176
other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__")
11761177
if other is NotImplemented:
11771178
return other
@@ -1219,7 +1220,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12191220
other = self._check_allowed_dtypes(other, "integer", "__rrshift__")
12201221
if other is NotImplemented:
12211222
return other
1222-
self._check_device(other)
1223+
self._check_type_device(other)
12231224
self, other = self._normalize_two_args(self, other)
12241225
res = self._array.__rrshift__(other._array)
12251226
return self.__class__._new(res, device=self.device)
@@ -1241,7 +1242,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
12411242
other = self._check_allowed_dtypes(other, "numeric", "__rsub__")
12421243
if other is NotImplemented:
12431244
return other
1244-
self._check_device(other)
1245+
self._check_type_device(other)
12451246
self, other = self._normalize_two_args(self, other)
12461247
res = self._array.__rsub__(other._array)
12471248
return self.__class__._new(res, device=self.device)
@@ -1263,7 +1264,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
12631264
other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__")
12641265
if other is NotImplemented:
12651266
return other
1266-
self._check_device(other)
1267+
self._check_type_device(other)
12671268
self, other = self._normalize_two_args(self, other)
12681269
res = self._array.__rtruediv__(other._array)
12691270
return self.__class__._new(res, device=self.device)
@@ -1285,7 +1286,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array:
12851286
other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__")
12861287
if other is NotImplemented:
12871288
return other
1288-
self._check_device(other)
1289+
self._check_type_device(other)
12891290
self, other = self._normalize_two_args(self, other)
12901291
res = self._array.__rxor__(other._array)
12911292
return self.__class__._new(res, device=self.device)

0 commit comments

Comments
 (0)