Skip to content

Commit 0a97875

Browse files
committed
Add QMC marginalization
1 parent 83ebb80 commit 0a97875

File tree

2 files changed

+119
-5
lines changed

2 files changed

+119
-5
lines changed

pymc_experimental/model/marginal_model.py

+102-5
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from typing import Sequence, Union
33

44
import numpy as np
5-
import pymc
65
import pytensor.tensor as pt
6+
import scipy
77
from arviz import InferenceData, dict_to_dataset
8-
from pymc import SymbolicRandomVariable
8+
from pymc import SymbolicRandomVariable, icdf
99
from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
10+
from pymc.distributions.continuous import Continuous, Normal
1011
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
1112
from pymc.distributions.transforms import Chain
1213
from pymc.logprob.abstract import _logprob
@@ -159,7 +160,11 @@ def _marginalize(self, user_warnings=False):
159160
f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
160161
)
161162

162-
old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
163+
if isinstance(rv_to_marginalize.owner.op, Continuous):
164+
subgraph_builder_fn = replace_continuous_marginal_subgraph
165+
else:
166+
subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
167+
old_rvs, new_rvs = subgraph_builder_fn(
163168
fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
164169
)
165170

@@ -267,7 +272,11 @@ def marginalize(
267272
)
268273

269274
rv_op = rv_to_marginalize.owner.op
270-
if isinstance(rv_op, DiscreteMarkovChain):
275+
276+
if isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
277+
pass
278+
279+
elif isinstance(rv_op, DiscreteMarkovChain):
271280
if rv_op.n_lags > 1:
272281
raise NotImplementedError(
273282
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +285,11 @@ def marginalize(
276285
raise NotImplementedError(
277286
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
278287
)
279-
elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
288+
289+
elif isinstance(rv_op, Normal):
290+
pass
291+
292+
else:
280293
raise NotImplementedError(
281294
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
282295
)
@@ -549,6 +562,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
549562
"""Base class for Discrete Marginal Markov Chain RVs"""
550563

551564

565+
class QMCMarginalNormalRV(MarginalRV):
    """Base class for Quasi-Monte Carlo (QMC) marginalized RVs.

    Wraps a continuous (Normal) RV together with its dependent RVs so that the
    marginal logp can later be approximated with Sobol QMC draws (see the
    ``_logprob`` registration for this Op).

    Parameters
    ----------
    qmc_order : int
        Base-2 exponent of the number of QMC draws: ``2**qmc_order`` Sobol
        points are used when evaluating the marginalized logp.
    """

    __props__ = ("qmc_order",)

    def __init__(self, *args, qmc_order: int, **kwargs):
        self.qmc_order = qmc_order
        super().__init__(*args, **kwargs)
552575
def static_shape_ancestors(vars):
553576
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
554577
return [
@@ -707,6 +730,36 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
707730
return rvs_to_marginalize, marginalized_rvs
708731

709732

733+
def replace_continuous_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs, qmc_order: int = 13):
    """Replace a continuous RV and its dependents by a single QMCMarginalNormalRV.

    The marginalized RV and every RV that conditionally depends on it are
    collapsed into one symbolic node whose logp is later approximated with
    Quasi-Monte Carlo draws.

    Parameters
    ----------
    fgraph : FunctionGraph
        Graph in which the replacement is performed in place.
    rv_to_marginalize : TensorVariable
        The continuous RV to marginalize.
    all_rvs : Sequence[TensorVariable]
        All model RVs, used to discover conditional dependencies.
    qmc_order : int, default 13
        Base-2 exponent of the number of Sobol draws used at logp time
        (``2**qmc_order`` draws). Previously hard-coded; the default preserves
        the old behavior.

    Returns
    -------
    tuple[list, list]
        The replaced RVs and their replacements (outputs of the new node).

    Raises
    ------
    ValueError
        If no other RV depends on ``rv_to_marginalize``.
    """
    dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
    if not dependent_rvs:
        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
    dependent_rvs_input_rvs = [
        rv
        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
        if rv is not rv_to_marginalize
    ]

    input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]

    outputs = rvs_to_marginalize
    # We are strict about shared variables in SymbolicRandomVariables
    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)

    marginalized_rvs = QMCMarginalNormalRV(
        inputs=inputs,
        outputs=outputs,
        ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
        qmc_order=qmc_order,
    )(*inputs)

    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
    return rvs_to_marginalize, marginalized_rvs
762+
710763
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
711764
op = rv.owner.op
712765
dist_params = rv.owner.op.dist_params(rv.owner)
@@ -870,3 +923,47 @@ def step_alpha(logp_emission, log_alpha, log_P):
870923
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
871924
dummy_logps = (pt.constant(0),) * (len(values) - 1)
872925
return joint_logp, *dummy_logps
926+
927+
928+
@_logprob.register(QMCMarginalNormalRV)
def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
    """Approximate the marginal logp of the dependent RVs via Sobol QMC draws.

    The marginalized RV is integrated out numerically: ``2**op.qmc_order``
    Sobol points are mapped through the RV's inverse CDF, the dependent logps
    are vectorized across those draws, and the average is taken on the log
    scale (logsumexp minus log of the number of draws).

    Returns one logp term per value variable; the marginalized RV itself
    contributes no separate term (its density is absorbed by the QMC draws).
    """
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    marginalized_rv_node = marginalized_rv.owner
    marginalized_rv_op = marginalized_rv_node.op

    # GET QMC draws from the marginalized RV
    # TODO: Make this an Op
    rng = marginalized_rv_op.rng_param(marginalized_rv_node)
    shape = constant_fold(tuple(marginalized_rv.shape))
    size = np.prod(shape).astype(int)
    n_draws = 2**op.qmc_order
    # Sobol dimension is the flattened RV size; draws are reshaped back below.
    qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
    uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))
    qmc_draws = icdf(marginalized_rv, uniform_draws)
    qmc_draws.name = f"QMC_{op.name}_draws"

    # Obtain the logp of the dependent variables
    # We need to include the marginalized RV for correctness, we remove it later.
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
    # Pop the logp term corresponding to the marginalized RV
    # (it already got accounted for in the bias of the QMC draws)
    logps_dict.pop(marginalized_vv)

    # Vectorize across QMC draws and take the mean on log scale
    core_marginalized_logps = list(logps_dict.values())
    batched_marginalized_logps = vectorize_graph(
        core_marginalized_logps, replace={marginalized_vv: qmc_draws}
    )
    # BUGFIX: the log-mean must be normalized by the number of draws being
    # averaged over (n_draws, the length of axis 0), not by `size` (the
    # flattened support size of the marginalized RV).
    return tuple(
        pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
        for batched_marginalized_logp in batched_marginalized_logps
    )

pymc_experimental/tests/model/test_marginal_model.py

+17
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pymc as pm
77
import pytensor.tensor as pt
88
import pytest
9+
import scipy
910
from arviz import InferenceData, dict_to_dataset
1011
from pymc.distributions import transforms
1112
from pymc.logprob.abstract import _logprob
@@ -802,3 +803,19 @@ def create_model(model_class):
802803
marginal_m.compile_logp()(ip),
803804
reference_m.compile_logp()(ip),
804805
)
806+
807+
808+
def test_marginalize_normal_via_qmc():
    """QMC marginalization of a latent Normal matches the analytic marginal logp."""
    with MarginalModel() as marginal_model:
        scale = pm.HalfNormal("SD", default_transform=None)
        latent = pm.Normal("X", sigma=scale)
        observed = pm.Normal("Y", mu=(2 * latent + 1), sigma=1, observed=[1, 2, 3])

    marginal_model.marginalize([latent])  # ideally method="qmc"

    # Analytic target: P(Y=[1, 2, 3] | SD=1) = int_x P(Y | SD=1, X=x) P(X=x | SD=1) dx
    # NOTE(review): the commit comment states Norm([1, 2, 3], 0.5, sqrt(2)) while the
    # assertion uses scale sqrt(2)/2 — confirm which parameterization is intended.
    [logp_eval] = marginal_model.compile_logp(vars=[observed], sum=False)({"SD": 1})
    np.testing.assert_allclose(
        logp_eval,
        scipy.stats.norm.logpdf([1, 2, 3], 0.5, np.sqrt(2) / 2),
    )

0 commit comments

Comments
 (0)