4343import numpy as np
4444import xarray
4545
46- from aesara .graph .basic import Constant , Variable
47- from aesara .tensor import TensorVariable
46+ from aesara import tensor as at
47+ from aesara .graph .basic import Apply , Constant , Variable , general_toposort , walk
48+ from aesara .graph .fg import FunctionGraph
49+ from aesara .tensor .random .op import RandomVariable
50+ from aesara .tensor .random .var import (
51+ RandomGeneratorSharedVariable ,
52+ RandomStateSharedVariable ,
53+ )
4854from aesara .tensor .sharedvar import SharedVariable
4955from arviz import InferenceData
5056from fastprogress .fastprogress import progress_bar
5157from typing_extensions import TypeAlias
5258
5359import pymc as pm
5460
55- from pymc .aesaraf import change_rv_size , compile_pymc , inputvars , walk_model
61+ from pymc .aesaraf import change_rv_size , compile_pymc
5662from pymc .backends .arviz import _DefaultTrace
5763from pymc .backends .base import BaseTrace , MultiTrace
5864from pymc .backends .ndarray import NDArray
7581 get_default_varnames ,
7682 get_untransformed_name ,
7783 is_transformed_name ,
84+ point_wrapper ,
7885)
7986from pymc .vartypes import discrete_types
8087
8390__all__ = [
8491 "sample" ,
8592 "iter_sample" ,
93+ "compile_forward_sampling_function" ,
8694 "sample_posterior_predictive" ,
8795 "sample_posterior_predictive_w" ,
8896 "init_nuts" ,
@@ -1534,6 +1542,147 @@ def stop_tuning(step):
15341542 return step
15351543
15361544
1545+ def get_vars_in_point_list (trace , model ):
1546+ """Get the list of Variable instances in the model that have values stored in the trace."""
1547+ if not isinstance (trace , MultiTrace ):
1548+ names_in_trace = list (trace [0 ])
1549+ else :
1550+ names_in_trace = trace .varnames
1551+ vars_in_trace = [model [v ] for v in names_in_trace ]
1552+ return vars_in_trace
1553+
1554+
1555+ def compile_forward_sampling_function (
1556+ outputs : List [Variable ],
1557+ vars_in_trace : List [Variable ],
1558+ basic_rvs : Optional [List [Variable ]] = None ,
1559+ givens_dict : Optional [Dict [Variable , Any ]] = None ,
1560+ ** kwargs ,
1561+ ) -> Callable [..., Union [np .ndarray , List [np .ndarray ]]]:
1562+ """Compile a function to draw samples, conditioned on the values of some variables.
1563+
1564+ The goal of this function is to walk the aesara computational graph from the list
1565+ of output nodes down to the root nodes, and then compile a function that will produce
1566+ values for these output nodes. The compiled function will take as inputs the subset of
1567+ variables in the ``vars_in_trace`` that are deemed to not be **volatile**.
1568+
1569+ Volatile variables are variables whose values could change between runs of the
1570+ compiled function or after inference has been run. These variables are:
1571+
1572+ - Variables in the outputs list
1573+ - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
1574+ - Basic RVs that are not in the ``vars_in_trace`` list
1575+ - Variables that are keys in the ``givens_dict``
1576+ - Variables that have volatile inputs
1577+
1578+ Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
1579+ that are in the ``basic_rvs`` list.
1580+
1581+ Concretely, this function can be used to compile a function to sample from the
1582+ posterior predictive distribution of a model that has variables that are conditioned
1583+ on ``MutableData`` instances. The variables that depend on the mutable data will be
1584+ considered volatile, and as such, they wont be included as inputs into the compiled function.
1585+ This means that if they have values stored in the posterior, these values will be ignored
1586+ and new values will be computed (in the case of deterministics and potentials) or sampled
1587+ (in the case of random variables).
1588+
1589+ This function also enables a way to impute values for any variable in the computational
1590+ graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
1591+ to set the ``givens`` argument of the aesara function compilation. This will essentially
1592+ replace a node in the computational graph with any other expression that has the same
1593+ type as the desired node. Passing variables in the givens_dict is considered an intervention
1594+ that might lead to different variable values from those that could have been seen during
1595+ inference, as such, **any variable that is passed in the ``givens_dict`` will be considered
1596+ volatile**.
1597+
1598+ Parameters
1599+ ----------
1600+ outputs : List[aesara.graph.basic.Variable]
1601+ The list of variables that will be returned by the compiled function
1602+ vars_in_trace : List[aesara.graph.basic.Variable]
1603+ The list of variables that are assumed to have values stored in the trace
1604+ basic_rvs : Optional[List[aesara.graph.basic.Variable]]
1605+ A list of random variables that are defined in the model. This list (which could be the
1606+ output of ``model.basic_RVs``) should have a reference to the variables that should
1607+ be considered as random variable instances. This includes variables that have
1608+ a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
1609+ Censored distributions. If ``None``, only pure random variables will be considered
1610+ as potential random variables.
1611+ givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
1612+ A dictionary that maps tensor variables to the values that should be used to replace them
1613+ in the compiled function. The types of the key and value should match or an error will be
1614+ raised during compilation.
1615+ """
1616+ if givens_dict is None :
1617+ givens_dict = {}
1618+
1619+ if basic_rvs is None :
1620+ basic_rvs = []
1621+
1622+ # We need a function graph to walk the clients and propagate the volatile property
1623+ fg = FunctionGraph (outputs = outputs , clone = False )
1624+
1625+ # Walk the graph from inputs to outputs and tag the volatile variables
1626+ nodes : List [Variable ] = general_toposort (
1627+ fg .outputs , deps = lambda x : x .owner .inputs if x .owner else []
1628+ )
1629+ volatile_nodes : Set [Any ] = set ()
1630+ for node in nodes :
1631+ if (
1632+ node in fg .outputs
1633+ or node in givens_dict
1634+ or ( # SharedVariables, except RandomState/Generators
1635+ isinstance (node , SharedVariable )
1636+ and not isinstance (node , (RandomStateSharedVariable , RandomGeneratorSharedVariable ))
1637+ )
1638+ or ( # Basic RVs that are not in the trace
1639+ node .owner
1640+ and isinstance (node .owner .op , RandomVariable )
1641+ and node in basic_rvs
1642+ and node not in vars_in_trace
1643+ )
1644+ or ( # Variables that have any volatile input
1645+ node .owner and any (inp in volatile_nodes for inp in node .owner .inputs )
1646+ )
1647+ ):
1648+ volatile_nodes .add (node )
1649+
1650+ # Collect the function inputs by walking the graph from the outputs. Inputs will be:
1651+ # 1. Random variables that are not volatile
1652+ # 2. Variables that have no owner and are not constant or shared
1653+ inputs = []
1654+
1655+ def expand (node ):
1656+ if (
1657+ (
1658+ node .owner is None and not isinstance (node , (Constant , SharedVariable ))
1659+ ) # Variables without owners that are not constant or shared
1660+ or node in vars_in_trace # Variables in the trace
1661+ ) and node not in volatile_nodes :
1662+ # This test will include variables without owners, and that are not constant
1663+ # or shared, because these nodes will never be considered volatile
1664+ inputs .append (node )
1665+ if node .owner :
1666+ return node .owner .inputs
1667+
1668+ # walk produces a generator, so we have to actually exhaust the generator in a list to walk
1669+ # the entire graph
1670+ list (walk (fg .outputs , expand ))
1671+
1672+ # Populate the givens list
1673+ givens = [
1674+ (
1675+ node ,
1676+ value
1677+ if isinstance (value , (Variable , Apply ))
1678+ else at .constant (value , dtype = getattr (node , "dtype" , None ), name = node .name ),
1679+ )
1680+ for node , value in givens_dict .items ()
1681+ ]
1682+
1683+ return compile_pymc (inputs , fg .outputs , givens = givens , on_unused_input = "ignore" , ** kwargs )
1684+
1685+
15371686def sample_posterior_predictive (
15381687 trace ,
15391688 samples : Optional [int ] = None ,
@@ -1718,38 +1867,23 @@ def sample_posterior_predictive(
17181867 return trace
17191868 return {}
17201869
1721- inputs : Sequence [TensorVariable ]
1722- input_names : Sequence [str ]
1723- if not isinstance (_trace , MultiTrace ):
1724- names_in_trace = list (_trace [0 ])
1725- else :
1726- names_in_trace = _trace .varnames
1727- inputs_and_names = [
1728- (rv , rv .name )
1729- for rv in walk_model (vars_to_sample , walk_past_rvs = True )
1730- if rv not in vars_to_sample
1731- and rv in model .named_vars .values ()
1732- and not isinstance (rv , (Constant , SharedVariable ))
1733- and rv .name in names_in_trace
1734- ]
1735- if inputs_and_names :
1736- inputs , input_names = zip (* inputs_and_names )
1737- else :
1738- inputs , input_names = [], []
1739-
17401870 if size is not None :
17411871 vars_to_sample = [change_rv_size (v , size , expand = True ) for v in vars_to_sample ]
1872+ vars_in_trace = get_vars_in_point_list (_trace , model )
17421873
17431874 if compile_kwargs is None :
17441875 compile_kwargs = {}
1745-
1746- sampler_fn = compile_pymc (
1747- inputs ,
1748- vars_to_sample ,
1749- allow_input_downcast = True ,
1750- accept_inplace = True ,
1751- on_unused_input = "ignore" ,
1752- ** compile_kwargs ,
1876+ compile_kwargs .setdefault ("allow_input_downcast" , True )
1877+ compile_kwargs .setdefault ("accept_inplace" , True )
1878+
1879+ sampler_fn = point_wrapper (
1880+ compile_forward_sampling_function (
1881+ outputs = vars_to_sample ,
1882+ vars_in_trace = vars_in_trace ,
1883+ basic_rvs = model .basic_RVs ,
1884+ givens_dict = None ,
1885+ ** compile_kwargs ,
1886+ )
17531887 )
17541888
17551889 ppc_trace_t = _DefaultTrace (samples )
@@ -1775,7 +1909,7 @@ def sample_posterior_predictive(
17751909 else :
17761910 param = _trace [idx % len_trace ]
17771911
1778- values = sampler_fn (* ( param [ n ] for n in input_names ) )
1912+ values = sampler_fn (** param )
17791913
17801914 for k , v in zip (vars_ , values ):
17811915 ppc_trace_t .insert (k .name , v , idx )
@@ -2063,16 +2197,16 @@ def sample_prior_predictive(
20632197 names .append (rv_var .name )
20642198 vars_to_sample .append (rv_var )
20652199
2066- inputs = [i for i in inputvars (vars_to_sample ) if not isinstance (i , (Constant , SharedVariable ))]
2067-
20682200 if compile_kwargs is None :
20692201 compile_kwargs = {}
2202+ compile_kwargs .setdefault ("allow_input_downcast" , True )
2203+ compile_kwargs .setdefault ("accept_inplace" , True )
20702204
2071- sampler_fn = compile_pymc (
2072- inputs ,
2205+ sampler_fn = compile_forward_sampling_function (
20732206 vars_to_sample ,
2074- allow_input_downcast = True ,
2075- accept_inplace = True ,
2207+ vars_in_trace = [],
2208+ basic_rvs = model .basic_RVs ,
2209+ givens_dict = None ,
20762210 ** compile_kwargs ,
20772211 )
20782212
0 commit comments