
Implement gradient for QR decomposition #1099

Closed
jessegrabowski opened this issue Nov 22, 2024 · 5 comments · Fixed by #1303

Comments

@jessegrabowski
Member

jessegrabowski commented Nov 22, 2024

Description

QR is one of the few remaining linalg ops that is missing a gradient. JAX code for the jvp is here, which also includes this derivation. This paper also claims to derive the gradients for QR, but I find it unreadable.

Relatedly, but perhaps worthy of a separate issue, this paper derives gradients for the LQ decomposition, $A = LQ$, where $L$ is lower triangular and $Q$ is orthonormal ($Q^T Q = I$). Compare this to QR, which gives you $A = QR$, where $Q$ is again orthonormal but $R$ is upper triangular, and you see why I mention it in this issue. It wouldn't be hard to offer LQ as well.
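
For orientation, here is a minimal NumPy sketch of the forward-mode (jvp) rule, in the spirit of the JAX reference linked above, assuming the reduced QR of a tall or square $A$ with invertible $R$. The helper name `qr_jvp` and the finite-difference check are purely illustrative, not existing PyTensor or JAX code.

```python
import numpy as np

def qr_jvp(A, dA):
    """Forward-mode rule for the reduced QR of A (m x n, m >= n, R invertible).

    Differentiating A = QR gives dA = dQ R + Q dR.  With X = Q^T dA R^{-1},
    the antisymmetric matrix built from the strict lower triangle of X gives
    Q^T dQ, and the remainder gives dR R^{-1}.
    """
    Q, R = np.linalg.qr(A)                  # reduced QR
    dA_Rinv = np.linalg.solve(R.T, dA.T).T  # dA @ inv(R), computed via a solve
    X = Q.T @ dA_Rinv
    L = np.tril(X, -1)
    Omega = L - L.T                         # antisymmetric part of X
    dQ = dA_Rinv + Q @ (Omega - X)
    dR = (X - Omega) @ R
    return (Q, R), (dQ, dR)

# quick central-difference check (illustrative only)
rng = np.random.default_rng(0)
A, dA = rng.normal(size=(5, 3)), rng.normal(size=(5, 3))
eps = 1e-5
(_, _), (dQ, dR) = qr_jvp(A, dA)
Qp, Rp = np.linalg.qr(A + eps * dA)
Qm, Rm = np.linalg.qr(A - eps * dA)
assert np.allclose((Qp - Qm) / (2 * eps), dQ, atol=1e-4)
assert np.allclose((Rp - Rm) / (2 * eps), dR, atol=1e-4)
```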

@educhesne
Contributor

I went for the backward-mode version of the solution implemented in JAX, so the gradient is only valid when R is square and full-rank.
A note about the implementation: at some point the inverse of R is needed, and I used matrix_inverse rather than solve_triangular because I assumed that no dependency on scipy was allowed in nlinalg; was I correct?
(first PR to pytensor here; any comment appreciated)
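
For concreteness, here is a minimal NumPy sketch of that backward rule for the tall/square case (m >= n, R invertible). It uses scipy.linalg.solve_triangular purely to illustrate the solve-based alternative to an explicit matrix_inverse discussed above; the name `qr_pullback` is made up for this sketch, not the PR's actual code.

```python
import numpy as np
from scipy.linalg import solve_triangular

def qr_pullback(Q, R, Q_bar, R_bar):
    """Backward-mode (vjp) rule for A = QR, with A (m x n), m >= n, R invertible.

    A_bar = (Q_bar + Q @ copyltu(M)) @ R^{-T},  M = R @ R_bar^T - Q_bar^T @ Q,
    where copyltu(M) keeps the lower triangle of M (diagonal included) and
    mirrors its strictly-lower part onto the upper triangle.
    """
    M = R @ R_bar.T - Q_bar.T @ Q
    copyltu = np.tril(M) + np.tril(M, -1).T
    X = Q_bar + Q @ copyltu
    # X @ R^{-T} without forming the inverse: solve R Z = X^T, then transpose
    return solve_triangular(R, X.T, lower=False).T
```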

JAX code for the jvp is here, which also includes this derivation.

This solution requires R to be square and full-rank for the gradient to be valid.

This paper also claims to derive the gradients for QR, but I find it unreadable.

As far as I can tell, this solution only works for square inputs (based on some verify_grad checks I ran; I can't really pinpoint which part of the derivation assumes it...)

Relatedly but perhaps worthy of a separate issue, this paper derives gradients for the LQ decomposition

I couldn't really make sense of this derivation... maybe there are some typos in it?

It wouldn't be hard to offer LQ as well.

I didn't do it, although if you want I can add LQ as a QR decomposition with a couple of transposes.
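
For what it's worth, a tiny NumPy sketch of the "couple of transposes" idea; `lq` here is a hypothetical helper, not an existing PyTensor op.

```python
import numpy as np

def lq(A):
    """LQ via QR of the transpose: if A^T = Q0 R0, then A = R0^T Q0^T = L Q,
    with L = R0^T lower triangular and Q = Q0^T having orthonormal rows."""
    Q0, R0 = np.linalg.qr(A.T)
    return R0.T, Q0.T

A = np.random.default_rng(0).normal(size=(3, 5))
L, Q = lq(A)
assert np.allclose(L @ Q, A)           # A = LQ
assert np.allclose(Q @ Q.T, np.eye(3)) # orthonormal rows
assert np.allclose(L, np.tril(L))      # lower triangular
```

The gradient would then follow from the QR gradient through the same transposes.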

@qiyang-ustc

In the past I have always referred to this blog; will it help? https://giggleliu.github.io/posts/2019-04-02-einsumbp/

@jessegrabowski
Member Author

Very nice reference! Maybe the einsum rule could improve on what we're doing now (which I think is just autodiffing through whatever graph einsum comes up with).
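
As a small illustration of the blog's point (not of what PyTensor currently does with einsum), the backward pass of an einsum is itself an einsum with the index strings permuted:

```python
import numpy as np

rng = np.random.default_rng(0)
A, B = rng.normal(size=(4, 5)), rng.normal(size=(5, 6))
C_bar = rng.normal(size=(4, 6))  # cotangent of C = einsum('ij,jk->ik', A, B)

# gradients are einsums over the same index letters, with the roles permuted
A_bar = np.einsum('ik,jk->ij', C_bar, B)
B_bar = np.einsum('ij,ik->jk', A, C_bar)

# matches the matmul gradients written out by hand
assert np.allclose(A_bar, C_bar @ B.T)
assert np.allclose(B_bar, A.T @ C_bar)
```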

@educhesne
Contributor

educhesne commented Mar 22, 2025

The derivation is much clearer in that blog! Thanks! Besides, in the paper they reference there is an expression for the QR gradient when m < n (where the factored matrix $A$ is m x n); that was not covered in the previous references. I'm going to look into it.

@educhesne
Contributor

I updated the PR with the gradient expression covering the other shapes from that paper (I think there is actually a tiny mistake in it, but I reckon it is corrected in the PR).
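
For reference, a sketch of how the m < n (wide) case can reduce to the square rule: split $A = [X, Y]$ with $X$ the leading m x m block (assumed invertible here) and $R = [U, V]$, so that $Q, U$ is the reduced QR of $X$ and $V = Q^T Y$. This is only an illustration under those assumptions, with made-up helper names, and not necessarily the exact expression from the paper or the PR.

```python
import numpy as np

def qr_pullback_square(Q, U, Q_bar, U_bar):
    # square/tall rule (same as the earlier sketch, repeated to stay self-contained)
    M = U @ U_bar.T - Q_bar.T @ Q
    copyltu = np.tril(M) + np.tril(M, -1).T
    return np.linalg.solve(U, (Q_bar + Q @ copyltu).T).T

def qr_pullback_wide(A, Q, R, Q_bar, R_bar):
    """Backward rule for A = QR with A (m x n), m < n, leading m x m block invertible.

    With A = [X, Y] and R = [U, V]: since V = Q^T Y, differentiating shows that
    Y only adds Y @ V_bar^T to the cotangent of Q and receives Q @ V_bar as its
    own gradient, while X gets the square-case rule with that adjusted cotangent.
    """
    m = A.shape[0]
    Y = A[:, m:]
    U, U_bar, V_bar = R[:, :m], R_bar[:, :m], R_bar[:, m:]
    X_bar = qr_pullback_square(Q, U, Q_bar + Y @ V_bar.T, U_bar)
    Y_bar = Q @ V_bar
    return np.concatenate([X_bar, Y_bar], axis=1)
```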
