Refactor pm.Simulator #4903


Merged: 2 commits merged into pymc-devs:main from restore_abc_nice_api on Sep 22, 2021

Conversation

ricardoV94 (Member)

Third attempt at refactoring pm.Simulator.

Apparently cloudpickle overcomes the limitations that we were facing in previous iterations of this PR in #4802 and #4877.
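
As a rough standalone illustration of that limitation (not code from this PR): the stdlib pickle cannot serialize lambdas or dynamically created classes by reference, while cloudpickle serializes them by value:

import pickle
import cloudpickle

# A lambda, like a user-supplied simulator function, fails with stdlib pickle
simulate = lambda rng, loc, size: rng.normal(loc, size=size)

try:
    pickle.dumps(simulate)
except pickle.PicklingError as exc:
    print("stdlib pickle failed:", exc)

# cloudpickle serializes the function by value and round-trips it
restored = cloudpickle.loads(cloudpickle.dumps(simulate))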

ricardoV94 added the SMC Sequential Monte Carlo label on Aug 4, 2021
ricardoV94 force-pushed the restore_abc_nice_api branch 2 times, most recently from b830c20 to b14d95c on August 4, 2021 14:09
# NoDistribution.register(rv_type)
NoDistribution.register(SimulatorRV)

# @_logp.register(rv_type)
ricardoV94 (Member, Author):

@lucianopaz, I cannot register it on the new rv_type, or I get a weird failure in multiprocessing, where the logp defaults to the standard at.zeros_like(value).

This causes the first Simulator test to fail, complaining that the input variables are not used in the compiled function (which is true, because the default logp is based only on value).

The strange thing is that when I run it in Jupyter, it fails the first time the cell is run but passes the second time.
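
For context, a minimal sketch of the dispatch pattern in question (hypothetical stand-ins, not the PR code): the logp is registered on the shared SimulatorRV base class rather than on each dynamically created rv_type, since per-instance classes may not survive pickling into worker processes:

from functools import singledispatch

import aesara.tensor as at

@singledispatch
def _logp(op, value, *inputs):
    # Default fallback: a logp that depends only on `value`,
    # mirroring the at.zeros_like(value) behavior described above
    return at.zeros_like(value)

class SimulatorRV:
    """Stand-in for the real SimulatorRV Op class."""

@_logp.register(SimulatorRV)
def simulator_logp(op, value, *inputs):
    # Registered on the shared base class, so the dispatch entry
    # exists in every worker process
    ...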

dtype=dtype,
inplace=False,
fn=fn,
_distance=distance,
ricardoV94 (Member, Author):

Ideally I wouldn't have to pass distance, sum_stat, or epsilon via the SimulatorRV, but the dispatching is not working very well yet; see the comment below.

codecov bot commented Aug 4, 2021

Codecov Report

Merging #4903 (737b24b) into main (bcc40ce) will decrease coverage by 1.86%.
The diff coverage is 95.04%.

❗ Current head 737b24b differs from pull request most recent head 7c9e31f. Consider uploading reports for the commit 7c9e31f to get more accurate results.

@@            Coverage Diff             @@
##             main    #4903      +/-   ##
==========================================
- Coverage   76.34%   74.48%   -1.87%     
==========================================
  Files          86       86              
  Lines       13931    13927       -4     
==========================================
- Hits        10636    10373     -263     
- Misses       3295     3554     +259     
Impacted Files Coverage Δ
pymc3/distributions/simulator.py 86.25% <94.23%> (+60.39%) ⬆️
pymc3/aesaraf.py 91.48% <100.00%> (+0.21%) ⬆️
pymc3/distributions/distribution.py 85.71% <100.00%> (+3.04%) ⬆️
pymc3/smc/sample_smc.py 85.50% <100.00%> (+0.50%) ⬆️
pymc3/smc/smc.py 98.70% <100.00%> (+0.32%) ⬆️
pymc3/distributions/bart.py 18.05% <0.00%> (-81.95%) ⬇️
pymc3/step_methods/pgbart.py 18.12% <0.00%> (-77.68%) ⬇️
pymc3/distributions/tree.py 30.86% <0.00%> (-69.14%) ⬇️
pymc3/distributions/multivariate.py 71.45% <0.00%> (-0.58%) ⬇️
pymc3/step_methods/metropolis.py 83.25% <0.00%> (-0.45%) ⬇️
... and 13 more

ricardoV94 force-pushed the restore_abc_nice_api branch from b14d95c to 0502d9c on August 4, 2021 16:02

ricardoV94 (Member, Author) commented Aug 4, 2021

I still have an issue with the logp function of models with Simulator variables. The upstream input RandomVariables persist in the final compiled function, even though they don't contribute at all to the logp term: https://gist.github.com/ricardoV94/11e0bba211b746d369252a51773cf429

They seem to persist in the graph of the sim_value_rng despite being properly replaced in the graph of sim_value.

This might be why the Simulator is performing slightly worse than in V3.

@brandonwillard do you have any idea?

brandonwillard (Contributor) commented Aug 4, 2021

> I still have an issue with the logp function of models with Simulator variables. The upstream input RandomVariables persist in the final compiled function, even though their output is completely disconnected: https://gist.github.com/ricardoV94/11e0bba211b746d369252a51773cf429
>
> They seem to persist in the graph of the sim_value_rng even though they are properly replaced in the graph of sim_value.
>
> This might be why the Simulator is performing slightly worse than in V3.
>
> @brandonwillard do you have any idea?

It looks like it's adding the shared RNG states implicitly created by a, b, and s as outputs to the compiled m.logpt graph, and I believe this has to do with shared variable updates. Those RNG state inputs have a default_update attribute (e.g. a.owner.inputs[0].default_update) that's probably being picked up by aesara.function and added to the compiled function's outputs.

In other words, aesara.function could be automatically adding things to its set of updates (i.e. shared variables that it needs to update), and those updates require samples from a and b in order to update the shared RNG state outputs.

If you add no_default_updates=True to the aesara.function call, you probably won't see those extra outputs; however, we should think about the repercussions of doing that, and perhaps adjust something in Aesara, if need be.
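
A minimal standalone sketch of that mechanism (toy graph, not this PR's code):

import aesara
import aesara.tensor as at

x = at.random.normal(0, 1, name="x")
rng = x.owner.inputs[0]                  # implicitly created shared RNG state
rng.default_update = x.owner.outputs[0]  # feed the next RNG state back in

# The default update is collected automatically, so the compiled function
# also computes (and stores) the updated RNG state.
f = aesara.function([], x)

# With no_default_updates=True the update is dropped along with the extra
# RNG-state output, and every call returns the same draw.
g = aesara.function([], x, no_default_updates=True)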

ricardoV94 (Member, Author) commented Aug 5, 2021

> It looks like it's adding the shared RNG states implicitly created by a, b, and s as outputs to the compiled m.logpt graph, and I believe this has to do with shared variable updates. Those RNG state inputs have a default_update attribute (e.g. a.owner.inputs[0].default_update) that's probably being picked up by aesara.function and added to the compiled function's outputs.

I updated my gist, where I try to manually recreate this sort of shared RNG with a default update and then replace one of the inputs, and it does not seem to have the same issue: https://gist.github.com/ricardoV94/11e0bba211b746d369252a51773cf429

> If you add no_default_updates=True to the aesara.function call, you probably won't see those extra outputs; however, we should think about the repercussions of doing that, and perhaps adjust something in Aesara, if need be.

That indeed removes them from the graph, but also removes the automatic updates for the inserted RandomVariable, so it's not what I need.

ricardoV94 (Member, Author) commented Aug 5, 2021

@brandonwillard, if I add the RandomVariable to replacements in pymc3.aesaraf.rvs_to_value_vars, the compiled logp function no longer has the unnecessary upstream RVs:

def transform_replacements(var, replacements):
    rv_var, rv_value_var = extract_rv_and_value_vars(var)

    if rv_value_var is None:
        # If RandomVariable does not have a value_var and corresponds to
        # a SimulatorRV, we allow further replacements in upstream graph
        if isinstance(rv_var.owner.op, SimulatorRV):
            replacements[rv_var] = rv_var  # << ADDED LINE
            # First 3 inputs are just rng, dtype, and size, which don't
            # need to be replaced.
            return rv_var.owner.inputs[3:]
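
For example, the logp of the following model now compiles without the upstream a and b RandomVariables: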
import numpy as np
import aesara
import pymc3 as pm

def normal_sim(rng, a, b, size):
    return rng.normal(a, b, size=size)

data = np.random.normal(loc=0, scale=1, size=10)

with pm.Model() as m:
    a = pm.Normal("a", mu=0, sigma=1)
    b = pm.HalfNormal("b", sigma=1)
    s = pm.Simulator("s", normal_sim, a, b, observed=5)

a.owner.outputs[0].name = "a_rng"
b.owner.outputs[0].name = "b_rng"
s.owner.outputs[0].name = "s_rng"

f = aesara.function([a.tag.value_var, b.tag.value_var], m.logpt)
aesara.dprint(f)
Sum{acc_dtype=float64} [id A] '__logp'   6
 |MakeVector{dtype='float64'} [id B] ''   5
   |Elemwise{Composite{(i0 * (i1 + (-sqr(i2))))}} [id C] '__logp_a'   1
   | |TensorConstant{0.5} [id D]
   | |TensorConstant{-1.8378770664093453} [id E]
   | |a [id F]
   |Elemwise{Composite{(Switch(Cast{int8}(GE(i0, i1)), (i2 + (i3 * sqr(i0))), i4) + i5)}}[(0, 0)] [id G] '__logp_b'   3
   | |Elemwise{exp,no_inplace} [id H] ''   0
   | | |b_log__ [id I]
   | |TensorConstant{0.0} [id J]
   | |TensorConstant{-0.2257913526447274} [id K]
   | |TensorConstant{-0.5} [id L]
   | |TensorConstant{-inf} [id M]
   | |b_log__ [id I]
   |Elemwise{Composite{(i0 * sqr((i1 - i2)))}}[(0, 2)] [id N] '__logp_s'   4
     |TensorConstant{-0.5} [id L]
     |TensorConstant{5.0} [id O]
     |Simulator_rv{0, (0, 0), floatX, True}.1 [id P] 'sim_value'   2
       |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F383DD42120>) [id Q]
       |TensorConstant{[]} [id R]
       |TensorConstant{11} [id S]
       |a [id F]
       |Elemwise{exp,no_inplace} [id H] ''   0
Simulator_rv{0, (0, 0), floatX, True}.0 [id P] ''   2

However, pymc3.aesaraf.join_nonshared_inputs stops working for this type of graph, probably due to a conflict between the default_update and the aesara.clone_replace call made inside (although it did work before with the more inefficient graph):

(new_datalogpt,), inarray = pm.aesaraf.join_nonshared_inputs(m.initial_point, [m.logpt], m.value_vars, {})
new_f = aesara.function([inarray], new_datalogpt)
aesara.dprint(new_f)
MissingInputError: Input 0 (b_log__) of the graph (indices start from 0), used to compute Elemwise{exp,no_inplace}(b_log__), was not provided and not given a value. Use the Aesara flag exception_verbosity='high', for more information on this error.

Two steps forward, one step backward...
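
A standalone toy reproduction of the suspected conflict (hypothetical graph, not the PyMC code): clone_replace rewires the requested outputs, but a shared variable's default_update graph is not cloned, so it keeps referencing the replaced input:

import aesara
import aesara.tensor as at
from aesara import clone_replace

x = at.scalar("x")
state = aesara.shared(0.0, name="state")
state.default_update = state + x  # the update graph references x
y = state + x

new_x = at.scalar("new_x")
(new_y,) = clone_replace([y], replace={x: new_x})

# new_y depends only on new_x, but the default update for `state`
# still needs x, so this raises MissingInputError for x:
# aesara.function([new_x], new_y)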

ricardoV94 force-pushed the restore_abc_nice_api branch 2 times, most recently from ba0bf79 to fdb4fe8 on August 5, 2021 12:33
):
    rng = rv.owner.inputs[0]
    if not hasattr(rng, "default_update"):
        rng.default_update = rv.owner.outputs[0]
ricardoV94 (Member, Author):

This is a WIP hack that seems to fix the issues I was finding in this PR, until I get to the bottom of the problem.

In the meantime, and just to be sure: is it a bad idea to set the default_update here before compilation? It requires a full graph traversal, but it also means we could remove this code here: https://github.com/pymc-devs/pymc3/blob/c913709d02deab1307df89245782630d8f25d953/pymc3/distributions/distribution.py#L307-L322
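
A rough sketch of what that full traversal could look like (assuming aesara.graph.basic.ancestors; not the code in this PR):

from aesara.graph.basic import ancestors
from aesara.tensor.random.op import RandomVariable

def set_rng_default_updates(outputs):
    # Walk every ancestor of the outputs and give each RNG shared
    # variable a default update pointing to its next-state output
    for var in ancestors(outputs):
        node = var.owner
        if node is not None and isinstance(node.op, RandomVariable):
            rng = node.inputs[0]
            if not hasattr(rng, "default_update"):
                rng.default_update = node.outputs[0]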

CC @brandonwillard

ricardoV94 (Member, Author) commented Aug 5, 2021

I added a hack to delay the specification of default_update for the shared RandomState variables until the function is actually compiled. This avoids the issue where old RVs remain in the logp graph, upstream of the Simulator rng output.

Now, this might not be something we want to do just for the sake of the Simulator, in which case I can create a specialized compile_rv_inplace. However, that means the Simulator will only work with SMC, since other samplers will keep calling the old compile_rv_inplace.

On the bright side, the slowdown I was observing relative to V3 is gone!


Edit: This gist shows the limitation of aesaraf.rvs_to_value_vars when the model is supposed to still contain a RandomVariable. Put simply, the RVs in the default_update graphs are not being replaced by value vars: https://gist.github.com/ricardoV94/11e0bba211b746d369252a51773cf429

ricardoV94 (Member, Author) commented Aug 10, 2021

I created a (semi-)minimal example of the problem I'm stuck on: https://gist.github.com/ricardoV94/08badc6c501652559900fdcb276e7e50

I need to figure out what is needed for aesaraf.join_nonshared_inputs to cope with RV models where the shared RNGs have a default_update set.

This is pretty much the last thing blocking the Simulator refactor.

ricardoV94 force-pushed the restore_abc_nice_api branch 6 times, most recently from 4c984ec to ef2a400 on August 17, 2021 09:57
ricardoV94 force-pushed the restore_abc_nice_api branch from ef2a400 to 4aebf29 on August 17, 2021 10:10
ricardoV94 force-pushed the restore_abc_nice_api branch from 737b24b to 0635e08 on August 18, 2021 10:27
ricardoV94 marked this pull request as ready for review on September 22, 2021 01:09
twiecki merged commit 55d455a into pymc-devs:main on Sep 22, 2021
ricardoV94 deleted the restore_abc_nice_api branch on January 31, 2022 09:20