Skip to content

Commit 5c75d1a

Browse files
committed
Speedup node eval
1 parent c29963b commit 5c75d1a

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

pytensor/graph/op.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -513,17 +513,24 @@ def make_py_thunk(
513513
"""
514514
node_input_storage = [storage_map[r] for r in node.inputs]
515515
node_output_storage = [storage_map[r] for r in node.outputs]
516+
node_compute_map = [compute_map[r] for r in node.outputs]
516517

517518
if debug and hasattr(self, "debug_perform"):
518519
p = node.op.debug_perform
519520
else:
520521
p = node.op.perform
521522

522523
@is_thunk_type
523-
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
524+
def rval(
525+
p=p,
526+
i=node_input_storage,
527+
o=node_output_storage,
528+
n=node,
529+
cm=node_compute_map,
530+
):
524531
r = p(n, [x[0] for x in i], o)
525-
for o in node.outputs:
526-
compute_map[o][0] = True
532+
for entry in cm:
533+
entry[0] = True
527534
return r
528535

529536
rval.inputs = node_input_storage

0 commit comments

Comments
 (0)