Refactor pm.Simulator
#4903
Conversation
Force-pushed from b830c20 to b14d95c
```python
# NoDistribution.register(rv_type)
NoDistribution.register(SimulatorRV)

# @_logp.register(rv_type)
```
@lucianopaz, I cannot register it on the new `rv_type`, or I get a weird failure under multiprocessing, where the logp defaults to the standard `at.zeros_like(value)`.

This causes the first Simulator test to fail, complaining that the input variables are not used in the compiled function (which is true, because the default logp is based only on `value`).

The strange thing is that when I run it on Jupyter, it fails the first time the cell is run but passes the second time.
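For context, here is a minimal, generic sketch (plain Python, not the actual PyMC internals) of why registering on the statically defined base class still covers the dynamically created classes: `functools.singledispatch` falls back along the MRO, so any per-instance `rv_type` that subclasses the registered `SimulatorRV` base dispatches to the same implementation. The names below (`_logp_demo`, `BaseSimulatorRV`) are illustrative only.

```python
from functools import singledispatch


@singledispatch
def _logp_demo(op, value):
    # Default fallback, analogous to the at.zeros_like(value) logp.
    return 0.0


class BaseSimulatorRV:
    """Stand-in for the statically defined SimulatorRV base class."""


@_logp_demo.register(BaseSimulatorRV)
def _simulator_logp_demo(op, value):
    return -0.5 * value**2


# A subclass created dynamically per Simulator instance (the "rv_type") still
# hits the base-class registration through the MRO.
DynamicRV = type("DynamicRV", (BaseSimulatorRV,), {})
print(_logp_demo(DynamicRV(), 1.0))  # -0.5, not the 0.0 fallback
```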
```python
dtype=dtype,
inplace=False,
fn=fn,
_distance=distance,
```
Ideally I wouldn't have to pass `distance`, `sum_stat`, or `epsilon` via the `SimulatorRV`, but the dispatching is not working very well yet; see the comment below.
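As a rough illustration of the workaround (plain Python, not PyMC's actual API; `make_simulator_rv` and `simulator_logp` are made-up names), the ABC-specific parameters are carried as attributes of a per-instance `SimulatorRV` subclass, so the logp implementation can read them back from the Op instead of relying on the dispatch machinery:

```python
class BaseSimulatorRV:
    """Stand-in for the SimulatorRV Op base class."""

    fn = None
    _distance = None
    _sum_stat = None
    _epsilon = None


def make_simulator_rv(fn, distance, sum_stat, epsilon):
    # Build a one-off subclass whose class attributes carry the parameters,
    # mirroring the fn=fn, _distance=distance, ... arguments in the diff above.
    rv_type = type(
        "SimulatorRV",
        (BaseSimulatorRV,),
        dict(fn=fn, _distance=distance, _sum_stat=sum_stat, _epsilon=epsilon),
    )
    return rv_type()


def simulator_logp(op, observed, simulated):
    # The logp can recover everything it needs from the Op itself.
    return op._distance(op._epsilon, op._sum_stat(observed), op._sum_stat(simulated))


rv_op = make_simulator_rv(
    fn=lambda rng, size: rng.normal(size=size),
    distance=lambda eps, obs, sim: -0.5 * ((obs - sim) / eps) ** 2,
    sum_stat=lambda x: x,
    epsilon=1.0,
)
print(simulator_logp(rv_op, 0.5, 0.3))  # -0.02
```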
Codecov Report

```
@@            Coverage Diff             @@
##             main    #4903      +/-   ##
==========================================
- Coverage   76.34%   74.48%   -1.87%
==========================================
  Files          86       86
  Lines       13931    13927       -4
==========================================
- Hits        10636    10373     -263
- Misses       3295     3554     +259
```
Force-pushed from b14d95c to 0502d9c
I still have an issue with the logp function of models with Simulator variables. The upstream input RandomVariables persist in the final compiled function, even though they don't contribute at all to the logp term: https://gist.github.com/ricardoV94/11e0bba211b746d369252a51773cf429

They seem to persist in the logp graph. This might be why the Simulator is performing slightly worse than in V3.

@brandonwillard do you have any idea?
It looks like it's adding the shared RNG states that are implicitly created by the RandomVariables. In other words, … If you add …
I updated my gist, where I try to manually recreate this sort of shared RNG with a default update and then replace one of the inputs, and it does not seem to have the same issue: https://gist.github.com/ricardoV94/11e0bba211b746d369252a51773cf429
That indeed removes them from the graph, but also removes the automatic updates for the inserted RandomVariable, so it's not what I need.
@brandonwillard, if I add the RandomVariable to `replacements` in `transform_replacements`, the upstream RandomVariables are correctly left out when compiling `m.logpt` directly:

```python
def transform_replacements(var, replacements):
    rv_var, rv_value_var = extract_rv_and_value_vars(var)

    if rv_value_var is None:
        # If RandomVariable does not have a value_var and corresponds to
        # a SimulatorRV, we allow further replacements in upstream graph
        if isinstance(rv_var.owner.op, SimulatorRV):
            replacements[rv_var] = rv_var  # << ADDED LINE

        # First 3 inputs are just rng, dtype, and size, which don't
        # need to be replaced.
        return rv_var.owner.inputs[3:]
```

```python
import numpy as np
import aesara
import pymc3 as pm


def normal_sim(rng, a, b, size):
    return rng.normal(a, b, size=size)


data = np.random.normal(loc=0, scale=1, size=10)

with pm.Model() as m:
    a = pm.Normal("a", mu=0, sigma=1)
    b = pm.HalfNormal("b", sigma=1)
    s = pm.Simulator("s", normal_sim, a, b, observed=5)

a.owner.outputs[0].name = "a_rng"
b.owner.outputs[0].name = "b_rng"
s.owner.outputs[0].name = "s_rng"

f = aesara.function([a.tag.value_var, b.tag.value_var], m.logpt)
aesara.dprint(f)
```

However, the issue comes back when going through `join_nonshared_inputs`:

```python
(new_datalogpt,), inarray = pm.aesaraf.join_nonshared_inputs(
    m.initial_point, [m.logpt], m.value_vars, {}
)
new_f = aesara.function([inarray], new_datalogpt)
aesara.dprint(new_f)
```

Two steps forward, one step backward...
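For readers unfamiliar with `join_nonshared_inputs`, here is a hedged, simplified sketch (assuming Aesara is installed; this shows the general clone-and-replace idea, not PyMC's actual implementation): the separate value inputs are replaced by slices of a single flat input vector, so the compiled logp takes one concatenated array.

```python
import numpy as np
import aesara
import aesara.tensor as at

# Two "value variables" feeding a toy logp-like graph.
a = at.scalar("a")
b = at.vector("b")
out = a + b.sum()

# Join them into one flat input: a -> inarray[0], b -> inarray[1:4]
# (here we simply assume b has length 3).
inarray = at.vector("inarray")
(new_out,) = aesara.clone_replace([out], replace={a: inarray[0], b: inarray[1:4]})

f = aesara.function([inarray], new_out)
print(f(np.arange(4.0)))  # 0 + (1 + 2 + 3) = 6.0
```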
Force-pushed from ba0bf79 to fdb4fe8
```python
):
    rng = rv.owner.inputs[0]
    if not hasattr(rng, "default_update"):
        rng.default_update = rv.owner.outputs[0]
```
This is a WIP hack that seems to fix the issues I was finding in this PR, until I get to the bottom of the problem.

In the meantime, and just to be sure: is it a bad idea to set the `default_update`s here, just before compilation? It requires a full graph traversal, but it also means we could remove this code here: https://github.com/pymc-devs/pymc3/blob/c913709d02deab1307df89245782630d8f25d953/pymc3/distributions/distribution.py#L307-L322
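For reference, a minimal standalone sketch (assuming a recent Aesara; not PyMC code) of the `default_update` behaviour this hack relies on: without it, the implicitly created shared RNG state is never advanced by the compiled function, and setting it restores the automatic updates that `aesara.function` picks up.

```python
import numpy as np
import aesara
import aesara.tensor as at

rng = aesara.shared(np.random.default_rng(123))
x = at.random.normal(0, 1, rng=rng)

# Without a default update, the RNG state is never advanced, so every call
# returns the same draw.
f = aesara.function([], x)
print(f(), f())  # identical values

# Setting default_update (as in the diff above) makes aesara.function update
# the shared RNG state automatically, so each call produces a fresh draw.
rng.default_update = x.owner.outputs[0]
g = aesara.function([], x)
print(g(), g())  # different values
```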
I added a hack to delay the specification of the `default_update`s. Now, this might not be something we want to do just for the sake of the Simulator.

For the good news: the slowdown I was observing relative to V3 is gone!

Edit: This gist shows the limitation of …
Force-pushed from 69f827a to 33acfc0
I created a (semi-)minimal example of the problem where I'm stuck: https://gist.github.com/ricardoV94/08badc6c501652559900fdcb276e7e50

I need to figure out what is needed for … This is pretty much the last thing blocking the Simulator refactor.
Force-pushed from 4c984ec to ef2a400
Force-pushed from ef2a400 to 4aebf29
Force-pushed from 737b24b to 0635e08
Force-pushed from 0635e08 to 2bb851b
Co-authored-by: Osvaldo Martin <[email protected]>
Force-pushed from 2bb851b to 7c9e31f
Third attempt at refactoring `pm.Simulator`.

Apparently `cloudpickle` overcomes the limitations that we were facing in previous iterations of this PR, in #4802 and #4877.
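To illustrate the point (a standalone sketch, assuming `cloudpickle` is installed; not code from this PR): the standard library `pickle` cannot serialize locally defined functions such as a user-supplied simulator, while `cloudpickle` can, which is the kind of limitation referred to above.

```python
import pickle

import cloudpickle


def make_simulator(scale):
    # A closure like this cannot be handled by the stdlib pickle module...
    def simulator(rng, loc, size):
        return rng.normal(loc, scale, size=size)

    return simulator


sim = make_simulator(2.0)

try:
    pickle.dumps(sim)
except (AttributeError, pickle.PicklingError) as err:
    print("stdlib pickle fails:", err)

# ...but cloudpickle serializes it, and the payload can be restored with
# plain pickle.loads on the other side.
restored = pickle.loads(cloudpickle.dumps(sim))
print(restored.__name__)  # "simulator"
```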