diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py
index 50c9bf4578..2973574540 100644
--- a/pytensor/link/numba/dispatch/random.py
+++ b/pytensor/link/numba/dispatch/random.py
@@ -398,7 +398,20 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
     core_shape_len = get_vector_length(core_shape)
     inplace = rv_op.inplace
 
-    core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+    try:
+        core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+    except NotImplementedError:
+        import warnings
+
+        warnings.warn(
+            f"RandomVariable {rv_op} is not implemented in Numba. "
+            "Falling back to object mode."
+        )
+
+        def fallback_object_mode(*args, **kwargs):
+            return None  # Stub result while the RV has no Numba implementation
+
+        return fallback_object_mode
 
     nin = 1 + len(dist_params)  # rng + params
     core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index a8b67dee4f..25cb4be5bc 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -2,7 +2,8 @@
 from collections.abc import Sequence
 from copy import deepcopy
+from itertools import zip_longest
 from typing import Any, cast
 
 import numpy as np
 
 import pytensor
@@ -11,6 +12,7 @@
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
     as_tensor_variable,
     concatenate,
+    get_underlying_scalar_constant_value,
@@ -32,6 +34,14 @@
 from pytensor.tensor.variable import TensorVariable
 
 
+def _get_value(x):
+    # Return the concrete constant value of `x` when one exists; otherwise `x` itself.
+    try:
+        return get_underlying_scalar_constant_value(x)
+    except Exception:
+        return x
+
+
 class RandomVariable(Op):
     """An `Op` that produces a sample from a random variable.
 
@@ -393,11 +403,29 @@ def dist_params(self, node) -> Sequence[Variable]:
     def perform(self, node, inputs, outputs):
         rng, size, *args = inputs
 
-        # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
+        # Draw from `rng` if `self.inplace` is True; otherwise, draw from a copy.
         if not self.inplace:
             rng = deepcopy(rng)
 
+        # Validate that `size` is broadcastable against the expected batch
+        # dimensions, taken from the shape of the first parameter (when present).
+        if size is not None and args:
+            expected_shape = args[0].shape  # Batch shape implied by the first parameter
+            # Resolve symbolic elements to concrete values where possible.
+            size_tuple = tuple(_get_value(s) for s in size)
+            expected_shape_concrete = tuple(_get_value(e) for e in expected_shape)
+            # Align dimensions from the right, as NumPy broadcasting does.
+            for s, e in zip_longest(
+                reversed(size_tuple), reversed(expected_shape_concrete), fillvalue=1
+            ):
+                # Dimensions are compatible when equal or when either is 1.
+                if not (s == e or s == 1 or e == 1):
+                    raise ValueError("Size length is incompatible with batched dimensions.")
+
+        # Assign the (possibly copied) RNG state.
         outputs[0][0] = rng
+
+        # Generate the random values and write them to the output storage.
         outputs[1][0] = np.asarray(
             self.rng_fn(rng, *args, None if size is None else tuple(size)),
             dtype=self.dtype,
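Note: the `perform` validation above compares dimensions pairwise under NumPy's
broadcasting rules, aligned from the right. A minimal standalone sketch of the
same check, assuming concrete integer dimensions (`size_is_compatible` is an
illustrative name, not a PyTensor API):

    from itertools import zip_longest

    def size_is_compatible(size, batch_shape):
        # Align from the right, as NumPy broadcasting does; missing leading
        # dimensions are treated as 1.
        for s, e in zip_longest(reversed(size), reversed(batch_shape), fillvalue=1):
            if not (s == e or s == 1 or e == 1):
                return False
        return True

    assert size_is_compatible((2, 4, 3), (4, 3))  # size may add leading dims
    assert not size_is_compatible((2,), (2, 4))   # trailing 2 conflicts with 4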
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 9694a022e3..558abb4460 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -295,6 +295,10 @@ def local_blockwise_dot_to_mul(fgraph, node):
         new_b = b
     else:
         return None
+
+    # New case: a (1, 1) @ (1, 1) product reduces to elementwise multiplication.
+    if a.type.shape == (1, 1) and b.type.shape == (1, 1):
+        return [a * b]  # Direct elementwise multiplication
 
     new_a = copy_stack_trace(a, new_a)
     new_b = copy_stack_trace(b, new_b)
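Note: the new rewrite case relies on the identity that, for 1x1 operands,
matrix multiplication and elementwise multiplication coincide. A quick NumPy
check of that identity:

    import numpy as np

    a = np.array([[3.0]])
    b = np.array([[4.0]])
    # (1, 1) @ (1, 1) is the same 1x1 matrix as the elementwise product.
    assert np.array_equal(a @ b, a * b)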
diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index edec9a4389..6df4675bc6 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -40,6 +40,10 @@ def test_RandomVariable_basics(strict_test_value_flags):
         config.floatX,
         inplace=True,
     )
+
+    # New test: an integer value for `ndims_params` is rejected
+    with pytest.raises(TypeError, match="Parameter ndims_params must be sequence type."):
+        RandomVariable("normal", ndims_params=2, signature="(),()->()", dtype="float32")
 
     # `size` should be a `Sequence` type
     with pytest.raises(TypeError, match="^Parameter size*"):
@@ -112,7 +116,7 @@ def test_RandomVariable_bcast(strict_test_value_flags):
     res = rv(0, 1, size=pt.as_tensor(1, dtype=np.int64))
     assert res.broadcastable == (True,)
 
-    res = rv(0, 1, size=(pt.as_tensor(1, dtype=np.int32), s3))
+    res = rv(0, 1, size=pt.as_tensor([1, s3], dtype=np.int64))
     assert res.broadcastable == (True, False)
 
@@ -126,7 +130,7 @@ def test_RandomVariable_bcast_specify_shape(strict_test_value_flags):
     s3.tag.test_value = 3
     s3 = Assert("testing")(s3, eq(s1, 1))
 
-    size = specify_shape(pt.as_tensor([s1, s3, s2, s2, s1]), (5,))
+    size = specify_shape(pt.as_tensor([1, s3, s2, s2, 1], dtype=np.int64), (5,))
 
     mu = tensor(dtype=config.floatX, shape=(None, None, 1))
     mu.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX)
@@ -184,6 +188,18 @@ def test_RandomVariable_incompatible_size(strict_test_value_flags):
         ValueError, match="Size length is incompatible with batched dimensions"
     ):
         rv_op(np.zeros((1, 3)), 1, size=(3,))
+
+    # New test: an empty size with batched parameters is rejected
+    with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
+        rv_op(np.zeros((2, 4, 3)), 1, size=())
+
+    # New test: a size shorter than the batch dimensions is rejected
+    with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
+        rv_op(np.zeros((2, 4, 3)), 1, size=(2,))
+
+    # New test: a size that conflicts with the trailing batch dimension is rejected
+    with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
+        rv_op(np.zeros((2, 4, 3)), 1, size=(2, 4, 5))
 
     rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True)
     with pytest.raises(
@@ -294,6 +310,20 @@ def test_vectorize():
     vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
     assert isinstance(vect_node.op, NormalRV)
     assert vect_node.default_output().type.shape == (10, 2, 5)
+
+    # Test broadcasting from scalar to vector
+    scalar_mu = pt.scalar("scalar_mu")
+    vector_mu = pt.vector("vector_mu", shape=(5,))
+    scalar_x = normal(scalar_mu)
+    vector_x = vectorize_graph(scalar_x, {scalar_mu: vector_mu})
+    assert vector_x.type.shape == (5,)
+
+    # Test broadcasting from vector to matrix
+    vector_mu = pt.vector("vector_mu", shape=(5,))
+    matrix_mu = pt.matrix("matrix_mu", shape=(5, 5))
+    vector_x = normal(vector_mu)
+    matrix_x = vectorize_graph(vector_x, {vector_mu: matrix_mu})
+    assert matrix_x.type.shape == (5, 5)
 
 
 def test_vectorize_empty_size():
@@ -312,6 +342,9 @@ def test_size_none_vs_empty():
         signature="(),()->()",
     )
     assert rv([0], [1], size=None).type.shape == (1,)
+
+    # New test: `size=None` with scalar parameters gives a scalar (shape ()) draw
+    assert rv(0, 1, size=None).type.shape == ()
 
     with pytest.raises(
         ValueError, match="Size length is incompatible with batched dimensions"
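Note: for context on the last hunk, `size=None` lets the parameters' batch
dimensions determine the output shape, while an explicit `size` must broadcast
against them. A short usage sketch (shapes shown are static type shapes;
`normal` is the public constructor from `pytensor.tensor.random.basic`):

    import pytensor.tensor as pt
    from pytensor.tensor.random.basic import normal

    x = normal(0, 1)                    # size=None, scalar params -> shape ()
    y = normal(pt.zeros((2, 4, 3)), 1)  # size=None, batched mu -> shape (2, 4, 3)
    z = normal(0, 1, size=(3, 2))       # explicit size fixes the shape -> (3, 2)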