Skip to content

Commit 827edd8

Browse files
authored
Merge pull request #275 from asmeurer/more-2023
More 2023.12 test fixes
2 parents 46d10db + b35152c commit 827edd8

15 files changed

+161
-106
lines changed

.github/workflows/test.yml

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
- name: Run the test suite
2828
env:
2929
ARRAY_API_TESTS_MODULE: array_api_strict
30+
ARRAY_API_STRICT_API_VERSION: 2023.12
3031
run: |
3132
pytest -v -rxXfE --skips-file array-api-strict-skips.txt array_api_tests/
3233
# We also have internal tests that isn't really necessary for adopters

array_api_tests/array_helpers.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._array_module import (isnan, all, any, equal, not_equal, logical_and,
2-
logical_or, isfinite, greater, less, less_equal,
2+
logical_or, isfinite, greater, less_equal,
33
zeros, ones, full, bool, int8, int16, int32,
44
int64, uint8, uint16, uint32, uint64, float32,
55
float64, nan, inf, pi, remainder, divide, isinf,
@@ -164,6 +164,16 @@ def notequal(x, y):
164164

165165
return not_equal(x, y)
166166

167+
def less(x, y):
168+
"""
169+
Same as less(x, y) except it allows comparing uint64 with signed int dtypes
170+
"""
171+
if x.dtype == uint64 and dh.dtype_signed[y.dtype]:
172+
return xp.where(y < 0, xp.asarray(False), xp.less(x, xp.astype(y, uint64)))
173+
if y.dtype == uint64 and dh.dtype_signed[x.dtype]:
174+
return xp.where(x < 0, xp.asarray(True), xp.less(xp.astype(x, uint64), y))
175+
return xp.less(x, y)
176+
167177
def assert_exactly_equal(x, y, msg_extra=None):
168178
"""
169179
Test that the arrays x and y are exactly equal.

array_api_tests/dtype_helpers.py

+3
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,9 @@ class MinMax(NamedTuple):
209209
min: Union[int, float]
210210
max: Union[int, float]
211211

212+
def __contains__(self, other):
213+
assert isinstance(other, (int, float))
214+
return self.min <= other <= self.max
212215

213216
dtype_ranges = _make_dtype_mapping_from_names(
214217
{

array_api_tests/hypothesis_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
177177
# Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
178178
# ARRAY_API_TESTS_SKIP_DTYPES
179179
all_dtypes = sampled_from(_sorted_dtypes)
180-
int_dtypes = sampled_from(dh.int_dtypes)
180+
int_dtypes = sampled_from(dh.all_int_dtypes)
181181
uint_dtypes = sampled_from(dh.uint_dtypes)
182182
real_dtypes = sampled_from(dh.real_dtypes)
183183
# Warning: The hypothesis "floating_dtypes" is what we call

array_api_tests/shape_helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
__all__ = [
1010
"broadcast_shapes",
11-
"normalise_axis",
11+
"normalize_axis",
1212
"ndindex",
1313
"axis_ndindex",
1414
"axes_ndindex",
@@ -65,7 +65,7 @@ def broadcast_shapes(*shapes: Shape):
6565
return result
6666

6767

68-
def normalise_axis(
68+
def normalize_axis(
6969
axis: Optional[Union[int, Sequence[int]]], ndim: int
7070
) -> Tuple[int, ...]:
7171
if axis is None:

array_api_tests/test_array_object.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def scalar_objects(
2626
)
2727

2828

29-
def normalise_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]:
29+
def normalize_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]:
3030
"""
31-
Normalise an indexing key.
31+
Normalize an indexing key.
3232
3333
* If a non-tuple index, wrap as a tuple.
3434
* Represent ellipsis as equivalent slices.
@@ -48,7 +48,7 @@ def get_indexed_axes_and_out_shape(
4848
key: Tuple[Union[int, slice, None], ...], shape: Shape
4949
) -> Tuple[Tuple[Sequence[int], ...], Shape]:
5050
"""
51-
From the (normalised) key and input shape, calculates:
51+
From the (normalized) key and input shape, calculates:
5252
5353
* indexed_axes: For each dimension, the axes which the key indexes.
5454
* out_shape: The resulting shape of indexing an array (of the input shape)
@@ -88,7 +88,7 @@ def test_getitem(shape, dtype, data):
8888
out = x[key]
8989

9090
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
91-
_key = normalise_key(key, shape)
91+
_key = normalize_key(key, shape)
9292
axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape)
9393
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
9494
out_zero_sided = any(side == 0 for side in expected_shape)
@@ -119,7 +119,7 @@ def test_setitem(shape, dtypes, data):
119119
x = xp.asarray(obj, dtype=dtypes.result_dtype)
120120
note(f"{x=}")
121121
key = data.draw(xps.indices(shape=shape), label="key")
122-
_key = normalise_key(key, shape)
122+
_key = normalize_key(key, shape)
123123
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
124124
value_strat = hh.arrays(dtype=dtypes.result_dtype, shape=out_shape)
125125
if out_shape == ():

array_api_tests/test_fft.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def assert_s_axes_shape(
9494
axes: Optional[List[int]],
9595
out: Array,
9696
):
97-
_axes = sh.normalise_axis(axes, x.ndim)
97+
_axes = sh.normalize_axis(axes, x.ndim)
9898
_s = x.shape if s is None else s
9999
expected = []
100100
for i in range(x.ndim):
@@ -193,7 +193,7 @@ def test_rfftn(x, data):
193193

194194
ph.assert_float_to_complex_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype)
195195

196-
_axes = sh.normalise_axis(axes, x.ndim)
196+
_axes = sh.normalize_axis(axes, x.ndim)
197197
_s = x.shape if s is None else s
198198
expected = []
199199
for i in range(x.ndim):
@@ -225,7 +225,7 @@ def test_irfftn(x, data):
225225
)
226226

227227
# TODO: assert shape correctly
228-
# _axes = sh.normalise_axis(axes, x.ndim)
228+
# _axes = sh.normalize_axis(axes, x.ndim)
229229
# _s = x.shape if s is None else s
230230
# expected = []
231231
# for i in range(x.ndim):

array_api_tests/test_linalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ def test_vector_norm(x, data):
980980
# TODO: Check that the ord values give the correct norms.
981981
# ord = kw.get('ord', 2)
982982

983-
_axes = sh.normalise_axis(axis, x.ndim)
983+
_axes = sh.normalize_axis(axis, x.ndim)
984984

985985
ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape,
986986
in_shape=x.shape, axes=_axes,

array_api_tests/test_manipulation_functions.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,27 @@ def test_moveaxis(x, data):
172172
out = xp.moveaxis(x, source, destination)
173173

174174
ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype)
175-
# TODO: shape and values testing
175+
176+
177+
_source = sh.normalize_axis(source, x.ndim)
178+
_destination = sh.normalize_axis(destination, x.ndim)
179+
180+
new_axes = [n for n in range(x.ndim) if n not in _source]
181+
182+
for dest, src in sorted(zip(_destination, _source)):
183+
new_axes.insert(dest, src)
184+
185+
expected_shape = tuple(x.shape[i] for i in new_axes)
186+
187+
ph.assert_result_shape("moveaxis", in_shapes=[x.shape],
188+
out_shape=out.shape, expected=expected_shape,
189+
kw={"source": source, "destination": destination})
190+
191+
indices = list(sh.ndindex(x.shape))
192+
permuted_indices = [tuple(idx[axis] for axis in new_axes) for idx in indices]
193+
assert_array_ndindex(
194+
"moveaxis", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=permuted_indices
195+
)
176196

177197
@pytest.mark.unvectorized
178198
@given(
@@ -190,7 +210,7 @@ def test_squeeze(x, data):
190210
)
191211

192212
axes = (axis,) if isinstance(axis, int) else axis
193-
axes = sh.normalise_axis(axes, x.ndim)
213+
axes = sh.normalize_axis(axes, x.ndim)
194214

195215
squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1]
196216
if any(i not in squeezable_axes for i in axes):
@@ -230,7 +250,7 @@ def test_flip(x, data):
230250

231251
ph.assert_dtype("flip", in_dtype=x.dtype, out_dtype=out.dtype)
232252

233-
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
253+
_axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
234254
for indices in sh.axes_ndindex(x.shape, _axes):
235255
reverse_indices = indices[::-1]
236256
assert_array_ndindex("flip", x, x_indices=indices, out=out,
@@ -360,7 +380,7 @@ def test_roll(x, data):
360380
assert_array_ndindex("roll", x, x_indices=indices, out=out, out_indices=shifted_indices, kw=kw)
361381
else:
362382
shifts = (shift,) if isinstance(shift, int) else shift
363-
axes = sh.normalise_axis(kw["axis"], x.ndim)
383+
axes = sh.normalize_axis(kw["axis"], x.ndim)
364384
shifted_indices = roll_ndindex(x.shape, shifts, axes)
365385
assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw)
366386

array_api_tests/test_operators_and_elementwise_functions.py

+79-74
Original file line numberDiff line numberDiff line change
@@ -929,8 +929,6 @@ def test_ceil(x):
929929
@pytest.mark.min_version("2023.12")
930930
@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
931931
def test_clip(x, data):
932-
# TODO: test min/max kwargs, adjust values testing accordingly
933-
934932
# Ensure that if both min and max are arrays that all three of x, min, max
935933
# are broadcast compatible.
936934
shape1, shape2 = data.draw(hh.mutually_broadcastable_shapes(2,
@@ -951,7 +949,7 @@ def test_clip(x, data):
951949
), label="max")
952950

953951
# min > max is undefined (but allow nans)
954-
assume(min is None or max is None or not xp.any(xp.asarray(min) > xp.asarray(max)))
952+
assume(min is None or max is None or not xp.any(ah.less(xp.asarray(max), xp.asarray(min))))
955953

956954
kw = data.draw(
957955
hh.specified_kwargs(
@@ -972,80 +970,86 @@ def test_clip(x, data):
972970
expected_shape = sh.broadcast_shapes(*shapes)
973971
ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape)
974972

975-
if min is max is None:
976-
ph.assert_array_elements("clip", out=out, expected=x)
977-
elif max is None:
978-
# If one operand is nan, the result is nan. See
979-
# https://github.com/data-apis/array-api/pull/813.
980-
def refimpl(_x, _min):
981-
if math.isnan(_x) or math.isnan(_min):
982-
return math.nan
973+
# This is based on right_scalar_assert_against_refimpl and
974+
# binary_assert_against_refimpl. clip() is currently the only ternary
975+
# elementwise function and the only function that supports arrays and
976+
# scalars. However, where() (in test_searching_functions) is similar
977+
# and if scalar support is added to it, we may want to factor out and
978+
# reuse this logic.
979+
980+
def refimpl(_x, _min, _max):
981+
# Skip cases where _min and _max are integers whose values do not
982+
# fit in the dtype of _x, since this behavior is unspecified.
983+
if dh.is_int_dtype(x.dtype):
984+
if _min is not None and _min not in dh.dtype_ranges[x.dtype]:
985+
return None
986+
if _max is not None and _max not in dh.dtype_ranges[x.dtype]:
987+
return None
988+
989+
# If min or max are float64 and x is float32, they will need to be
990+
# downcast to float32. This could result in a round in the wrong
991+
# direction meaning the resulting clipped value might not actually be
992+
# between min and max. This behavior is unspecified, so skip any cases
993+
# where x is within the rounding error of downcasting min or max.
994+
if x.dtype == xp.float32:
995+
if min is not None and not dh.is_scalar(min) and min.dtype == xp.float64 and math.isfinite(_min):
996+
_min_float32 = float(xp.asarray(_min, dtype=xp.float32))
997+
if math.isinf(_min_float32):
998+
return None
999+
tol = abs(_min - _min_float32)
1000+
if math.isclose(_min, _min_float32, abs_tol=tol):
1001+
return None
1002+
if max is not None and not dh.is_scalar(max) and max.dtype == xp.float64 and math.isfinite(_max):
1003+
_max_float32 = float(xp.asarray(_max, dtype=xp.float32))
1004+
if math.isinf(_max_float32):
1005+
return None
1006+
tol = abs(_max - _max_float32)
1007+
if math.isclose(_max, _max_float32, abs_tol=tol):
1008+
return None
1009+
1010+
if (math.isnan(_x)
1011+
or (_min is not None and math.isnan(_min))
1012+
or (_max is not None and math.isnan(_max))):
1013+
return math.nan
1014+
if _min is _max is None:
1015+
return _x
1016+
if _max is None:
9831017
return builtins.max(_x, _min)
984-
if dh.is_scalar(min):
985-
right_scalar_assert_against_refimpl(
986-
"clip", x, min, out, refimpl,
987-
left_sym="x",
988-
expr_template="clip({}, min={})",
989-
)
990-
else:
991-
binary_assert_against_refimpl(
992-
"clip", x, min, out, refimpl,
993-
left_sym="x", right_sym="min",
994-
expr_template="clip({}, min={})",
995-
)
996-
elif min is None:
997-
def refimpl(_x, _max):
998-
if math.isnan(_x) or math.isnan(_max):
999-
return math.nan
1018+
if _min is None:
10001019
return builtins.min(_x, _max)
1001-
if dh.is_scalar(max):
1002-
right_scalar_assert_against_refimpl(
1003-
"clip", x, max, out, refimpl,
1004-
left_sym="x",
1005-
expr_template="clip({}, max={})",
1020+
return builtins.min(builtins.max(_x, _min), _max)
1021+
1022+
stype = dh.get_scalar_type(x.dtype)
1023+
min_shape = () if min is None or dh.is_scalar(min) else min.shape
1024+
max_shape = () if max is None or dh.is_scalar(max) else max.shape
1025+
1026+
for x_idx, min_idx, max_idx, o_idx in sh.iter_indices(
1027+
x.shape, min_shape, max_shape, out.shape):
1028+
x_val = stype(x[x_idx])
1029+
if min is None or dh.is_scalar(min):
1030+
min_val = min
1031+
else:
1032+
min_val = stype(min[min_idx])
1033+
if max is None or dh.is_scalar(max):
1034+
max_val = max
1035+
else:
1036+
max_val = stype(max[max_idx])
1037+
expected = refimpl(x_val, min_val, max_val)
1038+
if expected is None:
1039+
continue
1040+
out_val = stype(out[o_idx])
1041+
if math.isnan(expected):
1042+
assert math.isnan(out_val), (
1043+
f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n"
1044+
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
10061045
)
10071046
else:
1008-
binary_assert_against_refimpl(
1009-
"clip", x, max, out, refimpl,
1010-
left_sym="x", right_sym="max",
1011-
expr_template="clip({}, max={})",
1047+
assert out_val == expected, (
1048+
f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n"
1049+
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
10121050
)
1013-
else:
1014-
def refimpl(_x, _min, _max):
1015-
if math.isnan(_x) or math.isnan(_min) or math.isnan(_max):
1016-
return math.nan
1017-
return builtins.min(builtins.max(_x, _min), _max)
1018-
1019-
# This is based on right_scalar_assert_against_refimpl and
1020-
# binary_assert_against_refimpl. clip() is currently the only ternary
1021-
# elementwise function and the only function that supports arrays and
1022-
# scalars. However, where() (in test_searching_functions) is similar
1023-
# and if scalar support is added to it, we may want to factor out and
1024-
# reuse this logic.
1025-
1026-
stype = dh.get_scalar_type(x.dtype)
1027-
min_shape = () if dh.is_scalar(min) else min.shape
1028-
max_shape = () if dh.is_scalar(max) else max.shape
1029-
1030-
for x_idx, min_idx, max_idx, o_idx in sh.iter_indices(
1031-
x.shape, min_shape, max_shape, out.shape):
1032-
x_val = stype(x[x_idx])
1033-
min_val = min if dh.is_scalar(min) else min[min_idx]
1034-
min_val = stype(min_val)
1035-
max_val = max if dh.is_scalar(max) else max[max_idx]
1036-
max_val = stype(max_val)
1037-
expected = refimpl(x_val, min_val, max_val)
1038-
out_val = stype(out[o_idx])
1039-
if math.isnan(expected):
1040-
assert math.isnan(out_val), (
1041-
f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n"
1042-
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
1043-
)
1044-
else:
1045-
assert out_val == expected, (
1046-
f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n"
1047-
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
1048-
)
1051+
1052+
10491053
if api_version >= "2022.12":
10501054

10511055
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
@@ -1062,7 +1066,7 @@ def test_copysign(x1, x2):
10621066
out = xp.copysign(x1, x2)
10631067
ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
10641068
ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1065-
# TODO: values testing
1069+
binary_assert_against_refimpl("copysign", x1, x2, out, math.copysign)
10661070

10671071

10681072
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
@@ -1535,7 +1539,8 @@ def test_signbit(x):
15351539
out = xp.signbit(x)
15361540
ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
15371541
ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape)
1538-
# TODO: values testing
1542+
refimpl = lambda x: math.copysign(1.0, x) < 0
1543+
unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True)
15391544

15401545

15411546
@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes(), elements=finite_kw))

0 commit comments

Comments
 (0)