1
+ from collections .abc import Sequence
2
+
1
3
import numpy as np
2
4
import pytest
3
5
4
6
import pytensor .tensor as pt
5
7
from pytensor import config , shared
6
8
from pytensor .compile .function import function
7
9
from pytensor .compile .mode import Mode
8
- from pytensor .graph .basic import Constant
10
+ from pytensor .graph .basic import Constant , Variable , ancestors
9
11
from pytensor .graph .fg import FunctionGraph
10
12
from pytensor .graph .rewriting .basic import EquilibriumGraphRewriter
11
13
from pytensor .graph .rewriting .db import RewriteDatabaseQuery
36
38
no_mode = Mode ("py" , RewriteDatabaseQuery (include = [], exclude = []))
37
39
38
40
41
+ def count_rv_nodes_in_graph (outputs : Sequence [Variable ]) -> int :
42
+ return len (
43
+ {
44
+ var .owner
45
+ for var in ancestors (outputs )
46
+ if var .owner and isinstance (var .owner .op , RandomVariable )
47
+ }
48
+ )
49
+
50
+
39
51
def apply_local_rewrite_to_rv (
40
52
rewrite , op_fn , dist_op , dist_params , size , rng , name = None
41
53
):
@@ -58,7 +70,14 @@ def apply_local_rewrite_to_rv(
58
70
s_pt .tag .test_value = s
59
71
size_pt .append (s_pt )
60
72
61
- dist_st = op_fn (dist_op (* dist_params_pt , size = size_pt , rng = rng , name = name ))
73
+ next_rng , rv = dist_op (
74
+ * dist_params_pt , size = size_pt , rng = rng , name = name
75
+ ).owner .outputs
76
+ dist_st = op_fn (rv )
77
+
78
+ assert (
79
+ count_rv_nodes_in_graph ([dist_st , next_rng ]) == 1
80
+ ), "Function expects a single RV in the graph"
62
81
63
82
f_inputs = [
64
83
p
@@ -72,13 +91,16 @@ def apply_local_rewrite_to_rv(
72
91
73
92
f_rewritten = function (
74
93
f_inputs ,
75
- dist_st ,
94
+ [ dist_st , next_rng ] ,
76
95
mode = mode ,
77
96
)
78
97
79
- (new_out ,) = f_rewritten .maker .fgraph .outputs
98
+ new_rv , new_next_rng = f_rewritten .maker .fgraph .outputs
99
+ assert (
100
+ count_rv_nodes_in_graph ([new_rv , new_next_rng ]) == 1
101
+ ), "Rewritten should have a single RV in the graph"
80
102
81
- return new_out , f_inputs , dist_st , f_rewritten
103
+ return new_rv , f_inputs , dist_st , f_rewritten
82
104
83
105
84
106
class TestRVExpraProps (RandomVariable ):
@@ -422,7 +444,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
422
444
423
445
arg_values = [p .get_test_value () for p in f_inputs ]
424
446
res_base = f_base (* arg_values )
425
- res_rewritten = f_rewritten (* arg_values )
447
+ res_rewritten , _ = f_rewritten (* arg_values )
426
448
427
449
np .testing .assert_allclose (res_base , res_rewritten , rtol = rtol )
428
450
@@ -825,7 +847,7 @@ def is_subtensor_or_dimshuffle_subtensor(inp) -> bool:
825
847
826
848
arg_values = [p .get_test_value () for p in f_inputs ]
827
849
res_base = f_base (* arg_values )
828
- res_rewritten = f_rewritten (* arg_values )
850
+ res_rewritten , _ = f_rewritten (* arg_values )
829
851
830
852
np .testing .assert_allclose (res_base , res_rewritten , rtol = 1e-3 , atol = 1e-2 )
831
853
0 commit comments