diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 483952e..579da90 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -233,15 +233,16 @@ def _check_allowed_dtypes( return other - def _check_device(self, other: Array | bool | int | float | complex) -> None: - """Check that other is on a device compatible with the current array""" - if isinstance(other, (bool, int, float, complex)): - return - elif isinstance(other, Array): + def _check_type_device(self, other: Array | bool | int | float | complex) -> None: + """Check that other is either a Python scalar or an array on a device + compatible with the current array. + """ + if isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") - else: - raise TypeError(f"Expected Array | python scalar; got {type(other)}") + # Disallow subclasses of Python scalars, such as np.float64 and np.complex128 + elif type(other) not in (bool, int, float, complex): + raise TypeError(f"Expected Array or Python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar: bool | int | float | complex) -> Array: @@ -542,7 +543,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __add__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other @@ -554,7 +555,7 @@ def __and__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __and__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other @@ -651,7 +652,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty """ Performs the operation __eq__. """ - self._check_device(other) + self._check_type_device(other) # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. other = self._check_allowed_dtypes(other, "all", "__eq__") @@ -677,7 +678,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __floordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other @@ -689,7 +690,7 @@ def __ge__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ge__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other @@ -741,7 +742,7 @@ def __gt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __gt__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other @@ -796,7 +797,7 @@ def __le__(self, other: Array | int | float, /) -> Array: """ Performs the operation __le__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other @@ -808,7 +809,7 @@ def __lshift__(self, other: Array | int, /) -> Array: """ Performs the operation __lshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other @@ -820,7 +821,7 @@ def __lt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __lt__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other @@ -832,7 +833,7 @@ def __matmul__(self, other: Array, /) -> Array: """ Performs the operation __matmul__. """ - self._check_device(other) + self._check_type_device(other) # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__matmul__") @@ -845,7 +846,7 @@ def __mod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __mod__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other @@ -857,7 +858,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __mul__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other @@ -869,7 +870,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty """ Performs the operation __ne__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other @@ -890,7 +891,7 @@ def __or__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __or__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other @@ -913,7 +914,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array: """ from ._elementwise_functions import pow # type: ignore[attr-defined] - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other @@ -925,7 +926,7 @@ def __rshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other @@ -961,7 +962,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __sub__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other @@ -975,7 +976,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __truediv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other @@ -987,7 +988,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __xor__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other @@ -999,7 +1000,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __iadd__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other @@ -1010,7 +1011,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __radd__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other @@ -1022,7 +1023,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __iand__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other @@ -1033,7 +1034,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __rand__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other @@ -1045,7 +1046,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ifloordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__") if other is NotImplemented: return other @@ -1056,7 +1057,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __rfloordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__") if other is NotImplemented: return other @@ -1068,7 +1069,7 @@ def __ilshift__(self, other: Array | int, /) -> Array: """ Performs the operation __ilshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other @@ -1079,7 +1080,7 @@ def __rlshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rlshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other @@ -1096,7 +1097,7 @@ def __imatmul__(self, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) res = self._array.__imatmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1109,7 +1110,7 @@ def __rmatmul__(self, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) res = self._array.__rmatmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1130,7 +1131,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array: other = self._check_allowed_dtypes(other, "real numeric", "__rmod__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) return self.__class__._new(res, device=self.device) @@ -1152,7 +1153,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1171,7 +1172,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ror__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other @@ -1219,7 +1220,7 @@ def __rrshift__(self, other: Array | int, /) -> Array: other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) return self.__class__._new(res, device=self.device) @@ -1241,7 +1242,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) return self.__class__._new(res, device=self.device) @@ -1263,7 +1264,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res, device=self.device) @@ -1285,7 +1286,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array: other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) return self.__class__._new(res, device=self.device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index dbab1af..e950be5 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -255,30 +255,37 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): func(s) return False +binary_op_dtypes = { + "__add__": "numeric", + "__and__": "integer or boolean", + "__eq__": "all", + "__floordiv__": "real numeric", + "__ge__": "real numeric", + "__gt__": "real numeric", + "__le__": "real numeric", + "__lshift__": "integer", + "__lt__": "real numeric", + "__mod__": "real numeric", + "__mul__": "numeric", + "__ne__": "all", + "__or__": "integer or boolean", + "__pow__": "numeric", + "__rshift__": "integer", + "__sub__": "numeric", + "__truediv__": "floating-point", + "__xor__": "integer or boolean", +} +unary_op_dtypes = { + "__abs__": "numeric", + "__invert__": "integer or boolean", + "__neg__": "numeric", + "__pos__": "numeric", +} def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise - binary_op_dtypes = { - "__add__": "numeric", - "__and__": "integer or boolean", - "__eq__": "all", - "__floordiv__": "real numeric", - "__ge__": "real numeric", - "__gt__": "real numeric", - "__le__": "real numeric", - "__lshift__": "integer", - "__lt__": "real numeric", - "__mod__": "real numeric", - "__mul__": "numeric", - "__ne__": "all", - "__or__": "integer or boolean", - "__pow__": "numeric", - "__rshift__": "integer", - "__sub__": "numeric", - "__truediv__": "floating-point", - "__xor__": "integer or boolean", - } + # Recompute each time because of in-place ops def _array_vals(): for d in _integer_dtypes: @@ -337,12 +344,6 @@ def _array_vals(): else: assert_raises(TypeError, lambda: getattr(x, _op)(y)) - unary_op_dtypes = { - "__abs__": "numeric", - "__invert__": "integer or boolean", - "__neg__": "numeric", - "__pos__": "numeric", - } for op, dtypes in unary_op_dtypes.items(): for a in _array_vals(): if ( @@ -410,6 +411,53 @@ def _matmul_array_vals(): x.__imatmul__(y) +@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) +def test_binary_operators_vs_numpy_generics(op, dtypes): + """Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128 + are disallowed in binary operators. + np.float64 and np.complex128 are subclasses of float and complex, so they need + special treatment in order to be rejected. + """ + match = "Expected Array or Python scalar" + + if dtypes not in ("numeric", "integer", "real numeric", "floating-point"): + a = asarray(True) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.bool_(True)) + + if dtypes != "floating-point": + a = asarray(1) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.int64(1)) + + if dtypes not in ("integer", "integer or boolean"): + a = asarray(1.,) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.float32(1.)) + with pytest.raises(TypeError, match=match): + func(np.float64(1.)) + + if dtypes not in ("integer", "integer or boolean", "real numeric"): + a = asarray(1.,) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.complex64(1.)) + with pytest.raises(TypeError, match=match): + func(np.complex128(1.)) + + +@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) +def test_binary_operators_device_mismatch(op, dtypes): + dtype = float64 if dtypes == "floating-point" else int64 + a = asarray(1, dtype=dtype, device=CPU_DEVICE) + b = asarray(1, dtype=dtype, device=Device("device1")) + with pytest.raises(ValueError, match="different devices"): + getattr(a, op)(b) + + def test_python_scalar_construtors(): b = asarray(False) i = asarray(0)