 from typing import Sequence, Union

 import numpy as np
-import pymc
 import pytensor.tensor as pt
+import scipy
 from arviz import InferenceData, dict_to_dataset
-from pymc import SymbolicRandomVariable
+from pymc import SymbolicRandomVariable, icdf
 from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
+from pymc.distributions.continuous import Continuous, Normal
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
@@ -159,7 +160,11 @@ def _marginalize(self, user_warnings=False):
                         f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
                     )

-            old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
+            if isinstance(rv_to_marginalize.owner.op, Continuous):
+                subgraph_builder_fn = replace_continuous_marginal_subgraph
+            else:
+                subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
+            old_rvs, new_rvs = subgraph_builder_fn(
                 fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
             )

@@ -267,7 +272,11 @@ def marginalize(
                 )

             rv_op = rv_to_marginalize.owner.op
-            if isinstance(rv_op, DiscreteMarkovChain):
+
+            if isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
+                pass
+
+            elif isinstance(rv_op, DiscreteMarkovChain):
                 if rv_op.n_lags > 1:
                     raise NotImplementedError(
                         "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +285,11 @@ def marginalize(
                     raise NotImplementedError(
                         "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
                     )
-            elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
+
+            elif isinstance(rv_op, Normal):
+                pass
+
+            else:
                 raise NotImplementedError(
                     f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
                 )
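As a usage sketch (not part of the diff): assuming this targets pymc-experimental's MarginalModel API, the new Normal branch above lets a continuous latent be requested for marginalization where it previously fell through to NotImplementedError. All model and variable names below are illustrative.

# Illustrative only: assumes pymc-experimental's MarginalModel API.
import numpy as np
import pymc as pm

from pymc_experimental import MarginalModel

with MarginalModel() as m:
    x = pm.Normal("x", mu=0.0, sigma=1.0)  # latent Normal to be marginalized
    y = pm.Normal("y", mu=x, sigma=1.0, observed=np.random.randn(100))

# After this change, a Normal RV reaches the `isinstance(rv_op, Normal)` branch
# instead of raising NotImplementedError.
m.marginalize([x])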
@@ -549,6 +562,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
     """Base class for Discrete Marginal Markov Chain RVs"""


+class QMCMarginalNormalRV(MarginalRV):
+    """Base class for QMC Marginalized RVs"""
+
+    __props__ = ("qmc_order",)
+
+    def __init__(self, *args, qmc_order: int, **kwargs):
+        self.qmc_order = qmc_order
+        super().__init__(*args, **kwargs)
+
+
 def static_shape_ancestors(vars):
     """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
     return [
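A note on `qmc_order` (sketch, not part of the diff): it fixes the number of quasi-random draws at 2**qmc_order, in line with SciPy's guidance that Sobol sequences keep their balance properties when sampled in powers of two. Roughly:

# Illustrative only: how qmc_order maps to the number of Sobol draws.
from scipy.stats import qmc

qmc_order = 13            # the value hard-coded further down in this diff
n_draws = 2**qmc_order    # 8192 low-discrepancy points in [0, 1)
engine = qmc.Sobol(d=1, seed=0)
print(engine.random(n_draws).shape)  # (8192, 1)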
@@ -707,6 +730,36 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     return rvs_to_marginalize, marginalized_rvs


+def replace_continuous_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
+    dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
+    if not dependent_rvs:
+        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
+
+    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
+    dependent_rvs_input_rvs = [
+        rv
+        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
+        if rv is not rv_to_marginalize
+    ]
+
+    input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
+    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]
+
+    outputs = rvs_to_marginalize
+    # We are strict about shared variables in SymbolicRandomVariables
+    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
+
+    marginalized_rvs = QMCMarginalNormalRV(
+        inputs=inputs,
+        outputs=outputs,
+        ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
+        qmc_order=13,
+    )(*inputs)
+
+    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
+    return rvs_to_marginalize, marginalized_rvs
+
+
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     op = rv.owner.op
     dist_params = rv.owner.op.dist_params(rv.owner)
@@ -870,3 +923,47 @@ def step_alpha(logp_emission, log_alpha, log_P):
     # return is the joint probability of everything together, but PyMC still expects one logp for each one.
     dummy_logps = (pt.constant(0),) * (len(values) - 1)
     return joint_logp, *dummy_logps


+@_logprob.register(QMCMarginalNormalRV)
+def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
+    # Clone the inner RV graph of the Marginalized RV
+    marginalized_rvs_node = op.make_node(*inputs)
+    marginalized_rv, *inner_rvs = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+
+    marginalized_rv_node = marginalized_rv.owner
+    marginalized_rv_op = marginalized_rv_node.op
+
+    # Get QMC draws from the marginalized RV
+    # TODO: Make this an Op
+    rng = marginalized_rv_op.rng_param(marginalized_rv_node)
+    shape = constant_fold(tuple(marginalized_rv.shape))
+    size = np.prod(shape).astype(int)
+    n_draws = 2**op.qmc_order
+    qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
+    uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))
+    qmc_draws = icdf(marginalized_rv, uniform_draws)
+    qmc_draws.name = f"QMC_{op.name}_draws"
+
+    # Obtain the logp of the dependent variables.
+    # We need to include the marginalized RV for correctness; we remove its term later.
+    inner_rv_values = dict(zip(inner_rvs, values))
+    marginalized_vv = marginalized_rv.clone()
+    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
+    # Pop the logp term corresponding to the marginalized RV
+    # (it is already accounted for by taking the QMC draws from its distribution)
+    logps_dict.pop(marginalized_vv)
+
+    # Vectorize across QMC draws and take the mean on the log scale
+    core_marginalized_logps = list(logps_dict.values())
+    batched_marginalized_logps = vectorize_graph(
+        core_marginalized_logps, replace={marginalized_vv: qmc_draws}
+    )
+    return tuple(
+        pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
+        for batched_marginalized_logp in batched_marginalized_logps
+    )
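For intuition, here is a self-contained NumPy/SciPy sketch of the estimator this logp builds (illustrative toy model only, not this file's API): Sobol points pushed through the Normal inverse CDF stand in for draws of the marginalized variable, and the marginal log-density of the dependent value is the log-mean-exp of the conditional log-densities over those draws.

# Illustrative sketch of QMC marginalization for x ~ Normal(0, 1), y | x ~ Normal(x, 1).
# The exact marginal is y ~ Normal(0, sqrt(2)), so the estimate can be checked directly.
import numpy as np
from scipy import stats
from scipy.special import logsumexp
from scipy.stats import qmc

qmc_order = 13
n_draws = 2**qmc_order

# Quasi-random draws of the marginalized variable: Sobol points through the inverse CDF
uniform_draws = qmc.Sobol(d=1, seed=42).random(n_draws).ravel()
x_draws = stats.norm(loc=0.0, scale=1.0).ppf(uniform_draws)

# Conditional log-density of an observed y given each draw
y_obs = 0.7
cond_logp = stats.norm(loc=x_draws, scale=1.0).logpdf(y_obs)

# log p(y) ~= logsumexp(cond_logp) - log(n_draws), i.e. a mean on the log scale
qmc_estimate = logsumexp(cond_logp) - np.log(n_draws)
exact = stats.norm(loc=0.0, scale=np.sqrt(2.0)).logpdf(y_obs)
print(qmc_estimate, exact)  # agree to several decimal places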