Skip to content

Commit 5391803

Browse files
authored
Merge pull request #255 from honno/2023-coverage
2023.12 coverage
2 parents e1fe6fb + 4974c68 commit 5391803

5 files changed

+203
-1
lines changed

Diff for: array_api_tests/test_inspection_functions.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import pytest
2+
from hypothesis import given, strategies as st
3+
4+
from . import xp
5+
6+
pytestmark = pytest.mark.min_version("2023.12")
7+
8+
9+
def test_array_namespace_info():
10+
out = xp.__array_namespace_info__()
11+
12+
capabilities = out.capabilities()
13+
assert isinstance(capabilities, dict)
14+
15+
out.default_device()
16+
17+
default_dtypes = out.default_dtypes()
18+
assert isinstance(default_dtypes, dict)
19+
assert {"real floating", "complex floating", "integral", "indexing"}.issubset(set(default_dtypes.keys()))
20+
21+
devices = out.devices()
22+
assert isinstance(devices, list)
23+
24+
25+
atomic_kinds = [
26+
"bool",
27+
"signed integer",
28+
"unsigned integer",
29+
"real floating",
30+
"complex floating",
31+
]
32+
33+
34+
@given(
35+
st.one_of(
36+
st.none(),
37+
st.sampled_from(atomic_kinds + ["integral", "numeric"]),
38+
st.lists(st.sampled_from(atomic_kinds), unique=True, min_size=1).map(tuple),
39+
)
40+
)
41+
def test_array_namespace_info_dtypes(kind):
42+
out = xp.__array_namespace_info__().dtypes(kind=kind)
43+
assert isinstance(out, dict)

Diff for: array_api_tests/test_manipulation_functions.py

+58
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,31 @@ def test_expand_dims(x, axis):
149149
)
150150

151151

152+
@pytest.mark.min_version("2023.12")
153+
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), data=st.data())
154+
def test_moveaxis(x, data):
155+
source = data.draw(
156+
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim), label="source"
157+
)
158+
if isinstance(source, int):
159+
destination = data.draw(st.integers(-x.ndim, x.ndim - 1), label="destination")
160+
else:
161+
assert isinstance(source, tuple) # sanity check
162+
destination = data.draw(
163+
st.lists(
164+
st.integers(-x.ndim, x.ndim - 1),
165+
min_size=len(source),
166+
max_size=len(source),
167+
unique_by=lambda n: n if n >= 0 else x.ndim + n,
168+
).map(tuple),
169+
label="destination"
170+
)
171+
172+
out = xp.moveaxis(x, source, destination)
173+
174+
ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype)
175+
# TODO: shape and values testing
176+
152177
@pytest.mark.unvectorized
153178
@given(
154179
x=hh.arrays(
@@ -253,6 +278,20 @@ def reshape_shapes(draw, shape):
253278
return tuple(rshape)
254279

255280

281+
@pytest.mark.min_version("2023.12")
282+
@given(
283+
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)),
284+
repeats=st.integers(1, 4),
285+
)
286+
def test_repeat(x, repeats):
287+
# TODO: test array repeats and non-None axis, adjust shape and value testing accordingly
288+
out = xp.repeat(x, repeats)
289+
ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
290+
expected_shape = (math.prod(x.shape) * repeats,)
291+
ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
292+
# TODO: values testing
293+
294+
256295
@pytest.mark.unvectorized
257296
@pytest.mark.skip("flaky") # TODO: fix!
258297
@given(
@@ -371,3 +410,22 @@ def test_stack(shape, dtypes, kw, data):
371410
out_val=out[out_idx],
372411
kw=kw,
373412
)
413+
414+
415+
@pytest.mark.min_version("2023.12")
416+
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data())
417+
def test_tile(x, data):
418+
repetitions = data.draw(st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple), label="repetitions")
419+
out = xp.tile(x, repetitions)
420+
ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype)
421+
# TODO: shapes and values testing
422+
423+
424+
@pytest.mark.min_version("2023.12")
425+
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), data=st.data())
426+
def test_unstack(x, data):
427+
axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
428+
kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
429+
out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype)
430+
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype)
431+
# TODO: shapes and values testing

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+55
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,16 @@ def test_ceil(x):
933933
unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)
934934

935935

936+
@pytest.mark.min_version("2023.12")
937+
@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
938+
def test_clip(x):
939+
# TODO: test min/max kwargs, adjust values testing accordingly
940+
out = xp.clip(x)
941+
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
942+
ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
943+
ph.assert_array_elements("clip", out=out, expected=x)
944+
945+
936946
if api_version >= "2022.12":
937947

938948
@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
@@ -943,6 +953,15 @@ def test_conj(x):
943953
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
944954

945955

956+
@pytest.mark.min_version("2023.12")
957+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
958+
def test_copysign(x1, x2):
959+
out = xp.copysign(x1, x2)
960+
ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
961+
ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
962+
# TODO: values testing
963+
964+
946965
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
947966
def test_cos(x):
948967
out = xp.cos(x)
@@ -1095,6 +1114,15 @@ def test_greater_equal(ctx, data):
10951114
)
10961115

10971116

1117+
@pytest.mark.min_version("2023.12")
1118+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1119+
def test_hypot(x1, x2):
1120+
out = xp.hypot(x1, x2)
1121+
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1122+
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1123+
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
1124+
1125+
10981126
if api_version >= "2022.12":
10991127

11001128
@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
@@ -1261,6 +1289,24 @@ def test_logical_xor(x1, x2):
12611289
)
12621290

12631291

1292+
@pytest.mark.min_version("2023.12")
1293+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1294+
def test_maximum(x1, x2):
1295+
out = xp.maximum(x1, x2)
1296+
ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1297+
ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1298+
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)
1299+
1300+
1301+
@pytest.mark.min_version("2023.12")
1302+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1303+
def test_minimum(x1, x2):
1304+
out = xp.minimum(x1, x2)
1305+
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1306+
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1307+
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)
1308+
1309+
12641310
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
12651311
@given(data=st.data())
12661312
def test_multiply(ctx, data):
@@ -1380,6 +1426,15 @@ def test_round(x):
13801426
unary_assert_against_refimpl("round", x, out, round, strict_check=True)
13811427

13821428

1429+
@pytest.mark.min_version("2023.12")
1430+
@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
1431+
def test_signbit(x):
1432+
out = xp.signbit(x)
1433+
ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
1434+
ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape)
1435+
# TODO: values testing
1436+
1437+
13831438
@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw))
13841439
def test_sign(x):
13851440
out = xp.sign(x)

Diff for: array_api_tests/test_searching_functions.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22

33
import pytest
4-
from hypothesis import given
4+
from hypothesis import given, note
55
from hypothesis import strategies as st
66

77
from . import _array_module as xp
@@ -167,3 +167,39 @@ def test_where(shapes, dtypes, data):
167167
out_repr=f"out[{idx}]",
168168
out_val=out[idx]
169169
)
170+
171+
172+
@pytest.mark.min_version("2023.12")
173+
@given(data=st.data())
174+
def test_searchsorted(data):
175+
# TODO: test side="right"
176+
_x1 = data.draw(
177+
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
178+
label="_x1",
179+
)
180+
x1 = xp.asarray(_x1, dtype=dh.default_float)
181+
if data.draw(st.booleans(), label="use sorter?"):
182+
sorter = data.draw(
183+
st.permutations(_x1).map(lambda o: xp.asarray(o, dtype=dh.default_float)),
184+
label="sorter",
185+
)
186+
else:
187+
sorter = None
188+
x1 = xp.sort(x1)
189+
note(f"{x1=}")
190+
x2 = data.draw(
191+
st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
192+
lambda o: xp.asarray(o, dtype=dh.default_float)
193+
),
194+
label="x2",
195+
)
196+
197+
out = xp.searchsorted(x1, x2, sorter=sorter)
198+
199+
ph.assert_dtype(
200+
"searchsorted",
201+
in_dtype=[x1.dtype, x2.dtype],
202+
out_dtype=out.dtype,
203+
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
204+
)
205+
# TODO: shapes and values testing

Diff for: array_api_tests/test_statistical_functions.py

+10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@
1616
from .typing import DataType
1717

1818

19+
@pytest.mark.min_version("2023.12")
20+
@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_dims=1, max_dims=1)))
21+
def test_cumulative_sum(x):
22+
# TODO: test kwargs + diff shapes, adjust shape and values testing accordingly
23+
out = xp.cumulative_sum(x)
24+
# TODO: assert dtype
25+
ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape)
26+
# TODO: assert values
27+
28+
1929
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
2030
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
2131
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]

0 commit comments

Comments
 (0)