|
42 | 42 | Shape,
|
43 | 43 | Shape_i,
|
44 | 44 | SpecifyShape,
|
45 |
| - Unbroadcast, |
46 | 45 | specify_shape,
|
47 |
| - unbroadcast, |
48 | 46 | )
|
49 | 47 | from pytensor.tensor.subtensor import Subtensor, get_idx_list
|
50 | 48 | from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
|
@@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node):
|
1296 | 1294 | # structure.
|
1297 | 1295 | replacement = shape_feature.scheduled[node]
|
1298 | 1296 | return [shape_feature.shape_of[replacement][node.op.i]]
|
1299 |
| - |
1300 |
| - |
1301 |
| -@register_useless |
1302 |
| -@register_canonicalize |
1303 |
| -@register_specialize |
1304 |
| -@node_rewriter([Unbroadcast]) |
1305 |
| -def local_useless_unbroadcast(fgraph, node): |
1306 |
| - """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.""" |
1307 |
| - if isinstance(node.op, Unbroadcast): |
1308 |
| - x = node.inputs[0] |
1309 |
| - if x.type.ndim == node.outputs[0].type.ndim and all( |
1310 |
| - s1 == s2 |
1311 |
| - for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True) |
1312 |
| - if s1 == 1 or s2 == 1 |
1313 |
| - ): |
1314 |
| - # No broadcastable flag was modified |
1315 |
| - # No need to copy over stack trace, |
1316 |
| - # because x should already have a stack trace. |
1317 |
| - return [x] |
1318 |
| - else: |
1319 |
| - # Keep the flags that modify something |
1320 |
| - new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1) |
1321 |
| - if new_axes == node.op.axes: |
1322 |
| - # All flags are useful |
1323 |
| - return None |
1324 |
| - else: |
1325 |
| - r = unbroadcast(x, *new_axes) |
1326 |
| - # Copy over stacktrace from previous output |
1327 |
| - copy_stack_trace(node.outputs, r) |
1328 |
| - return [r] |
1329 |
| - |
1330 |
| - |
1331 |
| -@register_canonicalize |
1332 |
| -@register_specialize |
1333 |
| -@node_rewriter([Unbroadcast]) |
1334 |
| -def local_unbroadcast_lift(fgraph, node): |
1335 |
| - """ |
1336 |
| - Lifts `Unbroadcast` through unary Elemwise operations, |
1337 |
| - and merges consecutive `Unbroadcast`s. |
1338 |
| -
|
1339 |
| - Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x)) |
1340 |
| - Unbroadcast(Unbroadcast(x)) => Unbroadcast(x) |
1341 |
| -
|
1342 |
| - TODO: Implement equivalent Elemwise lift for SpecifyShape |
1343 |
| - """ |
1344 |
| - op = node.op |
1345 |
| - if not isinstance(op, Unbroadcast): |
1346 |
| - return False |
1347 |
| - |
1348 |
| - inp = node.inputs[0] |
1349 |
| - inode = inp.owner |
1350 |
| - if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: |
1351 |
| - if len(fgraph.clients.get(inp, ())) == 1: |
1352 |
| - unbroadcasted = unbroadcast(inode.inputs[0], *op.axes) |
1353 |
| - copy_stack_trace(node.outputs, unbroadcasted) |
1354 |
| - |
1355 |
| - rval = inode.op.make_node(unbroadcasted).outputs |
1356 |
| - |
1357 |
| - # Copy over stacktrace from previous output (after unbroadcasting) |
1358 |
| - # and input (after elemwise operation) to new output, because an |
1359 |
| - # error in the new graph could have been caused by either of the |
1360 |
| - # two ops. |
1361 |
| - copy_stack_trace(node.outputs + node.inputs, rval) |
1362 |
| - return rval |
1363 |
| - |
1364 |
| - if inode and isinstance(inode.op, Unbroadcast): |
1365 |
| - # Merge axis of each unbroadcast |
1366 |
| - axis = tuple(set(inode.op.axes).union(set(op.axes))) |
1367 |
| - iinput = inode.inputs[0] |
1368 |
| - rval = [unbroadcast(iinput, *axis)] |
1369 |
| - # Copy over stacktrace from previous output (after second unbroadcasting) |
1370 |
| - # and from previous input (after first unbroadcasting) because an error in |
1371 |
| - # the new graph could have been caused by either of the two Unbroadcast ops. |
1372 |
| - copy_stack_trace(node.outputs + node.inputs, rval) |
1373 |
| - return rval |
0 commit comments