Skip to content

Commit 33f94d9

Browse files
committed
Replace MarginalModel by model transforms
1 parent 612db93 commit 33f94d9

12 files changed

+922
-678
lines changed

README.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ import pymc as pm
2626
import pymc_extras as pmx
2727

2828
with pm.Model():
29+
alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
2930

30-
alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
31-
32-
...
31+
...
3332

3433
```
3534

docs/api_reference.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ methods in the current release of PyMC experimental.
1212
:toctree: generated/
1313

1414
as_model
15-
MarginalModel
1615
marginalize
16+
recover_marginals
1717
model_builder.ModelBuilder
1818

1919
Inference
@@ -53,6 +53,7 @@ Utils
5353

5454
spline.bspline_interpolation
5555
prior.prior_from_idata
56+
model_equivalence.equivalent_models
5657

5758

5859
Statespace Models

pymc_extras/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from pymc_extras import gp, statespace, utils
1717
from pymc_extras.distributions import *
1818
from pymc_extras.inference.fit import fit
19-
from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
19+
from pymc_extras.model.marginal.marginal_model import (
20+
MarginalModel,
21+
marginalize,
22+
recover_marginals,
23+
)
2024
from pymc_extras.model.model_api import as_model
2125
from pymc_extras.version import __version__
2226

pymc_extras/distributions/timeseries.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ def transition(*args):
214214
discrete_mc_op = DiscreteMarkovChainRV(
215215
inputs=[P_, steps_, init_dist_, state_rng],
216216
outputs=[state_next_rng, discrete_mc_],
217-
ndim_supp=1,
218217
n_lags=n_lags,
218+
extended_signature="(p,p),(),(p),[rng]->[rng],(t)",
219219
)
220220

221221
discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)

pymc_extras/model/marginal/distributions.py

+96-4
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,36 @@
1+
import warnings
2+
13
from collections.abc import Sequence
24

35
import numpy as np
46
import pytensor.tensor as pt
57

68
from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
9+
from pymc.distributions.distribution import _support_point, support_point
710
from pymc.logprob.abstract import MeasurableOp, _logprob
811
from pymc.logprob.basic import conditional_logp, logp
9-
from pymc.pytensorf import constant_fold
12+
from pymc.model.fgraph import ModelVar
13+
from pymc.pytensorf import constant_fold, StringType
1014
from pytensor import Variable
1115
from pytensor.compile.builders import OpFromGraph
1216
from pytensor.compile.mode import Mode
13-
from pytensor.graph import Op, vectorize_graph
17+
from pytensor.graph import FunctionGraph, Op, vectorize_graph
18+
from pytensor.graph.basic import equal_computations, Apply
1419
from pytensor.graph.replace import clone_replace, graph_replace
1520
from pytensor.scan import map as scan_map
1621
from pytensor.scan import scan
1722
from pytensor.tensor import TensorVariable
23+
from pytensor.tensor.random.type import RandomType
1824

1925
from pymc_extras.distributions import DiscreteMarkovChain
2026

2127

2228
class MarginalRV(OpFromGraph, MeasurableOp):
2329
"""Base class for Marginalized RVs"""
2430

25-
def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
31+
def __init__(self, *args, dims_connections: tuple[tuple[int | None], ...], dims: tuple[Variable, ...], **kwargs) -> None:
2632
self.dims_connections = dims_connections
33+
self.dims = dims
2734
super().__init__(*args, **kwargs)
2835

2936
@property
@@ -43,6 +50,74 @@ def support_axes(self) -> tuple[tuple[int]]:
4350
)
4451
return tuple(support_axes_vars)
4552

53+
def __eq__(self, other):
54+
# Just to allow easy testing of equivalent models.
55+
# This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed.
56+
if type(self) is not type(other):
57+
return False
58+
59+
return equal_computations(
60+
self.inner_outputs,
61+
other.inner_outputs,
62+
self.inner_inputs,
63+
other.inner_inputs,
64+
)
65+
66+
def __hash__(self):
67+
# Just to allow easy testing of equivalent models.
68+
# This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed.
69+
return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
70+
71+
72+
@_support_point.register
73+
def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
74+
"""Support point for a marginalized RV.
75+
76+
The support point of a marginalized RV is the support point of the inner RV,
77+
conditioned on the marginalized RV taking its support point.
78+
"""
79+
outputs = rv.owner.outputs
80+
81+
inner_rv = op.inner_outputs[outputs.index(rv)]
82+
marginalized_inner_rv, *other_dependent_inner_rvs = (
83+
out
84+
for out in op.inner_outputs
85+
if out is not inner_rv and not isinstance(out.type, RandomType)
86+
)
87+
88+
# Replace references to inner rvs by the dummy variables (including the marginalized RV)
89+
# This is necessary because the inner RVs may depend on each other
90+
marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
91+
other_dependent_inner_rv_to_dummies = {
92+
inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
93+
}
94+
inner_rv = clone_replace(
95+
inner_rv,
96+
replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
97+
| other_dependent_inner_rv_to_dummies,
98+
)
99+
100+
# Get support point of inner RV and marginalized RV
101+
inner_rv_support_point = support_point(inner_rv)
102+
marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
103+
104+
replacements = [
105+
# Replace the marginalized RV dummy by its support point
106+
(marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
107+
# Replace other dependent RVs dummies by the respective outer outputs.
108+
# PyMC will replace them by their support points later
109+
*(
110+
(v, outputs[op.inner_outputs.index(k)])
111+
for k, v in other_dependent_inner_rv_to_dummies.items()
112+
),
113+
# Replace outer input RVs
114+
*zip(op.inner_inputs, inputs),
115+
]
116+
fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
117+
fgraph.replace_all(replacements, import_missing=True)
118+
[rv_support_point] = fgraph.outputs
119+
return rv_support_point
120+
46121

47122
class MarginalFiniteDiscreteRV(MarginalRV):
48123
"""Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +207,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
132207
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
133208
the inner graph.
134209
"""
135-
return clone_replace(
210+
return graph_replace(
136211
op.inner_outputs,
137212
replace=tuple(zip(op.inner_inputs, inputs)),
213+
strict=False,
138214
)
139215

140216

217+
class NonSeparableLogpWarning(UserWarning):
218+
pass
219+
220+
221+
def warn_non_separable_logp(values):
222+
if len(values) > 1:
223+
warnings.warn(
224+
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
225+
f"Their joint logp terms will be assigned to the first value: {values[0]}.",
226+
NonSeparableLogpWarning,
227+
stacklevel=2,
228+
)
229+
230+
141231
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
142232

143233

@@ -199,6 +289,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
199289
# Align logp with non-collapsed batch dimensions of first RV
200290
joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
201291

292+
warn_non_separable_logp(values)
202293
# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
203294
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
204295
return joint_logp, *dummy_logps
@@ -272,5 +363,6 @@ def step_alpha(logp_emission, log_alpha, log_P):
272363

273364
# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
274365
# return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
366+
warn_non_separable_logp(values)
275367
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
276368
return joint_logp, *dummy_logps

pymc_extras/model/marginal/graph_analysis.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from itertools import zip_longest
55

66
from pymc import SymbolicRandomVariable
7+
from pymc.model.fgraph import ModelVar
78
from pytensor.compile import SharedVariable
89
from pytensor.graph import Constant, Variable, ancestors
910
from pytensor.graph.basic import io_toposort
@@ -35,12 +36,12 @@ def static_shape_ancestors(vars):
3536

3637
def find_conditional_input_rvs(output_rvs, all_rvs):
3738
"""Find conditionally independent input RVs."""
38-
blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
39-
blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
39+
other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
40+
blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
4041
return [
4142
var
4243
for var in ancestors(output_rvs, blockers=blockers)
43-
if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
44+
if var in other_rvs
4445
]
4546

4647

@@ -141,6 +142,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
141142
# None of the inputs are related to the batch_axes of the input_vars
142143
continue
143144

145+
elif isinstance(node.op, ModelVar):
146+
var_dims[node.outputs[0]] = inputs_dims[0]
147+
144148
elif isinstance(node.op, DimShuffle):
145149
[input_dims] = inputs_dims
146150
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)

0 commit comments

Comments
 (0)