diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index ee33f6533c..8fff2a2f59 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -5,12 +5,15 @@
 
 import numpy as np
 
+import pytensor.tensor as pt
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
+from pytensor.ifelse import ifelse
 from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.raise_op import Assert
 from pytensor.tensor import TensorLike
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
@@ -512,6 +515,80 @@ def perform(self, node, inputs, outputs):
         else:
             outputs[0][0] = res
 
+    def L_op(self, inputs, outputs, output_grads):
+        """
+        Reverse-mode gradient of the QR function.
+
+        References
+        ----------
+        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
+        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
+        """
+
+        from pytensor.tensor.slinalg import solve_triangular
+
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+        m, n = A.shape
+
+        def _H(x: ptb.TensorVariable):
+            return x.conj().mT
+
+        def _copyltu(x: ptb.TensorVariable):
+            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))
+
+        if self.mode == "raw":
+            raise NotImplementedError("Gradient of qr not implemented for mode=raw")
+
+        elif self.mode == "r":
+            # We need all the components of the QR to compute the gradient of A even if we only
+            # use the upper triangular component in the cost function.
+            Q, R = qr(A, mode="reduced")
+            dQ = Q.zeros_like()
+            dR = cast(ptb.TensorVariable, output_grads[0])
+
+        else:
+            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
+            if self.mode == "complete":
+                qr_assert_op = Assert(
+                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
+                )
+                R = qr_assert_op(R, ptm.le(m, n))
+
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                # This should never be reached by Pytensor
+                return [DisconnectedType()()]  # pragma: no cover
+
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, [Q, R], strict=True
+            ):
+                if disconnected:
+                    new_output_grads.append(output.zeros_like())
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+        # gradient expression when m >= n
+        M = R @ _H(dR) - _H(dQ) @ Q
+        K = dQ + Q @ _copyltu(M)
+        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))
+
+        # gradient expression when m < n
+        Y = A[:, m:]
+        U = R[:, :m]
+        dU, dV = dR[:, :m], dR[:, m:]
+        dQ_Yt_dV = dQ + Y @ _H(dV)
+        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
+        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
+        Y_bar = Q @ dV
+        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)
+
+        return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
+
 
 def qr(a, mode="reduced"):
     """
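For reference, the two `A_bar` expressions in `L_op` are the QR pullback identities from refs [1] and [2] in the docstring. In the code's notation ($\bar Q$ = `dQ`, $\bar R$ = `dR`, $X^H$ = `_H(X)`), the `m >= n` branch computes

$$
M = R \bar R^{H} - \bar Q^{H} Q, \qquad
\bar A = \left[ \bar Q + Q \,\mathrm{copyltu}(M) \right] R^{-H},
$$

where $\mathrm{copyltu}(M) = \mathrm{tril}(M) + \mathrm{tril}(M, -1)^{H}$ copies the conjugated strict lower triangle into the upper triangle, and `solve_triangular` applies $R^{-H}$ without forming the inverse. The `m < n` branch partitions $A = [X \mid Y]$ with $X$ of shape $m \times m$ (so $R = [U \mid V]$), applies the same identity to the square block, and contributes $\bar Y = Q \bar V$ for the trailing columns.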
[(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]] + + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]] + + [((3, 3), 0, "raw")] + ), + ids=( + [ + f"shape={s}, gradient_test_case={c}, mode=reduced" + for s in [(3, 3), (6, 3), (3, 6)] + for c in ["Q", "R", "both"] + ] + + [ + f"shape={s}, gradient_test_case={c}, mode=complete" + for s in [(3, 3), (6, 3), (3, 6)] + for c in ["Q", "R", "both"] + ] + + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]] + + ["shape=(3, 3), gradient_test_case=Q, mode=raw"] + ), +) +@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) +def test_qr_grad(shape, gradient_test_case, mode, is_complex): + rng = np.random.default_rng(utt.fetch_seed()) + + def _test_fn(x, case=2, mode="reduced"): + if case == 0: + return qr(x, mode=mode)[0].sum() + elif case == 1: + return qr(x, mode=mode)[1].sum() + elif case == 2: + Q, R = qr(x, mode=mode) + return Q.sum() + R.sum() + + if is_complex: + pytest.xfail("Complex inputs currently not supported by verify_grad") + + m, n = shape + a = rng.standard_normal(shape).astype(config.floatX) + if is_complex: + a += 1j * rng.standard_normal(shape).astype(config.floatX) + + if mode == "raw": + with pytest.raises(NotImplementedError): + utt.verify_grad( + partial(_test_fn, case=gradient_test_case, mode=mode), + [a], + rng=np.random, + ) + + elif mode == "complete" and m > n: + with pytest.raises(AssertionError): + utt.verify_grad( + partial(_test_fn, case=gradient_test_case, mode=mode), + [a], + rng=np.random, + ) + + else: + utt.verify_grad( + partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random + ) + + class TestSvd(utt.InferShapeTester): op_class = SVD