Skip to content

Commit 933e610

Browse files
committed
Allow specifying the equality check in the linalg _test_stacks
We will eventually change this to be a float-point approximate check rather than exact equality. For now, this is not updated. See #44
1 parent 1a435ae commit 933e610

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

Diff for: array_api_tests/test_linalg.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
# Standin strategy for not yet implemented tests
4545
todo = none()
4646

47-
def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), **kw):
47+
def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1),
48+
assert_equal=assert_exactly_equal, **kw):
4849
"""
4950
Test that f(*args, **kw) maps across stacks of matrices
5051
@@ -75,9 +76,9 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1)
7576
res_stack = res[res_idx]
7677
x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)]
7778
decomp_res_stack = f(*x_stacks, **kw)
78-
assert_exactly_equal(res_stack, decomp_res_stack)
79+
assert_equal(res_stack, decomp_res_stack)
7980
if true_val:
80-
assert_exactly_equal(decomp_res_stack, true_val(*x_stacks))
81+
assert_equal(decomp_res_stack, true_val(*x_stacks))
8182

8283
def _test_namedtuple(res, fields, func_name):
8384
"""

0 commit comments

Comments
 (0)