Skip to content

Commit e7fcd34

Browse files
committed
Disallow float64 and complex128
1 parent b0e4df9 commit e7fcd34

File tree

2 files changed

+38
-62
lines changed

2 files changed

+38
-62
lines changed

array_api_strict/_array_object.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ def _check_type_device(self, other: Array | bool | int | float | complex) -> Non
240240
if isinstance(other, Array):
241241
if self.device != other.device:
242242
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
243-
elif not isinstance(other, bool | int | float | complex):
243+
# Disallow subclasses of Python scalars, such as np.float64 and np.complex128
244+
elif type(other) not in (bool, int, float, complex):
244245
raise TypeError(f"Expected Array or Python scalar; got {type(other)}")
245246

246247
# Helper function to match the type promotion rules in the spec

array_api_strict/tests/test_array_object.py

+36-61
Original file line numberDiff line numberDiff line change
@@ -411,72 +411,47 @@ def _matmul_array_vals():
411411
x.__imatmul__(y)
412412

413413

414-
@pytest.mark.parametrize(
415-
"op",
416-
[
417-
op for op, dtypes in binary_op_dtypes.items()
418-
if dtypes not in ("real numeric", "floating-point")
419-
],
420-
)
421-
def test_binary_operators_vs_numpy_int(op):
422-
"""np.int64 is not a subclass of int and must be disallowed"""
423-
a = asarray(1)
424-
i64 = np.int64(1)
425-
with pytest.raises(TypeError, match="Expected Array or Python scalar"):
426-
getattr(a, op)(i64)
427-
428-
429-
@pytest.mark.parametrize(
430-
"op",
431-
[
432-
op for op, dtypes in binary_op_dtypes.items()
433-
if dtypes not in ("integer", "integer or boolean")
434-
],
435-
)
436-
def test_binary_operators_vs_numpy_float(op):
437-
"""
438-
np.float64 is a subclass of float and must be allowed.
439-
np.float32 is not and must be rejected.
440-
"""
441-
a = asarray(1.)
442-
f64 = np.float64(1.)
443-
f32 = np.float32(1.)
444-
func = getattr(a, op)
445-
for op in binary_op_dtypes:
446-
assert isinstance(func(f64), Array)
447-
with pytest.raises(TypeError, match="Expected Array or Python scalar"):
448-
func(f32)
449-
450-
451-
@pytest.mark.parametrize(
452-
"op",
453-
[
454-
op for op, dtypes in binary_op_dtypes.items()
455-
if dtypes not in ("integer", "integer or boolean", "real numeric")
456-
],
457-
)
458-
def test_binary_operators_vs_numpy_complex(op):
459-
"""
460-
np.complex128 is a subclass of complex and must be allowed.
461-
np.complex64 is not and must be rejected.
414+
@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items())
415+
def test_binary_operators_vs_numpy_generics(op, dtypes):
416+
"""Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128
417+
are disallowed in binary operators.
418+
np.float64 and np.complex128 are subclasses of float and complex, so they need
419+
special treatment in order to be rejected.
462420
"""
463-
a = asarray(1.)
464-
c64 = np.complex64(1.)
465-
c128 = np.complex128(1.)
466-
func = getattr(a, op)
467-
for op in binary_op_dtypes:
468-
assert isinstance(func(c128), Array)
469-
with pytest.raises(TypeError, match="Expected Array or Python scalar"):
470-
func(c64)
421+
match = "Expected Array or Python scalar"
422+
423+
if dtypes not in ("numeric", "integer", "real numeric", "floating-point"):
424+
a = asarray(True)
425+
func = getattr(a, op)
426+
with pytest.raises(TypeError, match=match):
427+
func(np.bool_(True))
428+
429+
if dtypes != "floating-point":
430+
a = asarray(1)
431+
func = getattr(a, op)
432+
with pytest.raises(TypeError, match=match):
433+
func(np.int64(1))
434+
435+
if dtypes not in ("integer", "integer or boolean"):
436+
a = asarray(1.,)
437+
func = getattr(a, op)
438+
with pytest.raises(TypeError, match=match):
439+
func(np.float32(1.))
440+
with pytest.raises(TypeError, match=match):
441+
func(np.float64(1.))
442+
443+
if dtypes not in ("integer", "integer or boolean", "real numeric"):
444+
a = asarray(1.,)
445+
func = getattr(a, op)
446+
with pytest.raises(TypeError, match=match):
447+
func(np.complex64(1.))
448+
with pytest.raises(TypeError, match=match):
449+
func(np.complex128(1.))
471450

472451

473452
@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items())
474453
def test_binary_operators_device_mismatch(op, dtypes):
475-
if dtypes in ("real numeric", "floating-point"):
476-
dtype = float64
477-
else:
478-
dtype = int64
479-
454+
dtype = float64 if dtypes == "floating-point" else int64
480455
a = asarray(1, dtype=dtype, device=CPU_DEVICE)
481456
b = asarray(1, dtype=dtype, device=Device("device1"))
482457
with pytest.raises(ValueError, match="different devices"):

0 commit comments

Comments
 (0)