Implement fallback for unimplemented Numba RandomVariables (#1245) #1269

Open · wants to merge 3 commits into base: main
14 changes: 13 additions & 1 deletion pytensor/link/numba/dispatch/random.py
@@ -398,7 +398,19 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
core_shape_len = get_vector_length(core_shape)
inplace = rv_op.inplace

core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
try:
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
except NotImplementedError:
import warnings
warnings.warn(
f"RandomVariable {rv_op} is not implemented in Numba. Falling back to object mode."
)

def fallback_object_mode(*args, **kwargs):
return None # Default behavior when RV is not implemented

return fallback_object_mode

nin = 1 + len(dist_params) # rng + params
core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)

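Below is a minimal usage sketch, not part of the diff, of how the new fallback would surface to a user: compiling a graph in Numba mode should emit the warning only when the RandomVariable has no `numba_core_rv_funcify` registration. `normal` is used purely as a placeholder graph and is assumed to compile without hitting the fallback.

```python
import warnings

import pytensor
from pytensor.tensor.random import normal

# Build a graph containing a RandomVariable. Swap `normal` for an RV that has
# no Numba implementation to actually exercise the new fallback path.
x = normal(0, 1, size=(3,))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fn = pytensor.function([], x, mode="NUMBA")

# For an unimplemented RV, `caught` should contain the
# "... is not implemented in Numba. Falling back to object mode." warning.
print([str(w.message) for w in caught])
```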
27 changes: 25 additions & 2 deletions pytensor/tensor/random/op.py
@@ -2,7 +2,7 @@
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, cast

from itertools import zip_longest
import numpy as np

import pytensor
@@ -11,6 +11,7 @@
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.basic import (
as_tensor_variable,
concatenate,
@@ -32,6 +33,13 @@
from pytensor.tensor.variable import TensorVariable


def _get_value(x):
try:
return get_underlying_scalar_constant_value(x)
except Exception:
return x


class RandomVariable(Op):
"""An `Op` that produces a sample from a random variable.

@@ -393,11 +401,26 @@ def dist_params(self, node) -> Sequence[Variable]:
def perform(self, node, inputs, outputs):
rng, size, *args = inputs

# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
# Draw from `rng` if `self.inplace` is True; otherwise, use a copy.
if not self.inplace:
rng = deepcopy(rng)

# Validate that 'size' is broadcastable with the expected batch dimensions.
# Use the shape of the first input argument (if provided) as the expected batch shape.
if size is not None and args:
expected_shape = args[0].shape # Shape of the first input tensor
# Evaluate symbolic elements to concrete values when possible.
size_tuple = tuple(_get_value(s) for s in size)
expected_shape_concrete = tuple(_get_value(e) for e in expected_shape)
for s, e in zip_longest(size_tuple, expected_shape_concrete, fillvalue=1):
# Broadcasting rule: dimensions are compatible if they are equal or one of them is 1.
if not (s == e or s == 1 or e == 1):
raise ValueError("Size length is incompatible with batched dimensions.")

# Assign the (possibly updated) RNG state.
outputs[0][0] = rng

# Generate the random values and assign them to the output.
outputs[1][0] = np.asarray(
self.rng_fn(rng, *args, None if size is None else tuple(size)),
dtype=self.dtype,
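For reference, a self-contained sketch of the compatibility rule applied by the new validation in `perform`, written against plain tuples so it can be run in isolation; `size_is_compatible` is a hypothetical helper name, and note that `zip_longest` pads the shorter tuple with 1s from the left.

```python
from itertools import zip_longest


def size_is_compatible(size, batch_shape):
    """Mirror of the check added to `perform`: pad the shorter tuple with 1s
    and require each pair of dimensions to be equal or to contain a 1."""
    for s, e in zip_longest(size, batch_shape, fillvalue=1):
        if not (s == e or s == 1 or e == 1):
            return False
    return True


assert size_is_compatible((2, 4, 3), (2, 4, 3))   # exact match
assert size_is_compatible((2, 1, 3), (2, 4, 1))   # 1s broadcast against anything
assert not size_is_compatible((3,), (2, 4, 3))    # 3 vs 2 clashes in the first dim
```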
4 changes: 4 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -295,6 +295,10 @@ def local_blockwise_dot_to_mul(fgraph, node):
new_b = b
else:
return None

# New condition to handle (1, 1) @ (1, 1): compare static shapes via `type.shape`,
# since `a.shape` is symbolic and `==` on variables is identity-based.
if a.ndim == 2 and b.ndim == 2 and a.type.shape == (1, 1) and b.type.shape == (1, 1):
    return [a * b]  # A 1x1 matrix product is just an elementwise multiplication

new_a = copy_stack_trace(a, new_a)
new_b = copy_stack_trace(b, new_b)
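A quick NumPy sanity check of the identity this new branch relies on: for 1x1 operands, the matrix product and the elementwise product coincide.

```python
import numpy as np

a = np.array([[3.0]])
b = np.array([[4.0]])

# Matrix multiplication of 1x1 matrices equals elementwise multiplication.
assert np.array_equal(a @ b, a * b)  # both give [[12.0]]
```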
37 changes: 35 additions & 2 deletions tests/tensor/random/test_op.py
@@ -40,6 +40,10 @@ def test_RandomVariable_basics(strict_test_value_flags):
config.floatX,
inplace=True,
)

# New test: Ensure an integer value for ndims_params is handled properly
with pytest.raises(TypeError, match="Parameter ndims_params must be sequence type."):
RandomVariable("normal", ndims_params=2, signature="(),()->()", dtype="float32")

# `size` should be a `Sequence` type
with pytest.raises(TypeError, match="^Parameter size*"):
@@ -112,7 +116,7 @@ def test_RandomVariable_bcast(strict_test_value_flags):
res = rv(0, 1, size=pt.as_tensor(1, dtype=np.int64))
assert res.broadcastable == (True,)

res = rv(0, 1, size=(pt.as_tensor(1, dtype=np.int32), s3))
res = rv(0, 1, size=pt.as_tensor([1, s3], dtype=np.int64))
assert res.broadcastable == (True, False)


@@ -126,7 +130,7 @@ def test_RandomVariable_bcast_specify_shape(strict_test_value_flags):
s3.tag.test_value = 3
s3 = Assert("testing")(s3, eq(s1, 1))

size = specify_shape(pt.as_tensor([s1, s3, s2, s2, s1]), (5,))
size = specify_shape(pt.as_tensor([1, s3, s2, s2, 1], dtype=np.int64), (5,))
mu = tensor(dtype=config.floatX, shape=(None, None, 1))
mu.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX)

@@ -184,6 +188,18 @@ def test_RandomVariable_incompatible_size(strict_test_value_flags):
ValueError, match="Size length is incompatible with batched dimensions"
):
rv_op(np.zeros((1, 3)), 1, size=(3,))

# New test: Explicitly check handling of empty size
with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
rv_op(np.zeros((2, 4, 3)), 1, size=())

# New test: Passing a size length shorter than batch dims
with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
rv_op(np.zeros((2, 4, 3)), 1, size=(2,))

# New test: Passing a size length longer than batch dims
with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
rv_op(np.zeros((2, 4, 3)), 1, size=(2, 4, 5))

rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True)
with pytest.raises(
@@ -294,6 +310,20 @@ def test_vectorize():
vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert isinstance(vect_node.op, NormalRV)
assert vect_node.default_output().type.shape == (10, 2, 5)

# Test broadcasting from scalar to vector
scalar_mu = pt.scalar("scalar_mu")
vector_mu = pt.vector("vector_mu", shape=(5,))
scalar_x = normal(scalar_mu)
vector_x = vectorize_graph(scalar_x, {scalar_mu: vector_mu})
assert vector_x.type.shape == (5,)

# Test broadcasting from vector to matrix
vector_mu = pt.vector("vector_mu", shape=(5,))
matrix_mu = pt.matrix("matrix_mu", shape=(5, 5))
vector_x = normal(vector_mu)
matrix_x = vectorize_graph(vector_x, {vector_mu: matrix_mu})
assert matrix_x.type.shape == (5, 5)


def test_vectorize_empty_size():
@@ -312,6 +342,9 @@ def test_size_none_vs_empty():
signature="(),()->()",
)
assert rv([0], [1], size=None).type.shape == (1,)

# New test: Ensure `size=None` correctly returns an empty tuple shape
assert rv(0, 1, size=None).type.shape == ()

with pytest.raises(
ValueError, match="Size length is incompatible with batched dimensions"
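A sketch of the distinction this test exercises, using the stock `normal` RV (same `(),()->()` signature) instead of the test's hand-built op: `size=None` infers the batch shape from the broadcast parameter shapes, while `size=()` requests an explicitly empty batch shape and is rejected when the parameters are batched.

```python
import pytest

from pytensor.tensor.random import normal

# size=None: batch shape inferred from the (broadcast) parameter shapes.
assert normal(0, 1, size=None).type.shape == ()
assert normal([0.0, 0.0, 0.0], 1, size=None).type.shape == (3,)

# size=(): an explicitly empty batch shape; only valid for unbatched parameters.
assert normal(0, 1, size=()).type.shape == ()
with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
    normal([0.0, 0.0, 0.0], 1, size=())
```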