
Implement tridiagonal solve in numba backend #1311

Merged: 2 commits from numba_tridiagonal_impl into pymc-devs:main on Mar 27, 2025

Conversation

@ricardoV94 (Member) commented on Mar 21, 2025:

Also refactors the lapack codegen


📚 Documentation preview 📚: https://pytensor--1311.org.readthedocs.build/en/1311/
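For context, the tridiagonal solve presumably dispatches to the LAPACK gttrf/gttrs pair (the PR adds pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py and reworks the shared LAPACK codegen). A minimal sketch of that factor-then-solve flow, shown here with SciPy's low-level bindings rather than the PR's own helpers:

```python
import numpy as np
from scipy.linalg import lapack

# A small tridiagonal system A @ x = b, defined by its three diagonals.
rng = np.random.default_rng(0)
n = 6
dl = rng.standard_normal(n - 1)        # sub-diagonal
d = rng.standard_normal(n) + 4.0       # main diagonal, shifted for conditioning
du = rng.standard_normal(n - 1)        # super-diagonal
b = rng.standard_normal((n, 1))        # one right-hand side, as a column

# Factor once with gttrf (LU of the tridiagonal matrix), then solve with gttrs.
dl_f, d_f, du_f, du2, ipiv, info = lapack.dgttrf(dl, d, du)
assert info == 0
x, info = lapack.dgttrs(dl_f, d_f, du_f, du2, ipiv, b)
assert info == 0

# Check against the dense reconstruction of A.
A = np.diag(d) + np.diag(du, 1) + np.diag(dl, -1)
np.testing.assert_allclose(A @ x, b)
```

gttrf returns the LU factors in banded form (including a second super-diagonal du2 and pivot indices), and gttrs reuses them for one or more right-hand sides.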

@ricardoV94 added the enhancement (New feature or request), numba, and linalg (Linear algebra) labels on Mar 21, 2025
@ricardoV94 force-pushed the numba_tridiagonal_impl branch 5 times, most recently from 4cc3474 to abd00fb on March 21, 2025 14:56

codecov bot commented Mar 21, 2025

Codecov Report

Attention: Patch coverage is 50.22624% with 330 lines in your changes missing coverage. Please review.

Project coverage is 81.96%. Comparing base (8a7356c) to head (0757466).
Report is 3 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...tensor/link/numba/dispatch/linalg/solve/general.py | 44.54% | 61 Missing ⚠️ |
| ...nsor/link/numba/dispatch/linalg/solve/symmetric.py | 43.33% | 51 Missing ⚠️ |
| ...ytensor/link/numba/dispatch/linalg/solve/posdef.py | 46.98% | 44 Missing ⚠️ |
| ...or/link/numba/dispatch/linalg/solve/tridiagonal.py | 58.25% | 43 Missing ⚠️ |
| pytensor/link/numba/dispatch/linalg/utils.py | 42.62% | 29 Missing and 6 partials ⚠️ |
| ...sor/link/numba/dispatch/linalg/solve/triangular.py | 40.00% | 29 Missing and 1 partial ⚠️ |
| ...ensor/link/numba/dispatch/linalg/solve/cholesky.py | 43.18% | 25 Missing ⚠️ |
| ...nk/numba/dispatch/linalg/decomposition/cholesky.py | 47.22% | 19 Missing ⚠️ |
| pytensor/link/numba/dispatch/linalg/solve/norm.py | 61.53% | 10 Missing ⚠️ |
| pytensor/link/numba/dispatch/slinalg.py | 77.77% | 3 Missing and 5 partials ⚠️ |
| ... and 1 more | | |

❌ Your patch status has failed because the patch coverage (50.22%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1311      +/-   ##
==========================================
- Coverage   81.98%   81.96%   -0.03%     
==========================================
  Files         188      198      +10     
  Lines       48489    48672     +183     
  Branches     8673     8677       +4     
==========================================
+ Hits        39756    39896     +140     
- Misses       6582     6625      +43     
  Partials     2151     2151              
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/link/numba/dispatch/basic.py | 78.50% <ø> (ø) |
| pytensor/link/numba/dispatch/linalg/_LAPACK.py | 79.28% <100.00%> (ø) |
| pytensor/tensor/slinalg.py | 93.40% <100.00%> (+0.01%) ⬆️ |
| pytensor/link/numba/dispatch/linalg/solve/utils.py | 50.00% <50.00%> (ø) |
| pytensor/link/numba/dispatch/slinalg.py | 69.10% <77.77%> (+25.01%) ⬆️ |
| pytensor/link/numba/dispatch/linalg/solve/norm.py | 61.53% <61.53%> (ø) |
| ...nk/numba/dispatch/linalg/decomposition/cholesky.py | 47.22% <47.22%> (ø) |
| ...ensor/link/numba/dispatch/linalg/solve/cholesky.py | 43.18% <43.18%> (ø) |
| ...sor/link/numba/dispatch/linalg/solve/triangular.py | 40.00% <40.00%> (ø) |
| pytensor/link/numba/dispatch/linalg/utils.py | 42.62% <42.62%> (ø) |
| ... and 4 more | |

@jessegrabowski (Member) left a comment:

Looks good, I like the reorganization. We should take it as a template for refactoring tensor.linalg.

@@ -74,7 +74,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
         message=(
             "(\x1b\\[1m)*"  # ansi escape code for bold text
             "Cannot cache compiled function "
-            '"(numba_funcified_fgraph|store_core_outputs)" '
+            '"(numba_funcified_fgraph|store_core_outputs|nb_cholesky)" '
@jessegrabowski (Member):

why cholesky specifically?

@ricardoV94 (Member, Author):

It was emitting this warning in a test; strangely, the other LAPACK functions were not? Gotta double-check that.

@jessegrabowski (Member):

I assume none of them cache. I thought the tests all had a pytest.filterwarnings decorator on them.

@ricardoV94 (Member, Author):

We should filter the warnings that we know about and can't do anything about (and users who get them can do even less).

@ricardoV94 (Member, Author):

So the warning was only causing a failure with cholesky because we had a pytest.mark.filterwarnings("error") on the blockwise test, and we tested a blockwise with cholesky. Instead of ignoring it, I am also making the slinalg tests fail if warnings are emitted, and so I am filtering those other Ops here as well.
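A minimal sketch of the test-side setup being described, with an illustrative test name and filter string rather than the ones in the PR: the blanket "error" filter turns any warning into a test failure, while the known numba cache warning is exempted.

```python
import pytest


# Filters added later take precedence, so the known "Cannot cache compiled
# function" warning is ignored while any other warning still fails the test.
@pytest.mark.filterwarnings(
    "error",
    "ignore:Cannot cache compiled function",
)
def test_blockwise_cholesky_numba():
    ...  # build and evaluate the Blockwise(Cholesky) graph in numba mode
```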

Comment on lines 272 to 277
# Adapted from scipy _matrix_norm_tridiagonal:
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
anorm = np.abs(d)
anorm[1:] += np.abs(du)
anorm[:-1] += np.abs(dl)
anorm = anorm.max()
@jessegrabowski (Member):

nit: this should be split out to a separate function

@ricardoV94 (Member, Author) commented on Mar 26, 2025:

Sure, but numba loves compiling nested functions.

@jessegrabowski (Member):

If there's a performance advantage to keeping it inlined, keep it inlined.

I was thinking that if you break it out, it fits the template of the other solves, with three functions: norm, decomp, and solve.

@ricardoV94 (Member, Author):

I don't love the norm/warnings at all, but that's a separate matter.

I can split it; the only concern I could have is compile time, but that's not why I inlined it. I'll do a quick benchmark to see if it's warranted.

@ricardoV94 (Member, Author):

Split it; if anything, compile times seem a bit better, probably because at least that inner function can be cached.
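A sketch of what such a split-out norm helper might look like (the name and signature are illustrative, not necessarily what landed in the PR); the diagonal-based accumulation reproduces the dense 1-norm, i.e. the maximum absolute column sum:

```python
import numpy as np
from numba import njit


@njit(cache=True)
def _tridiagonal_one_norm(dl, d, du):
    # 1-norm of a tridiagonal matrix computed directly from its diagonals,
    # mirroring the scipy snippet quoted above: column j sums
    # |du[j - 1]| + |d[j]| + |dl[j]| wherever those entries exist.
    anorm = np.abs(d)
    anorm[1:] += np.abs(du)
    anorm[:-1] += np.abs(dl)
    return anorm.max()


# Sanity check against the dense 1-norm.
n = 5
dl, d, du = np.random.rand(n - 1), np.random.rand(n), np.random.rand(n - 1)
A = np.diag(d) + np.diag(du, 1) + np.diag(dl, -1)
assert np.isclose(_tridiagonal_one_norm(dl, d, du), np.linalg.norm(A, 1))
```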

@ricardoV94 force-pushed the numba_tridiagonal_impl branch from abd00fb to 278996e on March 27, 2025 11:01
@ricardoV94 force-pushed the numba_tridiagonal_impl branch from 278996e to 0757466 on March 27, 2025 11:08
@ricardoV94 merged commit a038c8e into pymc-devs:main on Mar 27, 2025
72 of 73 checks passed
Labels
enhancement (New feature or request), linalg (Linear algebra), numba
Projects
None yet
Development


2 participants