Implement gradient for QR decomposition #1303
Conversation
This is really great! The most important change that we need here is to use solve_triangular instead of inverse.

If you're willing, I would also like to move the QR function from nlinalg to slinalg, and use the scipy signature in the perform method. If you're not, that's fine, and we can do that in a future PR.
pytensor/tensor/nlinalg.py
Outdated
Rinvt = _H(inv(R))
A_bar = Q @ ((ptb.tril(R_dRt - _H(R_dRt), k=-1)) @ Rinvt + dR)
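For reference, the pullback this line implements can be cross-checked numerically. The sketch below uses a textbook form of the reverse-mode rule for A = QR (square A, reduced mode), written with an explicit inverse purely for clarity; `qr_pullback`, `WQ`, and `WR` are illustrative names, not identifiers from this PR. The formula is verified against central finite differences:

```python
import numpy as np

def qr_pullback(Q, R, dQ, dR):
    # Reverse-mode rule for A = QR (square A, reduced mode):
    # A_bar = Q @ (tril(M - M^T, -1) @ inv(R)^T + dR),  with M = R dR^T - dQ^T Q
    M = R @ dR.T - dQ.T @ Q
    return Q @ (np.tril(M - M.T, -1) @ np.linalg.inv(R).T + dR)

rng = np.random.default_rng(0)
n = 4
A = rng.normal(size=(n, n))
WQ = rng.normal(size=(n, n))           # cotangent for Q
WR = np.triu(rng.normal(size=(n, n)))  # cotangent for R (upper triangular)

def f(A):
    Q, R = np.linalg.qr(A)
    return np.sum(WQ * Q) + np.sum(WR * R)

Q, R = np.linalg.qr(A)
grad = qr_pullback(Q, R, WQ, WR)

# Central finite differences, entry by entry.
eps = 1e-6
fd = np.zeros_like(A)
for i in range(n):
    for j in range(n):
        E = np.zeros_like(A)
        E[i, j] = eps
        fd[i, j] = (f(A + E) - f(A - E)) / (2 * eps)

assert np.allclose(grad, fd, atol=1e-5)
```

The explicit `inv(R)` here is exactly the step the review below suggests replacing with a triangular solve.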
Use a solve here instead of explicitly inverting R. You can use solve_triangular to exploit the triangular structure of R, and set trans=1 so that R is never inverted at all.
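A minimal NumPy/SciPy sketch (independent of the PyTensor code; variable names are illustrative) of how `solve_triangular` with the transpose option replaces the explicit inverse:

```python
import numpy as np
from scipy.linalg import qr, solve_triangular

rng = np.random.default_rng(42)
A = rng.normal(size=(4, 4))
_, R = qr(A)                   # R is upper triangular
B = rng.normal(size=(4, 4))

# Explicit inverse: forms inv(R) densely, ignoring the triangular structure.
X_inv = B @ np.linalg.inv(R)

# Equivalent triangular solve: B @ inv(R) means solving R^T X^T = B^T,
# which trans="T" expresses without transposing R by hand.
X_solve = solve_triangular(R, B.T, trans="T", lower=False).T

assert np.allclose(X_inv, X_solve)
```

The triangular solve is both cheaper (back-substitution instead of a full inversion) and numerically better behaved for ill-conditioned R.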
Thanks for the quick review!

To import it without a circular import, use

Forget I brought up changing the signature; this PR is already amazing, so let's just focus on gradients here.
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@ Coverage Diff @@
## main #1303 +/- ##
==========================================
- Coverage 82.00% 82.00% -0.01%
==========================================
Files 188 188
Lines 48478 48530 +52
Branches 8665 8678 +13
==========================================
+ Hits 39755 39797 +42
- Misses 6575 6582 +7
- Partials 2148 2151 +3
Ok, thanks! I didn't use the trans argument of solve_triangular though; would that simplify the expression? (It works anyway.)

No worries about that. This is really great! Thanks for your contribution. It looks nice to me. Going to ask @ricardoV94 for a final sign-off, then merge.
This is awesome!
Just some small change requests
pytensor/tensor/nlinalg.py
Outdated
if self.mode == "raw" or (self.mode == "complete" and m != n):
    raise NotImplementedError("Gradient of qr not implemented")

elif m < n:
The gradient should work for non-static shapes, e.g. A = matrix("A", shape=(None, None)). If you need the shape check, add it symbolically with assert_op.
I replaced the shape checks with Assert on the output shape. Let me know if that's what you had in mind; I'm not very familiar with symbolic shapes.
pytensor/tensor/nlinalg.py
Outdated
return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))

if self.mode == "raw" or (self.mode == "complete" and m != n):
    raise NotImplementedError("Gradient of qr not implemented")
Mention what exactly makes it non-implemented
* for mode=reduced or mode=r, all input shapes are accepted
* for mode=complete, shapes m x n where m <= n are accepted
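The shape restriction can be seen directly from the output shapes of NumPy's qr modes (a plain NumPy illustration, not code from this PR): for a tall matrix, "complete" mode pads Q with extra orthonormal columns that are not uniquely determined by A, so the gradient is ill-defined there, while for a wide matrix "complete" and "reduced" coincide.

```python
import numpy as np

rng = np.random.default_rng(0)

# Tall matrix, m > n: "complete" returns a square m x m Q whose extra
# columns are an arbitrary orthonormal basis extension, so the
# decomposition is not unique and the gradient is not implemented.
A_tall = rng.normal(size=(5, 3))
Qc, Rc = np.linalg.qr(A_tall, mode="complete")
assert Qc.shape == (5, 5) and Rc.shape == (5, 3)

# "reduced" mode (the default) stays well-defined for any shape.
Qr, Rr = np.linalg.qr(A_tall, mode="reduced")
assert Qr.shape == (5, 3) and Rr.shape == (3, 3)

# Wide matrix, m <= n: "complete" gives the same shapes as "reduced",
# so the gradient is accepted in that case.
A_wide = rng.normal(size=(3, 5))
Qw, Rw = np.linalg.qr(A_wide, mode="complete")
assert Qw.shape == (3, 3) and Rw.shape == (3, 5)
```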
lgtm
I'm really sorry, wait a bit before merging; I may have mixed up conjugate transpose and transpose. The tests are only done with real matrices, so that wouldn't be caught.
Can you try adding a complex test? You can follow the template here. Those tests don't work because numba is fussy. I think this case won't work either because we don't support complex gradients, but you can still add the tests and xfail them (it's aspirational that way!)
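A hypothetical sketch of the xfail pattern being suggested (the test name, reason string, and body are illustrative, not the actual test added in this PR). With strict=True, pytest also fails the suite if the test unexpectedly starts passing, which matches the "did you check they actually fail" concern below:

```python
import numpy as np
import pytest

@pytest.mark.xfail(
    reason="gradients of complex-valued QR are not implemented",
    strict=True,  # fail the suite if this ever starts passing unnoticed
)
def test_qr_grad_complex():
    A = np.eye(3, dtype=np.complex128)
    # Placeholder for the real gradient computation, which is expected
    # to fail while complex gradients are unsupported.
    raise NotImplementedError("complex QR gradient")
```

Without strict=True an xfail that starts passing is only reported as XPASS, so strict mode is what makes the marker "aspirational" in a checked way.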
Ok, done.

Did you check if they actually do fail before you xfailed them? I'm not 100% sure on that.

Yes, I did.
Looks great! This is gold
Actually @jessegrabowski even
There have been some vague conversations about it, but nothing concrete has been planned. It needs a dev to come along and take up the banner to make it happen.

Ok, thanks. I could help, but being new to pytensor dev, I don't know about being the flag-bearer...

If you're excited about it, or you need it for your work, you should definitely hack around and see if you can put together a prototype, and open a draft PR. If you build it, they will come :)

Ok, I'll see... :)

Are you a student interested in doing GSoC? We could make a project like that if you are and it's something you're interested in. Or we can brainstorm something else.

Thanks, but I am not a student :-D

Ah, ok! Sorry, I just assumed. We get a lot of activity around this time of year from students looking to participate in GSoC. Anyway, still happy to brainstorm on projects/features you'd like to see in pytensor (if you're looking for stuff to do) :)
FYI, GSoC is no longer restricted to students as of last year; it now also accommodates people who are beginning to contribute to OSS in general.
Description
Implements support for the gradient of the QR decomposition. The gradient is only defined when R is a square matrix.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1303.org.readthedocs.build/en/1303/