Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add object mode fallback for Numba RandomVariables #1249

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

twiecki
Copy link
Member

@twiecki twiecki commented Feb 26, 2025

Fixes https://github.com/pymc-devs/pytensor/issues/1245\n\nSummary:\n- When a RandomVariable is not implemented in Numba, it now gracefully falls back to object mode.\n- Added tests to verify that unsupported RandomVariables correctly trigger the object mode fallback.\n- This update ensures a smoother degradation experience and improves testing coverage.\n\nCloses #1245\n\nTest Plan:\n- Run the test suite using pytest to ensure no regressions occur.\n\nAcknowledgements:\n- Thanks to ricardoV94 for the feedback and review comments.

When we find a RandomVariable that doesn't have a Numba implementation,
we now fallback to object mode instead of failing with NotImplementedError.

This provides a more graceful degradation path for random variables
that don't yet have specialized Numba implementations.

- Added rv_fallback_impl function to create object mode implementation
- Modified numba_funcify_RandomVariable to catch NotImplementedError
- Added test for unsupported random variable fallback

🤖 Generated with Claude Code
Co-Authored-By: Claude <[email protected]>
@ricardoV94
Copy link
Member

ricardoV94 commented Feb 26, 2025

Would be good if it referenced the original issues (there's a PR template you could tell it to fill). We shouldn't use it for beginner friendly issues, that's the point of marking them as beginner friendly? Fine if you're just testing.

I would be much more excited if it tackled docs issues. Like ask it to fix and finish the PR related to: #292 , #830

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There were also some RVs that weren't being tested because we were not falling back to objmode. Test them now


[rv_node] = op.fgraph.apply_nodes
rv_op: RandomVariable = rv_node.op

warnings.warn(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have a generic fallback implementation function can't we just use it like we do for other Ops?

May just need to do the unboxing of the RV that the other function is doing

inplace = rv_op.inplace

try:
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only this line should be in the try except

# Create a mock random variable that doesn't have a numba implementation
class CustomRV(ptr.RandomVariable):
name = "custom"
signature = "(d)->(d)" # We need a parameter for test to pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create a univariate rv which will be a simpler test

x = custom_rv(value, rng=rng)

# Capture warnings to check for the fallback warning
with warnings.catch_warnings(record=True) as w:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use pytest.warns

# Run again to make sure the compiled function works properly
result2 = fn()
assert isinstance(result2, np.ndarray)
assert not np.array_equal(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will fail because updates were not set

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually test with and without updates, in which case it should change or stay the same

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also set seed twice and compare to make sure it's following it

@ricardoV94
Copy link
Member

ricardoV94 commented Feb 26, 2025

Top post does not include Related to or Closes #

Edit: I repeated myself

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Numba RandomVariables should fallback to object mode
2 participants