|
17 | 17 | from hypothesis import assume, given
|
18 | 18 | from hypothesis.strategies import (booleans, composite, none, tuples, integers,
|
19 | 19 | shared, sampled_from, data, just)
|
| 20 | +from ndindex import iter_indices |
20 | 21 |
|
21 |
| -from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity |
| 22 | +from .array_helpers import assert_exactly_equal, asarray |
22 | 23 | from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
|
23 | 24 | square_matrix_shapes, symmetric_matrices,
|
24 | 25 | positive_definite_matrices, MAX_ARRAY_SIZE,
|
|
43 | 44 | # Standin strategy for not yet implemented tests
|
44 | 45 | todo = none()
|
45 | 46 |
|
46 |
| -def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw): |
| 47 | +def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), |
| 48 | + assert_equal=assert_exactly_equal, **kw): |
47 | 49 | """
|
48 | 50 | Test that f(*args, **kw) maps across stacks of matrices
|
49 | 51 |
|
50 |
| - dims is the number of dimensions f should have for a single n x m matrix |
51 |
| - stack. |
| 52 | + dims is the number of dimensions f(*args) should have for a single n x m |
| 53 | + matrix stack. |
| 54 | +
|
| 55 | + matrix_axes are the axes along which matrices (or vectors) are stacked in |
| 56 | + the input. |
| 57 | +
|
| 58 | + true_val may be a function such that true_val(*x_stacks, **kw) gives the |
| 59 | + true value for f on a stack. |
| 60 | +
|
| 61 | + res should be the result of f(*args, **kw). It is computed if not passed |
| 62 | + in. |
52 | 63 |
|
53 |
| - true_val may be a function such that true_val(*x_stacks) gives the true |
54 |
| - value for f on a stack |
55 | 64 | """
|
56 | 65 | if res is None:
|
57 | 66 | res = f(*args, **kw)
|
58 | 67 |
|
59 |
| - shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape |
60 |
| - for x in args]) |
61 |
| - for _idx in sh.ndindex(shape[:-2]): |
62 |
| - idx = _idx + (slice(None),)*dims |
63 |
| - res_stack = res[idx] |
64 |
| - x_stacks = [x[_idx + (...,)] for x in args] |
| 68 | + shapes = [x.shape for x in args] |
| 69 | + |
| 70 | + for (x_idxes, (res_idx,)) in zip( |
| 71 | + iter_indices(*shapes, skip_axes=matrix_axes), |
| 72 | + iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))): |
| 73 | + x_idxes = [x_idx.raw for x_idx in x_idxes] |
| 74 | + res_idx = res_idx.raw |
| 75 | + |
| 76 | + res_stack = res[res_idx] |
| 77 | + x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)] |
65 | 78 | decomp_res_stack = f(*x_stacks, **kw)
|
66 |
| - assert_exactly_equal(res_stack, decomp_res_stack) |
| 79 | + assert_equal(res_stack, decomp_res_stack) |
67 | 80 | if true_val:
|
68 |
| - assert_exactly_equal(decomp_res_stack, true_val(*x_stacks)) |
| 81 | + assert_equal(decomp_res_stack, true_val(*x_stacks)) |
69 | 82 |
|
70 | 83 | def _test_namedtuple(res, fields, func_name):
|
71 | 84 | """
|
@@ -452,10 +465,12 @@ def test_slogdet(x):
|
452 | 465 |
|
453 | 466 | # Check that when the determinant is 0, the sign and logabsdet are (0,
|
454 | 467 | # -inf).
|
455 |
| - d = linalg.det(x) |
456 |
| - zero_det = equal(d, zero(d.shape, d.dtype)) |
457 |
| - assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype)) |
458 |
| - assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype)) |
| 468 | + # TODO: This test does not necessarily hold exactly. Update it to test it |
| 469 | + # approximately. |
| 470 | + # d = linalg.det(x) |
| 471 | + # zero_det = equal(d, zero(d.shape, d.dtype)) |
| 472 | + # assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype)) |
| 473 | + # assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype)) |
459 | 474 |
|
460 | 475 | # More generally, det(x) should equal sign*exp(logabsdet), but this does
|
461 | 476 | # not hold exactly due to floating-point loss of precision.
|
@@ -614,7 +629,7 @@ def true_trace(x_stack):
|
614 | 629 |
|
615 | 630 | @given(
|
616 | 631 | dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes),
|
617 |
| - shape=shapes(), |
| 632 | + shape=shapes(min_dims=1), |
618 | 633 | data=data(),
|
619 | 634 | )
|
620 | 635 | def test_vecdot(dtypes, shape, data):
|
|
0 commit comments