Implement gradient for QR decomposition #1303
@@ -512,6 +512,80 @@ def perform(self, node, inputs, outputs):
        else:
            outputs[0][0] = res

    def L_op(self, inputs, outputs, output_grads):
        """
        Reverse-mode gradient of the QR function.

        Adapted from [1]_, which is also used for the forward-mode implementation in JAX:
        https://github.com/jax-ml/jax/blob/54691b125ab4b6f88c751dae460e4d51f5cf834a/jax/_src/lax/linalg.py#L1803

        and from [2]_, which describes a solution for the square-matrix case.

        References
        ----------
        .. [1] Townsend, James. "Differentiating the QR decomposition." Online draft,
               https://j-towns.github.io/papers/qr-derivative.pdf (2018).
        .. [2] Walter, Sebastian F., Lutz Lehmann, and René Lamour. "On evaluating higher-order
               derivatives of the QR decomposition of tall matrices with full column rank in
               forward and reverse mode algorithmic differentiation." Optimization Methods and
               Software 27(2), 391-403. DOI: 10.1080/10556788.2011.610454.
        """
        from pytensor.tensor.slinalg import solve_triangular

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
        *_, m, n = A.type.shape

        def _H(x: ptb.TensorVariable):
            # Conjugate transpose of the last two dimensions
            return x.conj().mT

        def _copyutl(x: ptb.TensorVariable):
            # Keep the upper triangle (including the diagonal) and mirror the strictly
            # upper triangle, conjugated, into the lower triangle
            return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))

        if self.mode == "raw" or (self.mode == "complete" and m != n):
            raise NotImplementedError("Gradient of qr not implemented")

        elif m < n:
            raise NotImplementedError(
                "Gradient of qr not implemented for m x n matrices with m < n"
            )

Review comment: The gradient should work for non-static shapes.
Reply: I replaced the shape checks with
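For context on the reviewer's point: the checks above read `m` and `n` from `A.type.shape`, which only carries concrete values when the input shape is known statically. For a plain symbolic matrix the entries are `None`, so a comparison like `m < n` cannot be decided at graph-construction time. A minimal sketch (variable names are only for illustration):

import pytensor.tensor as pt

A_static = pt.tensor(name="A_static", dtype="float64", shape=(5, 3))  # static shape known
A_dynamic = pt.matrix("A_dynamic")                                    # shape only known at runtime

print(A_static.type.shape)   # (5, 3) -> m < n can be decided while building the graph
print(A_dynamic.type.shape)  # (None, None) -> m < n is undecidable here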
        elif self.mode == "r":
            # We need all the components of the QR to compute the gradient of A, even if
            # only the upper-triangular component is used in the cost function.
            Q, R = qr(A, mode="reduced")
            dR = cast(ptb.TensorVariable, output_grads[0])
            R_dRt = R @ _H(dR)
            M = ptb.tril(R_dRt - _H(R_dRt), k=-1)
            # M @ R^{-H}, computed with a triangular solve instead of an explicit inverse
            M_Rinvt = _H(solve_triangular(R, _H(M)))
            A_bar = Q @ (M_Rinvt + dR)
            return [A_bar]

        else:
            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)

            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                # This should never be reached by PyTensor
                return [DisconnectedType()()]  # pragma: no cover

            # Replace disconnected output gradients with zeros so that a single formula
            # below covers costs that use only Q, only R, or both
            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [Q, R], strict=True
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

            Qt_dQ = _H(Q) @ dQ
            R_dRt = R @ _H(dR)
            M = Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
            # M @ R^{-H} via a triangular solve
            M_Rinvt = _H(solve_triangular(R, _H(M)))
            A_bar = M_Rinvt + Q @ dR

            return [A_bar]


def qr(a, mode="reduced"):
    """
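In matrix notation, writing \bar{Q} and \bar{R} for the incoming cotangents and X^{H} for the conjugate transpose, the two implemented branches above compute (a direct transcription of the code, where \mathrm{copyutl}(X) = \mathrm{triu}(X) + \mathrm{triu}(X, 1)^{H} and R^{-H} is applied through a triangular solve):

    mode="r":
        M = \mathrm{tril}\left(R\bar{R}^{H} - \bar{R}R^{H},\,-1\right), \qquad \bar{A} = Q\left(M R^{-H} + \bar{R}\right)

    mode="reduced" (and square "complete"):
        M = Q\left(\mathrm{tril}\left(R\bar{R}^{H} - \bar{R}R^{H},\,-1\right) - \mathrm{copyutl}\left(Q^{H}\bar{Q}\right)\right) + \bar{Q}, \qquad \bar{A} = M R^{-H} + Q\bar{R}

A quick end-to-end exercise of the new gradient could look like the following sketch (illustrative values only; `verify_grad` is PyTensor's finite-difference gradient checker):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import verify_grad
from pytensor.tensor.nlinalg import qr

# Symbolic gradient through both outputs of the reduced QR
A = pt.matrix("A")
Q, R = qr(A, mode="reduced")
cost = Q.sum() + R.sum()
dA = pt.grad(cost, A)
f = pytensor.function([A], dA)

rng = np.random.RandomState(0)
A_val = rng.normal(size=(5, 3))
print(f(A_val).shape)  # (5, 3)

# Numerical check of the mode="r" branch against finite differences
verify_grad(lambda a: qr(a, mode="r").sum(), [A_val], rng=rng)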
Review comment: Mention what exactly makes it non-implemented.
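One way to address this might be to name the unsupported configurations in the message itself; a hypothetical rewording of the first NotImplementedError:

        if self.mode == "raw" or (self.mode == "complete" and m != n):
            raise NotImplementedError(
                "Gradient of QR is only implemented for mode='reduced', mode='r', "
                "and mode='complete' with square inputs; "
                f"got mode={self.mode!r} for an input with static shape ({m}, {n})."
            )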