Skip to content

Commit 1b97e58

Browse files
committed
dump
1 parent e38ce34 commit 1b97e58

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

Diff for: array_api_tests/test_inspection_functions.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from . import xp
2+
3+
4+
def test_array_namespace_info():
5+
assert hasattr(xp, "__array_namespace_info__")
6+
# TODO: test output

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+46
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,16 @@ def test_ceil(x):
933933
unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)
934934

935935

936+
@pytest.mark.min_version("2023.12")
937+
@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
938+
def test_clip(x):
939+
# TODO: test min/max kwargs, adjust values testing accordingly
940+
out = xp.clip(x)
941+
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
942+
ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
943+
ph.assert_array_elements("clip", out=out, expected=x)
944+
945+
936946
if api_version >= "2022.12":
937947

938948
@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
@@ -943,6 +953,15 @@ def test_conj(x):
943953
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
944954

945955

956+
@pytest.mark.min_version("2023.12")
957+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
958+
def test_copysign(x1, x2):
959+
out = xp.copysign(x1, x2)
960+
ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
961+
ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
962+
# TODO: values testing
963+
964+
946965
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
947966
def test_cos(x):
948967
out = xp.cos(x)
@@ -1095,6 +1114,15 @@ def test_greater_equal(ctx, data):
10951114
)
10961115

10971116

1117+
@pytest.mark.min_version("2023.12")
1118+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1119+
def test_hypot(x1, x2):
1120+
out = xp.hypot(x1, x2)
1121+
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1122+
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1123+
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
1124+
1125+
10981126
if api_version >= "2022.12":
10991127

11001128
@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
@@ -1261,6 +1289,24 @@ def test_logical_xor(x1, x2):
12611289
)
12621290

12631291

1292+
@pytest.mark.min_version("2023.12")
1293+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1294+
def test_maximum(x1, x2):
1295+
out = xp.maximum(x1, x2)
1296+
ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1297+
ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1298+
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)
1299+
1300+
1301+
@pytest.mark.min_version("2023.12")
1302+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1303+
def test_minimum(x1, x2):
1304+
out = xp.minimum(x1, x2)
1305+
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1306+
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1307+
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)
1308+
1309+
12641310
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
12651311
@given(data=st.data())
12661312
def test_multiply(ctx, data):

Diff for: array_api_tests/test_statistical_functions.py

+9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@
1616
from .typing import DataType
1717

1818

19+
@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_dims=1, max_dims=1)))
20+
def test_cumulative_sum(x):
21+
# TODO: test kwargs + diff shapes, adjust shape and values testing accordingly
22+
out = xp.cumulative_sum(x)
23+
# TODO: assert dtype
24+
ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape)
25+
# TODO: assert values
26+
27+
1928
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
2029
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
2130
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]

0 commit comments

Comments
 (0)