@@ -233,15 +233,16 @@ def _check_allowed_dtypes(
233
233
234
234
return other
235
235
236
- def _check_device (self , other : Array | bool | int | float | complex ) -> None :
237
- """Check that other is on a device compatible with the current array"""
238
- if isinstance ( other , ( bool , int , float , complex )):
239
- return
240
- elif isinstance (other , Array ):
236
+ def _check_type_device (self , other : Array | bool | int | float | complex ) -> None :
237
+ """Check that other is either a Python scalar or an array on a device
238
+ compatible with the current array.
239
+ """
240
+ if isinstance (other , Array ):
241
241
if self .device != other .device :
242
242
raise ValueError (f"Arrays from two different devices ({ self .device } and { other .device } ) can not be combined." )
243
- else :
244
- raise TypeError (f"Expected Array | python scalar; got { type (other )} " )
243
+ # Disallow subclasses of Python scalars, such as np.float64 and np.complex128
244
+ elif type (other ) not in (bool , int , float , complex ):
245
+ raise TypeError (f"Expected Array or Python scalar; got { type (other )} " )
245
246
246
247
# Helper function to match the type promotion rules in the spec
247
248
def _promote_scalar (self , scalar : bool | int | float | complex ) -> Array :
@@ -542,7 +543,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
542
543
"""
543
544
Performs the operation __add__.
544
545
"""
545
- self ._check_device (other )
546
+ self ._check_type_device (other )
546
547
other = self ._check_allowed_dtypes (other , "numeric" , "__add__" )
547
548
if other is NotImplemented :
548
549
return other
@@ -554,7 +555,7 @@ def __and__(self, other: Array | bool | int, /) -> Array:
554
555
"""
555
556
Performs the operation __and__.
556
557
"""
557
- self ._check_device (other )
558
+ self ._check_type_device (other )
558
559
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__and__" )
559
560
if other is NotImplemented :
560
561
return other
@@ -651,7 +652,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
651
652
"""
652
653
Performs the operation __eq__.
653
654
"""
654
- self ._check_device (other )
655
+ self ._check_type_device (other )
655
656
# Even though "all" dtypes are allowed, we still require them to be
656
657
# promotable with each other.
657
658
other = self ._check_allowed_dtypes (other , "all" , "__eq__" )
@@ -677,7 +678,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
677
678
"""
678
679
Performs the operation __floordiv__.
679
680
"""
680
- self ._check_device (other )
681
+ self ._check_type_device (other )
681
682
other = self ._check_allowed_dtypes (other , "real numeric" , "__floordiv__" )
682
683
if other is NotImplemented :
683
684
return other
@@ -689,7 +690,7 @@ def __ge__(self, other: Array | int | float, /) -> Array:
689
690
"""
690
691
Performs the operation __ge__.
691
692
"""
692
- self ._check_device (other )
693
+ self ._check_type_device (other )
693
694
other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" )
694
695
if other is NotImplemented :
695
696
return other
@@ -741,7 +742,7 @@ def __gt__(self, other: Array | int | float, /) -> Array:
741
742
"""
742
743
Performs the operation __gt__.
743
744
"""
744
- self ._check_device (other )
745
+ self ._check_type_device (other )
745
746
other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" )
746
747
if other is NotImplemented :
747
748
return other
@@ -796,7 +797,7 @@ def __le__(self, other: Array | int | float, /) -> Array:
796
797
"""
797
798
Performs the operation __le__.
798
799
"""
799
- self ._check_device (other )
800
+ self ._check_type_device (other )
800
801
other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" )
801
802
if other is NotImplemented :
802
803
return other
@@ -808,7 +809,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
808
809
"""
809
810
Performs the operation __lshift__.
810
811
"""
811
- self ._check_device (other )
812
+ self ._check_type_device (other )
812
813
other = self ._check_allowed_dtypes (other , "integer" , "__lshift__" )
813
814
if other is NotImplemented :
814
815
return other
@@ -820,7 +821,7 @@ def __lt__(self, other: Array | int | float, /) -> Array:
820
821
"""
821
822
Performs the operation __lt__.
822
823
"""
823
- self ._check_device (other )
824
+ self ._check_type_device (other )
824
825
other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" )
825
826
if other is NotImplemented :
826
827
return other
@@ -832,7 +833,7 @@ def __matmul__(self, other: Array, /) -> Array:
832
833
"""
833
834
Performs the operation __matmul__.
834
835
"""
835
- self ._check_device (other )
836
+ self ._check_type_device (other )
836
837
# matmul is not defined for scalars, but without this, we may get
837
838
# the wrong error message from asarray.
838
839
other = self ._check_allowed_dtypes (other , "numeric" , "__matmul__" )
@@ -845,7 +846,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
845
846
"""
846
847
Performs the operation __mod__.
847
848
"""
848
- self ._check_device (other )
849
+ self ._check_type_device (other )
849
850
other = self ._check_allowed_dtypes (other , "real numeric" , "__mod__" )
850
851
if other is NotImplemented :
851
852
return other
@@ -857,7 +858,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
857
858
"""
858
859
Performs the operation __mul__.
859
860
"""
860
- self ._check_device (other )
861
+ self ._check_type_device (other )
861
862
other = self ._check_allowed_dtypes (other , "numeric" , "__mul__" )
862
863
if other is NotImplemented :
863
864
return other
@@ -869,7 +870,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
869
870
"""
870
871
Performs the operation __ne__.
871
872
"""
872
- self ._check_device (other )
873
+ self ._check_type_device (other )
873
874
other = self ._check_allowed_dtypes (other , "all" , "__ne__" )
874
875
if other is NotImplemented :
875
876
return other
@@ -890,7 +891,7 @@ def __or__(self, other: Array | bool | int, /) -> Array:
890
891
"""
891
892
Performs the operation __or__.
892
893
"""
893
- self ._check_device (other )
894
+ self ._check_type_device (other )
894
895
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__or__" )
895
896
if other is NotImplemented :
896
897
return other
@@ -913,7 +914,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array:
913
914
"""
914
915
from ._elementwise_functions import pow # type: ignore[attr-defined]
915
916
916
- self ._check_device (other )
917
+ self ._check_type_device (other )
917
918
other = self ._check_allowed_dtypes (other , "numeric" , "__pow__" )
918
919
if other is NotImplemented :
919
920
return other
@@ -925,7 +926,7 @@ def __rshift__(self, other: Array | int, /) -> Array:
925
926
"""
926
927
Performs the operation __rshift__.
927
928
"""
928
- self ._check_device (other )
929
+ self ._check_type_device (other )
929
930
other = self ._check_allowed_dtypes (other , "integer" , "__rshift__" )
930
931
if other is NotImplemented :
931
932
return other
@@ -961,7 +962,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
961
962
"""
962
963
Performs the operation __sub__.
963
964
"""
964
- self ._check_device (other )
965
+ self ._check_type_device (other )
965
966
other = self ._check_allowed_dtypes (other , "numeric" , "__sub__" )
966
967
if other is NotImplemented :
967
968
return other
@@ -975,7 +976,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
975
976
"""
976
977
Performs the operation __truediv__.
977
978
"""
978
- self ._check_device (other )
979
+ self ._check_type_device (other )
979
980
other = self ._check_allowed_dtypes (other , "floating-point" , "__truediv__" )
980
981
if other is NotImplemented :
981
982
return other
@@ -987,7 +988,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
987
988
"""
988
989
Performs the operation __xor__.
989
990
"""
990
- self ._check_device (other )
991
+ self ._check_type_device (other )
991
992
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__xor__" )
992
993
if other is NotImplemented :
993
994
return other
@@ -999,7 +1000,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
999
1000
"""
1000
1001
Performs the operation __iadd__.
1001
1002
"""
1002
- self ._check_device (other )
1003
+ self ._check_type_device (other )
1003
1004
other = self ._check_allowed_dtypes (other , "numeric" , "__iadd__" )
1004
1005
if other is NotImplemented :
1005
1006
return other
@@ -1010,7 +1011,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
1010
1011
"""
1011
1012
Performs the operation __radd__.
1012
1013
"""
1013
- self ._check_device (other )
1014
+ self ._check_type_device (other )
1014
1015
other = self ._check_allowed_dtypes (other , "numeric" , "__radd__" )
1015
1016
if other is NotImplemented :
1016
1017
return other
@@ -1022,7 +1023,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
1022
1023
"""
1023
1024
Performs the operation __iand__.
1024
1025
"""
1025
- self ._check_device (other )
1026
+ self ._check_type_device (other )
1026
1027
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__iand__" )
1027
1028
if other is NotImplemented :
1028
1029
return other
@@ -1033,7 +1034,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
1033
1034
"""
1034
1035
Performs the operation __rand__.
1035
1036
"""
1036
- self ._check_device (other )
1037
+ self ._check_type_device (other )
1037
1038
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rand__" )
1038
1039
if other is NotImplemented :
1039
1040
return other
@@ -1045,7 +1046,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
1045
1046
"""
1046
1047
Performs the operation __ifloordiv__.
1047
1048
"""
1048
- self ._check_device (other )
1049
+ self ._check_type_device (other )
1049
1050
other = self ._check_allowed_dtypes (other , "real numeric" , "__ifloordiv__" )
1050
1051
if other is NotImplemented :
1051
1052
return other
@@ -1056,7 +1057,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array:
1056
1057
"""
1057
1058
Performs the operation __rfloordiv__.
1058
1059
"""
1059
- self ._check_device (other )
1060
+ self ._check_type_device (other )
1060
1061
other = self ._check_allowed_dtypes (other , "real numeric" , "__rfloordiv__" )
1061
1062
if other is NotImplemented :
1062
1063
return other
@@ -1068,7 +1069,7 @@ def __ilshift__(self, other: Array | int, /) -> Array:
1068
1069
"""
1069
1070
Performs the operation __ilshift__.
1070
1071
"""
1071
- self ._check_device (other )
1072
+ self ._check_type_device (other )
1072
1073
other = self ._check_allowed_dtypes (other , "integer" , "__ilshift__" )
1073
1074
if other is NotImplemented :
1074
1075
return other
@@ -1079,7 +1080,7 @@ def __rlshift__(self, other: Array | int, /) -> Array:
1079
1080
"""
1080
1081
Performs the operation __rlshift__.
1081
1082
"""
1082
- self ._check_device (other )
1083
+ self ._check_type_device (other )
1083
1084
other = self ._check_allowed_dtypes (other , "integer" , "__rlshift__" )
1084
1085
if other is NotImplemented :
1085
1086
return other
@@ -1096,7 +1097,7 @@ def __imatmul__(self, other: Array, /) -> Array:
1096
1097
other = self ._check_allowed_dtypes (other , "numeric" , "__imatmul__" )
1097
1098
if other is NotImplemented :
1098
1099
return other
1099
- self ._check_device (other )
1100
+ self ._check_type_device (other )
1100
1101
res = self ._array .__imatmul__ (other ._array )
1101
1102
return self .__class__ ._new (res , device = self .device )
1102
1103
@@ -1109,7 +1110,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
1109
1110
other = self ._check_allowed_dtypes (other , "numeric" , "__rmatmul__" )
1110
1111
if other is NotImplemented :
1111
1112
return other
1112
- self ._check_device (other )
1113
+ self ._check_type_device (other )
1113
1114
res = self ._array .__rmatmul__ (other ._array )
1114
1115
return self .__class__ ._new (res , device = self .device )
1115
1116
@@ -1130,7 +1131,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
1130
1131
other = self ._check_allowed_dtypes (other , "real numeric" , "__rmod__" )
1131
1132
if other is NotImplemented :
1132
1133
return other
1133
- self ._check_device (other )
1134
+ self ._check_type_device (other )
1134
1135
self , other = self ._normalize_two_args (self , other )
1135
1136
res = self ._array .__rmod__ (other ._array )
1136
1137
return self .__class__ ._new (res , device = self .device )
@@ -1152,7 +1153,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
1152
1153
other = self ._check_allowed_dtypes (other , "numeric" , "__rmul__" )
1153
1154
if other is NotImplemented :
1154
1155
return other
1155
- self ._check_device (other )
1156
+ self ._check_type_device (other )
1156
1157
self , other = self ._normalize_two_args (self , other )
1157
1158
res = self ._array .__rmul__ (other ._array )
1158
1159
return self .__class__ ._new (res , device = self .device )
@@ -1171,7 +1172,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
1171
1172
"""
1172
1173
Performs the operation __ror__.
1173
1174
"""
1174
- self ._check_device (other )
1175
+ self ._check_type_device (other )
1175
1176
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__ror__" )
1176
1177
if other is NotImplemented :
1177
1178
return other
@@ -1219,7 +1220,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
1219
1220
other = self ._check_allowed_dtypes (other , "integer" , "__rrshift__" )
1220
1221
if other is NotImplemented :
1221
1222
return other
1222
- self ._check_device (other )
1223
+ self ._check_type_device (other )
1223
1224
self , other = self ._normalize_two_args (self , other )
1224
1225
res = self ._array .__rrshift__ (other ._array )
1225
1226
return self .__class__ ._new (res , device = self .device )
@@ -1241,7 +1242,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
1241
1242
other = self ._check_allowed_dtypes (other , "numeric" , "__rsub__" )
1242
1243
if other is NotImplemented :
1243
1244
return other
1244
- self ._check_device (other )
1245
+ self ._check_type_device (other )
1245
1246
self , other = self ._normalize_two_args (self , other )
1246
1247
res = self ._array .__rsub__ (other ._array )
1247
1248
return self .__class__ ._new (res , device = self .device )
@@ -1263,7 +1264,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
1263
1264
other = self ._check_allowed_dtypes (other , "floating-point" , "__rtruediv__" )
1264
1265
if other is NotImplemented :
1265
1266
return other
1266
- self ._check_device (other )
1267
+ self ._check_type_device (other )
1267
1268
self , other = self ._normalize_two_args (self , other )
1268
1269
res = self ._array .__rtruediv__ (other ._array )
1269
1270
return self .__class__ ._new (res , device = self .device )
@@ -1285,7 +1286,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array:
1285
1286
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rxor__" )
1286
1287
if other is NotImplemented :
1287
1288
return other
1288
- self ._check_device (other )
1289
+ self ._check_type_device (other )
1289
1290
self , other = self ._normalize_two_args (self , other )
1290
1291
res = self ._array .__rxor__ (other ._array )
1291
1292
return self .__class__ ._new (res , device = self .device )
0 commit comments