
Implement gradient for QR decomposition #1303

Merged (7 commits into pymc-devs:main on Mar 27, 2025)

Conversation

@educhesne (Contributor) commented Mar 19, 2025

Description

Implement support for the gradient of the QR decomposition. It is only defined when R is a square matrix.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1303.org.readthedocs.build/en/1303/

@jessegrabowski (Member) left a comment

This is really great! The most important change that we need here is to use solve_triangular instead of inverse.

If you're willing, I would also like to move the QR function from nlinalg to slinalg, and use the scipy signature in the perform method. If you're not, that's fine, and we can do that in a future PR.

Comment on lines 553 to 554
Rinvt = _H(inv(R))
A_bar = Q @ ((ptb.tril(R_dRt - _H(R_dRt), k=-1)) @ Rinvt + dR)
Member

Use a solve here instead of explicitly inverting R. You can use solve_triangular to exploit the structure of R, and set trans=1 instead of inverting R at all.
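The suggested replacement can be sketched numerically with SciPy (a stand-in for pytensor's solve_triangular; the variable names mirror the reviewed snippet, and the real case is used so the conjugate transpose reduces to a plain transpose):

```python
import numpy as np
from scipy.linalg import qr, solve_triangular

# Numerical sketch: a term of the form M @ _H(inv(R)) can be computed
# without ever forming inv(R).  Since X = M @ R^{-T} is equivalent to
# R @ X^T = M^T, a single triangular solve against R suffices.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
Q, R = qr(A)                       # R is upper triangular

M = rng.standard_normal((4, 4))    # stand-in for the bracketed gradient term

via_inverse = M @ np.linalg.inv(R).T                  # explicit inverse (old)
via_solve = solve_triangular(R, M.T, lower=False).T   # triangular solve (new)

assert np.allclose(via_inverse, via_solve)
```

The solve exploits the triangular structure of R (an O(n^2) back-substitution per column) instead of the full O(n^3) inversion, and is also better conditioned.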

@educhesne (Contributor Author)

Thanks for the quick review!
I cannot import solve_triangular in pt.nlinalg (it creates circular dependencies), so I'll move qr from nlinalg to slinalg directly.
The scipy signature is different; in particular, the mode values change (from complete/reduced to full/economic) and the default mode is complete/full rather than reduced/economic. Should qr reflect the scipy signature, or keep the numpy one?

@jessegrabowski (Member) commented Mar 19, 2025

To import it without a circular dependency, use from pytensor.tensor.slinalg import solve_triangular inside the gradient function.

Forget I brought up changing the signature, this PR is already amazing so let's just focus on gradients here.
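The deferred-import pattern suggested here can be sketched generically (the function and the imported name below are illustrative stand-ins, not pytensor code):

```python
# A function-level import runs at call time, not at module-load time, so a
# cycle between two modules never materializes while they are still loading.
def grad_sketch():
    # stand-in for: from pytensor.tensor.slinalg import solve_triangular
    from math import sqrt
    return sqrt(9.0)

assert grad_sketch() == 3.0
```

By the time the gradient function is actually called, both modules have finished importing, so the lookup succeeds.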

codecov bot commented Mar 19, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.00%. Comparing base (a149f6c) to head (4edc698).
Report is 12 commits behind head on main.


@@            Coverage Diff             @@
##             main    #1303      +/-   ##
==========================================
- Coverage   82.00%   82.00%   -0.01%     
==========================================
  Files         188      188              
  Lines       48478    48530      +52     
  Branches     8665     8678      +13     
==========================================
+ Hits        39755    39797      +42     
- Misses       6575     6582       +7     
- Partials     2148     2151       +3     
Files with missing lines: pytensor/tensor/nlinalg.py — 95.63% <100.00%> (+0.36%) ⬆️

... and 6 files with indirect coverage changes


@educhesne (Contributor Author)

Ok, thanks! I didn't use the trans argument of solve_triangular though; would that simplify the expression? (It works anyway.)

@jessegrabowski (Member)

No worries about the trans argument, I think I misunderstood what was needed.

This is really great! Thanks for your contribution. It looks nice to me. Going to ask @ricardoV94 for a final sign off then merge.

@ricardoV94 (Member) left a comment

This is awesome!

Just some small change requests

if self.mode == "raw" or (self.mode == "complete" and m != n):
raise NotImplementedError("Gradient of qr not implemented")

elif m < n:
Member

The gradient should work for non-static shapes, e.g. A = matrix("A", shape=(None, None)). If you need the shape check, add it symbolically with assert_op.

Contributor Author

I replaced the shape checks with Assert on the output shape. Let me know if that's what you had in mind; I'm not very familiar with symbolic shapes.

return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))

if self.mode == "raw" or (self.mode == "complete" and m != n):
raise NotImplementedError("Gradient of qr not implemented")
Member

Mention what exactly makes it non-implemented

@ricardoV94 ricardoV94 added gradients linalg Linear algebra labels Mar 21, 2025
Etienne Duchesne added 2 commits March 21, 2025 17:03
* for mode=reduced or mode=r, all input shapes are accepted
* for mode=complete, shapes m x n where m <= n are accepted
@jessegrabowski (Member) left a comment

lgtm

@educhesne (Contributor Author)

I'm really sorry; wait a bit before merging. I may have mixed up conjugate transpose and transpose. The tests only use real matrices, so that wouldn't be caught.

@jessegrabowski (Member)

Can you try adding a complex test? You can follow the template here. Those tests don't work because numba is fussy. I think this case won't work either because we don't support complex gradients, but you can still add the tests and xfail them (it's aspirational that way!)
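The xfail pattern being suggested can be sketched like this (the test body is a placeholder; the real test would call pytensor's verify_grad on a complex input):

```python
import numpy as np
import pytest

# An xfail-marked test documents the aspiration: it runs, is expected to
# fail today, and will show up as XPASS the day complex gradients land.
@pytest.mark.xfail(reason="complex gradients are not supported yet")
def test_qr_grad_complex():
    rng = np.random.default_rng(0)
    A = rng.standard_normal((3, 3)) + 1j * rng.standard_normal((3, 3))
    # placeholder for: utt.verify_grad(lambda a: qr(a, mode="reduced"), [A])
    raise NotImplementedError("complex dtypes rejected by verify_grad")
```

Without `strict=True`, an unexpected pass is reported as XPASS rather than an error, which is the gentle default for aspirational tests.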

@educhesne (Contributor Author)

Ok, done.
I've rechecked the conjugate/transpose thing and I am more confident it is correct now; sorry, I overreacted...

@jessegrabowski (Member)

Did you check that they actually fail before you xfailed them? I'm not 100% sure on that.

@educhesne (Contributor Author) commented Mar 26, 2025

Yes I did; verify_grad doesn't accept complex dtypes.

@ricardoV94 (Member) left a comment

Looks great! This is gold

@ricardoV94 ricardoV94 added the enhancement New feature or request label Mar 27, 2025
@ricardoV94 ricardoV94 merged commit 2774599 into pymc-devs:main Mar 27, 2025
73 checks passed
@educhesne educhesne deleted the qr_gradient branch March 28, 2025 07:53
@educhesne (Contributor Author)

Actually @jessegrabowski, even pytensor.gradient.grad doesn't accept complex dtypes.
I found old discussions from the Theano era about it, but nothing else (in particular, nothing about choosing a convention for complex gradients: the jax/holomorphic-compatible one vs. the torch/gradient-descent-compatible one).
Do you know if it is something planned or considered?

@jessegrabowski (Member)

There have been some vague conversations about it, but nothing concrete has been planned. It needs a dev to come along and take up the banner to make it happen.

@educhesne (Contributor Author)

Ok, thanks. I could help, but being new to pytensor development I don't know about being the flag-bearer...

@jessegrabowski (Member)

If you're excited about it, or you need it for your work, you should definitely hack around, see if you can put together a prototype, and open a draft PR. If you build it, they will come :)

@educhesne (Contributor Author)

Ok, I'll see... :)
(I don't need it, but that looks interesting...)

@jessegrabowski (Member)

Are you a student interested in doing GSoC? We could make a project like that if you are and it's something you're interested in. Or we can brainstorm something else.

@educhesne (Contributor Author)

Thanks but I am not a student :-D

@jessegrabowski (Member)

Ah ok! Sorry, I just assumed. We get a lot of activity around this time of year from students looking to participate in GSoC. Anyway, still happy to brainstorm on projects/features you'd like to see in pytensor (if you're looking for stuff to do) :)

@ricardoV94 (Member)

> Thanks but I am not a student :-D

FYI, GSoC is no longer restricted to students as of last year; it now also accommodates people who are beginning to contribute to OSS in general.
