Skip to content

Commit df5f6f5

Browse files
committed
Update the field names for the qr() and svd() namedtuples
1 parent 9c9ffe1 commit df5f6f5

File tree

1 file changed

+28
-28
lines changed

1 file changed

+28
-28
lines changed

array_api_tests/test_linalg.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -399,29 +399,29 @@ def test_qr(x, kw):
399399
M, N = x.shape[-2:]
400400
K = min(M, N)
401401

402-
_test_namedtuple(res, ['q', 'r'], 'qr')
403-
q = res.q
404-
r = res.r
402+
_test_namedtuple(res, ['Q', 'R'], 'qr')
403+
Q = res.Q
404+
R = res.R
405405

406-
assert q.dtype == x.dtype, "qr().q did not return the correct dtype"
406+
assert Q.dtype == x.dtype, "qr().Q did not return the correct dtype"
407407
if mode == 'complete':
408-
assert q.shape == x.shape[:-2] + (M, M), "qr().q did not return the correct shape"
408+
assert Q.shape == x.shape[:-2] + (M, M), "qr().Q did not return the correct shape"
409409
else:
410-
assert q.shape == x.shape[:-2] + (M, K), "qr().q did not return the correct shape"
410+
assert Q.shape == x.shape[:-2] + (M, K), "qr().Q did not return the correct shape"
411411

412-
assert r.dtype == x.dtype, "qr().r did not return the correct dtype"
412+
assert R.dtype == x.dtype, "qr().R did not return the correct dtype"
413413
if mode == 'complete':
414-
assert r.shape == x.shape[:-2] + (M, N), "qr().r did not return the correct shape"
414+
assert R.shape == x.shape[:-2] + (M, N), "qr().R did not return the correct shape"
415415
else:
416-
assert r.shape == x.shape[:-2] + (K, N), "qr().r did not return the correct shape"
416+
assert R.shape == x.shape[:-2] + (K, N), "qr().R did not return the correct shape"
417417

418-
_test_stacks(lambda x: linalg.qr(x, **kw).q, x, res=q)
419-
_test_stacks(lambda x: linalg.qr(x, **kw).r, x, res=r)
418+
_test_stacks(lambda x: linalg.qr(x, **kw).Q, x, res=Q)
419+
_test_stacks(lambda x: linalg.qr(x, **kw).R, x, res=R)
420420

421-
# TODO: Test that q is orthonormal
421+
# TODO: Test that Q is orthonormal
422422

423-
# Check that r is upper-triangular.
424-
assert_exactly_equal(r, _array_module.triu(r))
423+
# Check that R is upper-triangular.
424+
assert_exactly_equal(R, _array_module.triu(R))
425425

426426
@pytest.mark.xp_extension('linalg')
427427
@given(
@@ -506,29 +506,29 @@ def test_svd(x, kw):
506506
*stack, M, N = x.shape
507507
K = min(M, N)
508508

509-
_test_namedtuple(res, ['u', 's', 'vh'], 'svd')
509+
_test_namedtuple(res, ['U', 'S', 'Vh'], 'svd')
510510

511-
u, s, vh = res
511+
U, S, Vh = res
512512

513-
assert u.dtype == x.dtype, "svd().u did not return the correct dtype"
514-
assert s.dtype == x.dtype, "svd().s did not return the correct dtype"
515-
assert vh.dtype == x.dtype, "svd().vh did not return the correct dtype"
513+
assert U.dtype == x.dtype, "svd().U did not return the correct dtype"
514+
assert S.dtype == x.dtype, "svd().S did not return the correct dtype"
515+
assert Vh.dtype == x.dtype, "svd().Vh did not return the correct dtype"
516516

517517
if full_matrices:
518-
assert u.shape == (*stack, M, M), "svd().u did not return the correct shape"
519-
assert vh.shape == (*stack, N, N), "svd().vh did not return the correct shape"
518+
assert U.shape == (*stack, M, M), "svd().U did not return the correct shape"
519+
assert Vh.shape == (*stack, N, N), "svd().Vh did not return the correct shape"
520520
else:
521-
assert u.shape == (*stack, M, K), "svd(full_matrices=False).u did not return the correct shape"
522-
assert vh.shape == (*stack, K, N), "svd(full_matrices=False).vh did not return the correct shape"
523-
assert s.shape == (*stack, K), "svd().s did not return the correct shape"
521+
assert U.shape == (*stack, M, K), "svd(full_matrices=False).U did not return the correct shape"
522+
assert Vh.shape == (*stack, K, N), "svd(full_matrices=False).Vh did not return the correct shape"
523+
assert S.shape == (*stack, K), "svd().S did not return the correct shape"
524524

525525
# The values of s must be sorted from largest to smallest
526526
if K >= 1:
527-
assert _array_module.all(s[..., :-1] >= s[..., 1:]), "svd().s values are not sorted from largest to smallest"
527+
assert _array_module.all(S[..., :-1] >= S[..., 1:]), "svd().S values are not sorted from largest to smallest"
528528

529-
_test_stacks(lambda x: linalg.svd(x, **kw).u, x, res=u)
530-
_test_stacks(lambda x: linalg.svd(x, **kw).s, x, dims=1, res=s)
531-
_test_stacks(lambda x: linalg.svd(x, **kw).vh, x, res=vh)
529+
_test_stacks(lambda x: linalg.svd(x, **kw).U, x, res=U)
530+
_test_stacks(lambda x: linalg.svd(x, **kw).S, x, dims=1, res=S)
531+
_test_stacks(lambda x: linalg.svd(x, **kw).Vh, x, res=Vh)
532532

533533
@pytest.mark.xp_extension('linalg')
534534
@given(

0 commit comments

Comments
 (0)