From 979ed328f19b5e52946ecc01159d094580916712 Mon Sep 17 00:00:00 2001
From: Swarnim Shekhar <swarnim2302@gmail.com>
Date: Tue, 4 Mar 2025 22:00:32 +0530
Subject: [PATCH 1/3] Rewrite scalar dot as multiplication #1205

---
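Note: a minimal way to exercise the new rewrite (a sketch, assuming the
rewrite is registered in the default compilation mode; the variable names
are illustrative only):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    a = pt.matrix("a", shape=(1, 1))
    b = pt.matrix("b", shape=(1, 1))
    # With the rewrite applied, the Dot should be replaced by an elementwise Mul.
    f = pytensor.function([a, b], a @ b)
    print(f(np.array([[2.0]]), np.array([[3.0]])))  # [[6.]]
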
 pytensor/tensor/rewriting/math.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 9694a022e3..558abb4460 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -295,6 +295,10 @@ def local_blockwise_dot_to_mul(fgraph, node):
                 new_b = b
         else:
             return None
+
+    # New case: (1, 1) @ (1, 1) reduces to an elementwise multiplication.
+    if a.type.shape == (1, 1) and b.type.shape == (1, 1):
+        return [copy_stack_trace(node.outputs[0], a * b)]
 
     new_a = copy_stack_trace(a, new_a)
     new_b = copy_stack_trace(b, new_b)

From 9227415b46fe04d382312ab2c3662d5bdb08c320 Mon Sep 17 00:00:00 2001
From: Swarnim Shekhar <swarnim2302@gmail.com>
Date: Wed, 5 Mar 2025 19:49:50 +0530
Subject: [PATCH 2/3] Improve size validation in RandomVariable.perform and
 extend the related tests

---
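Note: the size check added to `perform` follows NumPy broadcasting, which
aligns shapes from the trailing (rightmost) dimension. A standalone sketch
of the rule being enforced (the helper name `compatible` is illustrative
only):

    from itertools import zip_longest

    def compatible(size, batch_shape):
        # Right-align the two shapes, padding the shorter one with 1s.
        pairs = zip_longest(reversed(size), reversed(batch_shape), fillvalue=1)
        # Two dimensions are compatible when they are equal or either is 1.
        return all(s == e or s == 1 or e == 1 for s, e in pairs)

    assert compatible((2, 4, 3), (4, 3))  # size may add leading dimensions
    assert not compatible((5,), (3,))     # 5 vs. 3 cannot broadcast
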
 pytensor/tensor/random/op.py   | 31 ++++++++++++++++++++++++++++++-
 tests/tensor/random/test_op.py | 37 ++++++++++++++++++++++++++++++++--
 2 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index a8b67dee4f..25cb4be5bc 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -2,7 +2,8 @@
 from collections.abc import Sequence
 from copy import deepcopy
+from itertools import zip_longest
 from typing import Any, cast
 
 import numpy as np
 
 import pytensor
@@ -11,6 +12,7 @@
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
 from pytensor.scalar import ScalarVariable
+from pytensor.tensor.basic import get_underlying_scalar_constant_value
 from pytensor.tensor.basic import (
     as_tensor_variable,
     concatenate,
@@ -32,6 +34,14 @@
 from pytensor.tensor.variable import TensorVariable
 
 
+def _get_value(x):
+    # Return the concrete value of `x` when it can be determined, else `x` itself.
+    try:
+        return get_underlying_scalar_constant_value(x)
+    except Exception:
+        return x
+
+
 class RandomVariable(Op):
     """An `Op` that produces a sample from a random variable.
 
@@ -393,11 +403,30 @@ def dist_params(self, node) -> Sequence[Variable]:
     def perform(self, node, inputs, outputs):
         rng, size, *args = inputs
 
-        # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
+        # Draw from `rng` if `self.inplace` is True; otherwise, use a copy.
         if not self.inplace:
             rng = deepcopy(rng)
 
+        # Validate that `size` is broadcastable with the expected batch shape,
+        # i.e. the shape of the first parameter minus its core dimensions.
+        if size is not None and args:
+            batch_ndim = args[0].ndim - self.ndims_params[0]
+            expected_shape = args[0].shape[:batch_ndim]
+            # Coerce elements to concrete scalars when possible.
+            size_tuple = tuple(_get_value(s) for s in size)
+            expected_concrete = tuple(_get_value(e) for e in expected_shape)
+            # Broadcasting aligns trailing dimensions, so compare right-to-left;
+            # two dimensions are compatible when equal or when either is 1.
+            for s, e in zip_longest(
+                reversed(size_tuple), reversed(expected_concrete), fillvalue=1
+            ):
+                if not (s == e or s == 1 or e == 1):
+                    raise ValueError("Size length is incompatible with batched dimensions.")
+
+        # Assign the (possibly updated) RNG state.
         outputs[0][0] = rng
+
+        # Generate the random values and assign them to the output.
         outputs[1][0] = np.asarray(
             self.rng_fn(rng, *args, None if size is None else tuple(size)),
             dtype=self.dtype,
diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index edec9a4389..6df4675bc6 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -40,6 +40,10 @@ def test_RandomVariable_basics(strict_test_value_flags):
             config.floatX,
             inplace=True,
         )
+
+    # New test: an integer `ndims_params` must raise a TypeError
+    with pytest.raises(TypeError, match="Parameter ndims_params must be sequence type."):
+        RandomVariable("normal", ndims_params=2, signature="(),()->()", dtype="float32")
 
     # `size` should be a `Sequence` type
     with pytest.raises(TypeError, match="^Parameter size*"):
@@ -112,7 +116,7 @@ def test_RandomVariable_bcast(strict_test_value_flags):
     res = rv(0, 1, size=pt.as_tensor(1, dtype=np.int64))
     assert res.broadcastable == (True,)
 
-    res = rv(0, 1, size=(pt.as_tensor(1, dtype=np.int32), s3))
+    res = rv(0, 1, size=pt.as_tensor([1, s3], dtype=np.int64))
     assert res.broadcastable == (True, False)
 
 
@@ -126,7 +130,7 @@ def test_RandomVariable_bcast_specify_shape(strict_test_value_flags):
     s3.tag.test_value = 3
     s3 = Assert("testing")(s3, eq(s1, 1))
 
-    size = specify_shape(pt.as_tensor([s1, s3, s2, s2, s1]), (5,))
+    size = specify_shape(pt.as_tensor([1, s3, s2, s2, 1], dtype=np.int64), (5,))
     mu = tensor(dtype=config.floatX, shape=(None, None, 1))
     mu.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX)
 
@@ -184,6 +188,18 @@ def test_RandomVariable_incompatible_size(strict_test_value_flags):
         ValueError, match="Size length is incompatible with batched dimensions"
     ):
         rv_op(np.zeros((1, 3)), 1, size=(3,))
+
+    # New test: an empty `size` must be rejected when batch dimensions are present
+    with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
+        rv_op(np.zeros((2, 4, 3)), 1, size=())
+
+    # New test: Passing a size length shorter than batch dims
+    with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
+        rv_op(np.zeros((2, 4, 3)), 1, size=(2,))
+
+    # New test: Passing a size length longer than batch dims
+    with pytest.raises(ValueError, match="Size length is incompatible with batched dimensions"):
+        rv_op(np.zeros((2, 4, 3)), 1, size=(2, 4, 5))
 
     rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True)
     with pytest.raises(
@@ -294,6 +310,20 @@ def test_vectorize():
     vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
     assert isinstance(vect_node.op, NormalRV)
     assert vect_node.default_output().type.shape == (10, 2, 5)
+
+    # Test broadcasting from scalar to vector
+    scalar_mu = pt.scalar("scalar_mu")
+    vector_mu = pt.vector("vector_mu", shape=(5,))
+    scalar_x = normal(scalar_mu)
+    vector_x = vectorize_graph(scalar_x, {scalar_mu: vector_mu})
+    assert vector_x.type.shape == (5,)
+
+    # Test broadcasting from vector to matrix
+    vector_mu = pt.vector("vector_mu", shape=(5,))
+    matrix_mu = pt.matrix("matrix_mu", shape=(5, 5))
+    vector_x = normal(vector_mu)
+    matrix_x = vectorize_graph(vector_x, {vector_mu: matrix_mu})
+    assert matrix_x.type.shape == (5, 5)
 
 
 def test_vectorize_empty_size():
@@ -312,6 +342,9 @@ def test_size_none_vs_empty():
         signature="(),()->()",
     )
     assert rv([0], [1], size=None).type.shape == (1,)
+
+    # New test: scalar parameters with `size=None` must yield a scalar (shape ())
+    assert rv(0, 1, size=None).type.shape == ()
 
     with pytest.raises(
         ValueError, match="Size length is incompatible with batched dimensions"

From 7d6f9ddeae51aa66dd3aa21d3f0bbe4570c72fb8 Mon Sep 17 00:00:00 2001
From: Swarnim Shekhar <swarnim2302@gmail.com>
Date: Wed, 5 Mar 2025 19:52:09 +0530
Subject: [PATCH 3/3] Handle RandomVariables without a Numba implementation
 gracefully

---
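Note: the change below catches the NotImplementedError raised during
dispatch so that compilation can proceed, deferring an explicit error to
run time. A minimal sketch of the pattern (the names here are illustrative,
not the actual pytensor API):

    import warnings

    def funcify_with_fallback(dispatch, op, node):
        try:
            return dispatch(op, node)
        except NotImplementedError:
            warnings.warn(f"{op} has no Numba implementation; deferring the error.")

            def fallback(*args, **kwargs):
                # Fail loudly at run time rather than returning invalid output.
                raise NotImplementedError(f"{op} is not implemented in Numba.")

            return fallback
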
 pytensor/link/numba/dispatch/random.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py
index 50c9bf4578..2973574540 100644
--- a/pytensor/link/numba/dispatch/random.py
+++ b/pytensor/link/numba/dispatch/random.py
@@ -398,7 +398,23 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
     core_shape_len = get_vector_length(core_shape)
     inplace = rv_op.inplace
 
-    core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+    try:
+        core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+    except NotImplementedError:
+        import warnings
+
+        warnings.warn(
+            f"RandomVariable {rv_op} has no Numba implementation; "
+            "the compiled function will raise if this node is executed."
+        )
+
+        def fallback_object_mode(*args, **kwargs):
+            # Fail loudly at run time instead of silently returning None,
+            # which downstream code could not use as valid output.
+            raise NotImplementedError(f"{rv_op} is not implemented in Numba.")
+
+        return fallback_object_mode
+
     nin = 1 + len(dist_params)  # rng + params
     core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)