
Commit 42a7adb

Remove Unbroadcast Op
1 parent a24f534 commit 42a7adb

22 files changed: +164 -656 lines


doc/library/tensor/basic.rst (+6 -9)

@@ -619,9 +619,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.

 .. function:: shape_padleft(x, n_ones=1)

-    Reshape `x` by left padding the shape with `n_ones` 1s. Note that all
-    this new dimension will be broadcastable. To make them non-broadcastable
-    see the :func:`unbroadcast`.
+    Reshape `x` by left padding the shape with `n_ones` 1s.
+    All new dimensions will be broadcastable.

     :param x: variable to be reshaped
     :type x: any `TensorVariable` (or compatible)
@@ -633,9 +632,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.

 .. function:: shape_padright(x, n_ones=1)

-    Reshape `x` by right padding the shape with `n_ones` ones. Note that all
-    this new dimension will be broadcastable. To make them non-broadcastable
-    see the :func:`unbroadcast`.
+    Reshape `x` by right padding the shape with `n_ones` ones.
+    All new dimensions will be broadcastable.

     :param x: variable to be reshaped
     :type x: any TensorVariable (or compatible)
@@ -646,9 +644,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.

 .. function:: shape_padaxis(t, axis)

-    Reshape `t` by inserting ``1`` at the dimension `axis`. Note that this new
-    dimension will be broadcastable. To make it non-broadcastable
-    see the :func:`unbroadcast`.
+    Reshape `t` by inserting ``1`` at the dimension `axis`.
+    All new dimensions will be broadcastable.

     :type x: any `TensorVariable` (or compatible)
     :param x: variable to be reshaped
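
The docstring changes above drop the reference to `unbroadcast`; the padding helpers always add length-1 (broadcastable) dimensions. A minimal sketch of that behaviour, assuming a current PyTensor build where these helpers are exposed under `pytensor.tensor`:

    import pytensor.tensor as pt

    x = pt.matrix("x")                 # static shape (None, None)

    # shape_padleft prepends broadcastable dimensions ...
    y = pt.shape_padleft(x, n_ones=2)
    print(y.type.shape)                # (1, 1, None, None)

    # ... shape_padright appends them ...
    z = pt.shape_padright(x)
    print(z.broadcastable)             # (False, False, True)

    # ... and shape_padaxis inserts a single one at the given axis.
    w = pt.shape_padaxis(x, axis=1)
    print(w.type.shape)                # (None, 1, None)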

pytensor/compile/function/pfunc.py (+1 -7)

@@ -292,14 +292,8 @@ def clone_inputs(i):
                 f" shared_var.type={store_into.type},"
                 f" update_val={update_val}, update_val.type={getattr(update_val, 'type', None)})."
             )
-            err_sug = (
-                "If the difference is related to the broadcast pattern,"
-                " you can call the"
-                " tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
-                " function to mask broadcastable dimensions."
-            )

-            raise TypeError(err_msg, err_sug)
+            raise TypeError(err_msg)
         assert store_into.type.is_super(update_val.type)

         update_d[store_into] = update_val

pytensor/ifelse.py (+1 -2)

@@ -26,7 +26,7 @@
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
 from pytensor.graph.type import HasDataType, HasShape
-from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
+from pytensor.tensor.shape import Reshape, Shape, SpecifyShape


 if TYPE_CHECKING:
@@ -481,7 +481,6 @@ def cond_make_inplace(fgraph, node):
         Shape,
         SpecifyShape,
         Reshape,
-        Unbroadcast,
         pt.math.Dot,
         pt.math.Max,
         pt.math.Argmax,

pytensor/link/jax/dispatch/shape.py (+1 -9)

@@ -4,7 +4,7 @@
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
+from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
 from pytensor.tensor.type import TensorType


@@ -104,11 +104,3 @@ def specifyshape(x, *shape):
         return x

     return specifyshape
-
-
-@jax_funcify.register(Unbroadcast)
-def jax_funcify_Unbroadcast(op, **kwargs):
-    def unbroadcast(x):
-        return x
-
-    return unbroadcast

pytensor/link/numba/dispatch/tensor_basic.py (-10)

@@ -17,7 +17,6 @@
     Split,
     TensorFromScalar,
 )
-from pytensor.tensor.shape import Unbroadcast


 @numba_funcify.register(AllocEmpty)
@@ -232,15 +231,6 @@ def makevector({", ".join(input_names)}):
     return numba_basic.numba_njit(makevector_fn)


-@numba_funcify.register(Unbroadcast)
-def numba_funcify_Unbroadcast(op, **kwargs):
-    @numba_basic.numba_njit
-    def unbroadcast(x):
-        return x
-
-    return unbroadcast
-
-
 @numba_funcify.register(TensorFromScalar)
 def numba_funcify_TensorFromScalar(op, **kwargs):
     @numba_basic.numba_njit(inline="always")

pytensor/link/pytorch/dispatch/shape.py (+1 -9)

@@ -2,7 +2,7 @@

 from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
-from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
+from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape


 @pytorch_funcify.register(Reshape)
@@ -56,11 +56,3 @@ def specifyshape(x, *shape):
         return x

     return specifyshape
-
-
-@pytorch_funcify.register(Unbroadcast)
-def pytorch_funcify_Unbroadcast(op, **kwargs):
-    def unbroadcast(x):
-        return x
-
-    return unbroadcast

pytensor/scan/basic.py (+5 -5)

@@ -15,7 +15,7 @@
 from pytensor.tensor.basic import get_underlying_scalar_constant_value
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import minimum
-from pytensor.tensor.shape import shape_padleft, unbroadcast
+from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.type import TensorType, integer_dtypes
 from pytensor.updates import OrderedUpdates

@@ -748,7 +748,7 @@ def wrap_into_list(x):
             # defined in scan utils
             sit_sot_scan_inputs.append(
                 expand_empty(
-                    unbroadcast(shape_padleft(actual_arg), 0),
+                    shape_padleft(actual_arg),
                     actual_n_steps,
                 )
             )
@@ -865,13 +865,13 @@ def wrap_into_list(x):
     if n_fixed_steps in (1, -1):
         for pos, inner_out in enumerate(outputs):
             # we need to see if we need to pad our sequences with an
-            # unbroadcastable dimension; case example : we return an
+            # extra dimension; case example : we return an
             # output for which we want all intermediate. If n_steps is 1
             # then, if we return the output as given by the innner function
             # this will represent only a slice and it will have one
             # dimension less.
             if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
-                outputs[pos] = unbroadcast(shape_padleft(inner_out), 0)
+                outputs[pos] = shape_padleft(inner_out)

     if not return_list and len(outputs) == 1:
         outputs = outputs[0]
@@ -1002,7 +1002,7 @@ def wrap_into_list(x):
         sit_sot_inner_inputs.append(new_var)
         sit_sot_scan_inputs.append(
             expand_empty(
-                unbroadcast(shape_padleft(input.variable), 0),
+                shape_padleft(input.variable),
                 actual_n_steps,
             )
         )
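
The comment rewritten above explains why the unrolled path pads with `shape_padleft`: when `n_steps` is a literal 1, the inner function returns a single slice, so scan restores the leading time axis itself and no longer has to mark it non-broadcastable. A rough illustration of that behaviour (assuming the public `pytensor.scan` entry point):

    import pytensor
    import pytensor.tensor as pt

    x0 = pt.vector("x0")

    # One recurrent (sit-sot) output; with a literal n_steps=1 the loop is
    # unrolled and the per-step result gets a leading time dimension.
    ys, _ = pytensor.scan(lambda prev: prev + 1, outputs_info=x0, n_steps=1)
    print(ys.ndim)  # 2: (time, elements), even though only one step runs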

pytensor/scan/op.py (+1 -2)

@@ -166,8 +166,7 @@ def check_broadcast(v1, v2):
         "axis %d in `output_info`. This can happen if one of the "
         "dimension is fixed to 1 in the input, while it is still "
         "variable in the output, or vice-verca. You have to make "
-        "them consistent, e.g. using pytensor.tensor."
-        "{unbroadcast, specify_broadcastable}."
+        "them consistent, e.g. using pytensor.tensor.specify_broadcastable."
     )
     size = min(v1.type.ndim, v2.type.ndim)
     for n, (b1, b2) in enumerate(
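
With `unbroadcast` removed, the error message above now points only to `specify_broadcastable` for reconciling mismatched broadcast patterns between scan's inputs and outputs. A small sketch of that helper, assuming the public `pytensor.tensor.specify_broadcastable` function:

    import pytensor.tensor as pt

    x = pt.matrix("x")                    # static shape (None, None)
    x1 = pt.specify_broadcastable(x, 0)   # declare axis 0 to have length 1
    print(x1.broadcastable)               # (True, False)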

pytensor/tensor/basic.py (+1 -12)

@@ -53,7 +53,6 @@
 from pytensor.tensor.shape import (
     Shape,
     Shape_i,
-    Unbroadcast,
     shape,
     shape_padaxis,
     shape_padleft,
@@ -334,9 +333,7 @@ def _get_underlying_scalar_constant_value(
     if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
         op = v.owner.op
         max_recur -= 1
-        if isinstance(
-            op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
-        ):
+        if isinstance(op, Alloc | DimShuffle | OutputGuard | DeepCopyOp):
             # OutputGuard is only used in debugmode but we
             # keep it here to avoid problems with old pickles
             v = v.owner.inputs[0]
@@ -498,14 +495,6 @@ def _get_underlying_scalar_constant_value(
                 grandparent = leftmost_parent.owner.inputs[0]
                 gp_shape = grandparent.type.shape
                 ndim = grandparent.type.ndim
-                if grandparent.owner and isinstance(
-                    grandparent.owner.op, Unbroadcast
-                ):
-                    ggp_shape = grandparent.owner.inputs[0].type.shape
-                    l = [
-                        _get_underlying_scalar_constant_value(s) for s in ggp_shape
-                    ]
-                    gp_shape = tuple(l)

                 if not (idx < ndim):
                     msg = (

pytensor/tensor/rewriting/shape.py (-77)

@@ -42,9 +42,7 @@
     Shape,
     Shape_i,
     SpecifyShape,
-    Unbroadcast,
     specify_shape,
-    unbroadcast,
 )
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
@@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node):
     # structure.
     replacement = shape_feature.scheduled[node]
     return [shape_feature.shape_of[replacement][node.op.i]]
-
-
-@register_useless
-@register_canonicalize
-@register_specialize
-@node_rewriter([Unbroadcast])
-def local_useless_unbroadcast(fgraph, node):
-    """Remove `Unbroadcast` if it does not actually change the broadcasting pattern."""
-    if isinstance(node.op, Unbroadcast):
-        x = node.inputs[0]
-        if x.type.ndim == node.outputs[0].type.ndim and all(
-            s1 == s2
-            for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True)
-            if s1 == 1 or s2 == 1
-        ):
-            # No broadcastable flag was modified
-            # No need to copy over stack trace,
-            # because x should already have a stack trace.
-            return [x]
-        else:
-            # Keep the flags that modify something
-            new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
-            if new_axes == node.op.axes:
-                # All flags are useful
-                return None
-            else:
-                r = unbroadcast(x, *new_axes)
-                # Copy over stacktrace from previous output
-                copy_stack_trace(node.outputs, r)
-                return [r]
-
-
-@register_canonicalize
-@register_specialize
-@node_rewriter([Unbroadcast])
-def local_unbroadcast_lift(fgraph, node):
-    """
-    Lifts `Unbroadcast` through unary Elemwise operations,
-    and merges consecutive `Unbroadcast`s.
-
-    Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
-    Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)
-
-    TODO: Implement equivalent Elemwise lift for SpecifyShape
-    """
-    op = node.op
-    if not isinstance(op, Unbroadcast):
-        return False
-
-    inp = node.inputs[0]
-    inode = inp.owner
-    if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
-        if len(fgraph.clients.get(inp, ())) == 1:
-            unbroadcasted = unbroadcast(inode.inputs[0], *op.axes)
-            copy_stack_trace(node.outputs, unbroadcasted)
-
-            rval = inode.op.make_node(unbroadcasted).outputs
-
-            # Copy over stacktrace from previous output (after unbroadcasting)
-            # and input (after elemwise operation) to new output, because an
-            # error in the new graph could have been caused by either of the
-            # two ops.
-            copy_stack_trace(node.outputs + node.inputs, rval)
-            return rval
-
-    if inode and isinstance(inode.op, Unbroadcast):
-        # Merge axis of each unbroadcast
-        axis = tuple(set(inode.op.axes).union(set(op.axes)))
-        iinput = inode.inputs[0]
-        rval = [unbroadcast(iinput, *axis)]
-        # Copy over stacktrace from previous output (after second unbroadcasting)
-        # and from previous input (after first unbroadcasting) because an error in
-        # the new graph could have been caused by either of the two Unbroadcast ops.
-        copy_stack_trace(node.outputs + node.inputs, rval)
-        return rval

pytensor/tensor/rewriting/subtensor.py (-37)

@@ -59,11 +59,9 @@
 from pytensor.tensor.shape import (
     Shape,
     SpecifyShape,
-    Unbroadcast,
     shape_padleft,
     shape_tuple,
     specify_shape,
-    unbroadcast,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
 from pytensor.tensor.subtensor import (
@@ -429,7 +427,6 @@ def local_subtensor_lift(fgraph, node):
     Handles the following unary ops:
     elemwise(x,...)[idx] -> elemwise(x[idx],...)
       when x,... are broadcasted scalar or not broadcasted at all
-    Unbroadcast(x)[idx] => Unbroadcast(x[idx])

     """
     if isinstance(node.op, Subtensor):
@@ -488,40 +485,6 @@ def local_subtensor_lift(fgraph, node):
             copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
             return [ret]

-        if isinstance(u.owner.op, Unbroadcast):
-            # Subtensor might reduce dim., adapt broadcast pattern accordingly
-            old_axes = u.owner.op.axes
-            new_axes = []
-
-            # loop through indices being subtensor-ed
-            # i indexes broadcastable pattern before subtensor
-            # j indexes broadcastable pattern after subtensor
-            j = 0
-            for i, x in enumerate(node.op.idx_list):
-                # if it is not a slice, it will reduce the dimension, should
-                # not appear in the broascastable dimensions
-                if isinstance(x, slice):
-                    if i in old_axes:
-                        new_axes.append(j)
-                    j += 1
-            # now keep the broadcastable pattern of all
-            # items not appearing in subtensor list
-            for i in range(len(node.op.idx_list), len(u.broadcastable)):
-                if i in old_axes:
-                    new_axes.append(j)
-                j += 1
-
-            subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs[0], subt_x)
-
-            rbcast_subt_x = unbroadcast(subt_x, *new_axes)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
-
-            return [rbcast_subt_x]
-

 @register_canonicalize
 @register_specialize
