Skip to content

Commit dae731d

Browse files
committed
Optimize blockwise fallback gufunc function
1 parent e39fda3 commit dae731d

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

pytensor/tensor/blockwise.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -308,17 +308,23 @@ def _create_node_gufunc(self, node) -> None:
308308
# Wrap core_op perform method in numpy vectorize
309309
n_outs = len(self.outputs_sig)
310310
core_node = self._create_dummy_core_node(node.inputs)
311+
inner_outputs_storage = [[None] for _ in range(n_outs)]
312+
313+
def core_func(
314+
*inner_inputs,
315+
core_node=core_node,
316+
inner_outputs_storage=inner_outputs_storage,
317+
):
318+
self.core_op.perform(
319+
core_node,
320+
[np.asarray(inp) for inp in inner_inputs],
321+
inner_outputs_storage,
322+
)
311323

312-
def core_func(*inner_inputs):
313-
inner_outputs = [[None] for _ in range(n_outs)]
314-
315-
inner_inputs = [np.asarray(inp) for inp in inner_inputs]
316-
self.core_op.perform(core_node, inner_inputs, inner_outputs)
317-
318-
if len(inner_outputs) == 1:
319-
return inner_outputs[0][0]
324+
if n_outs == 1:
325+
return inner_outputs_storage[0][0]
320326
else:
321-
return tuple(r[0] for r in inner_outputs)
327+
return tuple(r[0] for r in inner_outputs_storage)
322328

323329
gufunc = np.vectorize(core_func, signature=self.signature)
324330

0 commit comments

Comments
 (0)