Skip to content

Refactor pm.Simulator #4903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Distributions
distributions/discrete
distributions/multivariate
distributions/mixture
distributions/simulator
distributions/timeseries
distributions/transforms
distributions/utilities
11 changes: 11 additions & 0 deletions docs/source/api/distributions/simulator.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
**********
Simulator
**********

.. currentmodule:: pymc3.distributions.simulator
.. autosummary::

Simulator

.. automodule:: pymc3.distributions.simulator
:members:
35 changes: 30 additions & 5 deletions pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,15 +350,24 @@ def rvs_to_value_vars(

"""

# Avoid circular dependency
from pymc3.distributions import NoDistribution

def transform_replacements(var, replacements):
rv_var, rv_value_var = extract_rv_and_value_vars(var)

if rv_value_var is None:
warnings.warn(
f"No value variable found for {rv_var}; "
"the random variable will not be replaced."
)
return []
# If RandomVariable does not have a value_var and corresponds to
# a NoDistribution, we allow further replacements in upstream graph
if isinstance(rv_var.owner.op, NoDistribution):
return rv_var.owner.inputs

else:
warnings.warn(
f"No value variable found for {rv_var}; "
"the random variable will not be replaced."
)
return []

transform = getattr(rv_value_var.tag, "transform", None)

Expand Down Expand Up @@ -886,6 +895,22 @@ def compile_rv_inplace(inputs, outputs, mode=None, **kwargs):
Using this function ensures that compiled functions containing random
variables will produce new samples on each call.
"""

# Avoid circular dependency
from pymc3.distributions import NoDistribution

# Set the default update of a NoDistribution RNG so that it is automatically
# updated after every function call
output_to_list = outputs if isinstance(outputs, list) else [outputs]
for rv in (
node
for node in walk_model(output_to_list, walk_past_rvs=True)
if node.owner and isinstance(node.owner.op, NoDistribution)
):
rng = rv.owner.inputs[0]
if not hasattr(rng, "default_update"):
rng.default_update = rv.owner.outputs[0]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a WIP hack that seems to fix the issues I was finding in this PR, until I get to the bottom of the problem.

In the meantime and just to be sure, is it a bad idea to set the default_updates here before compilation. It requires a full graph transversal, but also means we could remove this code here: https://github.com/pymc-devs/pymc3/blob/c913709d02deab1307df89245782630d8f25d953/pymc3/distributions/distribution.py#L307-L322

CC @brandonwillard


mode = get_mode(mode)
opt_qry = mode.provided_optimizer.including("random_make_inplace")
mode = Mode(linker=mode.linker, optimizer=opt_qry)
Expand Down
49 changes: 7 additions & 42 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from typing import Optional

import aesara
import aesara.tensor as at

from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable
Expand Down Expand Up @@ -353,47 +352,6 @@ def get_moment(rv: TensorVariable) -> TensorVariable:
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:])


class NoDistribution(Distribution):
def __init__(
self,
shape,
dtype,
initval=None,
defaults=(),
parent_dist=None,
*args,
**kwargs,
):
super().__init__(
shape=shape, dtype=dtype, initval=initval, defaults=defaults, *args, **kwargs
)
self.parent_dist = parent_dist

def __getattr__(self, name):
# Do not use __getstate__ and __setstate__ from parent_dist
# to avoid infinite recursion during unpickling
if name.startswith("__"):
raise AttributeError("'NoDistribution' has no attribute '%s'" % name)
return getattr(self.parent_dist, name)

def logp(self, x):
"""Calculate log probability.

Parameters
----------
x: numeric
Value for which log-probability is calculated.

Returns
-------
TensorVariable
"""
return at.zeros_like(x)

def _distr_parameters_for_repr(self):
return []


class Discrete(Distribution):
"""Base class for discrete distributions"""

Expand All @@ -409,6 +367,13 @@ class Continuous(Distribution):
"""Base class for continuous distributions"""


class NoDistribution(Distribution):
"""Base class for artifical distributions

RandomVariables that share this type are allowed in logprob graphs
"""


class DensityDist(Distribution):
"""Distribution based on a given log density function.

Expand Down
Loading