From 5379bd5796c6c7039033c6f59e2b5923f3fc9162 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 10 Jul 2024 13:01:20 -0600 Subject: [PATCH 1/3] Allow any combination of real dtypes in comparisons This does not change == or != because the standard is currently unclear about that so I'd like to see what happens there first. --- array_api_strict/_array_object.py | 19 +++++++++++++------ array_api_strict/_elementwise_functions.py | 8 -------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index d8ed018..86d2b8b 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -152,7 +152,13 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None # spec in places where it either deviates from or is more strict than # NumPy behavior - def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: + def _check_allowed_dtypes( + self, + other: bool | int | float | Array, + dtype_category: str, + op: str, + check_promotion: bool = True, + ) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -176,7 +182,8 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor # This will raise TypeError for type combinations that are not allowed # to promote in the spec (even if the NumPy array operator would # promote them). - res_dtype = _result_type(self.dtype, other.dtype) + if check_promotion: + res_dtype = _result_type(self.dtype, other.dtype) if op.startswith("__i"): # Note: NumPy will allow in-place operators in some cases where # the type promoted operator does not match the left-hand side @@ -604,7 +611,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ge__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__ge__") + other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -638,7 +645,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __gt__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__gt__") + other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -692,7 +699,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__le__") + other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -714,7 +721,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __lt__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__lt__") + other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b39bd86..f0c94db 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -439,8 +439,6 @@ def greater(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater(x1._array, x2._array)) @@ -453,8 +451,6 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) @@ -524,8 +520,6 @@ def less(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less(x1._array, x2._array)) @@ -538,8 +532,6 @@ def less_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less_equal(x1._array, x2._array)) From b47039fbe6bb6b67a8be1d739a9b81c24e42e3f2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 10 Jul 2024 13:24:13 -0600 Subject: [PATCH 2/3] Add helpful error messages to assert_raises calls in test_array_object.py --- array_api_strict/tests/test_array_object.py | 65 ++++++++++----------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index b0d4868..d4b8794 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,7 +1,7 @@ import operator from builtins import all as all_ -from numpy.testing import assert_raises, suppress_warnings +import numpy.testing import numpy as np import pytest @@ -29,6 +29,10 @@ import array_api_strict +def assert_raises(exception, func, msg=None): + with numpy.testing.assert_raises(exception, msg=msg): + func() + def test_validate_index(): # The indexing tests in the official array API test suite test that the # array object correctly handles the subset of indices that are required @@ -111,6 +115,7 @@ def test_operators(): "__truediv__": "floating", "__xor__": "integer_or_boolean", } + comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"] # Recompute each time because of in-place ops def _array_vals(): for d in _integer_dtypes: @@ -124,7 +129,7 @@ def _array_vals(): BIG_INT = int(1e30) for op, dtypes in binary_op_dtypes.items(): ops = [op] - if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: + if op not in comparison_ops: rop = "__r" + op[2:] iop = "__i" + op[2:] ops += [rop, iop] @@ -155,16 +160,16 @@ def _array_vals(): or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] )): if a.dtype in _integer_dtypes and s == BIG_INT: - assert_raises(OverflowError, lambda: getattr(a, _op)(s)) + assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op) else: # Only test for no error - with suppress_warnings() as sup: + with numpy.testing.suppress_warnings() as sup: # ignore warnings from pow(BIG_INT) sup.filter(RuntimeWarning, "invalid value encountered in power") getattr(a, _op)(s) else: - assert_raises(TypeError, lambda: getattr(a, _op)(s)) + assert_raises(TypeError, lambda: getattr(a, _op)(s), _op) # Test array op array. for _op in ops: @@ -188,7 +193,7 @@ def _array_vals(): _op.startswith("__i") and result_type(x.dtype, y.dtype) != x.dtype ): - assert_raises(TypeError, lambda: getattr(x, _op)(y)) + assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) # Ensure only those dtypes that are required for every operator are allowed. elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) @@ -202,7 +207,7 @@ def _array_vals(): ): getattr(x, _op)(y) else: - assert_raises(TypeError, lambda: getattr(x, _op)(y)) + assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) unary_op_dtypes = { "__abs__": "numeric", @@ -221,7 +226,7 @@ def _array_vals(): # Only test for no error getattr(a, op)() else: - assert_raises(TypeError, lambda: getattr(a, op)()) + assert_raises(TypeError, lambda: getattr(a, op)(), _op) # Finally, matmul() must be tested separately, because it works a bit # different from the other operations. @@ -240,9 +245,9 @@ def _matmul_array_vals(): or type(s) == int and a.dtype in _integer_dtypes): # Type promotion is valid, but @ is not allowed on 0-D # inputs, so the error is a ValueError - assert_raises(ValueError, lambda: getattr(a, _op)(s)) + assert_raises(ValueError, lambda: getattr(a, _op)(s), _op) else: - assert_raises(TypeError, lambda: getattr(a, _op)(s)) + assert_raises(TypeError, lambda: getattr(a, _op)(s), _op) for x in _matmul_array_vals(): for y in _matmul_array_vals(): @@ -356,20 +361,17 @@ def test_allow_newaxis(): def test_disallow_flat_indexing_with_newaxis(): a = ones((3, 3, 3)) - with pytest.raises(IndexError): - a[None, 0, 0] + assert_raises(IndexError, lambda: a[None, 0, 0]) def test_disallow_mask_with_newaxis(): a = ones((3, 3, 3)) - with pytest.raises(IndexError): - a[None, asarray(True)] + assert_raises(IndexError, lambda: a[None, asarray(True)]) @pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) @pytest.mark.parametrize("index", ["string", False, True]) def test_error_on_invalid_index(shape, index): a = ones(shape) - with pytest.raises(IndexError): - a[index] + assert_raises(IndexError, lambda: a[index]) def test_mask_0d_array_without_errors(): a = ones(()) @@ -380,10 +382,8 @@ def test_mask_0d_array_without_errors(): ) def test_error_on_invalid_index_with_ellipsis(i): a = ones((3, 3, 3)) - with pytest.raises(IndexError): - a[..., i] - with pytest.raises(IndexError): - a[i, ...] + assert_raises(IndexError, lambda: a[..., i]) + assert_raises(IndexError, lambda: a[i, ...]) def test_array_keys_use_private_array(): """ @@ -400,8 +400,7 @@ def test_array_keys_use_private_array(): a = ones((0,), dtype=bool_) key = ones((0, 0), dtype=bool_) - with pytest.raises(IndexError): - a[key] + assert_raises(IndexError, lambda: a[key]) def test_array_namespace(): a = ones((3, 3)) @@ -422,16 +421,16 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) + assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) def test_iter(): - pytest.raises(TypeError, lambda: iter(asarray(3))) + assert_raises(TypeError, lambda: iter(asarray(3))) assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)] assert all_(isinstance(a, Array) for a in iter(ones(3))) assert all_(a.shape == () for a in iter(ones(3))) assert all_(a.dtype == float64 for a in iter(ones(3))) - pytest.raises(TypeError, lambda: iter(ones((3, 3)))) + assert_raises(TypeError, lambda: iter(ones((3, 3)))) @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) def dlpack_2023_12(api_version): @@ -447,17 +446,17 @@ def dlpack_2023_12(api_version): exception = NotImplementedError if api_version >= '2023.12' else ValueError - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(dl_device=CPU_DEVICE)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(dl_device=None)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(max_version=(1, 0))) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(max_version=None)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(copy=False)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(copy=True)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(copy=None)) From 1f8769914880a626eb1bca6ccbe0399a47d73694 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 10 Jul 2024 13:33:21 -0600 Subject: [PATCH 3/3] Allow all dtypes in equal/not_equal/==/!= and update tests Also update elementwise function tests to check for disallowed type promotions, not just disallowed mixed kind types. --- array_api_strict/_array_object.py | 5 ++-- array_api_strict/_elementwise_functions.py | 4 --- array_api_strict/tests/test_array_object.py | 28 ++++++++--------- .../tests/test_elementwise_functions.py | 30 +++++++++++++++++-- 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 86d2b8b..bded0c6 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -157,6 +157,7 @@ def _check_allowed_dtypes( other: bool | int | float | Array, dtype_category: str, op: str, + *, check_promotion: bool = True, ) -> Array: """ @@ -577,7 +578,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. - other = self._check_allowed_dtypes(other, "all", "__eq__") + other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -766,7 +767,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __ne__. """ - other = self._check_allowed_dtypes(other, "all", "__ne__") + other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index f0c94db..d4a108d 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -375,8 +375,6 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.equal(x1._array, x2._array)) @@ -707,8 +705,6 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.not_equal(x1._array, x2._array)) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index d4b8794..04e606e 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -94,7 +94,7 @@ def test_validate_index(): def test_operators(): # For every operator, we test that it works for the required type - # combinations and raises TypeError otherwise + # combinations and assert_raises TypeError otherwise binary_op_dtypes = { "__add__": "numeric", "__and__": "integer_or_boolean", @@ -178,16 +178,17 @@ def _array_vals(): # See the promotion table in NEP 47 or the array # API spec page on type promotion. Mixed kind # promotion is not defined. - if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] - or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] - or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes - or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes - or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes - or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes - or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes - or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes - ): - assert_raises(TypeError, lambda: getattr(x, _op)(y)) + if (op not in comparison_ops and + (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + )): + assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) # Ensure in-place operators only promote to the same dtype as the left operand. elif ( _op.startswith("__i") @@ -195,8 +196,7 @@ def _array_vals(): ): assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) # Ensure only those dtypes that are required for every operator are allowed. - elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes - or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + elif (dtypes == "all" or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes) or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes @@ -207,7 +207,7 @@ def _array_vals(): ): getattr(x, _op)(y) else: - assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) + assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y)) unary_op_dtypes = { "__abs__": "numeric", diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 90994f3..92c9c59 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,6 +1,6 @@ from inspect import getfullargspec, getmodule -from numpy.testing import assert_raises +from .test_array_object import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift @@ -9,6 +9,11 @@ _boolean_dtypes, _floating_dtypes, _integer_dtypes, + int8, + int16, + int32, + int64, + uint64, ) from .._flags import set_array_api_strict_flags @@ -86,6 +91,15 @@ def nargs(func): "trunc": "real numeric", } +comparison_functions = [ + 'equal', + 'greater', + 'greater_equal', + 'less', + 'less_equal', + 'not_equal', +] + def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod @@ -115,8 +129,20 @@ def _array_vals(): func = getattr(_elementwise_functions, func_name) if nargs(func) == 2: for y in _array_vals(): + # Disallow dtypes that aren't type promotable + if (func_name not in comparison_functions and + (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + )): + assert_raises(TypeError, lambda: func(x, y), (func_name, x, y)) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, lambda: func(x, y)) + assert_raises(TypeError, lambda: func(x, y), (func_name, x, y)) else: if x.dtype not in dtypes: assert_raises(TypeError, lambda: func(x))