-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add object mode fallback for Numba RandomVariables #1249
base: main
Are you sure you want to change the base?
Conversation
When we find a RandomVariable that doesn't have a Numba implementation, we now fallback to object mode instead of failing with NotImplementedError. This provides a more graceful degradation path for random variables that don't yet have specialized Numba implementations. - Added rv_fallback_impl function to create object mode implementation - Modified numba_funcify_RandomVariable to catch NotImplementedError - Added test for unsupported random variable fallback 🤖 Generated with Claude Code Co-Authored-By: Claude <[email protected]>
Would be good if it referenced the original issues (there's a PR template you could tell it to fill). We shouldn't use it for beginner friendly issues, that's the point of marking them as beginner friendly? Fine if you're just testing. I would be much more excited if it tackled docs issues. Like ask it to fix and finish the PR related to: #292 , #830 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There were also some RVs that weren't being tested because we were not falling back to objmode. Test them now
|
||
[rv_node] = op.fgraph.apply_nodes | ||
rv_op: RandomVariable = rv_node.op | ||
|
||
warnings.warn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have a generic fallback implementation function can't we just use it like we do for other Ops?
May just need to do the unboxing of the RV that the other function is doing
inplace = rv_op.inplace | ||
|
||
try: | ||
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only this line should be in the try except
# Create a mock random variable that doesn't have a numba implementation | ||
class CustomRV(ptr.RandomVariable): | ||
name = "custom" | ||
signature = "(d)->(d)" # We need a parameter for test to pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
create a univariate rv which will be a simpler test
x = custom_rv(value, rng=rng) | ||
|
||
# Capture warnings to check for the fallback warning | ||
with warnings.catch_warnings(record=True) as w: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use pytest.warns
# Run again to make sure the compiled function works properly | ||
result2 = fn() | ||
assert isinstance(result2, np.ndarray) | ||
assert not np.array_equal( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will fail because updates were not set
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually test with and without updates, in which case it should change or stay the same
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also set seed twice and compare to make sure it's following it
Top post does not include Related to or Closes # Edit: I repeated myself |
Fixes https://github.com/pymc-devs/pytensor/issues/1245\n\nSummary:\n- When a RandomVariable is not implemented in Numba, it now gracefully falls back to object mode.\n- Added tests to verify that unsupported RandomVariables correctly trigger the object mode fallback.\n- This update ensures a smoother degradation experience and improves testing coverage.\n\nCloses #1245\n\nTest Plan:\n- Run the test suite using pytest to ensure no regressions occur.\n\nAcknowledgements:\n- Thanks to ricardoV94 for the feedback and review comments.