Implement gradient for QR decomposition #1099
Comments
I went for the backward-mode version of the solution implemented in JAX. So the gradient is only valid when R is square and full-rank.
As far as I can tell, this solution only works on square input (based on some `verify_grad` runs; I can't really pinpoint which part of the derivation assumes it...)
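For concreteness, here is a minimal NumPy sketch of that backward-mode rule, assuming R is square and full-rank; `copyltu` and `qr_backward_square` are just illustrative names, and the formula is the standard reverse-mode QR rule that the JAX derivation also arrives at:

```python
import numpy as np

def copyltu(M):
    # Build a matrix from the lower triangle of M:
    # out[i, j] = M[max(i, j), min(i, j)]
    return np.tril(M) + np.tril(M, -1).T

def qr_backward_square(Q, R, Q_bar, R_bar):
    # Reverse-mode rule for A = QR with R square and full-rank:
    #   A_bar = (Q_bar + Q @ copyltu(M)) @ R^{-T},
    #   where M = R @ R_bar^T - Q_bar^T @ Q
    M = R @ R_bar.T - Q_bar.T @ Q
    # X @ R^{-T} computed as solve(R, X^T)^T; a triangular solve
    # would be the right tool in a real implementation.
    return np.linalg.solve(R, (Q_bar + Q @ copyltu(M)).T).T
```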
I couldn't really make sense of this derivation... maybe there are some typos in it?
I didn't do it, although if you want I can add LQ as a QR decomposition with a couple of transposes.
In the past I always referred to this blog; will it help? https://giggleliu.github.io/posts/2019-04-02-einsumbp/
Very nice reference! Maybe the einsum rule could improve on what we're doing now (which I think is just autodiffing through whatever graph einsum comes up with).
The derivation is much clearer in that blog! Thanks! Besides, the paper it references has an expression for the QR gradient when m < n (where the decomposed matrix is m x n); that was not covered in the previous references. I'm going to look into it.
I updated the PR with the gradient expression covering the other shapes from that paper (I think there is actually a tiny mistake in it, but I reckon it is corrected in the PR).
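For reference, a rough sketch of how the m < n case reduces to the square case, reusing `copyltu` and `qr_backward_square` from the sketch above. The split-and-reduce structure follows the paper's wide-case expression as I read it, so treat it as an assumption rather than the exact PR code:

```python
import numpy as np

def qr_backward_wide(A, Q, R, Q_bar, R_bar):
    # Wide case m < n: split A = [X | Y] with X of shape (m, m), and
    # R = [U | V] correspondingly, so that X = Q @ U and Y = Q @ V.
    m = A.shape[0]
    Y = A[:, m:]
    U, V = R[:, :m], R[:, m:]
    U_bar, V_bar = R_bar[:, :m], R_bar[:, m:]
    # V = Q^T @ Y contributes an extra Y @ V_bar^T term to Q's cotangent;
    # the square-case rule then handles X = Q @ U.
    X_bar = qr_backward_square(Q, U, Q_bar + Y @ V_bar.T, U_bar)
    Y_bar = Q @ V_bar
    return np.concatenate([X_bar, Y_bar], axis=1)
```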
Description
QR is one of the few remaining linalg ops that is missing a gradient. JAX code for the jvp is here, which also includes this derivation. This paper also claims to derive the gradients for QR, but I find it unreadable.
Relatedly, but perhaps worthy of a separate issue, this paper derives gradients for the LQ decomposition, $A = LQ$, where $L$ is lower triangular and $Q$ is orthonormal ($Q^T Q = I$). Compare this to QR, which gives you $A = QR$, where $Q$ is again orthonormal, but $R$ is upper triangular, and you see why I mention it in this issue. It wouldn't be hard to offer LQ as well.
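As a quick illustration of the "couple of transposes" idea from the comments, here is a hypothetical `lq` helper built on top of `np.linalg.qr` (not an existing API): if $A^T = Q'R'$, then $A = R'^T Q'^T$, so $L = R'^T$ and $Q = Q'^T$.

```python
import numpy as np

def lq(A):
    # Hypothetical LQ via QR of the transpose: A^T = Q' R' implies
    # A = R'^T Q'^T, with L = R'^T lower triangular and Q = Q'^T.
    Qt, Rt = np.linalg.qr(A.T)
    return Rt.T, Qt.T

# Sanity check on a wide matrix: A == L @ Q, and the reduced-form Q
# has orthonormal rows.
A = np.random.default_rng(0).normal(size=(3, 5))
L, Q = lq(A)
assert np.allclose(L @ Q, A)
assert np.allclose(Q @ Q.T, np.eye(3))
```

The gradient would fall out the same way: transpose the inputs and outputs around the existing QR gradient.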