From 5f729a3fccd60d5e4c02d11a158796f5318775a1 Mon Sep 17 00:00:00 2001
From: Thomas Wiecki
Date: Wed, 26 Feb 2025 18:35:17 +0800
Subject: [PATCH] Add object mode fallback for Numba RandomVariables
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When we find a RandomVariable that doesn't have a Numba implementation,
we now fall back to object mode instead of failing with
NotImplementedError. This provides a more graceful degradation path for
random variables that don't yet have specialized Numba implementations.

- Added rv_fallback_impl function to create an object-mode implementation
- Modified numba_funcify_RandomVariable to catch NotImplementedError
- Added a test for the unsupported random variable fallback

🤖 Generated with Claude Code

Co-Authored-By: Claude
---
 pytensor/link/numba/dispatch/random.py | 123 ++++++++++++++++++-------
 tests/link/numba/test_random.py        |  49 ++++++++++
 2 files changed, 140 insertions(+), 32 deletions(-)

diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py
index 50c9bf4578..2182995d70 100644
--- a/pytensor/link/numba/dispatch/random.py
+++ b/pytensor/link/numba/dispatch/random.py
@@ -386,50 +386,45 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
     )
 
 
-@numba_funcify.register
-def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
-    core_shape = node.inputs[0]
+def rv_fallback_impl(op: RandomVariableWithCoreShape, node):
+    """Create a fallback implementation for random variables using object mode."""
+    import warnings
 
     [rv_node] = op.fgraph.apply_nodes
     rv_op: RandomVariable = rv_node.op
+
+    warnings.warn(
+        f"Numba will use object mode to execute the random variable {rv_op.name}",
+        UserWarning,
+    )
+
     size = rv_op.size_param(rv_node)
-    dist_params = rv_op.dist_params(rv_node)
     size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
-    core_shape_len = get_vector_length(core_shape)
     inplace = rv_op.inplace
 
-    core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
-    nin = 1 + len(dist_params)  # rng + params
-    core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
-
-    batch_ndim = rv_op.batch_ndim(rv_node)
-
-    # numba doesn't support nested literals right now...
-    input_bc_patterns = encode_literals(
-        tuple(input_var.type.broadcastable[:batch_ndim] for input_var in dist_params)
-    )
-    output_bc_patterns = encode_literals(
-        (rv_node.outputs[1].type.broadcastable[:batch_ndim],)
-    )
-    output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
-    inplace_pattern = encode_literals(())
-
     def random_wrapper(core_shape, rng, size, *dist_params):
         if not inplace:
             rng = copy(rng)
 
-        draws = _vectorized(
-            core_op_fn,
-            input_bc_patterns,
-            output_bc_patterns,
-            output_dtypes,
-            inplace_pattern,
-            (rng,),
-            dist_params,
-            (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
-            None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len),
+        fixed_size = (
+            None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len)
         )
-        return rng, draws
+
+        with numba.objmode(res="UniTuple(types.npy_rng, types.pyobject)"):
+            # Convert tuple params back to arrays for perform method
+            np_dist_params = [np.asarray(p) for p in dist_params]
+
+            # Prepare output storage for perform method
+            outputs = [[None], [None]]
+
+            # Call the perform method directly
+            rv_op.perform(rv_node, [rng, fixed_size, *np_dist_params], outputs)
+
+            next_rng = outputs[0][0]
+            result = outputs[1][0]
+            res = (next_rng, result)
+
+        return res
 
     def random(core_shape, rng, size, *dist_params):
         raise NotImplementedError("Non-jitted random variable not implemented")
@@ -439,3 +434,67 @@ def ov_random(core_shape, rng, size, *dist_params):
         return random_wrapper
 
     return random
+
+
+@numba_funcify.register
+def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
+    core_shape = node.inputs[0]
+
+    [rv_node] = op.fgraph.apply_nodes
+    rv_op: RandomVariable = rv_node.op
+    size = rv_op.size_param(rv_node)
+    dist_params = rv_op.dist_params(rv_node)
+    size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
+    core_shape_len = get_vector_length(core_shape)
+    inplace = rv_op.inplace
+
+    try:
+        core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+        nin = 1 + len(dist_params)  # rng + params
+        core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
+
+        batch_ndim = rv_op.batch_ndim(rv_node)
+
+        # numba doesn't support nested literals right now...
+        input_bc_patterns = encode_literals(
+            tuple(
+                input_var.type.broadcastable[:batch_ndim] for input_var in dist_params
+            )
+        )
+        output_bc_patterns = encode_literals(
+            (rv_node.outputs[1].type.broadcastable[:batch_ndim],)
+        )
+        output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
+        inplace_pattern = encode_literals(())
+
+        def random_wrapper(core_shape, rng, size, *dist_params):
+            if not inplace:
+                rng = copy(rng)
+
+            draws = _vectorized(
+                core_op_fn,
+                input_bc_patterns,
+                output_bc_patterns,
+                output_dtypes,
+                inplace_pattern,
+                (rng,),
+                dist_params,
+                (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
+                None
+                if size_len is None
+                else numba_ndarray.to_fixed_tuple(size, size_len),
+            )
+            return rng, draws
+
+        def random(core_shape, rng, size, *dist_params):
+            raise NotImplementedError("Non-jitted random variable not implemented")
+
+        @overload(random, jit_options=_jit_options)
+        def ov_random(core_shape, rng, size, *dist_params):
+            return random_wrapper
+
+        return random
+
+    except NotImplementedError:
+        # Fall back to object mode for random variables that don't have a core implementation
+        return rv_fallback_impl(op, node)
diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py
index d2301a54cb..9d5667beeb 100644
--- a/tests/link/numba/test_random.py
+++ b/tests/link/numba/test_random.py
@@ -705,3 +705,52 @@ def test_repeated_args():
     final_node = fn.maker.fgraph.outputs[0].owner
     assert isinstance(final_node.op, RandomVariableWithCoreShape)
     assert final_node.inputs[-2] is final_node.inputs[-1]
+
+
+def test_unsupported_rv_fallback():
+    """Test that unsupported random variables fall back to object mode."""
+    import warnings
+
+    # Create a mock random variable that doesn't have a numba implementation
+    class CustomRV(ptr.RandomVariable):
+        name = "custom"
+        signature = "(d)->(d)"  # We need a parameter for the test to pass
+        dtype = "float64"
+
+        def _supp_shape_from_params(self, dist_params, param_shapes=None):
+            # Return the shape of the support
+            return [1]
+
+        def rng_fn(self, rng, value, size=None):
+            # Just return the value plus a random number
+            return value + rng.standard_normal()
+
+    custom_rv = CustomRV()
+
+    # Create a graph with the unsupported RV
+    rng = shared(np.random.default_rng(123))
+    value = np.array(1.0)
+    x = custom_rv(value, rng=rng)
+
+    # Capture warnings to check for the fallback warning
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+
+        # Compile with numba mode
+        fn = function([], x, mode=numba_mode)
+
+        # Execute to trigger the fallback
+        result = fn()
+
+        # Check that a warning was raised about object mode
+        assert any("object mode" in str(warning.message) for warning in w)
+
+    # Verify the result is as expected
+    assert isinstance(result, np.ndarray)
+
+    # Run again to make sure the compiled function works properly
+    result2 = fn()
+    assert isinstance(result2, np.ndarray)
+    assert not np.array_equal(
+        result, result2
+    )  # Results should differ with different RNG states
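
Note: the fallback in rv_fallback_impl relies on numba's object-mode escape hatch,
which executes a block in the interpreter from inside a jitted function. A minimal
standalone sketch of that general pattern (illustrative names only, not PyTensor code):

    import numpy as np
    from numba import njit, objmode

    @njit
    def fallback_example(x):
        # Work that has no nopython implementation runs in the interpreter;
        # variables crossing back out of the block must be type-annotated.
        with objmode(y="float64[:]"):
            y = np.array([float(x) + 1.0])
        return y

    fallback_example(1.0)  # -> array([2.])

The output annotation on the objmode block is what lets numba type the rest of the
compiled function around the interpreter call.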