@@ -88,6 +88,7 @@ def debugprint(
88
88
| FunctionGraph
89
89
| Sequence [Variable | Apply | Function | FunctionGraph ],
90
90
depth : int = - 1 ,
91
+ inner_depth : int = - 1 ,
91
92
print_type : bool = False ,
92
93
print_shape : bool = False ,
93
94
file : Literal ["str" ] | TextIO | None = None ,
@@ -299,7 +300,7 @@ def debugprint(
299
300
isinstance (var .owner .op , HasInnerGraph )
300
301
or hasattr (var .owner .op , "scalar_op" )
301
302
and isinstance (var .owner .op .scalar_op , HasInnerGraph )
302
- ) and var not in inner_graph_vars :
303
+ ) and not inner_depth and var not in inner_graph_vars :
303
304
inner_graph_vars .append (var )
304
305
if print_op_info :
305
306
op_information .update (op_debug_information (var .owner .op , var .owner ))
@@ -325,7 +326,7 @@ def debugprint(
325
326
print_view_map = print_view_map ,
326
327
)
327
328
328
- if len (inner_graph_vars ) > 0 and print_inner_graphs :
329
+ if len (inner_graph_vars ) > 0 and inner_depth :
329
330
print ("" , file = _file )
330
331
prefix = ""
331
332
new_prefix = prefix + " ← "
@@ -377,7 +378,7 @@ def debugprint(
377
378
_debugprint (
378
379
ig_var ,
379
380
prefix = prefix ,
380
- depth = depth ,
381
+ depth = inner_depth ,
381
382
done = done ,
382
383
print_type = print_type ,
383
384
print_shape = print_shape ,
@@ -400,7 +401,7 @@ def debugprint(
400
401
_debugprint (
401
402
inp ,
402
403
prefix = " → " ,
403
- depth = depth ,
404
+ depth = inner_depth ,
404
405
done = done ,
405
406
print_type = print_type ,
406
407
print_shape = print_shape ,
@@ -435,7 +436,7 @@ def debugprint(
435
436
_debugprint (
436
437
out ,
437
438
prefix = new_prefix ,
438
- depth = depth ,
439
+ depth = inner_depth ,
439
440
done = done ,
440
441
print_type = print_type ,
441
442
print_shape = print_shape ,
0 commit comments