
Commit 6dd6172

Fix Blockwise infer shape from core Op
Sometimes `_create_dummy_core_node` can create a multi-node graph whose root inputs are not `node.inputs`; `infer_shape` may then bypass the intermediate nodes. This was the case with `Subtensor`, which introduces `ScalarFromTensor` nodes but ignores them in the shape graph (for a cleaner graph).
1 parent 9578bd3 commit 6dd6172
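
To make the failure mode concrete, here is a minimal, illustrative sketch (not part of the commit) showing that a `Subtensor` node's direct inputs include a `ScalarFromTensor` intermediate, while `explicit_graph_inputs` walks back to the root variables that the shape graph may reference instead:

# Illustrative only: indexing a vector with a symbolic integer scalar goes
# through Subtensor, whose make_node wraps the index in a ScalarFromTensor
# node, so the node's direct inputs are intermediate variables rather than
# the root variables of the graph.
import pytensor.tensor as pt
from pytensor.graph.basic import explicit_graph_inputs

x = pt.vector("x")
idx = pt.lscalar("idx")
y = x[idx]  # Subtensor applied to (x, ScalarFromTensor(idx))

subtensor_node = y.owner
print(subtensor_node.inputs)                               # roughly [x, ScalarFromTensor.0]
print(list(explicit_graph_inputs(subtensor_node.inputs)))  # [x, idx] -- the true roots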

2 files changed: +15 −7 lines

pytensor/tensor/blockwise.py (+4 −3)

@@ -7,7 +7,7 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
 from pytensor.graph import FunctionGraph
-from pytensor.graph.basic import Apply, Constant, ancestors
+from pytensor.graph.basic import Apply, Constant, explicit_graph_inputs
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.graph.replace import (

@@ -190,7 +190,7 @@ def infer_shape(
         core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
         if core_op_infer_shape is not None:
             dummy_core_node = self._create_dummy_core_node(node.inputs)
-            dummy_core_inputs = dummy_core_node.inputs
+            dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs))
             dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
             core_input_shapes = [
                 input_shape[batch_ndims:] for input_shape in input_shapes

@@ -214,7 +214,8 @@ def infer_shape(
                         # of the core_node as the value is not constant across batch dims of the Blockwise
                         core_out_dim = core_output_shapes[o][i]
                         if not (
-                            set(dummy_core_inputs) & set(ancestors([core_out_dim]))
+                            set(dummy_core_inputs)
+                            & set(explicit_graph_inputs([core_out_dim]))
                         ):
                             core_out_shape.append(core_out_dim)
                             continue
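
Why the intersection test needed both sides rooted: if the core op's `infer_shape` builds an output dimension from the root variable (skipping the `ScalarFromTensor` intermediate, as `Subtensor` does), checking the node's direct inputs against `ancestors` of the shape graph finds nothing, and the dependency on input values is silently missed. A small hypothetical illustration of that check, outside the Blockwise internals:

# Hypothetical stand-ins for the dummy core node and a shape dimension returned
# by the core op's infer_shape; not the actual Blockwise code path.
import pytensor.tensor as pt
from pytensor.graph.basic import ancestors, explicit_graph_inputs

x = pt.vector("x")
idx = pt.lscalar("idx")
core_node = x[idx].owner  # inputs: [x, ScalarFromTensor(idx)]

# Suppose infer_shape returned a dimension built from the root `idx` directly,
# bypassing the ScalarFromTensor intermediate.
core_out_dim = idx + 1

# Old check: direct inputs vs. ancestors of the shape graph -> empty, so the
# value dependency is missed and the dummy dimension would be reused.
print(set(core_node.inputs) & set(ancestors([core_out_dim])))  # set()

# Fixed check: compare root inputs on both sides -> the dependency shows up.
print(
    set(explicit_graph_inputs(core_node.inputs))
    & set(explicit_graph_inputs([core_out_dim]))
)  # {idx}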

tests/tensor/test_blockwise.py (+11 −4)

@@ -264,9 +264,13 @@ class TestOpWithInferShape(Op):
         def make_node(self, a, b):
             assert a.type.ndim == 1
             assert b.type.ndim == 1
+            # Simulate make_node that introduces operations on inputs
+            a_identity = a.copy()
+            b_identity = b.copy()
+
             c = tensor(shape=(None,))
             d = tensor(shape=(None,))
-            return Apply(self, [a, b], [c, d])
+            return Apply(self, [a_identity, b_identity], [c, d])

         def perform(self, node, inputs, outputs):
             a, b = inputs

@@ -277,9 +281,12 @@ def perform(self, node, inputs, outputs):
         def infer_shape(self, fgraph, node, input_shapes):
             # First output shape depends only on input_shapes
             # Second output shape depends on input values
-            x, y = node.inputs
-            [(x_shape,), (y_shape,)] = input_shapes
-            return (x_shape + y_shape,), (x.sum() + y.sum(),)
+            a_identity, b_identity = node.inputs
+            # Simulate shape depending on original inputs, not the ones that go directly into the node
+            a = a_identity.owner.inputs[0]
+            b = b_identity.owner.inputs[0]
+            [(a_shape,), (b_shape,)] = input_shapes
+            return (a_shape + b_shape,), (a.sum() + b.sum(),)

     blockwise_op = Blockwise(
         core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"
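
A note on the test's indirection: `TensorVariable.copy()` builds a symbolic identity `Elemwise` node rather than a detached clone, which is why `infer_shape` above can recover the original variables through `owner.inputs[0]`. A quick check of that assumption:

# Quick check of the indirection the test relies on: copy() inserts an identity
# node, so the original variable sits one node behind the copy.
import pytensor.tensor as pt

a = pt.vector("a")
a_identity = a.copy()

print(a_identity.owner.op)              # an identity Elemwise
assert a_identity.owner.inputs[0] is a  # the original input, one node behind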
