1+ from collections .abc import Sequence
2+
13import numpy as np
24import pytest
35
46import pytensor .tensor as pt
57from pytensor import config , shared
68from pytensor .compile .function import function
79from pytensor .compile .mode import Mode
8- from pytensor .graph .basic import Constant
10+ from pytensor .graph .basic import Constant , Variable , ancestors
911from pytensor .graph .fg import FunctionGraph
1012from pytensor .graph .rewriting .basic import EquilibriumGraphRewriter
1113from pytensor .graph .rewriting .db import RewriteDatabaseQuery
3638no_mode = Mode ("py" , RewriteDatabaseQuery (include = [], exclude = []))
3739
3840
41+ def count_rv_nodes_in_graph (outputs : Sequence [Variable ]) -> int :
42+ return len (
43+ {
44+ var .owner
45+ for var in ancestors (outputs )
46+ if var .owner and isinstance (var .owner .op , RandomVariable )
47+ }
48+ )
49+
50+
3951def apply_local_rewrite_to_rv (
4052 rewrite , op_fn , dist_op , dist_params , size , rng , name = None
4153):
@@ -58,7 +70,14 @@ def apply_local_rewrite_to_rv(
5870 s_pt .tag .test_value = s
5971 size_pt .append (s_pt )
6072
61- dist_st = op_fn (dist_op (* dist_params_pt , size = size_pt , rng = rng , name = name ))
73+ next_rng , rv = dist_op (
74+ * dist_params_pt , size = size_pt , rng = rng , name = name
75+ ).owner .outputs
76+ dist_st = op_fn (rv )
77+
78+ assert (
79+ count_rv_nodes_in_graph ([dist_st , next_rng ]) == 1
80+ ), "Function expects a single RV in the graph"
6281
6382 f_inputs = [
6483 p
@@ -72,13 +91,16 @@ def apply_local_rewrite_to_rv(
7291
7392 f_rewritten = function (
7493 f_inputs ,
75- dist_st ,
94+ [ dist_st , next_rng ] ,
7695 mode = mode ,
7796 )
7897
79- (new_out ,) = f_rewritten .maker .fgraph .outputs
98+ new_rv , new_next_rng = f_rewritten .maker .fgraph .outputs
99+ assert (
100+ count_rv_nodes_in_graph ([new_rv , new_next_rng ]) == 1
101+ ), "Rewritten should have a single RV in the graph"
80102
81- return new_out , f_inputs , dist_st , f_rewritten
103+ return new_rv , f_inputs , dist_st , f_rewritten
82104
83105
84106class TestRVExpraProps (RandomVariable ):
@@ -422,7 +444,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
422444
423445 arg_values = [p .get_test_value () for p in f_inputs ]
424446 res_base = f_base (* arg_values )
425- res_rewritten = f_rewritten (* arg_values )
447+ res_rewritten , _ = f_rewritten (* arg_values )
426448
427449 np .testing .assert_allclose (res_base , res_rewritten , rtol = rtol )
428450
@@ -825,7 +847,7 @@ def is_subtensor_or_dimshuffle_subtensor(inp) -> bool:
825847
826848 arg_values = [p .get_test_value () for p in f_inputs ]
827849 res_base = f_base (* arg_values )
828- res_rewritten = f_rewritten (* arg_values )
850+ res_rewritten , _ = f_rewritten (* arg_values )
829851
830852 np .testing .assert_allclose (res_base , res_rewritten , rtol = 1e-3 , atol = 1e-2 )
831853
0 commit comments