Skip to content

Commit 809a198

Browse files
authoredJun 25, 2024··
Merge pull request #262 from asmeurer/2023-fixes
Fixes for 2023.12 tests
2 parents dbdca7b + 7ce5eb3 commit 809a198

10 files changed

+289
-57
lines changed
 

‎array_api_tests/dtype_helpers.py

+7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"is_int_dtype",
3535
"is_float_dtype",
3636
"get_scalar_type",
37+
"is_scalar",
3738
"dtype_ranges",
3839
"default_int",
3940
"default_uint",
@@ -189,6 +190,8 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
189190
else:
190191
return bool
191192

193+
def is_scalar(x):
194+
return isinstance(x, (int, float, complex, bool))
192195

193196
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
194197
dtype_value_pairs = []
@@ -275,6 +278,9 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
275278
_dtype = x_dtype
276279
else:
277280
_dtype = default_dtype
281+
elif api_version >= '2023.12':
282+
# Starting in 2023.12, floats should not promote with dtype=None
283+
_dtype = x_dtype
278284
elif is_float_dtype(x_dtype, include_complex=False):
279285
if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
280286
_dtype = x_dtype
@@ -322,6 +328,7 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
322328
)
323329
else:
324330
default_complex = None
331+
325332
if dtype_nbits[default_int] == 32:
326333
default_uint = _name_to_dtype.get("uint32")
327334
else:

‎array_api_tests/hypothesis_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def scalars(draw, dtypes, finite=False):
370370
dtypes should be one of the shared_* dtypes strategies.
371371
"""
372372
dtype = draw(dtypes)
373-
if dtype in dh.dtype_ranges:
373+
if dh.is_int_dtype(dtype):
374374
m, M = dh.dtype_ranges[dtype]
375375
return draw(integers(m, M))
376376
elif dtype == bool_dtype:

‎array_api_tests/stubs.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
category_to_funcs: Dict[str, List[FunctionType]] = {}
4646
for name, mod in name_to_mod.items():
47-
if name.endswith("_functions") or name == "info": # info functions file just named info.py
47+
if name.endswith("_functions"):
4848
category = name.replace("_functions", "")
4949
objects = [getattr(mod, name) for name in mod.__all__]
5050
assert all(isinstance(o, FunctionType) for o in objects) # sanity check
@@ -55,7 +55,26 @@
5555
all_funcs.extend(funcs)
5656
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
5757

58-
EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available
58+
info_funcs = []
59+
if api_version >= "2023.12":
60+
# The info functions in the stubs are in info.py, but this is not a name
61+
# in the standard.
62+
info_mod = name_to_mod["info"]
63+
64+
# Note that __array_namespace_info__ is in info.__all__ but it is in the
65+
# top-level namespace, not the info namespace.
66+
info_funcs = [getattr(info_mod, name) for name in info_mod.__all__
67+
if name != '__array_namespace_info__']
68+
assert all(isinstance(f, FunctionType) for f in info_funcs)
69+
name_to_func.update({f.__name__: f for f in info_funcs})
70+
71+
all_funcs.append(info_mod.__array_namespace_info__)
72+
name_to_func['__array_namespace_info__'] = info_mod.__array_namespace_info__
73+
category_to_funcs['info'] = [info_mod.__array_namespace_info__]
74+
75+
EXTENSIONS: List[str] = ["linalg"]
76+
if api_version >= "2022.12":
77+
EXTENSIONS.append("fft")
5978
extension_to_funcs: Dict[str, List[FunctionType]] = {}
6079
for ext in EXTENSIONS:
6180
mod = name_to_mod[ext]

‎array_api_tests/test_manipulation_functions.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,22 @@ def test_tile(x, data):
426426
def test_unstack(x, data):
427427
axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
428428
kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
429-
out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype)
430-
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype)
431-
# TODO: shapes and values testing
429+
out = xp.unstack(x, **kw)
430+
431+
assert isinstance(out, tuple)
432+
assert len(out) == x.shape[axis]
433+
expected_shape = list(x.shape)
434+
expected_shape.pop(axis)
435+
expected_shape = tuple(expected_shape)
436+
for i in range(x.shape[axis]):
437+
arr = out[i]
438+
ph.assert_result_shape("unstack", in_shapes=[x.shape],
439+
out_shape=arr.shape, expected=expected_shape,
440+
kw=kw, repr_name=f"out[{i}].shape")
441+
442+
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype,
443+
repr_name=f"out[{i}].dtype")
444+
445+
idx = [slice(None)] * x.ndim
446+
idx[axis] = i
447+
ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]")

‎array_api_tests/test_operators_and_elementwise_functions.py

+123-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import cmath
55
import math
66
import operator
7+
import builtins
78
from copy import copy
89
from enum import Enum, auto
910
from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union
@@ -369,6 +370,8 @@ def right_scalar_assert_against_refimpl(
369370
370371
See unary_assert_against_refimpl for more information.
371372
"""
373+
if expr_template is None:
374+
expr_template = func_name + "({}, {})={}"
372375
if left.dtype in dh.complex_dtypes:
373376
component_filter = copy(filter_)
374377
filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
@@ -422,7 +425,7 @@ def right_scalar_assert_against_refimpl(
422425
)
423426

424427

425-
# When appropiate, this module tests operators alongside their respective
428+
# When appropriate, this module tests operators alongside their respective
426429
# elementwise methods. We do this by parametrizing a generalised test method
427430
# with every relevant method and operator.
428431
#
@@ -432,8 +435,8 @@ def right_scalar_assert_against_refimpl(
432435
# - The argument strategies, which can be used to draw arguments for the test
433436
# case. They may require additional filtering for certain test cases.
434437
# - right_is_scalar (binary parameters only), which denotes if the right
435-
# argument is a scalar in a test case. This can be used to appropiately adjust
436-
# draw filtering and test logic.
438+
# argument is a scalar in a test case. This can be used to appropriately
439+
# adjust draw filtering and test logic.
437440

438441

439442
func_to_op = {v: k for k, v in dh.op_to_func.items()}
@@ -475,7 +478,7 @@ def make_unary_params(
475478
)
476479
if api_version < min_version:
477480
marks = pytest.mark.skip(
478-
reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}"
481+
reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
479482
)
480483
else:
481484
marks = ()
@@ -924,15 +927,125 @@ def test_ceil(x):
924927

925928

926929
@pytest.mark.min_version("2023.12")
927-
@given(hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes()))
928-
def test_clip(x):
930+
@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
931+
def test_clip(x, data):
929932
# TODO: test min/max kwargs, adjust values testing accordingly
930-
out = xp.clip(x)
931-
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
932-
ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
933-
ph.assert_array_elements("clip", out=out, expected=x)
934933

934+
# Ensure that if both min and max are arrays that all three of x, min, max
935+
# are broadcast compatible.
936+
shape1, shape2 = data.draw(hh.mutually_broadcastable_shapes(2,
937+
base_shape=x.shape),
938+
label="min.shape, max.shape")
939+
940+
dtypes = hh.real_floating_dtypes if dh.is_float_dtype(x.dtype) else hh.int_dtypes
941+
942+
min = data.draw(st.one_of(
943+
st.none(),
944+
hh.scalars(dtypes=st.just(x.dtype)),
945+
hh.arrays(dtype=dtypes, shape=shape1),
946+
), label="min")
947+
max = data.draw(st.one_of(
948+
st.none(),
949+
hh.scalars(dtypes=st.just(x.dtype)),
950+
hh.arrays(dtype=dtypes, shape=shape2),
951+
), label="max")
952+
953+
# min > max is undefined (but allow nans)
954+
assume(min is None or max is None or not xp.any(xp.asarray(min) > xp.asarray(max)))
955+
956+
kw = data.draw(
957+
hh.specified_kwargs(
958+
("min", min, None),
959+
("max", max, None)),
960+
label="kwargs")
961+
962+
out = xp.clip(x, **kw)
963+
964+
# min and max do not participate in type promotion
965+
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
935966

967+
shapes = [x.shape]
968+
if min is not None and not dh.is_scalar(min):
969+
shapes.append(min.shape)
970+
if max is not None and not dh.is_scalar(max):
971+
shapes.append(max.shape)
972+
expected_shape = sh.broadcast_shapes(*shapes)
973+
ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape)
974+
975+
if min is max is None:
976+
ph.assert_array_elements("clip", out=out, expected=x)
977+
elif max is None:
978+
# If one operand is nan, the result is nan. See
979+
# https://github.com/data-apis/array-api/pull/813.
980+
def refimpl(_x, _min):
981+
if math.isnan(_x) or math.isnan(_min):
982+
return math.nan
983+
return builtins.max(_x, _min)
984+
if dh.is_scalar(min):
985+
right_scalar_assert_against_refimpl(
986+
"clip", x, min, out, refimpl,
987+
left_sym="x",
988+
expr_template="clip({}, min={})",
989+
)
990+
else:
991+
binary_assert_against_refimpl(
992+
"clip", x, min, out, refimpl,
993+
left_sym="x", right_sym="min",
994+
expr_template="clip({}, min={})",
995+
)
996+
elif min is None:
997+
def refimpl(_x, _max):
998+
if math.isnan(_x) or math.isnan(_max):
999+
return math.nan
1000+
return builtins.min(_x, _max)
1001+
if dh.is_scalar(max):
1002+
right_scalar_assert_against_refimpl(
1003+
"clip", x, max, out, refimpl,
1004+
left_sym="x",
1005+
expr_template="clip({}, max={})",
1006+
)
1007+
else:
1008+
binary_assert_against_refimpl(
1009+
"clip", x, max, out, refimpl,
1010+
left_sym="x", right_sym="max",
1011+
expr_template="clip({}, max={})",
1012+
)
1013+
else:
1014+
def refimpl(_x, _min, _max):
1015+
if math.isnan(_x) or math.isnan(_min) or math.isnan(_max):
1016+
return math.nan
1017+
return builtins.min(builtins.max(_x, _min), _max)
1018+
1019+
# This is based on right_scalar_assert_against_refimpl and
1020+
# binary_assert_against_refimpl. clip() is currently the only ternary
1021+
# elementwise function and the only function that supports arrays and
1022+
# scalars. However, where() (in test_searching_functions) is similar
1023+
# and if scalar support is added to it, we may want to factor out and
1024+
# reuse this logic.
1025+
1026+
stype = dh.get_scalar_type(x.dtype)
1027+
min_shape = () if dh.is_scalar(min) else min.shape
1028+
max_shape = () if dh.is_scalar(max) else max.shape
1029+
1030+
for x_idx, min_idx, max_idx, o_idx in sh.iter_indices(
1031+
x.shape, min_shape, max_shape, out.shape):
1032+
x_val = stype(x[x_idx])
1033+
min_val = min if dh.is_scalar(min) else min[min_idx]
1034+
min_val = stype(min_val)
1035+
max_val = max if dh.is_scalar(max) else max[max_idx]
1036+
max_val = stype(max_val)
1037+
expected = refimpl(x_val, min_val, max_val)
1038+
out_val = stype(out[o_idx])
1039+
if math.isnan(expected):
1040+
assert math.isnan(out_val), (
1041+
f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n"
1042+
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
1043+
)
1044+
else:
1045+
assert out_val == expected, (
1046+
f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n"
1047+
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
1048+
)
9361049
if api_version >= "2022.12":
9371050

9381051
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))

‎array_api_tests/test_searching_functions.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,14 @@ def test_where(shapes, dtypes, data):
173173
@given(data=st.data())
174174
def test_searchsorted(data):
175175
# TODO: test side="right"
176+
# TODO: Allow different dtypes for x1 and x2
176177
_x1 = data.draw(
177178
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
178179
label="_x1",
179180
)
180181
x1 = xp.asarray(_x1, dtype=dh.default_float)
181182
if data.draw(st.booleans(), label="use sorter?"):
182-
sorter = data.draw(
183-
st.permutations(_x1).map(lambda o: xp.asarray(o, dtype=dh.default_float)),
184-
label="sorter",
185-
)
183+
sorter = xp.argsort(x1)
186184
else:
187185
sorter = None
188186
x1 = xp.sort(x1)

‎array_api_tests/test_signatures.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def squeeze(x, /, axis):
3131

3232
from . import dtype_helpers as dh
3333
from . import xp
34-
from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func
34+
from .stubs import (array_methods, category_to_funcs, extension_to_funcs,
35+
name_to_func, info_funcs)
3536

3637
ParameterKind = Literal[
3738
Parameter.POSITIONAL_ONLY,
@@ -308,3 +309,15 @@ def test_array_method_signature(stub: FunctionType):
308309
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
309310
method = getattr(x, stub.__name__)
310311
_test_func_signature(method, stub, is_method=True)
312+
313+
if info_funcs: # pytest fails collecting if info_funcs is empty
314+
@pytest.mark.min_version("2023.12")
315+
@pytest.mark.parametrize("stub", info_funcs, ids=lambda f: f.__name__)
316+
def test_info_func_signature(stub: FunctionType):
317+
try:
318+
info_namespace = xp.__array_namespace_info__()
319+
except Exception as e:
320+
raise AssertionError(f"Could not get info namespace from xp.__array_namespace_info__(): {e}")
321+
322+
func = getattr(info_namespace, stub.__name__)
323+
_test_func_signature(func, stub)

0 commit comments

Comments
 (0)
Please sign in to comment.