@@ -2226,21 +2226,21 @@ def make_node(self, x, axis, splits):
         return Apply(self, inputs, outputs)

-    def perform(self, node, inputs, outputs):
+    def perform(self, node, inputs, outputs_storage):
         x, axis, splits = inputs

         if len(splits) != self.len_splits:
             raise ValueError("Length of splits is not equal to n_splits")
-        if np.sum(splits) != x.shape[axis]:
+        if splits.sum() != x.shape[axis]:
             raise ValueError(
-                f"Split sizes sum to {np.sum(splits)}; expected {x.shape[axis]}"
+                f"Split sizes sum to {splits.sum()}; expected {x.shape[axis]}"
             )
-        if np.any(splits < 0):
+        if (splits < 0).any():
             raise ValueError("Split sizes cannot be negative")

         split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
-        for i, out in enumerate(split_outs):
-            outputs[i][0] = out
+        for out_storage, out in zip(outputs_storage, split_outs, strict=False):
+            out_storage[0] = out

     def infer_shape(self, fgraph, node, in_shapes):
         axis = node.inputs[1]
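Aside on the first hunk: the rewritten loop relies on the storage protocol that Theano-family Ops use for `perform` — each entry of `outputs_storage` is a one-element list, and the Op writes its result into slot `[0]`. Below is a minimal runnable sketch of that convention and of the `np.cumsum` size-to-offset conversion feeding `np.split`; the variable names are illustrative, not taken from this PR.

```python
import numpy as np

x = np.arange(10)
splits = np.asarray([3, 3, 4])

# np.split takes cut *positions*, so the sizes are turned into offsets with
# cumsum; the last size is dropped because the final segment is implied.
split_outs = np.split(x, np.cumsum(splits[:-1]), axis=0)

# Sketch of the perform() storage convention: one single-element list per
# output; the computed array goes into cell [0].
outputs_storage = [[None] for _ in split_outs]
for out_storage, out in zip(outputs_storage, split_outs):
    out_storage[0] = out

print([s[0] for s in outputs_storage])
# [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8, 9])]
```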
@@ -2254,10 +2254,10 @@ def infer_shape(self, fgraph, node, in_shapes):
             out_shapes.append(temp)
         return out_shapes

-    def grad(self, inputs, g_outputs):
+    def L_op(self, inputs, outputs, g_outputs):
         """Join the gradients along the axis that was used to split x."""
         x, axis, n = inputs
-        outputs = self(*inputs, return_list=True)
+
         # If all the output gradients are disconnected, then so are the inputs
         if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
             return [
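Note on the second hunk: `grad(inputs, g_outputs)` never sees the node's existing outputs, so the old code had to re-apply the Op (`self(*inputs, return_list=True)`) just to obtain symbolic handles to them, creating a redundant Apply node in the graph. `L_op(inputs, outputs, g_outputs)` is handed those outputs directly. A simplified sketch of how the two hooks relate in Theano-style libraries follows; this is a hypothetical base class, the real one differs in detail.

```python
# Hypothetical, simplified Op base class: L_op is the general gradient hook
# and falls back to grad(), which simply ignores the precomputed outputs.
class Op:
    def L_op(self, inputs, outputs, output_grads):
        return self.grad(inputs, output_grads)

    def grad(self, inputs, output_grads):
        raise NotImplementedError
```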