Skip to content

Commit d8971c4

Browse files
ricardoV94, aseyboldt, larryshamalama, zaxtax, theorashid
committed
Add QMC marginalization
Co-authored-by: Adrian Seyboldt <[email protected]> Co-authored-by: larryshamalama <[email protected]> Co-authored-by: Rob Zinkov <[email protected]> Co-authored-by: theorashid <[email protected]>
1 parent 87d4aea commit d8971c4

File tree

2 files changed

+203
-16
lines changed

2 files changed

+203
-16
lines changed

pymc_experimental/model/marginal_model.py

+140-16
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@
44
import numpy as np
55
import pymc
66
import pytensor.tensor as pt
7+
import scipy
78
from arviz import InferenceData, dict_to_dataset
8-
from pymc import SymbolicRandomVariable
99
from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
10+
from pymc.distributions import MvNormal, SymbolicRandomVariable
11+
from pymc.distributions.continuous import Continuous
1012
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
1113
from pymc.distributions.transforms import Chain
1214
from pymc.logprob.abstract import _logprob
13-
from pymc.logprob.basic import conditional_logp, logp
15+
from pymc.logprob.basic import conditional_logp, icdf, logp
1416
from pymc.logprob.transforms import IntervalTransform
1517
from pymc.model import Model
16-
from pymc.pytensorf import compile_pymc, constant_fold
18+
from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold
1719
from pymc.util import RandomState, _get_seeds_per_chain, treedict
1820
from pytensor import Mode, scan
1921
from pytensor.compile import SharedVariable
@@ -159,17 +161,17 @@ def _marginalize(self, user_warnings=False):
159161
f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
160162
)
161163

162-
old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
163-
fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
164+
if isinstance(rv_to_marginalize.owner.op, Continuous):
165+
subgraph_builder_fn = replace_continuous_marginal_subgraph
166+
else:
167+
subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
168+
old_rvs, new_rvs = subgraph_builder_fn(
169+
fg,
170+
rv_to_marginalize,
171+
self.basic_RVs + rvs_left_to_marginalize,
172+
user_warnings=user_warnings,
164173
)
165174

166-
if user_warnings and len(new_rvs) > 2:
167-
warnings.warn(
168-
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
169-
f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}",
170-
UserWarning,
171-
)
172-
173175
rvs_left_to_marginalize.remove(rv_to_marginalize)
174176

175177
for old_rv, new_rv in zip(old_rvs, new_rvs):
@@ -267,7 +269,11 @@ def marginalize(
267269
)
268270

269271
rv_op = rv_to_marginalize.owner.op
270-
if isinstance(rv_op, DiscreteMarkovChain):
272+
273+
if isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
274+
pass
275+
276+
elif isinstance(rv_op, DiscreteMarkovChain):
271277
if rv_op.n_lags > 1:
272278
raise NotImplementedError(
273279
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +282,11 @@ def marginalize(
276282
raise NotImplementedError(
277283
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
278284
)
279-
elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
285+
286+
elif isinstance(rv_op, Continuous):
287+
pass
288+
289+
else:
280290
raise NotImplementedError(
281291
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
282292
)
@@ -449,7 +459,7 @@ def transform_input(inputs):
449459
rv_loglike_fn = None
450460
joint_logps_norm = log_softmax(joint_logps, axis=-1)
451461
if return_samples:
452-
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
462+
sample_rv_outs = Categorical.dist(logit_p=joint_logps)
453463
if isinstance(marginalized_rv.owner.op, DiscreteUniform):
454464
sample_rv_outs += rv_domain[0]
455465

@@ -549,6 +559,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
549559
"""Base class for Discrete Marginal Markov Chain RVs"""
550560

551561

562+
class QMCMarginalNormalRV(MarginalRV):
    """Base class for QMC-marginalized RVs.

    Wraps a marginalized (Normal/MvNormal) RV and its dependent RVs in a single
    symbolic Op whose logp is computed by quasi-Monte-Carlo integration
    (see the ``_logprob`` registration for this class).
    """

    # qmc_order participates in Op equality/hashing via __props__
    __props__ = ("qmc_order",)

    def __init__(self, *args, qmc_order: int, **kwargs):
        # Number of Sobol draws used at logp time is 2**qmc_order
        self.qmc_order = qmc_order
        super().__init__(*args, **kwargs)
571+
552572
def static_shape_ancestors(vars):
553573
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
554574
return [
@@ -646,7 +666,9 @@ def collect_shared_vars(outputs, blockers):
646666
]
647667

648668

649-
def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
669+
def replace_finite_discrete_marginal_subgraph(
670+
fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False
671+
):
650672
# TODO: This should eventually be integrated in a more general routine that can
651673
# identify other types of supported marginalization, of which finite discrete
652674
# RVs is just one
@@ -655,6 +677,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
655677
if not dependent_rvs:
656678
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
657679

680+
if user_warnings and len(dependent_rvs) > 1:
681+
warnings.warn(
682+
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
683+
f"Their joint logp terms will be assigned to the first RV: {dependent_rvs[0]}",
684+
UserWarning,
685+
)
686+
658687
ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
659688
if len(ndim_supp) != 1:
660689
raise NotImplementedError(
@@ -707,6 +736,39 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
707736
return rvs_to_marginalize, marginalized_rvs
708737

709738

739+
def replace_continuous_marginal_subgraph(
    fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False, qmc_order: int = 13
):
    """Replace a continuous RV and its dependent RVs by a single QMCMarginalNormalRV.

    Parameters
    ----------
    fgraph : FunctionGraph
        Graph in which the replacement is performed in-place.
    rv_to_marginalize : TensorVariable
        The continuous RV to marginalize out.
    all_rvs : list of TensorVariable
        All model RVs, used to delimit the conditional subgraph.
    user_warnings : bool
        Accepted for signature parity with
        ``replace_finite_discrete_marginal_subgraph``; currently unused here.
    qmc_order : int
        Log2 of the number of QMC draws used when evaluating the logp
        (previously hard-coded to 13, kept as the default).

    Returns
    -------
    (old_rvs, new_rvs) : tuple of lists
        The replaced RVs and their marginalized counterparts, in matching order.

    Raises
    ------
    ValueError
        If no RV depends on ``rv_to_marginalize``.
    """
    dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
    if not dependent_rvs:
        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
    dependent_rvs_input_rvs = [
        rv
        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
        if rv is not rv_to_marginalize
    ]

    input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]

    outputs = rvs_to_marginalize
    # We are strict about shared variables in SymbolicRandomVariables
    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)

    # TODO: Assert no non-marginalized variables depend on the rng output of the marginalized variables!!!
    marginalized_rvs = QMCMarginalNormalRV(
        inputs=inputs,
        # Inner outputs carry the RNG updates after the RVs so draws stay reproducible
        outputs=[*outputs, *collect_default_updates(inputs=inputs, outputs=outputs).values()],
        ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
        qmc_order=qmc_order,
    )(*inputs)[: len(outputs)]

    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
    return rvs_to_marginalize, marginalized_rvs
770+
771+
710772
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
711773
op = rv.owner.op
712774
dist_params = rv.owner.op.dist_params(rv.owner)
@@ -870,3 +932,65 @@ def step_alpha(logp_emission, log_alpha, log_P):
870932
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
871933
dummy_logps = (pt.constant(0),) * (len(values) - 1)
872934
return joint_logp, *dummy_logps
935+
936+
937+
@_logprob.register(QMCMarginalNormalRV)
def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
    """Compute the logp of the dependent RVs with the marginalized RV integrated
    out by quasi-Monte-Carlo: average the conditional logp over 2**op.qmc_order
    Sobol draws of the marginalized variable.
    """
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    # The MarginalizedRV contains the following outputs:
    # 1. The variable we marginalized
    # 2. The dependent variables
    # 3. The updates for the marginalized and dependent variables
    marginalized_rv, *inner_rvs_and_updates = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )
    # First portion of the remaining outputs are the dependent RVs, the rest are
    # RNG updates. NOTE(review): this slicing assumes one update per RV plus one
    # for the marginalized RV itself — confirm against the builder's outputs.
    inner_rvs = inner_rvs_and_updates[: (len(inner_rvs_and_updates) - 1) // 2]

    marginalized_rv_node = marginalized_rv.owner
    marginalized_rv_op = marginalized_rv_node.op

    # GET QMC draws from the marginalized RV
    # TODO: Make this an Op
    rng = marginalized_rv_op.rng_param(marginalized_rv_node)
    # Static shape is required here: the draws are produced eagerly with scipy
    shape = constant_fold(tuple(marginalized_rv.shape))
    size = np.prod(shape).astype(int)
    n_draws = 2**op.qmc_order

    # TODO: Wrap Sobol in an Op so we can control the RNG and change whenever
    qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
    uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))

    if isinstance(marginalized_rv_op, MvNormal):
        # Adapted from https://github.com/scipy/scipy/blob/87c46641a8b3b5b47b81de44c07b840468f7ebe7/scipy/stats/_qmc.py#L2211-L2298
        mean, cov = marginalized_rv_op.dist_params(marginalized_rv_node)
        corr_matrix = pt.linalg.cholesky(cov).mT
        # Clip uniforms away from {0, 1} before the Gaussian inverse-CDF to avoid +/-inf
        base_draws = pt.as_tensor(scipy.stats.norm.ppf(0.5 + (1 - 1e-10) * (uniform_draws - 0.5)))
        qmc_draws = base_draws @ corr_matrix + mean
    else:
        # Univariate case: push uniform draws through the marginalized RV's inverse CDF
        qmc_draws = icdf(marginalized_rv, uniform_draws)

    qmc_draws.name = f"QMC_{marginalized_rv_op.name}_draws"

    # Obtain the logp of the dependent variables
    # We need to include the marginalized RV for correctness, we remove it later.
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
    # Pop the logp term corresponding to the marginalized RV
    # (it already got accounted for in the bias of the QMC draws)
    logps_dict.pop(marginalized_vv)

    # Vectorize across QMC draws and take the mean on log scale
    core_marginalized_logps = list(logps_dict.values())
    batched_marginalized_logps = vectorize_graph(
        core_marginalized_logps, replace={marginalized_vv: qmc_draws}
    )

    # Take the mean in log scale
    return tuple(
        pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
        for batched_marginalized_logp in batched_marginalized_logps
    )

pymc_experimental/tests/model/test_marginal_model.py

+63
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
import pymc as pm
77
import pytensor.tensor as pt
88
import pytest
9+
import scipy
910
from arviz import InferenceData, dict_to_dataset
1011
from pymc.distributions import transforms
1112
from pymc.logprob.abstract import _logprob
1213
from pymc.model.fgraph import fgraph_from_model
1314
from pymc.pytensorf import inputvars
1415
from pymc.util import UNSET
16+
from pytensor.graph import FunctionGraph
1517
from scipy.special import log_softmax, logsumexp
1618
from scipy.stats import halfnorm, norm
1719

@@ -21,6 +23,7 @@
2123
MarginalModel,
2224
is_conditional_dependent,
2325
marginalize,
26+
replace_continuous_marginal_subgraph,
2427
)
2528
from pymc_experimental.tests.utils import equal_computations_up_to_root
2629

@@ -803,3 +806,63 @@ def create_model(model_class):
803806
marginal_m.compile_logp()(ip),
804807
reference_m.compile_logp()(ip),
805808
)
809+
810+
811+
@pytest.mark.parametrize("univariate", (True, False), ids=["univariate", "multivariate"])
@pytest.mark.parametrize(
    "multiple_dependent", (False, True), ids=["single-dependent", "multiple-dependent"]
)
def test_marginalize_normal_qmc(univariate, multiple_dependent):
    """QMC marginalization of a (Mv)Normal X matches the analytic marginal of Y.

    With X ~ Normal(0, SD) and Y ~ Normal(2X + 1, 1), marginally
    Y ~ Normal(1, sqrt(4 * SD**2 + 1)); at SD=2 that is Normal(1, sqrt(17)).
    """
    with MarginalModel() as m:
        SD = pm.HalfNormal("SD", default_transform=None)
        if univariate:
            X = pm.Normal("X", sigma=SD, shape=(3,))
        else:
            X = pm.MvNormal("X", mu=[0, 0, 0], cov=np.eye(3) * SD**2)

        if multiple_dependent:
            # Split the dependent variable to exercise the multiple-dependents path
            Y = [
                pm.Normal("Y[0]", mu=(2 * X[0] + 1), sigma=1, observed=1),
                pm.Normal("Y[1:]", mu=(2 * X[1:] + 1), sigma=1, observed=[2, 3]),
            ]
        else:
            Y = [pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3])]

    m.marginalize([X])  # ideally method="qmc"

    logp_eval = np.hstack(m.compile_logp(vars=Y, sum=False)({"SD": 2.0}))

    # QMC is approximate, hence the loose relative tolerance
    np.testing.assert_allclose(
        logp_eval,
        scipy.stats.norm.logpdf([1, 2, 3], 1, np.sqrt(17)),
        rtol=1e-5,
    )
840+
841+
842+
def test_marginalize_non_trivial_mvnormal_qmc():
    """QMC marginalization of a correlated MvNormal X matches the analytic marginal.

    With X ~ MvN(0, C) and Y ~ MvN(2X + 1, I), marginally
    Y ~ MvN(1, 4C + I); at SD=1 with C = [[1, .5], [.5, 1]] that is
    cov [[5, 2], [2, 5]].
    """
    with MarginalModel() as m:
        SD = pm.HalfNormal("SD", default_transform=None)
        X = pm.MvNormal("X", cov=[[1.0, 0.5], [0.5, 1.0]] * SD**2)
        Y = pm.MvNormal("Y", mu=2 * X + 1, cov=np.eye(2), observed=[1, 2])

    m.marginalize([X])

    [logp_eval] = m.compile_logp(vars=Y, sum=False)({"SD": 1})

    np.testing.assert_allclose(
        logp_eval,
        scipy.stats.multivariate_normal.logpdf([1, 2], [1, 1], [[5, 2], [2, 5]]),
        rtol=1e-5,
    )
857+
858+
859+
def test_marginalize_sample():
    """Forward sampling from a QMC-marginalized subgraph remains stochastic.

    Uses the subgraph builder directly (X is an unregistered helper RV), then
    checks two forward draws differ, i.e. RNG updates were propagated.
    """
    with pm.Model() as m:
        SD = pm.HalfNormal("SD")
        X = pm.Normal.dist(sigma=SD, name="X")
        Y = pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3])

    fg = FunctionGraph(outputs=[SD, Y, X], clone=False)
    old_rvs, new_rvs = replace_continuous_marginal_subgraph(fg, X, [Y, SD, X])
    res1, res2 = pm.draw(new_rvs, draws=2)
    # Draws should not repeat across calls if updates are wired correctly
    assert not np.allclose(res1, res2)

0 commit comments

Comments
 (0)