
Commit 3cdcfde (1 parent: a920c09)

Fix Blockwise and RandomVariable in Numba with repeated arguments
3 files changed: +33 -2 lines

pytensor/tensor/blockwise.py (+7)

@@ -443,6 +443,13 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
 class OpWithCoreShape(OpFromGraph):
     """Generalizes an `Op` to include core shape as an additional input."""
 
+    def __init__(self, *args, on_unused_input="ignore", **kwargs):
+        # We set on_unused_input="ignore" so that we can easily wrap nodes with repeated inputs.
+        # In that case the subsequent appearances of a repeated input get disconnected in the inner graph.
+        # I can't think of a scenario where this will backfire, but if there is one,
+        # I bet on inplacing operations (time will tell).
+        return super().__init__(*args, on_unused_input=on_unused_input, **kwargs)
+
 
 class BlockwiseWithCoreShape(OpWithCoreShape):
     """Generalizes a Blockwise `Op` to include a core shape parameter."""

tests/link/numba/test_blockwise.py (+14 -2)

@@ -2,9 +2,9 @@
 import pytest
 
 from pytensor import function
-from pytensor.tensor import tensor
+from pytensor.tensor import tensor, tensor3
 from pytensor.tensor.basic import ARange
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
 from pytensor.tensor.nlinalg import SVD, Det
 from pytensor.tensor.slinalg import Cholesky, cholesky
 from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
@@ -58,3 +58,15 @@ def test_blockwise_benchmark(benchmark):
     x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
     fn(x_test)  # JIT compile
     benchmark(fn, x_test)
+
+
+def test_repeated_args():
+    x = tensor3("x")
+    x_test = np.full((1, 1, 1), 2.0, dtype=x.type.dtype)
+    out = x @ x
+    fn, _ = compare_numba_and_py([x], [out], [x_test], eval_obj_mode=False)
+
+    # Confirm we are testing a Blockwise with repeated inputs
+    final_node = fn.maker.fgraph.outputs[0].owner
+    assert isinstance(final_node.op, BlockwiseWithCoreShape)
+    assert final_node.inputs[0] is final_node.inputs[1]
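Outside the test harness, the scenario this test exercises looks roughly like the sketch below (hypothetical user-facing code, with the repo's compare_numba_and_py helper replaced by a plain function compiled in NUMBA mode):

    import numpy as np
    from pytensor import function
    from pytensor.tensor import tensor3

    x = tensor3("x")
    out = x @ x  # batched matmul: a Blockwise node with x appearing twice

    fn = function([x], out, mode="NUMBA")
    x_test = np.full((1, 2, 2), 2.0)
    print(fn(x_test))  # the repeated-argument case this commit fixes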

tests/link/numba/test_random.py (+12)

@@ -10,6 +10,7 @@
 from pytensor import shared
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
+from pytensor.tensor.random.op import RandomVariableWithCoreShape
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
     numba_mode,
@@ -693,3 +694,14 @@ def test_rv_inside_ofg():
 def test_unnatural_batched_dims(batch_dims_tester):
     """Tests for RVs that don't have natural batch dims in Numba API."""
     batch_dims_tester(mode="NUMBA")
+
+
+def test_repeated_args():
+    v = pt.scalar()
+    x = ptr.beta(v, v)
+    fn, _ = compare_numba_and_py([v], [x], [0.5 * 1e6], eval_obj_mode=False)
+
+    # Confirm we are testing a RandomVariable with repeated inputs
+    final_node = fn.maker.fgraph.outputs[0].owner
+    assert isinstance(final_node.op, RandomVariableWithCoreShape)
+    assert final_node.inputs[-2] is final_node.inputs[-1]
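The equivalent user-facing sketch for the random-variable case (hypothetical, not part of the commit) is a beta RV whose two parameters are the same input variable; the huge test value 0.5 * 1e6 presumably makes draws nearly deterministic, since Beta(a, a) concentrates around 0.5 as a grows:

    import pytensor.tensor as pt
    import pytensor.tensor.random as ptr
    from pytensor import function

    v = pt.scalar("v")
    x = ptr.beta(v, v)  # both parameters are the same variable

    fn = function([v], x, mode="NUMBA")
    print(fn(0.5e6))  # close to 0.5 for such a large concentration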
