@@ -259,6 +259,58 @@ def test_blockwise_shape():
259
259
assert tuple (shape_fn (inp1_test , inp2_test )[1 ]) == (7 , 5 , 4 )
260
260
261
261
262
+ def test_blockwise_infer_core_shape ():
263
+ class TestOpWithInferShape (Op ):
264
+ def make_node (self , a , b ):
265
+ assert a .type .ndim == 1
266
+ assert b .type .ndim == 1
267
+ c = tensor (shape = (None ,))
268
+ d = tensor (shape = (None ,))
269
+ return Apply (self , [a , b ], [c , d ])
270
+
271
+ def perform (self , node , inputs , outputs ):
272
+ a , b = inputs
273
+ c , d = outputs
274
+ c [0 ] = np .arange (a .size + b .size )
275
+ d [0 ] = np .arange (a .sum () + b .sum ())
276
+
277
+ def infer_shape (self , fgraph , node , input_shapes ):
278
+ # First output shape depends only on input_shapes
279
+ # Second output shape depends on input values
280
+ x , y = node .inputs
281
+ [(x_shape ,), (y_shape ,)] = input_shapes
282
+ return (x_shape + y_shape ,), (x .sum () + y .sum (),)
283
+
284
+ blockwise_op = Blockwise (
285
+ core_op = TestOpWithInferShape (), signature = "(a),(b)->(c),(d)"
286
+ )
287
+
288
+ a = tensor ("a" , shape = (5 , 3 ))
289
+ b = tensor ("b" , shape = (1 , 4 ))
290
+ c , d = blockwise_op (a , b )
291
+ assert c .type .shape == (5 , None )
292
+ assert d .type .shape == (5 , None )
293
+
294
+ c_shape_fn = pytensor .function ([a , b ], c .shape )
295
+ # c_shape can be computed from the input shapes alone
296
+ assert not any (
297
+ isinstance (getattr (n .op , "core_op" , n .op ), TestOpWithInferShape )
298
+ for n in c_shape_fn .maker .fgraph .apply_nodes
299
+ )
300
+
301
+ d_shape_fn = pytensor .function ([a , b ], d .shape )
302
+ # d_shape cannot be computed from the input shapes alone
303
+ assert any (
304
+ isinstance (getattr (n .op , "core_op" , n .op ), TestOpWithInferShape )
305
+ for n in d_shape_fn .maker .fgraph .apply_nodes
306
+ )
307
+
308
+ a_test = np .zeros (a .type .shape , dtype = a .type .dtype )
309
+ b_test = np .zeros (b .type .shape , dtype = b .type .dtype )
310
+ assert tuple (c_shape_fn (a_test , b_test )) == (5 , 7 )
311
+ assert tuple (d_shape_fn (a_test , b_test )) == (5 , 0 )
312
+
313
+
262
314
class BlockwiseOpTester :
263
315
"""Base class to test Blockwise works for specific Ops"""
264
316
0 commit comments