Skip to content

Commit fbec5fd

Browse files
committed
Add a less() helper to allow comparing uint64 and signed int arrays
1 parent 704f456 commit fbec5fd

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

array_api_tests/array_helpers.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._array_module import (isnan, all, any, equal, not_equal, logical_and,
2-
logical_or, isfinite, greater, less, less_equal,
2+
logical_or, isfinite, greater, less_equal,
33
zeros, ones, full, bool, int8, int16, int32,
44
int64, uint8, uint16, uint32, uint64, float32,
55
float64, nan, inf, pi, remainder, divide, isinf,
@@ -164,6 +164,16 @@ def notequal(x, y):
164164

165165
return not_equal(x, y)
166166

167+
def less(x, y):
168+
"""
169+
Same as less(x, y) except it allows comparing uint64 with signed int dtypes
170+
"""
171+
if x.dtype == uint64 and dh.dtype_signed[y.dtype]:
172+
return xp.where(y < 0, xp.asarray(False), xp.less(x, xp.astype(y, uint64)))
173+
if y.dtype == uint64 and dh.dtype_signed[x.dtype]:
174+
return xp.where(x < 0, xp.asarray(True), xp.less(xp.astype(x, uint64), y))
175+
return xp.less(x, y)
176+
167177
def assert_exactly_equal(x, y, msg_extra=None):
168178
"""
169179
Test that the arrays x and y are exactly equal.

meta_tests/test_array_helpers.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
from hypothesis import given
2+
13
from array_api_tests import _array_module as xp
2-
from array_api_tests .array_helpers import exactly_equal, notequal
4+
from array_api_tests.hypothesis_helpers import two_mutual_arrays
5+
from array_api_tests.dtype_helpers import int_dtypes
6+
from array_api_tests.shape_helpers import iter_indices, broadcast_shapes
7+
from array_api_tests .array_helpers import exactly_equal, notequal, less
38

49
# TODO: These meta-tests currently only work with NumPy
510

@@ -17,3 +22,10 @@ def test_notequal():
1722
res = xp.asarray([False, True, False, False, False, True, False, True])
1823
assert xp.all(xp.equal(notequal(a, b), res))
1924

25+
26+
@given(*two_mutual_arrays(dtypes=int_dtypes))
27+
def test_less(x, y):
28+
res = less(x, y)
29+
30+
for i, j, k in iter_indices(x.shape, y.shape, broadcast_shapes(x.shape, y.shape)):
31+
assert res[k] == (int(x[i]) < int(y[j]))

0 commit comments

Comments
 (0)