
Commit 7f6676d

Cleanup Split methods
1 parent d50db11 commit 7f6676d

1 file changed (+8, −8)


pytensor/tensor/basic.py (+8, −8)
@@ -2226,21 +2226,21 @@ def make_node(self, x, axis, splits):
 
         return Apply(self, inputs, outputs)
 
-    def perform(self, node, inputs, outputs):
+    def perform(self, node, inputs, outputs_storage):
         x, axis, splits = inputs
 
         if len(splits) != self.len_splits:
             raise ValueError("Length of splits is not equal to n_splits")
-        if np.sum(splits) != x.shape[axis]:
+        if splits.sum() != x.shape[axis]:
             raise ValueError(
-                f"Split sizes sum to {np.sum(splits)}; expected {x.shape[axis]}"
+                f"Split sizes sum to {splits.sum()}; expected {x.shape[axis]}"
             )
-        if np.any(splits < 0):
+        if (splits < 0).any():
             raise ValueError("Split sizes cannot be negative")
 
         split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
-        for i, out in enumerate(split_outs):
-            outputs[i][0] = out
+        for out_storage, out in zip(outputs_storage, split_outs, strict=False):
+            out_storage[0] = out
 
     def infer_shape(self, fgraph, node, in_shapes):
         axis = node.inputs[1]
@@ -2254,10 +2254,10 @@ def infer_shape(self, fgraph, node, in_shapes):
         out_shapes.append(temp)
         return out_shapes
 
-    def grad(self, inputs, g_outputs):
+    def L_op(self, inputs, outputs, g_outputs):
         """Join the gradients along the axis that was used to split x."""
         x, axis, n = inputs
-        outputs = self(*inputs, return_list=True)
+
         # If all the output gradients are disconnected, then so are the inputs
         if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
             return [
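For reference, a minimal NumPy-only sketch of what the updated perform method does with the split sizes (illustrative values, not taken from the commit): np.split uses the running totals of all but the last split size as cut points along the chosen axis, and each resulting piece is written into its output's storage cell.

    import numpy as np

    # Hypothetical inputs, for illustration only.
    x = np.arange(10)             # array to split
    splits = np.array([3, 5, 2])  # piece sizes; must sum to x.shape[axis]
    axis = 0

    # Same validity checks as the updated perform method.
    assert splits.sum() == x.shape[axis]
    assert not (splits < 0).any()

    # Cut points are the cumulative sums of all but the last size: [3, 8].
    split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)

    # Mimic the per-output storage cells that perform writes into.
    outputs_storage = [[None] for _ in range(len(splits))]
    for out_storage, out in zip(outputs_storage, split_outs, strict=False):
        out_storage[0] = out

    print([s[0] for s in outputs_storage])
    # [array([0, 1, 2]), array([3, 4, 5, 6, 7]), array([8, 9])]

The switch from grad to L_op also means the gradient method receives the node's existing outputs directly, so the removed line that rebuilt them via self(*inputs, return_list=True) is no longer needed.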
