Skip to content

Commit 862bd05

Browse files
lucianopazricardoV94
authored andcommitted
Add compile_forward_sampling_function
1 parent 4969460 commit 862bd05

File tree

4 files changed

+392
-37
lines changed

4 files changed

+392
-37
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
139139
- `pymc.sampling_jax` samplers support `log_likelihood`, `observed_data`, and `sample_stats` in returned InferenceData object (see [#5189](https://github.com/pymc-devs/pymc/pull/5189))
140140
- Adding support for `pm.Deterministic` in `pymc.sampling_jax` (see [#5182](https://github.com/pymc-devs/pymc/pull/5182))
141141
- Added an alternative parametrization, `logit_p` to `pm.Binomial` and `pm.Categorical` distributions (see [5637](https://github.com/pymc-devs/pymc/pull/5637)).
142+
- Added the low level `compile_forward_sampling_function` method to compile the aesara function responsible for generating forward samples (see [#5759](https://github.com/pymc-devs/pymc/pull/5759)).
142143
- ...
143144

144145

pymc/sampling.py

+171-37
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,22 @@
4343
import numpy as np
4444
import xarray
4545

46-
from aesara.graph.basic import Constant, Variable
47-
from aesara.tensor import TensorVariable
46+
from aesara import tensor as at
47+
from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk
48+
from aesara.graph.fg import FunctionGraph
49+
from aesara.tensor.random.op import RandomVariable
50+
from aesara.tensor.random.var import (
51+
RandomGeneratorSharedVariable,
52+
RandomStateSharedVariable,
53+
)
4854
from aesara.tensor.sharedvar import SharedVariable
4955
from arviz import InferenceData
5056
from fastprogress.fastprogress import progress_bar
5157
from typing_extensions import TypeAlias
5258

5359
import pymc as pm
5460

55-
from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model
61+
from pymc.aesaraf import change_rv_size, compile_pymc
5662
from pymc.backends.arviz import _DefaultTrace
5763
from pymc.backends.base import BaseTrace, MultiTrace
5864
from pymc.backends.ndarray import NDArray
@@ -75,6 +81,7 @@
7581
get_default_varnames,
7682
get_untransformed_name,
7783
is_transformed_name,
84+
point_wrapper,
7885
)
7986
from pymc.vartypes import discrete_types
8087

@@ -83,6 +90,7 @@
8390
__all__ = [
8491
"sample",
8592
"iter_sample",
93+
"compile_forward_sampling_function",
8694
"sample_posterior_predictive",
8795
"sample_posterior_predictive_w",
8896
"init_nuts",
@@ -1534,6 +1542,147 @@ def stop_tuning(step):
15341542
return step
15351543

15361544

1545+
def get_vars_in_point_list(trace, model):
1546+
"""Get the list of Variable instances in the model that have values stored in the trace."""
1547+
if not isinstance(trace, MultiTrace):
1548+
names_in_trace = list(trace[0])
1549+
else:
1550+
names_in_trace = trace.varnames
1551+
vars_in_trace = [model[v] for v in names_in_trace]
1552+
return vars_in_trace
1553+
1554+
1555+
def compile_forward_sampling_function(
1556+
outputs: List[Variable],
1557+
vars_in_trace: List[Variable],
1558+
basic_rvs: Optional[List[Variable]] = None,
1559+
givens_dict: Optional[Dict[Variable, Any]] = None,
1560+
**kwargs,
1561+
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
1562+
"""Compile a function to draw samples, conditioned on the values of some variables.
1563+
1564+
The goal of this function is to walk the aesara computational graph from the list
1565+
of output nodes down to the root nodes, and then compile a function that will produce
1566+
values for these output nodes. The compiled function will take as inputs the subset of
1567+
variables in the ``vars_in_trace`` that are deemed to not be **volatile**.
1568+
1569+
Volatile variables are variables whose values could change between runs of the
1570+
compiled function or after inference has been run. These variables are:
1571+
1572+
- Variables in the outputs list
1573+
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
1574+
- Basic RVs that are not in the ``vars_in_trace`` list
1575+
- Variables that are keys in the ``givens_dict``
1576+
- Variables that have volatile inputs
1577+
1578+
Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
1579+
that are in the ``basic_rvs`` list.
1580+
1581+
Concretely, this function can be used to compile a function to sample from the
1582+
posterior predictive distribution of a model that has variables that are conditioned
1583+
on ``MutableData`` instances. The variables that depend on the mutable data will be
1584+
considered volatile, and as such, they wont be included as inputs into the compiled function.
1585+
This means that if they have values stored in the posterior, these values will be ignored
1586+
and new values will be computed (in the case of deterministics and potentials) or sampled
1587+
(in the case of random variables).
1588+
1589+
This function also enables a way to impute values for any variable in the computational
1590+
graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
1591+
to set the ``givens`` argument of the aesara function compilation. This will essentially
1592+
replace a node in the computational graph with any other expression that has the same
1593+
type as the desired node. Passing variables in the givens_dict is considered an intervention
1594+
that might lead to different variable values from those that could have been seen during
1595+
inference, as such, **any variable that is passed in the ``givens_dict`` will be considered
1596+
volatile**.
1597+
1598+
Parameters
1599+
----------
1600+
outputs : List[aesara.graph.basic.Variable]
1601+
The list of variables that will be returned by the compiled function
1602+
vars_in_trace : List[aesara.graph.basic.Variable]
1603+
The list of variables that are assumed to have values stored in the trace
1604+
basic_rvs : Optional[List[aesara.graph.basic.Variable]]
1605+
A list of random variables that are defined in the model. This list (which could be the
1606+
output of ``model.basic_RVs``) should have a reference to the variables that should
1607+
be considered as random variable instances. This includes variables that have
1608+
a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
1609+
Censored distributions. If ``None``, only pure random variables will be considered
1610+
as potential random variables.
1611+
givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
1612+
A dictionary that maps tensor variables to the values that should be used to replace them
1613+
in the compiled function. The types of the key and value should match or an error will be
1614+
raised during compilation.
1615+
"""
1616+
if givens_dict is None:
1617+
givens_dict = {}
1618+
1619+
if basic_rvs is None:
1620+
basic_rvs = []
1621+
1622+
# We need a function graph to walk the clients and propagate the volatile property
1623+
fg = FunctionGraph(outputs=outputs, clone=False)
1624+
1625+
# Walk the graph from inputs to outputs and tag the volatile variables
1626+
nodes: List[Variable] = general_toposort(
1627+
fg.outputs, deps=lambda x: x.owner.inputs if x.owner else []
1628+
)
1629+
volatile_nodes: Set[Any] = set()
1630+
for node in nodes:
1631+
if (
1632+
node in fg.outputs
1633+
or node in givens_dict
1634+
or ( # SharedVariables, except RandomState/Generators
1635+
isinstance(node, SharedVariable)
1636+
and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
1637+
)
1638+
or ( # Basic RVs that are not in the trace
1639+
node.owner
1640+
and isinstance(node.owner.op, RandomVariable)
1641+
and node in basic_rvs
1642+
and node not in vars_in_trace
1643+
)
1644+
or ( # Variables that have any volatile input
1645+
node.owner and any(inp in volatile_nodes for inp in node.owner.inputs)
1646+
)
1647+
):
1648+
volatile_nodes.add(node)
1649+
1650+
# Collect the function inputs by walking the graph from the outputs. Inputs will be:
1651+
# 1. Random variables that are not volatile
1652+
# 2. Variables that have no owner and are not constant or shared
1653+
inputs = []
1654+
1655+
def expand(node):
1656+
if (
1657+
(
1658+
node.owner is None and not isinstance(node, (Constant, SharedVariable))
1659+
) # Variables without owners that are not constant or shared
1660+
or node in vars_in_trace # Variables in the trace
1661+
) and node not in volatile_nodes:
1662+
# This test will include variables without owners, and that are not constant
1663+
# or shared, because these nodes will never be considered volatile
1664+
inputs.append(node)
1665+
if node.owner:
1666+
return node.owner.inputs
1667+
1668+
# walk produces a generator, so we have to actually exhaust the generator in a list to walk
1669+
# the entire graph
1670+
list(walk(fg.outputs, expand))
1671+
1672+
# Populate the givens list
1673+
givens = [
1674+
(
1675+
node,
1676+
value
1677+
if isinstance(value, (Variable, Apply))
1678+
else at.constant(value, dtype=getattr(node, "dtype", None), name=node.name),
1679+
)
1680+
for node, value in givens_dict.items()
1681+
]
1682+
1683+
return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
1684+
1685+
15371686
def sample_posterior_predictive(
15381687
trace,
15391688
samples: Optional[int] = None,
@@ -1718,38 +1867,23 @@ def sample_posterior_predictive(
17181867
return trace
17191868
return {}
17201869

1721-
inputs: Sequence[TensorVariable]
1722-
input_names: Sequence[str]
1723-
if not isinstance(_trace, MultiTrace):
1724-
names_in_trace = list(_trace[0])
1725-
else:
1726-
names_in_trace = _trace.varnames
1727-
inputs_and_names = [
1728-
(rv, rv.name)
1729-
for rv in walk_model(vars_to_sample, walk_past_rvs=True)
1730-
if rv not in vars_to_sample
1731-
and rv in model.named_vars.values()
1732-
and not isinstance(rv, (Constant, SharedVariable))
1733-
and rv.name in names_in_trace
1734-
]
1735-
if inputs_and_names:
1736-
inputs, input_names = zip(*inputs_and_names)
1737-
else:
1738-
inputs, input_names = [], []
1739-
17401870
if size is not None:
17411871
vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample]
1872+
vars_in_trace = get_vars_in_point_list(_trace, model)
17421873

17431874
if compile_kwargs is None:
17441875
compile_kwargs = {}
1745-
1746-
sampler_fn = compile_pymc(
1747-
inputs,
1748-
vars_to_sample,
1749-
allow_input_downcast=True,
1750-
accept_inplace=True,
1751-
on_unused_input="ignore",
1752-
**compile_kwargs,
1876+
compile_kwargs.setdefault("allow_input_downcast", True)
1877+
compile_kwargs.setdefault("accept_inplace", True)
1878+
1879+
sampler_fn = point_wrapper(
1880+
compile_forward_sampling_function(
1881+
outputs=vars_to_sample,
1882+
vars_in_trace=vars_in_trace,
1883+
basic_rvs=model.basic_RVs,
1884+
givens_dict=None,
1885+
**compile_kwargs,
1886+
)
17531887
)
17541888

17551889
ppc_trace_t = _DefaultTrace(samples)
@@ -1775,7 +1909,7 @@ def sample_posterior_predictive(
17751909
else:
17761910
param = _trace[idx % len_trace]
17771911

1778-
values = sampler_fn(*(param[n] for n in input_names))
1912+
values = sampler_fn(**param)
17791913

17801914
for k, v in zip(vars_, values):
17811915
ppc_trace_t.insert(k.name, v, idx)
@@ -2063,16 +2197,16 @@ def sample_prior_predictive(
20632197
names.append(rv_var.name)
20642198
vars_to_sample.append(rv_var)
20652199

2066-
inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, (Constant, SharedVariable))]
2067-
20682200
if compile_kwargs is None:
20692201
compile_kwargs = {}
2202+
compile_kwargs.setdefault("allow_input_downcast", True)
2203+
compile_kwargs.setdefault("accept_inplace", True)
20702204

2071-
sampler_fn = compile_pymc(
2072-
inputs,
2205+
sampler_fn = compile_forward_sampling_function(
20732206
vars_to_sample,
2074-
allow_input_downcast=True,
2075-
accept_inplace=True,
2207+
vars_in_trace=[],
2208+
basic_rvs=model.basic_RVs,
2209+
givens_dict=None,
20762210
**compile_kwargs,
20772211
)
20782212

0 commit comments

Comments
 (0)