43
43
import numpy as np
44
44
import xarray
45
45
46
- from aesara .graph .basic import Constant , Variable
47
- from aesara .tensor import TensorVariable
46
+ from aesara import tensor as at
47
+ from aesara .graph .basic import Apply , Constant , Variable , general_toposort , walk
48
+ from aesara .graph .fg import FunctionGraph
49
+ from aesara .tensor .random .op import RandomVariable
50
+ from aesara .tensor .random .var import (
51
+ RandomGeneratorSharedVariable ,
52
+ RandomStateSharedVariable ,
53
+ )
48
54
from aesara .tensor .sharedvar import SharedVariable
49
55
from arviz import InferenceData
50
56
from fastprogress .fastprogress import progress_bar
51
57
from typing_extensions import TypeAlias
52
58
53
59
import pymc as pm
54
60
55
- from pymc .aesaraf import change_rv_size , compile_pymc , inputvars , walk_model
61
+ from pymc .aesaraf import change_rv_size , compile_pymc
56
62
from pymc .backends .arviz import _DefaultTrace
57
63
from pymc .backends .base import BaseTrace , MultiTrace
58
64
from pymc .backends .ndarray import NDArray
75
81
get_default_varnames ,
76
82
get_untransformed_name ,
77
83
is_transformed_name ,
84
+ point_wrapper ,
78
85
)
79
86
from pymc .vartypes import discrete_types
80
87
83
90
__all__ = [
84
91
"sample" ,
85
92
"iter_sample" ,
93
+ "compile_forward_sampling_function" ,
86
94
"sample_posterior_predictive" ,
87
95
"sample_posterior_predictive_w" ,
88
96
"init_nuts" ,
@@ -1534,6 +1542,147 @@ def stop_tuning(step):
1534
1542
return step
1535
1543
1536
1544
1545
def get_vars_in_point_list(trace, model):
    """Get the list of Variable instances in the model that have values stored in the trace."""
    if isinstance(trace, MultiTrace):
        names_in_trace = trace.varnames
    else:
        # A list of point dicts: the first point's keys are the stored variable names
        names_in_trace = list(trace[0])
    return [model[name] for name in names_in_trace]
1553
+
1554
+
1555
def compile_forward_sampling_function(
    outputs: List[Variable],
    vars_in_trace: List[Variable],
    basic_rvs: Optional[List[Variable]] = None,
    givens_dict: Optional[Dict[Variable, Any]] = None,
    **kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
    """Compile a function to draw samples, conditioned on the values of some variables.

    The goal of this function is to walk the aesara computational graph from the list
    of output nodes down to the root nodes, and then compile a function that will produce
    values for these output nodes. The compiled function will take as inputs the subset of
    variables in the ``vars_in_trace`` that are deemed to not be **volatile**.

    Volatile variables are variables whose values could change between runs of the
    compiled function or after inference has been run. These variables are:

    - Variables in the outputs list
    - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
    - Basic RVs that are not in the ``vars_in_trace`` list
    - Variables that are keys in the ``givens_dict``
    - Variables that have volatile inputs

    Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
    that are in the ``basic_rvs`` list.

    Concretely, this function can be used to compile a function to sample from the
    posterior predictive distribution of a model that has variables that are conditioned
    on ``MutableData`` instances. The variables that depend on the mutable data will be
    considered volatile, and as such, they won't be included as inputs into the compiled
    function. This means that if they have values stored in the posterior, these values
    will be ignored and new values will be computed (in the case of deterministics and
    potentials) or sampled (in the case of random variables).

    This function also enables a way to impute values for any variable in the computational
    graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
    to set the ``givens`` argument of the aesara function compilation. This will essentially
    replace a node in the computational graph with any other expression that has the same
    type as the desired node. Passing variables in the givens_dict is considered an intervention
    that might lead to different variable values from those that could have been seen during
    inference, as such, **any variable that is passed in the ``givens_dict`` will be considered
    volatile**.

    Parameters
    ----------
    outputs : List[aesara.graph.basic.Variable]
        The list of variables that will be returned by the compiled function
    vars_in_trace : List[aesara.graph.basic.Variable]
        The list of variables that are assumed to have values stored in the trace
    basic_rvs : Optional[List[aesara.graph.basic.Variable]]
        A list of random variables that are defined in the model. This list (which could be the
        output of ``model.basic_RVs``) should have a reference to the variables that should
        be considered as random variable instances. This includes variables that have
        a ``RandomVariable`` owner op, but also impure random variables like Mixtures, or
        Censored distributions. If ``None``, only pure random variables will be considered
        as potential random variables.
    givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
        A dictionary that maps tensor variables to the values that should be used to replace them
        in the compiled function. The types of the key and value should match or an error will be
        raised during compilation.

    Returns
    -------
    Callable
        The compiled aesara function, returned by ``compile_pymc``.
    """
    if givens_dict is None:
        givens_dict = {}

    if basic_rvs is None:
        basic_rvs = []

    # We need a function graph to walk the clients and propagate the volatile property
    fg = FunctionGraph(outputs=outputs, clone=False)

    # Walk the graph from inputs to outputs and tag the volatile variables.
    # general_toposort yields ancestors before descendants, so by the time a node is
    # visited, the volatility of all of its inputs is already known.
    nodes: List[Variable] = general_toposort(
        fg.outputs, deps=lambda x: x.owner.inputs if x.owner else []
    )
    volatile_nodes: Set[Any] = set()
    for node in nodes:
        if (
            node in fg.outputs
            or node in givens_dict
            or (  # SharedVariables, except RandomState/Generators
                isinstance(node, SharedVariable)
                and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
            )
            or (  # Basic RVs that are not in the trace
                node.owner
                and isinstance(node.owner.op, RandomVariable)
                and node in basic_rvs
                and node not in vars_in_trace
            )
            or (  # Variables that have any volatile input
                node.owner and any(inp in volatile_nodes for inp in node.owner.inputs)
            )
        ):
            volatile_nodes.add(node)

    # Collect the function inputs by walking the graph from the outputs. Inputs will be:
    # 1. Random variables that are not volatile
    # 2. Variables that have no owner and are not constant or shared
    inputs = []

    def expand(node):
        if (
            (
                node.owner is None and not isinstance(node, (Constant, SharedVariable))
            )  # Variables without owners that are not constant or shared
            or node in vars_in_trace  # Variables in the trace
        ) and node not in volatile_nodes:
            # This test will include variables without owners, and that are not constant
            # or shared, because these nodes will never be considered volatile
            inputs.append(node)
        if node.owner:
            return node.owner.inputs

    # walk produces a generator, so we have to actually exhaust the generator in a list to walk
    # the entire graph
    list(walk(fg.outputs, expand))

    # Populate the givens list. Raw (non-symbolic) values are wrapped in a constant so
    # aesara can type-check the substitution at compilation time.
    givens = [
        (
            node,
            value
            if isinstance(value, (Variable, Apply))
            else at.constant(value, dtype=getattr(node, "dtype", None), name=node.name),
        )
        for node, value in givens_dict.items()
    ]

    # on_unused_input="ignore": some collected inputs may not actually feed the outputs
    return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
1684
+
1685
+
1537
1686
def sample_posterior_predictive (
1538
1687
trace ,
1539
1688
samples : Optional [int ] = None ,
@@ -1718,38 +1867,23 @@ def sample_posterior_predictive(
1718
1867
return trace
1719
1868
return {}
1720
1869
1721
- inputs : Sequence [TensorVariable ]
1722
- input_names : Sequence [str ]
1723
- if not isinstance (_trace , MultiTrace ):
1724
- names_in_trace = list (_trace [0 ])
1725
- else :
1726
- names_in_trace = _trace .varnames
1727
- inputs_and_names = [
1728
- (rv , rv .name )
1729
- for rv in walk_model (vars_to_sample , walk_past_rvs = True )
1730
- if rv not in vars_to_sample
1731
- and rv in model .named_vars .values ()
1732
- and not isinstance (rv , (Constant , SharedVariable ))
1733
- and rv .name in names_in_trace
1734
- ]
1735
- if inputs_and_names :
1736
- inputs , input_names = zip (* inputs_and_names )
1737
- else :
1738
- inputs , input_names = [], []
1739
-
1740
1870
if size is not None :
1741
1871
vars_to_sample = [change_rv_size (v , size , expand = True ) for v in vars_to_sample ]
1872
+ vars_in_trace = get_vars_in_point_list (_trace , model )
1742
1873
1743
1874
if compile_kwargs is None :
1744
1875
compile_kwargs = {}
1745
-
1746
- sampler_fn = compile_pymc (
1747
- inputs ,
1748
- vars_to_sample ,
1749
- allow_input_downcast = True ,
1750
- accept_inplace = True ,
1751
- on_unused_input = "ignore" ,
1752
- ** compile_kwargs ,
1876
+ compile_kwargs .setdefault ("allow_input_downcast" , True )
1877
+ compile_kwargs .setdefault ("accept_inplace" , True )
1878
+
1879
+ sampler_fn = point_wrapper (
1880
+ compile_forward_sampling_function (
1881
+ outputs = vars_to_sample ,
1882
+ vars_in_trace = vars_in_trace ,
1883
+ basic_rvs = model .basic_RVs ,
1884
+ givens_dict = None ,
1885
+ ** compile_kwargs ,
1886
+ )
1753
1887
)
1754
1888
1755
1889
ppc_trace_t = _DefaultTrace (samples )
@@ -1775,7 +1909,7 @@ def sample_posterior_predictive(
1775
1909
else :
1776
1910
param = _trace [idx % len_trace ]
1777
1911
1778
- values = sampler_fn (* ( param [ n ] for n in input_names ) )
1912
+ values = sampler_fn (** param )
1779
1913
1780
1914
for k , v in zip (vars_ , values ):
1781
1915
ppc_trace_t .insert (k .name , v , idx )
@@ -2063,16 +2197,16 @@ def sample_prior_predictive(
2063
2197
names .append (rv_var .name )
2064
2198
vars_to_sample .append (rv_var )
2065
2199
2066
- inputs = [i for i in inputvars (vars_to_sample ) if not isinstance (i , (Constant , SharedVariable ))]
2067
-
2068
2200
if compile_kwargs is None :
2069
2201
compile_kwargs = {}
2202
+ compile_kwargs .setdefault ("allow_input_downcast" , True )
2203
+ compile_kwargs .setdefault ("accept_inplace" , True )
2070
2204
2071
- sampler_fn = compile_pymc (
2072
- inputs ,
2205
+ sampler_fn = compile_forward_sampling_function (
2073
2206
vars_to_sample ,
2074
- allow_input_downcast = True ,
2075
- accept_inplace = True ,
2207
+ vars_in_trace = [],
2208
+ basic_rvs = model .basic_RVs ,
2209
+ givens_dict = None ,
2076
2210
** compile_kwargs ,
2077
2211
)
2078
2212
0 commit comments