Skip to content

Fixes for 2023.12 tests #262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
825949e
Fix accumulation_result_dtype for 2023.12 logic
asmeurer May 6, 2024
9aad419
Merge branch 'master' into 2023-fixes
asmeurer May 14, 2024
4a0d975
Fix the searchsorted test (and add a TODO)
asmeurer May 14, 2024
b7065de
Enable fft in the stubs
asmeurer May 15, 2024
6737695
Properly include __array_namespace_info__ in the stubs
asmeurer May 15, 2024
5aa865d
Test info functions in the signature tests
asmeurer May 15, 2024
ebb4f37
Skip cumulative_sum in the nan propogation special case test
asmeurer May 16, 2024
5e62058
Print the function name for non-machine readable special cases
asmeurer May 16, 2024
c8c9498
Enable dtype checks in sum and prod for 2023.12
asmeurer May 16, 2024
8a50ebc
Add shape, dtype, and value testing for cumulative_sum
asmeurer May 17, 2024
303d756
Merge branch 'master' into 2023-fixes
asmeurer May 23, 2024
bb33ff2
Fix flake8 issue
asmeurer May 23, 2024
a04ff8f
Fix test_cumulative_sum @given inputs
asmeurer May 23, 2024
6362204
Add missing tests for unstack()
asmeurer May 30, 2024
0f311e2
Fix formatting of greater than or equal sign
asmeurer Jun 5, 2024
279677d
Print the array module version in the tests header
asmeurer Jun 5, 2024
aafb6a1
Fix scalars() to not generate integers for floating-point dtypes
asmeurer Jun 5, 2024
bf3b773
Generate keyword arguments in test_clip()
asmeurer Jun 5, 2024
930932a
Add helper function is_scalar
asmeurer Jun 5, 2024
5993eca
Fix potentially undefined variable
asmeurer Jun 5, 2024
873eb64
Fix spelling
asmeurer Jun 5, 2024
4f0214e
Add value testing for clip()
asmeurer Jun 5, 2024
b311c83
Merge branch 'master' into 2023-fixes
asmeurer Jun 7, 2024
1b10ebf
Add the remainder of the value testing for clip()
asmeurer Jun 7, 2024
ceb63ea
Support int dtypes in test_clip and some portability fixes
asmeurer Jun 7, 2024
7ce5eb3
Fix input dtypes for test_clip
asmeurer Jun 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"is_int_dtype",
"is_float_dtype",
"get_scalar_type",
"is_scalar",
"dtype_ranges",
"default_int",
"default_uint",
Expand Down Expand Up @@ -189,6 +190,8 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
else:
return bool

def is_scalar(x):
return isinstance(x, (int, float, complex, bool))

def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
dtype_value_pairs = []
Expand Down Expand Up @@ -275,6 +278,9 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
_dtype = x_dtype
else:
_dtype = default_dtype
elif api_version >= '2023.12':
# Starting in 2023.12, floats should not promote with dtype=None
_dtype = x_dtype
elif is_float_dtype(x_dtype, include_complex=False):
if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
_dtype = x_dtype
Expand Down Expand Up @@ -322,6 +328,7 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
)
else:
default_complex = None

if dtype_nbits[default_int] == 32:
default_uint = _name_to_dtype.get("uint32")
else:
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def scalars(draw, dtypes, finite=False):
dtypes should be one of the shared_* dtypes strategies.
"""
dtype = draw(dtypes)
if dtype in dh.dtype_ranges:
if dh.is_int_dtype(dtype):
m, M = dh.dtype_ranges[dtype]
return draw(integers(m, M))
elif dtype == bool_dtype:
Expand Down
23 changes: 21 additions & 2 deletions array_api_tests/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

category_to_funcs: Dict[str, List[FunctionType]] = {}
for name, mod in name_to_mod.items():
if name.endswith("_functions") or name == "info": # info functions file just named info.py
if name.endswith("_functions"):
category = name.replace("_functions", "")
objects = [getattr(mod, name) for name in mod.__all__]
assert all(isinstance(o, FunctionType) for o in objects) # sanity check
Expand All @@ -55,7 +55,26 @@
all_funcs.extend(funcs)
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}

EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available
info_funcs = []
if api_version >= "2023.12":
# The info functions in the stubs are in info.py, but this is not a name
# in the standard.
info_mod = name_to_mod["info"]

# Note that __array_namespace_info__ is in info.__all__ but it is in the
# top-level namespace, not the info namespace.
info_funcs = [getattr(info_mod, name) for name in info_mod.__all__
if name != '__array_namespace_info__']
assert all(isinstance(f, FunctionType) for f in info_funcs)
name_to_func.update({f.__name__: f for f in info_funcs})

all_funcs.append(info_mod.__array_namespace_info__)
name_to_func['__array_namespace_info__'] = info_mod.__array_namespace_info__
category_to_funcs['info'] = [info_mod.__array_namespace_info__]

EXTENSIONS: List[str] = ["linalg"]
if api_version >= "2022.12":
EXTENSIONS.append("fft")
extension_to_funcs: Dict[str, List[FunctionType]] = {}
for ext in EXTENSIONS:
mod = name_to_mod[ext]
Expand Down
22 changes: 19 additions & 3 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,22 @@ def test_tile(x, data):
def test_unstack(x, data):
axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype)
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype)
# TODO: shapes and values testing
out = xp.unstack(x, **kw)

assert isinstance(out, tuple)
assert len(out) == x.shape[axis]
expected_shape = list(x.shape)
expected_shape.pop(axis)
expected_shape = tuple(expected_shape)
for i in range(x.shape[axis]):
arr = out[i]
ph.assert_result_shape("unstack", in_shapes=[x.shape],
out_shape=arr.shape, expected=expected_shape,
kw=kw, repr_name=f"out[{i}].shape")

ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype,
repr_name=f"out[{i}].dtype")

idx = [slice(None)] * x.ndim
idx[axis] = i
ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]")
133 changes: 123 additions & 10 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import cmath
import math
import operator
import builtins
from copy import copy
from enum import Enum, auto
from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union
Expand Down Expand Up @@ -369,6 +370,8 @@ def right_scalar_assert_against_refimpl(

See unary_assert_against_refimpl for more information.
"""
if expr_template is None:
expr_template = func_name + "({}, {})={}"
if left.dtype in dh.complex_dtypes:
component_filter = copy(filter_)
filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
Expand Down Expand Up @@ -422,7 +425,7 @@ def right_scalar_assert_against_refimpl(
)


# When appropiate, this module tests operators alongside their respective
# When appropriate, this module tests operators alongside their respective
# elementwise methods. We do this by parametrizing a generalised test method
# with every relevant method and operator.
#
Expand All @@ -432,8 +435,8 @@ def right_scalar_assert_against_refimpl(
# - The argument strategies, which can be used to draw arguments for the test
# case. They may require additional filtering for certain test cases.
# - right_is_scalar (binary parameters only), which denotes if the right
# argument is a scalar in a test case. This can be used to appropiately adjust
# draw filtering and test logic.
# argument is a scalar in a test case. This can be used to appropriately
# adjust draw filtering and test logic.


func_to_op = {v: k for k, v in dh.op_to_func.items()}
Expand Down Expand Up @@ -475,7 +478,7 @@ def make_unary_params(
)
if api_version < min_version:
marks = pytest.mark.skip(
reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}"
reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
)
else:
marks = ()
Expand Down Expand Up @@ -924,15 +927,125 @@ def test_ceil(x):


@pytest.mark.min_version("2023.12")
@given(hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes()))
def test_clip(x):
@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
def test_clip(x, data):
# TODO: test min/max kwargs, adjust values testing accordingly
out = xp.clip(x)
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
ph.assert_array_elements("clip", out=out, expected=x)

# Ensure that if both min and max are arrays that all three of x, min, max
# are broadcast compatible.
shape1, shape2 = data.draw(hh.mutually_broadcastable_shapes(2,
base_shape=x.shape),
label="min.shape, max.shape")

dtypes = hh.real_floating_dtypes if dh.is_float_dtype(x.dtype) else hh.int_dtypes

min = data.draw(st.one_of(
st.none(),
hh.scalars(dtypes=st.just(x.dtype)),
hh.arrays(dtype=dtypes, shape=shape1),
), label="min")
max = data.draw(st.one_of(
st.none(),
hh.scalars(dtypes=st.just(x.dtype)),
hh.arrays(dtype=dtypes, shape=shape2),
), label="max")

# min > max is undefined (but allow nans)
assume(min is None or max is None or not xp.any(xp.asarray(min) > xp.asarray(max)))

kw = data.draw(
hh.specified_kwargs(
("min", min, None),
("max", max, None)),
label="kwargs")

out = xp.clip(x, **kw)

# min and max do not participate in type promotion
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)

shapes = [x.shape]
if min is not None and not dh.is_scalar(min):
shapes.append(min.shape)
if max is not None and not dh.is_scalar(max):
shapes.append(max.shape)
expected_shape = sh.broadcast_shapes(*shapes)
ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape)

if min is max is None:
ph.assert_array_elements("clip", out=out, expected=x)
elif max is None:
# If one operand is nan, the result is nan. See
# https://github.com/data-apis/array-api/pull/813.
def refimpl(_x, _min):
if math.isnan(_x) or math.isnan(_min):
return math.nan
return builtins.max(_x, _min)
if dh.is_scalar(min):
right_scalar_assert_against_refimpl(
"clip", x, min, out, refimpl,
left_sym="x",
expr_template="clip({}, min={})",
)
else:
binary_assert_against_refimpl(
"clip", x, min, out, refimpl,
left_sym="x", right_sym="min",
expr_template="clip({}, min={})",
)
elif min is None:
def refimpl(_x, _max):
if math.isnan(_x) or math.isnan(_max):
return math.nan
return builtins.min(_x, _max)
if dh.is_scalar(max):
right_scalar_assert_against_refimpl(
"clip", x, max, out, refimpl,
left_sym="x",
expr_template="clip({}, max={})",
)
else:
binary_assert_against_refimpl(
"clip", x, max, out, refimpl,
left_sym="x", right_sym="max",
expr_template="clip({}, max={})",
)
else:
def refimpl(_x, _min, _max):
if math.isnan(_x) or math.isnan(_min) or math.isnan(_max):
return math.nan
return builtins.min(builtins.max(_x, _min), _max)

# This is based on right_scalar_assert_against_refimpl and
# binary_assert_against_refimpl. clip() is currently the only ternary
# elementwise function and the only function that supports arrays and
# scalars. However, where() (in test_searching_functions) is similar
# and if scalar support is added to it, we may want to factor out and
# reuse this logic.

stype = dh.get_scalar_type(x.dtype)
min_shape = () if dh.is_scalar(min) else min.shape
max_shape = () if dh.is_scalar(max) else max.shape

for x_idx, min_idx, max_idx, o_idx in sh.iter_indices(
x.shape, min_shape, max_shape, out.shape):
x_val = stype(x[x_idx])
min_val = min if dh.is_scalar(min) else min[min_idx]
min_val = stype(min_val)
max_val = max if dh.is_scalar(max) else max[max_idx]
max_val = stype(max_val)
expected = refimpl(x_val, min_val, max_val)
out_val = stype(out[o_idx])
if math.isnan(expected):
assert math.isnan(out_val), (
f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n"
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
)
else:
assert out_val == expected, (
f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n"
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
)
if api_version >= "2022.12":

@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
Expand Down
6 changes: 2 additions & 4 deletions array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,14 @@ def test_where(shapes, dtypes, data):
@given(data=st.data())
def test_searchsorted(data):
# TODO: test side="right"
# TODO: Allow different dtypes for x1 and x2
_x1 = data.draw(
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
label="_x1",
)
x1 = xp.asarray(_x1, dtype=dh.default_float)
if data.draw(st.booleans(), label="use sorter?"):
sorter = data.draw(
st.permutations(_x1).map(lambda o: xp.asarray(o, dtype=dh.default_float)),
label="sorter",
)
sorter = xp.argsort(x1)
else:
sorter = None
x1 = xp.sort(x1)
Expand Down
15 changes: 14 additions & 1 deletion array_api_tests/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def squeeze(x, /, axis):

from . import dtype_helpers as dh
from . import xp
from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func
from .stubs import (array_methods, category_to_funcs, extension_to_funcs,
name_to_func, info_funcs)

ParameterKind = Literal[
Parameter.POSITIONAL_ONLY,
Expand Down Expand Up @@ -308,3 +309,15 @@ def test_array_method_signature(stub: FunctionType):
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
method = getattr(x, stub.__name__)
_test_func_signature(method, stub, is_method=True)

if info_funcs: # pytest fails collecting if info_funcs is empty
@pytest.mark.min_version("2023.12")
@pytest.mark.parametrize("stub", info_funcs, ids=lambda f: f.__name__)
def test_info_func_signature(stub: FunctionType):
try:
info_namespace = xp.__array_namespace_info__()
except Exception as e:
raise AssertionError(f"Could not get info namespace from xp.__array_namespace_info__(): {e}")

func = getattr(info_namespace, stub.__name__)
_test_func_signature(func, stub)
Loading
Loading