-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement tridiagonal solve in numba backend #1311
Conversation
4cc3474
to
abd00fb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, I like the reorganization. We should take it as a template to refactor tensor.linalg
@@ -74,7 +74,7 @@ def numba_njit(*args, fastmath=None, **kwargs): | |||
message=( | |||
"(\x1b\\[1m)*" # ansi escape code for bold text | |||
"Cannot cache compiled function " | |||
'"(numba_funcified_fgraph|store_core_outputs)" ' | |||
'"(numba_funcified_fgraph|store_core_outputs|nb_cholesky)" ' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why cholesky specifically?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was emitting this warning in a test, strangely the other lapack functions were not? Gotta double check that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume none of them cache. I thought the tests all had a pytest.filterwarnings decorator on them
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should filter the warnings we know of and can't do anything about (even less the users who get them)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So it was the case that the warning was only making it fail with the cholesky, because we had a pytest.mark.filterwarnings("error")
in the blockwise test and we tested a blockwise with cholesky. Instead of ignorning I am also making it fail if warnings are emitted on the slinalg test, and thus I am filtering those other Ops here as well.
# Adapted from scipy _matrix_norm_tridiagonal: | ||
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367 | ||
anorm = np.abs(d) | ||
anorm[1:] += np.abs(du) | ||
anorm[:-1] += np.abs(dl) | ||
anorm = anorm.max() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this should be split out to a separate function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure but numba loves compiling nested functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if there's a performance advantage to keeping it inlined, keep it inlined.
I was thinking if you break it out, it fits the template of the other solves with 3 functions: norm, decomp, and solve
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't love the norm/warnings at all, but that's a separate matter.
I can split, the only concern I could have is compile time but that's not why I inlined it. I'll do a quick benchmark to see if it's warranted
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Split it, if anything compile times seem a bit better, probably because at least that inner function can be cached.
abd00fb
to
278996e
Compare
278996e
to
0757466
Compare
Also refactors the lapack codegen
📚 Documentation preview 📚: https://pytensor--1311.org.readthedocs.build/en/1311/