Skip to content

Commit 1f87699

Browse files
committed
Allow all dtypes in equal/not_equal/==/!= and update tests
Also update elementwise function tests to check for disallowed type promotions, not just disallowed mixed kind types.
1 parent b47039f commit 1f87699

File tree

4 files changed

+45
-22
lines changed

4 files changed

+45
-22
lines changed

array_api_strict/_array_object.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def _check_allowed_dtypes(
157157
other: bool | int | float | Array,
158158
dtype_category: str,
159159
op: str,
160+
*,
160161
check_promotion: bool = True,
161162
) -> Array:
162163
"""
@@ -577,7 +578,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
577578
"""
578579
# Even though "all" dtypes are allowed, we still require them to be
579580
# promotable with each other.
580-
other = self._check_allowed_dtypes(other, "all", "__eq__")
581+
other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False)
581582
if other is NotImplemented:
582583
return other
583584
self, other = self._normalize_two_args(self, other)
@@ -766,7 +767,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
766767
"""
767768
Performs the operation __ne__.
768769
"""
769-
other = self._check_allowed_dtypes(other, "all", "__ne__")
770+
other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False)
770771
if other is NotImplemented:
771772
return other
772773
self, other = self._normalize_two_args(self, other)

array_api_strict/_elementwise_functions.py

-4
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,6 @@ def equal(x1: Array, x2: Array, /) -> Array:
375375
376376
See its docstring for more information.
377377
"""
378-
# Call result type here just to raise on disallowed type combinations
379-
_result_type(x1.dtype, x2.dtype)
380378
x1, x2 = Array._normalize_two_args(x1, x2)
381379
return Array._new(np.equal(x1._array, x2._array))
382380

@@ -707,8 +705,6 @@ def not_equal(x1: Array, x2: Array, /) -> Array:
707705
708706
See its docstring for more information.
709707
"""
710-
# Call result type here just to raise on disallowed type combinations
711-
_result_type(x1.dtype, x2.dtype)
712708
x1, x2 = Array._normalize_two_args(x1, x2)
713709
return Array._new(np.not_equal(x1._array, x2._array))
714710

array_api_strict/tests/test_array_object.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_validate_index():
9494

9595
def test_operators():
9696
# For every operator, we test that it works for the required type
97-
# combinations and raises TypeError otherwise
97+
# combinations and assert_raises TypeError otherwise
9898
binary_op_dtypes = {
9999
"__add__": "numeric",
100100
"__and__": "integer_or_boolean",
@@ -178,25 +178,25 @@ def _array_vals():
178178
# See the promotion table in NEP 47 or the array
179179
# API spec page on type promotion. Mixed kind
180180
# promotion is not defined.
181-
if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
182-
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
183-
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
184-
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
185-
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
186-
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
187-
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
188-
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
189-
):
190-
assert_raises(TypeError, lambda: getattr(x, _op)(y))
181+
if (op not in comparison_ops and
182+
(x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
183+
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
184+
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
185+
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
186+
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
187+
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
188+
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
189+
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
190+
)):
191+
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
191192
# Ensure in-place operators only promote to the same dtype as the left operand.
192193
elif (
193194
_op.startswith("__i")
194195
and result_type(x.dtype, y.dtype) != x.dtype
195196
):
196197
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
197198
# Ensure only those dtypes that are required for every operator are allowed.
198-
elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
199-
or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
199+
elif (dtypes == "all"
200200
or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes)
201201
or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
202202
or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
@@ -207,7 +207,7 @@ def _array_vals():
207207
):
208208
getattr(x, _op)(y)
209209
else:
210-
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
210+
assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y))
211211

212212
unary_op_dtypes = {
213213
"__abs__": "numeric",

array_api_strict/tests/test_elementwise_functions.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from inspect import getfullargspec, getmodule
22

3-
from numpy.testing import assert_raises
3+
from .test_array_object import assert_raises
44

55
from .. import asarray, _elementwise_functions
66
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
@@ -9,6 +9,11 @@
99
_boolean_dtypes,
1010
_floating_dtypes,
1111
_integer_dtypes,
12+
int8,
13+
int16,
14+
int32,
15+
int64,
16+
uint64,
1217
)
1318
from .._flags import set_array_api_strict_flags
1419

@@ -86,6 +91,15 @@ def nargs(func):
8691
"trunc": "real numeric",
8792
}
8893

94+
comparison_functions = [
95+
'equal',
96+
'greater',
97+
'greater_equal',
98+
'less',
99+
'less_equal',
100+
'not_equal',
101+
]
102+
89103
def test_missing_functions():
90104
# Ensure the above dictionary is complete.
91105
import array_api_strict._elementwise_functions as mod
@@ -115,8 +129,20 @@ def _array_vals():
115129
func = getattr(_elementwise_functions, func_name)
116130
if nargs(func) == 2:
117131
for y in _array_vals():
132+
# Disallow dtypes that aren't type promotable
133+
if (func_name not in comparison_functions and
134+
(x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
135+
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
136+
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
137+
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
138+
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
139+
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
140+
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
141+
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
142+
)):
143+
assert_raises(TypeError, lambda: func(x, y), (func_name, x, y))
118144
if x.dtype not in dtypes or y.dtype not in dtypes:
119-
assert_raises(TypeError, lambda: func(x, y))
145+
assert_raises(TypeError, lambda: func(x, y), (func_name, x, y))
120146
else:
121147
if x.dtype not in dtypes:
122148
assert_raises(TypeError, lambda: func(x))

0 commit comments

Comments
 (0)