Skip to content

Commit df769f6

Browse files
committed
Add dprint shortcut to FunctionGraph and Function
1 parent 448e558 commit df769f6

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

pytensor/compile/function/types.py

+12
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,18 @@ def sync_shared(self):
10971097
# NOTE: sync was needed on old gpu backend
10981098
pass
10991099

1100+
def dprint(self, **kwargs):
1101+
"""Debug print itself
1102+
1103+
Parameters
1104+
----------
1105+
kwargs:
1106+
Optional keyword arguments to pass to debugprint function.
1107+
"""
1108+
from pytensor.printing import debugprint
1109+
1110+
return debugprint(self, **kwargs)
1111+
11001112

11011113
# pickling/deepcopy support for Function
11021114
def _pickle_Function(f):

pytensor/graph/fg.py

+12
Original file line numberDiff line numberDiff line change
@@ -927,3 +927,15 @@ def __contains__(self, item: Variable | Apply) -> bool:
927927
return item in self.apply_nodes
928928
else:
929929
raise TypeError()
930+
931+
def dprint(self, **kwargs):
932+
"""Debug print itself
933+
934+
Parameters
935+
----------
936+
kwargs:
937+
Optional keyword arguments to pass to debugprint function.
938+
"""
939+
from pytensor.printing import debugprint
940+
941+
return debugprint(self, **kwargs)

tests/compile/function/test_types.py

+7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
1717
from pytensor.graph.utils import MissingInputError
1818
from pytensor.link.vm import VMLinker
19+
from pytensor.printing import debugprint
1920
from pytensor.tensor.math import dot, tanh
2021
from pytensor.tensor.math import sum as pt_sum
2122
from pytensor.tensor.type import (
@@ -862,6 +863,12 @@ def test_key_string_requirement(self):
862863
with pytest.raises(AssertionError):
863864
function([x], outputs={(1, "b"): x, 1.0: x**2})
864865

866+
def test_dprint(self):
867+
x = pt.scalar("x")
868+
out = x + 1
869+
f = function([x], out)
870+
assert f.dprint(file="str") == debugprint(f, file="str")
871+
865872

866873
class TestPicklefunction:
867874
def test_deepcopy(self):

tests/graph/test_fg.py

+7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytensor.graph.basic import NominalVariable
99
from pytensor.graph.fg import FunctionGraph
1010
from pytensor.graph.utils import MissingInputError
11+
from pytensor.printing import debugprint
1112
from tests.graph.utils import (
1213
MyConstant,
1314
MyOp,
@@ -706,3 +707,9 @@ def test_nominals(self):
706707
assert nm2 not in fg.inputs
707708
assert nm in fg.variables
708709
assert nm2 in fg.variables
710+
711+
def test_dprint(self):
712+
r1, r2 = MyVariable("x"), MyVariable("y")
713+
o1 = op1(r1, r2)
714+
fg = FunctionGraph([r1, r2], [o1], clone=False)
715+
assert fg.dprint(file="str") == debugprint(fg, file="str")

0 commit comments

Comments
 (0)