Skip to content

Commit 18ba52c

Browse files
committed
Use infer_shape of core_op to infer Blockwise core shapes
This can only be done when the output of infer_shape of the core_op depends only on the input shapes, and not their values.
1 parent ef97287 commit 18ba52c

File tree

2 files changed

+81
-3
lines changed

2 files changed

+81
-3
lines changed

pytensor/tensor/blockwise.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from pytensor import config
77
from pytensor.compile.builders import OpFromGraph
88
from pytensor.gradient import DisconnectedType
9-
from pytensor.graph.basic import Apply, Constant
9+
from pytensor.graph import FunctionGraph
10+
from pytensor.graph.basic import Apply, Constant, ancestors
1011
from pytensor.graph.null_type import NullType
1112
from pytensor.graph.op import Op
1213
from pytensor.graph.replace import (
@@ -185,15 +186,40 @@ def infer_shape(
185186

186187
batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
187188

189+
# Try to extract the core shapes from the core_op
190+
core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
191+
if core_op_infer_shape is not None:
192+
dummy_core_node = self._create_dummy_core_node(node.inputs)
193+
dummy_core_inputs = dummy_core_node.inputs
194+
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
195+
core_input_shapes = [
196+
input_shape[batch_ndims:] for input_shape in input_shapes
197+
]
198+
core_output_shapes = core_op_infer_shape(
199+
dummy_fgraph, dummy_core_node, core_input_shapes
200+
)
201+
188202
out_shapes = []
189-
for output, sig in zip(node.outputs, self.outputs_sig, strict=True):
203+
for o, (output, sig) in enumerate(
204+
zip(node.outputs, self.outputs_sig, strict=True)
205+
):
190206
core_out_shape = []
191207
for i, dim_name in enumerate(sig):
192208
# The output dim is the same as another input dim
193209
if dim_name in core_dims:
194210
core_out_shape.append(core_dims[dim_name])
195211
else:
196-
# TODO: We could try to make use of infer_shape of core_op
212+
if core_op_infer_shape is not None:
213+
# If the input values are needed to compute the dimension length, we can't use the infer_shape
214+
# of the core_node as the value is not constant across batch dims of the Blockwise
215+
core_out_dim = core_output_shapes[o][i]
216+
if not (
217+
set(dummy_core_inputs) & set(ancestors([core_out_dim]))
218+
):
219+
core_out_shape.append(core_out_dim)
220+
continue
221+
222+
# Fallback shape requires evaluating the Blockwise Op
197223
core_out_shape.append(Shape_i(batch_ndims + i)(output))
198224
out_shapes.append((*batch_shape, *core_out_shape))
199225

tests/tensor/test_blockwise.py

+52
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,58 @@ def test_blockwise_shape():
259259
assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4)
260260

261261

262+
def test_blockwise_infer_core_shape():
263+
class TestOpWithInferShape(Op):
264+
def make_node(self, a, b):
265+
assert a.type.ndim == 1
266+
assert b.type.ndim == 1
267+
c = tensor(shape=(None,))
268+
d = tensor(shape=(None,))
269+
return Apply(self, [a, b], [c, d])
270+
271+
def perform(self, node, inputs, outputs):
272+
a, b = inputs
273+
c, d = outputs
274+
c[0] = np.arange(a.size + b.size)
275+
d[0] = np.arange(a.sum() + b.sum())
276+
277+
def infer_shape(self, fgraph, node, input_shapes):
278+
# First output shape depends only on input_shapes
279+
# Second output shape depends on input values
280+
x, y = node.inputs
281+
[(x_shape,), (y_shape,)] = input_shapes
282+
return (x_shape + y_shape,), (x.sum() + y.sum(),)
283+
284+
blockwise_op = Blockwise(
285+
core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"
286+
)
287+
288+
a = tensor("a", shape=(5, 3))
289+
b = tensor("b", shape=(1, 4))
290+
c, d = blockwise_op(a, b)
291+
assert c.type.shape == (5, None)
292+
assert d.type.shape == (5, None)
293+
294+
c_shape_fn = pytensor.function([a, b], c.shape)
295+
# c_shape can be computed from the input shapes alone
296+
assert not any(
297+
isinstance(getattr(n.op, "core_op", n.op), TestOpWithInferShape)
298+
for n in c_shape_fn.maker.fgraph.apply_nodes
299+
)
300+
301+
d_shape_fn = pytensor.function([a, b], d.shape)
302+
# d_shape cannot be computed from the input shapes alone
303+
assert any(
304+
isinstance(getattr(n.op, "core_op", n.op), TestOpWithInferShape)
305+
for n in d_shape_fn.maker.fgraph.apply_nodes
306+
)
307+
308+
a_test = np.zeros(a.type.shape, dtype=a.type.dtype)
309+
b_test = np.zeros(b.type.shape, dtype=b.type.dtype)
310+
assert tuple(c_shape_fn(a_test, b_test)) == (5, 7)
311+
assert tuple(d_shape_fn(a_test, b_test)) == (5, 0)
312+
313+
262314
class BlockwiseOpTester:
263315
"""Base class to test Blockwise works for specific Ops"""
264316

0 commit comments

Comments
 (0)