Skip to content

Commit db87db9

Browse files
authored
Merge pull request #86 from asmeurer/iter_indices-linalg
Use ndindex.iter_indices in _test_stacks in the linalg tests
2 parents 520e685 + a3423fe commit db87db9

File tree

3 files changed

+38
-21
lines changed

3 files changed

+38
-21
lines changed

Diff for: .github/workflows/numpy.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ jobs:
3434
array_api_tests/test_creation_functions.py::test_linspace
3535
# https://github.com/numpy/numpy/issues/20870
3636
array_api_tests/test_data_type_functions.py::test_can_cast
37-
# linalg tests generally need more mulling over
38-
array_api_tests/test_linalg.py
37+
# The return dtype for trace is not consistent in the spec
38+
# (https://github.com/data-apis/array-api/issues/202#issuecomment-952529197)
39+
array_api_tests/test_linalg.py::test_trace
3940
# waiting on NumPy to allow/revert distinct NaNs for np.unique
4041
# https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448
4142
array_api_tests/test_set_functions.py

Diff for: array_api_tests/test_linalg.py

+34-19
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from hypothesis import assume, given
1818
from hypothesis.strategies import (booleans, composite, none, tuples, integers,
1919
shared, sampled_from, data, just)
20+
from ndindex import iter_indices
2021

21-
from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity
22+
from .array_helpers import assert_exactly_equal, asarray
2223
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
2324
square_matrix_shapes, symmetric_matrices,
2425
positive_definite_matrices, MAX_ARRAY_SIZE,
@@ -43,29 +44,41 @@
4344
# Standin strategy for not yet implemented tests
4445
todo = none()
4546

46-
def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw):
47+
def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1),
48+
assert_equal=assert_exactly_equal, **kw):
4749
"""
4850
Test that f(*args, **kw) maps across stacks of matrices
4951
50-
dims is the number of dimensions f should have for a single n x m matrix
51-
stack.
52+
dims is the number of dimensions f(*args) should have for a single n x m
53+
matrix stack.
54+
55+
matrix_axes are the axes along which matrices (or vectors) are stacked in
56+
the input.
57+
58+
true_val may be a function such that true_val(*x_stacks, **kw) gives the
59+
true value for f on a stack.
60+
61+
res should be the result of f(*args, **kw). It is computed if not passed
62+
in.
5263
53-
true_val may be a function such that true_val(*x_stacks) gives the true
54-
value for f on a stack
5564
"""
5665
if res is None:
5766
res = f(*args, **kw)
5867

59-
shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape
60-
for x in args])
61-
for _idx in sh.ndindex(shape[:-2]):
62-
idx = _idx + (slice(None),)*dims
63-
res_stack = res[idx]
64-
x_stacks = [x[_idx + (...,)] for x in args]
68+
shapes = [x.shape for x in args]
69+
70+
for (x_idxes, (res_idx,)) in zip(
71+
iter_indices(*shapes, skip_axes=matrix_axes),
72+
iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))):
73+
x_idxes = [x_idx.raw for x_idx in x_idxes]
74+
res_idx = res_idx.raw
75+
76+
res_stack = res[res_idx]
77+
x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)]
6578
decomp_res_stack = f(*x_stacks, **kw)
66-
assert_exactly_equal(res_stack, decomp_res_stack)
79+
assert_equal(res_stack, decomp_res_stack)
6780
if true_val:
68-
assert_exactly_equal(decomp_res_stack, true_val(*x_stacks))
81+
assert_equal(decomp_res_stack, true_val(*x_stacks))
6982

7083
def _test_namedtuple(res, fields, func_name):
7184
"""
@@ -452,10 +465,12 @@ def test_slogdet(x):
452465

453466
# Check that when the determinant is 0, the sign and logabsdet are (0,
454467
# -inf).
455-
d = linalg.det(x)
456-
zero_det = equal(d, zero(d.shape, d.dtype))
457-
assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype))
458-
assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype))
468+
# TODO: This test does not necessarily hold exactly. Update it to test it
469+
# approximately.
470+
# d = linalg.det(x)
471+
# zero_det = equal(d, zero(d.shape, d.dtype))
472+
# assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype))
473+
# assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype))
459474

460475
# More generally, det(x) should equal sign*exp(logabsdet), but this does
461476
# not hold exactly due to floating-point loss of precision.
@@ -614,7 +629,7 @@ def true_trace(x_stack):
614629

615630
@given(
616631
dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes),
617-
shape=shapes(),
632+
shape=shapes(min_dims=1),
618633
data=data(),
619634
)
620635
def test_vecdot(dtypes, shape, data):

Diff for: requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pytest
22
hypothesis>=6.31.1
3+
ndindex>=1.6
34
regex
45
removestar

0 commit comments

Comments
 (0)