@@ -1,29 +1,41 @@
+import warnings
+
 from collections.abc import Sequence
 
 import numpy as np
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
 from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import Op, vectorize_graph
+from pytensor.graph import FunctionGraph, Op, vectorize_graph
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc_extras.distributions import DiscreteMarkovChain
 
 
 class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""
 
-    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+    def __init__(
+        self,
+        *args,
+        dims_connections: tuple[tuple[int | None], ...],
+        dims: tuple[Variable, ...],
+        **kwargs,
+    ) -> None:
         self.dims_connections = dims_connections
+        self.dims = dims
         super().__init__(*args, **kwargs)
 
     @property
@@ -43,6 +55,74 @@ def support_axes(self) -> tuple[tuple[int]]:
         )
         return tuple(support_axes_vars)
 
+    def __eq__(self, other):
+        # Just to allow easy testing of equivalent models,
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        if type(self) is not type(other):
+            return False
+
+        return equal_computations(
+            self.inner_outputs,
+            other.inner_outputs,
+            self.inner_inputs,
+            other.inner_inputs,
+        )
+
+    def __hash__(self):
+        # Just to allow easy testing of equivalent models,
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
+
+
+@_support_point.register
+def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
+    """Support point for a marginalized RV.
+
+    The support point of a marginalized RV is the support point of the inner RV,
+    conditioned on the marginalized RV taking its support point.
+    """
+    outputs = rv.owner.outputs
+
+    inner_rv = op.inner_outputs[outputs.index(rv)]
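+    # The unpacking below assumes the marginalized RV is the first non-RNG inner
+    # output other than the queried RV; any remaining ones are other dependent RVs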
+    marginalized_inner_rv, *other_dependent_inner_rvs = (
+        out
+        for out in op.inner_outputs
+        if out is not inner_rv and not isinstance(out.type, RandomType)
+    )
+
+    # Replace references to inner rvs by the dummy variables (including the marginalized RV)
+    # This is necessary because the inner RVs may depend on each other
+    marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
+    other_dependent_inner_rv_to_dummies = {
+        inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
+    }
+    inner_rv = clone_replace(
+        inner_rv,
+        replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
+        | other_dependent_inner_rv_to_dummies,
+    )
+
+    # Get support point of inner RV and marginalized RV
+    inner_rv_support_point = support_point(inner_rv)
+    marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
+
+    replacements = [
+        # Replace the marginalized RV dummy by its support point
+        (marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
+        # Replace other dependent RVs dummies by the respective outer outputs.
+        # PyMC will replace them by their support points later
+        *(
+            (v, outputs[op.inner_outputs.index(k)])
+            for k, v in other_dependent_inner_rv_to_dummies.items()
+        ),
+        # Replace outer input RVs
+        *zip(op.inner_inputs, inputs),
+    ]
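+    # import_missing=True lets replace_all pull replacement variables that are
+    # not yet part of the FunctionGraph into the graph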
+    fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
+    fgraph.replace_all(replacements, import_missing=True)
+    [rv_support_point] = fgraph.outputs
+    return rv_support_point
+
 
 class MarginalFiniteDiscreteRV(MarginalRV):
     """Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +212,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
     Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
     the inner graph.
     """
-    return clone_replace(
+    return graph_replace(
         op.inner_outputs,
         replace=tuple(zip(op.inner_inputs, inputs)),
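+        # strict=False: replacements for inner inputs that do not appear in the
+        # outputs' graph are silently ignored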
+        strict=False,
     )
 
 
+class NonSeparableLogpWarning(UserWarning):
+    pass
+
+
+def warn_non_separable_logp(values):
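+    """Warn that the joint logp of multiple dependent variables will be assigned to a single value variable."""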
+    if len(values) > 1:
+        warnings.warn(
+            "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
+            f"Their joint logp terms will be assigned to the first value: {values[0]}.",
+            NonSeparableLogpWarning,
+            stacklevel=2,
+        )
+
+
 DUMMY_ZERO = pt.constant(0, name="dummy_zero")
 
 
@@ -199,6 +294,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
     # Align logp with non-collapsed batch dimensions of first RV
     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
 
+    warn_non_separable_logp(values)
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
@@ -272,5 +368,6 @@ def step_alpha(logp_emission, log_alpha, log_P):
 
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    warn_non_separable_logp(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps