Skip to content

Commit 66432ec

Browse files
committed
Add inner_depth parameter to debugprint for depth controlled inner graph printing
1 parent af698a7 commit 66432ec

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

pytensor/printing.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def debugprint(
8888
| FunctionGraph
8989
| Sequence[Variable | Apply | Function | FunctionGraph],
9090
depth: int = -1,
91+
inner_depth: int = -1,
9192
print_type: bool = False,
9293
print_shape: bool = False,
9394
file: Literal["str"] | TextIO | None = None,
@@ -299,7 +300,7 @@ def debugprint(
299300
isinstance(var.owner.op, HasInnerGraph)
300301
or hasattr(var.owner.op, "scalar_op")
301302
and isinstance(var.owner.op.scalar_op, HasInnerGraph)
302-
) and var not in inner_graph_vars:
303+
) and not inner_depth and var not in inner_graph_vars:
303304
inner_graph_vars.append(var)
304305
if print_op_info:
305306
op_information.update(op_debug_information(var.owner.op, var.owner))
@@ -325,7 +326,7 @@ def debugprint(
325326
print_view_map=print_view_map,
326327
)
327328

328-
if len(inner_graph_vars) > 0 and print_inner_graphs:
329+
if len(inner_graph_vars) > 0 and inner_depth:
329330
print("", file=_file)
330331
prefix = ""
331332
new_prefix = prefix + " ← "
@@ -377,7 +378,7 @@ def debugprint(
377378
_debugprint(
378379
ig_var,
379380
prefix=prefix,
380-
depth=depth,
381+
depth=inner_depth,
381382
done=done,
382383
print_type=print_type,
383384
print_shape=print_shape,
@@ -400,7 +401,7 @@ def debugprint(
400401
_debugprint(
401402
inp,
402403
prefix=" → ",
403-
depth=depth,
404+
depth=inner_depth,
404405
done=done,
405406
print_type=print_type,
406407
print_shape=print_shape,
@@ -435,7 +436,7 @@ def debugprint(
435436
_debugprint(
436437
out,
437438
prefix=new_prefix,
438-
depth=depth,
439+
depth=inner_depth,
439440
done=done,
440441
print_type=print_type,
441442
print_shape=print_shape,

0 commit comments

Comments
 (0)