4
4
import numpy as np
5
5
import pymc
6
6
import pytensor .tensor as pt
7
+ import scipy
7
8
from arviz import InferenceData , dict_to_dataset
8
- from pymc import SymbolicRandomVariable
9
9
from pymc .backends .arviz import coords_and_dims_for_inferencedata , dataset_to_point_list
10
+ from pymc .distributions import MvNormal , SymbolicRandomVariable
11
+ from pymc .distributions .continuous import Continuous
10
12
from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
11
13
from pymc .distributions .transforms import Chain
12
14
from pymc .logprob .abstract import _logprob
13
- from pymc .logprob .basic import conditional_logp , logp
15
+ from pymc .logprob .basic import conditional_logp , icdf , logp
14
16
from pymc .logprob .transforms import IntervalTransform
15
17
from pymc .model import Model
16
- from pymc .pytensorf import compile_pymc , constant_fold
18
+ from pymc .pytensorf import collect_default_updates , compile_pymc , constant_fold
17
19
from pymc .util import RandomState , _get_seeds_per_chain , treedict
18
20
from pytensor import Mode , scan
19
21
from pytensor .compile import SharedVariable
@@ -159,17 +161,17 @@ def _marginalize(self, user_warnings=False):
159
161
f"Cannot marginalize { rv_to_marginalize } due to dependent Potential { pot } "
160
162
)
161
163
162
- old_rvs , new_rvs = replace_finite_discrete_marginal_subgraph (
163
- fg , rv_to_marginalize , self .basic_RVs + rvs_left_to_marginalize
164
+ if isinstance (rv_to_marginalize .owner .op , Continuous ):
165
+ subgraph_builder_fn = replace_continuous_marginal_subgraph
166
+ else :
167
+ subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
168
+ old_rvs , new_rvs = subgraph_builder_fn (
169
+ fg ,
170
+ rv_to_marginalize ,
171
+ self .basic_RVs + rvs_left_to_marginalize ,
172
+ user_warnings = user_warnings ,
164
173
)
165
174
166
- if user_warnings and len (new_rvs ) > 2 :
167
- warnings .warn (
168
- "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
169
- f"Their joint logp terms will be assigned to the first RV: { old_rvs [1 ]} " ,
170
- UserWarning ,
171
- )
172
-
173
175
rvs_left_to_marginalize .remove (rv_to_marginalize )
174
176
175
177
for old_rv , new_rv in zip (old_rvs , new_rvs ):
@@ -267,7 +269,11 @@ def marginalize(
267
269
)
268
270
269
271
rv_op = rv_to_marginalize .owner .op
270
- if isinstance (rv_op , DiscreteMarkovChain ):
272
+
273
+ if isinstance (rv_op , (Bernoulli , Categorical , DiscreteUniform )):
274
+ pass
275
+
276
+ elif isinstance (rv_op , DiscreteMarkovChain ):
271
277
if rv_op .n_lags > 1 :
272
278
raise NotImplementedError (
273
279
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +282,11 @@ def marginalize(
276
282
raise NotImplementedError (
277
283
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
278
284
)
279
- elif not isinstance (rv_op , (Bernoulli , Categorical , DiscreteUniform )):
285
+
286
+ elif isinstance (rv_op , Continuous ):
287
+ pass
288
+
289
+ else :
280
290
raise NotImplementedError (
281
291
f"Marginalization of RV with distribution { rv_to_marginalize .owner .op } is not supported"
282
292
)
@@ -449,7 +459,7 @@ def transform_input(inputs):
449
459
rv_loglike_fn = None
450
460
joint_logps_norm = log_softmax (joint_logps , axis = - 1 )
451
461
if return_samples :
452
- sample_rv_outs = pymc . Categorical .dist (logit_p = joint_logps )
462
+ sample_rv_outs = Categorical .dist (logit_p = joint_logps )
453
463
if isinstance (marginalized_rv .owner .op , DiscreteUniform ):
454
464
sample_rv_outs += rv_domain [0 ]
455
465
@@ -549,6 +559,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
549
559
"""Base class for Discrete Marginal Markov Chain RVs"""
550
560
551
561
562
class QMCMarginalNormalRV(MarginalRV):
    """Base class for Quasi-Monte Carlo marginalized RVs.

    The marginalized variable's logp contribution is approximated by
    averaging the dependent variables' logp over ``2**qmc_order``
    quasi-random (Sobol) draws of the marginalized variable
    (see ``qmc_marginal_rv_logp``).

    Parameters
    ----------
    qmc_order : int
        Base-2 exponent of the number of QMC draws; ``2**qmc_order``
        Sobol points are used. Powers of two preserve the balance
        properties of the Sobol sequence.
    """

    __props__ = ("qmc_order",)

    def __init__(self, *args, qmc_order: int, **kwargs):
        self.qmc_order = qmc_order
        super().__init__(*args, **kwargs)
552
572
def static_shape_ancestors (vars ):
553
573
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
554
574
return [
@@ -646,7 +666,9 @@ def collect_shared_vars(outputs, blockers):
646
666
]
647
667
648
668
649
- def replace_finite_discrete_marginal_subgraph (fgraph , rv_to_marginalize , all_rvs ):
669
+ def replace_finite_discrete_marginal_subgraph (
670
+ fgraph , rv_to_marginalize , all_rvs , user_warnings : bool = False
671
+ ):
650
672
# TODO: This should eventually be integrated in a more general routine that can
651
673
# identify other types of supported marginalization, of which finite discrete
652
674
# RVs is just one
@@ -655,6 +677,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
655
677
if not dependent_rvs :
656
678
raise ValueError (f"No RVs depend on marginalized RV { rv_to_marginalize } " )
657
679
680
+ if user_warnings and len (dependent_rvs ) > 1 :
681
+ warnings .warn (
682
+ "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
683
+ f"Their joint logp terms will be assigned to the first RV: { dependent_rvs [0 ]} " ,
684
+ UserWarning ,
685
+ )
686
+
658
687
ndim_supp = {rv .owner .op .ndim_supp for rv in dependent_rvs }
659
688
if len (ndim_supp ) != 1 :
660
689
raise NotImplementedError (
@@ -707,6 +736,39 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
707
736
return rvs_to_marginalize , marginalized_rvs
708
737
709
738
739
def replace_continuous_marginal_subgraph(
    fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False, qmc_order: int = 13
):
    """Replace a continuous RV and its dependents by a ``QMCMarginalNormalRV``.

    The marginalized RV and every RV that depends on it are encapsulated in a
    single ``QMCMarginalNormalRV`` whose logp marginalizes the target variable
    via quasi-Monte Carlo integration.

    Parameters
    ----------
    fgraph : FunctionGraph
        Graph in which the replacement is performed in place.
    rv_to_marginalize : TensorVariable
        The continuous RV to marginalize out.
    all_rvs : list of TensorVariable
        All model RVs, used to delimit the conditional subgraph.
    user_warnings : bool
        Accepted for signature compatibility with
        ``replace_finite_discrete_marginal_subgraph``; currently unused here.
    qmc_order : int
        Base-2 exponent of the number of QMC draws used by the logp
        (``2**qmc_order`` Sobol points). Defaults to 13 (8192 draws),
        matching the previous hard-coded value.

    Returns
    -------
    (old_rvs, new_rvs) : tuple of lists
        The replaced RVs and their marginalized counterparts, in
        ``[marginalized, *dependents]`` order.

    Raises
    ------
    ValueError
        If no RV depends on ``rv_to_marginalize`` (marginalization would be
        meaningless).
    """
    dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
    if not dependent_rvs:
        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
    dependent_rvs_input_rvs = [
        rv
        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
        if rv is not rv_to_marginalize
    ]

    input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]

    outputs = rvs_to_marginalize
    # We are strict about shared variables in SymbolicRandomVariables
    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)

    # TODO: Assert no non-marginalized variables depend on the rng output of the marginalized variables!!!
    marginalized_rvs = QMCMarginalNormalRV(
        inputs=inputs,
        # The RNG update outputs are appended so the Op remains a valid
        # SymbolicRandomVariable; only the RV outputs are returned below.
        outputs=[*outputs, *collect_default_updates(inputs=inputs, outputs=outputs).values()],
        ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
        qmc_order=qmc_order,
    )(*inputs)[: len(outputs)]

    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
    return rvs_to_marginalize, marginalized_rvs
+
771
+
710
772
def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> tuple [int , ...]:
711
773
op = rv .owner .op
712
774
dist_params = rv .owner .op .dist_params (rv .owner )
@@ -870,3 +932,65 @@ def step_alpha(logp_emission, log_alpha, log_P):
870
932
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
871
933
dummy_logps = (pt .constant (0 ),) * (len (values ) - 1 )
872
934
return joint_logp , * dummy_logps
935
+
936
+
937
@_logprob.register(QMCMarginalNormalRV)
def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
    """Approximate the marginal logp of the dependent RVs by QMC integration.

    The marginalized RV is integrated out numerically: its logp is replaced
    by an average of the dependent RVs' conditional logp over ``2**op.qmc_order``
    quasi-random (Sobol) draws of the marginalized variable.
    """
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    # The MarginalizedRV contains the following outputs:
    # 1. The variable we marginalized
    # 2. The dependent variables
    # 3. The updates for the marginalized and dependent variables
    marginalized_rv, *inner_rvs_and_updates = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )
    # NOTE(review): this split assumes one RNG update per RV, so after removing
    # the marginalized RV the remaining 2d+1 outputs are d dependent RVs
    # followed by d+1 updates — confirm this invariant holds for all builders.
    inner_rvs = inner_rvs_and_updates[: (len(inner_rvs_and_updates) - 1) // 2]

    marginalized_rv_node = marginalized_rv.owner
    marginalized_rv_op = marginalized_rv_node.op

    # GET QMC draws from the marginalized RV
    # TODO: Make this an Op
    rng = marginalized_rv_op.rng_param(marginalized_rv_node)
    # Shapes must be static for the draws to be generated eagerly with scipy.
    shape = constant_fold(tuple(marginalized_rv.shape))
    size = np.prod(shape).astype(int)
    n_draws = 2**op.qmc_order

    # TODO: Wrap Sobol in an Op so we can control the RNG and change whenever
    qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
    uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))

    if isinstance(marginalized_rv_op, MvNormal):
        # MvNormal has no icdf; map uniform draws through the standard-normal
        # ppf and correlate them with the Cholesky factor of the covariance.
        # Adapted from https://github.com/scipy/scipy/blob/87c46641a8b3b5b47b81de44c07b840468f7ebe7/scipy/stats/_qmc.py#L2211-L2298
        mean, cov = marginalized_rv_op.dist_params(marginalized_rv_node)
        corr_matrix = pt.linalg.cholesky(cov).mT
        # Shrink u away from {0, 1} to keep ppf finite at the boundary.
        base_draws = pt.as_tensor(scipy.stats.norm.ppf(0.5 + (1 - 1e-10) * (uniform_draws - 0.5)))
        qmc_draws = base_draws @ corr_matrix + mean
    else:
        # Univariate continuous distributions: inverse-CDF transform.
        qmc_draws = icdf(marginalized_rv, uniform_draws)

    qmc_draws.name = f"QMC_{marginalized_rv_op.name}_draws"

    # Obtain the logp of the dependent variables
    # We need to include the marginalized RV for correctness, we remove it later.
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
    # Pop the logp term corresponding to the marginalized RV
    # (it already got accounted for in the bias of the QMC draws)
    logps_dict.pop(marginalized_vv)

    # Vectorize across QMC draws and take the mean on log scale
    core_marginalized_logps = list(logps_dict.values())
    batched_marginalized_logps = vectorize_graph(
        core_marginalized_logps, replace={marginalized_vv: qmc_draws}
    )

    # Take the mean in log scale: logsumexp over draws minus log(n_draws).
    return tuple(
        pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
        for batched_marginalized_logp in batched_marginalized_logps
    )
0 commit comments