Skip to content

Commit 076afb8

Browse files
ricardoV94aloctavodia
authored andcommitted
Refactor pm.Simulator
Co-authored-by: Osvaldo Martin <[email protected]>
1 parent 4f8ad5d commit 076afb8

File tree

8 files changed

+554
-210
lines changed

8 files changed

+554
-210
lines changed

docs/source/api/distributions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Distributions
88
distributions/discrete
99
distributions/multivariate
1010
distributions/mixture
11+
distributions/simulator
1112
distributions/timeseries
1213
distributions/transforms
1314
distributions/utilities
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
**********
2+
Simulator
3+
**********
4+
5+
.. currentmodule:: pymc3.distributions.simulator
6+
.. autosummary::
7+
8+
Simulator
9+
10+
.. automodule:: pymc3.distributions.simulator
11+
:members:

pymc3/aesaraf.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -350,15 +350,24 @@ def rvs_to_value_vars(
350350
351351
"""
352352

353+
# Avoid circular dependency
354+
from pymc3.distributions import NoDistribution
355+
353356
def transform_replacements(var, replacements):
354357
rv_var, rv_value_var = extract_rv_and_value_vars(var)
355358

356359
if rv_value_var is None:
357-
warnings.warn(
358-
f"No value variable found for {rv_var}; "
359-
"the random variable will not be replaced."
360-
)
361-
return []
360+
# If RandomVariable does not have a value_var and corresponds to
361+
# a NoDistribution, we allow further replacements in upstream graph
362+
if isinstance(rv_var.owner.op, NoDistribution):
363+
return rv_var.owner.inputs
364+
365+
else:
366+
warnings.warn(
367+
f"No value variable found for {rv_var}; "
368+
"the random variable will not be replaced."
369+
)
370+
return []
362371

363372
transform = getattr(rv_value_var.tag, "transform", None)
364373

pymc3/distributions/distribution.py

+7-42
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from typing import Optional
2424

2525
import aesara
26-
import aesara.tensor as at
2726

2827
from aesara.tensor.random.op import RandomVariable
2928
from aesara.tensor.random.var import RandomStateSharedVariable
@@ -353,47 +352,6 @@ def get_moment(rv: TensorVariable) -> TensorVariable:
353352
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:])
354353

355354

356-
class NoDistribution(Distribution):
357-
def __init__(
358-
self,
359-
shape,
360-
dtype,
361-
initval=None,
362-
defaults=(),
363-
parent_dist=None,
364-
*args,
365-
**kwargs,
366-
):
367-
super().__init__(
368-
shape=shape, dtype=dtype, initval=initval, defaults=defaults, *args, **kwargs
369-
)
370-
self.parent_dist = parent_dist
371-
372-
def __getattr__(self, name):
373-
# Do not use __getstate__ and __setstate__ from parent_dist
374-
# to avoid infinite recursion during unpickling
375-
if name.startswith("__"):
376-
raise AttributeError("'NoDistribution' has no attribute '%s'" % name)
377-
return getattr(self.parent_dist, name)
378-
379-
def logp(self, x):
380-
"""Calculate log probability.
381-
382-
Parameters
383-
----------
384-
x: numeric
385-
Value for which log-probability is calculated.
386-
387-
Returns
388-
-------
389-
TensorVariable
390-
"""
391-
return at.zeros_like(x)
392-
393-
def _distr_parameters_for_repr(self):
394-
return []
395-
396-
397355
class Discrete(Distribution):
398356
"""Base class for discrete distributions"""
399357

@@ -409,6 +367,13 @@ class Continuous(Distribution):
409367
"""Base class for continuous distributions"""
410368

411369

370+
class NoDistribution(Distribution):
371+
"""Base class for artifical distributions
372+
373+
RandomVariables that share this type are allowed in logprob graphs
374+
"""
375+
376+
412377
class DensityDist(Distribution):
413378
"""Distribution based on a given log density function.
414379

0 commit comments

Comments
 (0)