
Commit 49d6682

Replace MarginalModel by model transforms

1 parent 612db93
12 files changed: +930, -681 lines

README.md (+2, -3)

@@ -26,10 +26,9 @@ import pymc as pm
 import pymc_extras as pmx
 
 with pm.Model():
+    alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
 
-    alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
-
-    ...
+    ...
 
 ```
 

docs/api_reference.rst (+2, -1)

@@ -12,8 +12,8 @@ methods in the current release of PyMC experimental.
    :toctree: generated/
 
    as_model
-   MarginalModel
    marginalize
+   recover_marginals
    model_builder.ModelBuilder
 
 Inference

@@ -53,6 +53,7 @@ Utils
 
    spline.bspline_interpolation
    prior.prior_from_idata
+   model_equivalence.equivalent_models
 
 
 Statespace Models
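The new autosummary entry `model_equivalence.equivalent_models` points at a helper introduced by this commit. As a rough, hedged sketch of how it might be used (the import path and exact signature below are guesses, assuming it accepts two `pm.Model` objects and returns a bool):

    import pymc as pm

    # Assumed import path; see the module added in this commit for the real one.
    from pymc_extras.model.model_equivalence import equivalent_models

    with pm.Model() as m1:
        x = pm.Normal("x")
    with pm.Model() as m2:
        x = pm.Normal("x")

    # Assumed behaviour: structurally identical models compare as equivalent.
    assert equivalent_models(m1, m2)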

pymc_extras/__init__.py (+5, -1)

@@ -16,7 +16,11 @@
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
 from pymc_extras.inference.fit import fit
-from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
+from pymc_extras.model.marginal.marginal_model import (
+    MarginalModel,
+    marginalize,
+    recover_marginals,
+)
 from pymc_extras.model.model_api import as_model
 from pymc_extras.version import __version__

pymc_extras/distributions/timeseries.py (+1, -1)

@@ -214,8 +214,8 @@ def transition(*args):
         discrete_mc_op = DiscreteMarkovChainRV(
             inputs=[P_, steps_, init_dist_, state_rng],
             outputs=[state_next_rng, discrete_mc_],
-            ndim_supp=1,
             n_lags=n_lags,
+            extended_signature="(p,p),(),(p),[rng]->[rng],(t)",
         )
 
         discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
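A brief reading of the `extended_signature` string above, for context: it is PyMC's gufunc-like spec for `SymbolicRandomVariable`s, and the interpretation below is ours, not part of the diff.

    # "(p,p),(),(p),[rng]->[rng],(t)"
    #   (p,p)  : transition matrix P over p states
    #   ()     : scalar number of steps
    #   (p)    : initial-state distribution over the p states
    #   [rng]  : random generator input
    #   [rng]  : updated generator output
    #   (t)    : sampled chain with one core dimension of length t,
    #            which is what the removed ndim_supp=1 used to declare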

pymc_extras/model/marginal/distributions.py (+100, -3)
@@ -1,29 +1,41 @@
+import warnings
+
 from collections.abc import Sequence
 
 import numpy as np
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
 from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import Op, vectorize_graph
+from pytensor.graph import FunctionGraph, Op, vectorize_graph
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc_extras.distributions import DiscreteMarkovChain
 
 
 class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""
 
-    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+    def __init__(
+        self,
+        *args,
+        dims_connections: tuple[tuple[int | None], ...],
+        dims: tuple[Variable, ...],
+        **kwargs,
+    ) -> None:
         self.dims_connections = dims_connections
+        self.dims = dims
         super().__init__(*args, **kwargs)
 
     @property
@@ -43,6 +55,74 @@ def support_axes(self) -> tuple[tuple[int]]:
         )
         return tuple(support_axes_vars)
 
+    def __eq__(self, other):
+        # Just to allow easy testing of equivalent models,
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        if type(self) is not type(other):
+            return False
+
+        return equal_computations(
+            self.inner_outputs,
+            other.inner_outputs,
+            self.inner_inputs,
+            other.inner_inputs,
+        )
+
+    def __hash__(self):
+        # Just to allow easy testing of equivalent models,
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
+
+
+@_support_point.register
+def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
+    """Support point for a marginalized RV.
+
+    The support point of a marginalized RV is the support point of the inner RV,
+    conditioned on the marginalized RV taking its support point.
+    """
+    outputs = rv.owner.outputs
+
+    inner_rv = op.inner_outputs[outputs.index(rv)]
+    marginalized_inner_rv, *other_dependent_inner_rvs = (
+        out
+        for out in op.inner_outputs
+        if out is not inner_rv and not isinstance(out.type, RandomType)
+    )
+
+    # Replace references to inner rvs by the dummy variables (including the marginalized RV)
+    # This is necessary because the inner RVs may depend on each other
+    marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
+    other_dependent_inner_rv_to_dummies = {
+        inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
+    }
+    inner_rv = clone_replace(
+        inner_rv,
+        replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
+        | other_dependent_inner_rv_to_dummies,
+    )
+
+    # Get support point of inner RV and marginalized RV
+    inner_rv_support_point = support_point(inner_rv)
+    marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
+
+    replacements = [
+        # Replace the marginalized RV dummy by its support point
+        (marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
+        # Replace other dependent RVs dummies by the respective outer outputs.
+        # PyMC will replace them by their support points later
+        *(
+            (v, outputs[op.inner_outputs.index(k)])
+            for k, v in other_dependent_inner_rv_to_dummies.items()
+        ),
+        # Replace outer input RVs
+        *zip(op.inner_inputs, inputs),
+    ]
+    fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
+    fgraph.replace_all(replacements, import_missing=True)
+    [rv_support_point] = fgraph.outputs
+    return rv_support_point
+
 
 class MarginalFiniteDiscreteRV(MarginalRV):
     """Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +212,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
     Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
     the inner graph.
     """
-    return clone_replace(
+    return graph_replace(
         op.inner_outputs,
         replace=tuple(zip(op.inner_inputs, inputs)),
+        strict=False,
     )
 
 
+class NonSeparableLogpWarning(UserWarning):
+    pass
+
+
+def warn_non_separable_logp(values):
+    if len(values) > 1:
+        warnings.warn(
+            "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
+            f"Their joint logp terms will be assigned to the first value: {values[0]}.",
+            NonSeparableLogpWarning,
+            stacklevel=2,
+        )
+
+
 DUMMY_ZERO = pt.constant(0, name="dummy_zero")
 
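The switch from `clone_replace` to `graph_replace(..., strict=False)` above relies on non-strict replacement semantics: pairs whose key does not actually appear in the graph are skipped instead of raising. A minimal illustration, separate from the diff:

    import pytensor.tensor as pt

    from pytensor.graph.replace import graph_replace

    x = pt.scalar("x")
    z = pt.scalar("z")  # not part of the graph below
    y = x + 1

    # With strict=False the (z -> 0.0) pair is ignored because z is not in the graph;
    # with the default strict=True this would raise instead.
    [y_replaced] = graph_replace([y], replace={z: pt.constant(0.0)}, strict=False)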

@@ -199,6 +294,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
     # Align logp with non-collapsed batch dimensions of first RV
     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
 
+    warn_non_separable_logp(values)
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
@@ -272,5 +368,6 @@ def step_alpha(logp_emission, log_alpha, log_P):
 
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    warn_non_separable_logp(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
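Because `NonSeparableLogpWarning` is a public class in this module, downstream code can filter it explicitly once the joint-logp-on-first-value behaviour is understood. A short usage sketch (not part of the diff):

    import warnings

    from pymc_extras.model.marginal.distributions import NonSeparableLogpWarning

    # Silence only this warning while keeping other UserWarnings visible.
    warnings.filterwarnings("ignore", category=NonSeparableLogpWarning)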

pymc_extras/model/marginal/graph_analysis.py (+8, -9)
@@ -4,8 +4,8 @@
 from itertools import zip_longest
 
 from pymc import SymbolicRandomVariable
-from pytensor.compile import SharedVariable
-from pytensor.graph import Constant, Variable, ancestors
+from pymc.model.fgraph import ModelVar
+from pytensor.graph import Variable, ancestors
 from pytensor.graph.basic import io_toposort
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
@@ -35,13 +35,9 @@ def static_shape_ancestors(vars):
 
 def find_conditional_input_rvs(output_rvs, all_rvs):
     """Find conditionally indepedent input RVs."""
-    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
-    blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
-    return [
-        var
-        for var in ancestors(output_rvs, blockers=blockers)
-        if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
-    ]
+    other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
+    blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
+    return [var for var in ancestors(output_rvs, blockers=blockers) if var in other_rvs]
 
 
 def is_conditional_dependent(
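The rewritten `find_conditional_input_rvs` is built on `ancestors(..., blockers=...)`: traversal stops at blocker variables, and the function now keeps only the blocked-at RVs themselves. A tiny illustration of the blocker semantics (not part of the diff):

    import pytensor.tensor as pt

    from pytensor.graph import ancestors

    a = pt.scalar("a")
    b = a + 1
    c = b * 2

    # Walking back from c stops at the blocker b, so a is never reached.
    reached = set(ancestors([c], blockers=[b]))
    assert b in reached and a not in reached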
@@ -141,6 +137,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
             # None of the inputs are related to the batch_axes of the input_vars
             continue
 
+        elif isinstance(node.op, ModelVar):
+            var_dims[node.outputs[0]] = inputs_dims[0]
+
         elif isinstance(node.op, DimShuffle):
             [input_dims] = inputs_dims
             output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
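The new `ModelVar` branch matters because, in PyMC's FunctionGraph representation of a model, each random variable is wrapped by a `ModelVar` Op (e.g. `ModelFreeRV`) whose first input is the underlying variable, so its batch dims simply pass through. A short illustration of that representation (not part of the diff):

    import pymc as pm

    from pymc.model.fgraph import fgraph_from_model

    with pm.Model() as m:
        x = pm.Normal("x", shape=(3,))

    fgraph, memo = fgraph_from_model(m)
    # Outputs are ModelVar-wrapped variables, e.g. a ModelFreeRV whose first input is x.
    print(fgraph.outputs)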
