diff --git a/pytensor/printing.py b/pytensor/printing.py index b7b71622e8..5bdac3e7fd 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -88,6 +88,7 @@ def debugprint( | FunctionGraph | Sequence[Variable | Apply | Function | FunctionGraph], depth: int = -1, + inner_depth: int = -1, print_type: bool = False, print_shape: bool = False, file: Literal["str"] | TextIO | None = None, @@ -123,6 +124,8 @@ def debugprint( The object(s) to be printed. depth Print graph to this depth (``-1`` for unlimited). + inner_depth + Print inner graph to this depth (``-1`` for unlimited). print_type If ``True``, print the `Type`\s of each `Variable` in the graph. print_shape @@ -161,6 +164,8 @@ def debugprint( Whether to set both `print_destroy_map` and `print_view_map` to ``True``. print_fgraph_inputs Print the inputs of `FunctionGraph`\s. + print_inner_graphs + Whether to print the inner graphs of `Op`\s Returns ------- @@ -293,10 +298,14 @@ def debugprint( ): if hasattr(var.owner, "op"): if ( - isinstance(var.owner.op, HasInnerGraph) - or hasattr(var.owner.op, "scalar_op") - and isinstance(var.owner.op.scalar_op, HasInnerGraph) - ) and var not in inner_graph_vars: + ( + isinstance(var.owner.op, HasInnerGraph) + or hasattr(var.owner.op, "scalar_op") + and isinstance(var.owner.op.scalar_op, HasInnerGraph) + ) + and inner_depth + and var not in inner_graph_vars + ): inner_graph_vars.append(var) if print_op_info: op_information.update(op_debug_information(var.owner.op, var.owner)) @@ -322,7 +331,7 @@ def debugprint( print_view_map=print_view_map, ) - if len(inner_graph_vars) > 0: + if len(inner_graph_vars) > 0 and inner_depth: print("", file=_file) prefix = "" new_prefix = prefix + " ← " @@ -374,7 +383,7 @@ def debugprint( _debugprint( ig_var, prefix=prefix, - depth=depth, + depth=inner_depth, done=done, print_type=print_type, print_shape=print_shape, @@ -397,7 +406,7 @@ def debugprint( _debugprint( inp, prefix=" → ", - depth=depth, + depth=inner_depth, done=done, print_type=print_type, print_shape=print_shape, @@ -432,7 +441,7 @@ def debugprint( _debugprint( out, prefix=new_prefix, - depth=depth, + depth=inner_depth, done=done, print_type=print_type, print_shape=print_shape,