File tree 4 files changed +38
-0
lines changed
4 files changed +38
-0
lines changed Original file line number Diff line number Diff line change @@ -1097,6 +1097,18 @@ def sync_shared(self):
1097
1097
# NOTE: sync was needed on old gpu backend
1098
1098
pass
1099
1099
1100
+ def dprint (self , ** kwargs ):
1101
+ """Debug print itself
1102
+
1103
+ Parameters
1104
+ ----------
1105
+ kwargs:
1106
+ Optional keyword arguments to pass to debugprint function.
1107
+ """
1108
+ from pytensor .printing import debugprint
1109
+
1110
+ return debugprint (self , ** kwargs )
1111
+
1100
1112
1101
1113
# pickling/deepcopy support for Function
1102
1114
def _pickle_Function (f ):
Original file line number Diff line number Diff line change @@ -927,3 +927,15 @@ def __contains__(self, item: Variable | Apply) -> bool:
927
927
return item in self .apply_nodes
928
928
else :
929
929
raise TypeError ()
930
+
931
+ def dprint (self , ** kwargs ):
932
+ """Debug print itself
933
+
934
+ Parameters
935
+ ----------
936
+ kwargs:
937
+ Optional keyword arguments to pass to debugprint function.
938
+ """
939
+ from pytensor .printing import debugprint
940
+
941
+ return debugprint (self , ** kwargs )
Original file line number Diff line number Diff line change 16
16
from pytensor .graph .rewriting .basic import OpKeyGraphRewriter , PatternNodeRewriter
17
17
from pytensor .graph .utils import MissingInputError
18
18
from pytensor .link .vm import VMLinker
19
+ from pytensor .printing import debugprint
19
20
from pytensor .tensor .math import dot , tanh
20
21
from pytensor .tensor .math import sum as pt_sum
21
22
from pytensor .tensor .type import (
@@ -862,6 +863,12 @@ def test_key_string_requirement(self):
862
863
with pytest .raises (AssertionError ):
863
864
function ([x ], outputs = {(1 , "b" ): x , 1.0 : x ** 2 })
864
865
866
+ def test_dprint (self ):
867
+ x = pt .scalar ("x" )
868
+ out = x + 1
869
+ f = function ([x ], out )
870
+ assert f .dprint (file = "str" ) == debugprint (f , file = "str" )
871
+
865
872
866
873
class TestPicklefunction :
867
874
def test_deepcopy (self ):
Original file line number Diff line number Diff line change 8
8
from pytensor .graph .basic import NominalVariable
9
9
from pytensor .graph .fg import FunctionGraph
10
10
from pytensor .graph .utils import MissingInputError
11
+ from pytensor .printing import debugprint
11
12
from tests .graph .utils import (
12
13
MyConstant ,
13
14
MyOp ,
@@ -706,3 +707,9 @@ def test_nominals(self):
706
707
assert nm2 not in fg .inputs
707
708
assert nm in fg .variables
708
709
assert nm2 in fg .variables
710
+
711
+ def test_dprint (self ):
712
+ r1 , r2 = MyVariable ("x" ), MyVariable ("y" )
713
+ o1 = op1 (r1 , r2 )
714
+ fg = FunctionGraph ([r1 , r2 ], [o1 ], clone = False )
715
+ assert fg .dprint (file = "str" ) == debugprint (fg , file = "str" )
You can’t perform that action at this time.
0 commit comments