Skip to content

2560: fail-fast when trace vars mismatches model vars #2658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

jazzmuesli
Copy link

I'm not sure about testing, I reproduced mine in 2560.py - but I guess it's better to add it in pymc3/tests//test_text_backend.py. Do I really want to create text files in unit tests though?

@springcoil
Copy link
Contributor

#2560 This largely looks good to me. Some suggestions are to create a regression test with your .py file in it.

Copy link
Member

@ColCarroll ColCarroll left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this -- it is a funny problem, and I think this is a helpful fix. I left perhaps overly specific feedback on the test, since it can be hard to navigate our test suite!

@@ -0,0 +1,15 @@
import pymc3 as pm
with pm.Model() as model:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be a function in pymc3/tests/test_text_backend.py. Something like

def test_wont_dump_mismatched_vars():
    with pm.Model() as model:
        a = pm.Gamma('a', mu=10.0, sd=2.0)
        b = pm.Gamma('b', mu=10.0, sd=2.0)

        partial_trace = pm.sample(draws=100, tune=0, trace=[model.a, model.a_log__])
        assert len(partial_trace.varnames) == 2
    
        with pytest.raises(ValueError):
            pm.backends.text.dump('trace.text', partial_trace)

trace_vars = sorted(trace.varnames)

if model_vars != trace_vars:
raise ValueError('Variables mismatch: model_vars=' + str(model_vars) + ", trace_vars=" + str(trace_vars))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error might be more explicit:

ValueError('Cannot serialize a trace that does not include all model variables.'
           '\n\tmodel variables: {}'
           '\n\ttrace variables: {}'.format(', '.join(model_vars), ', '.join(trace_vars))

@springcoil
Copy link
Contributor

springcoil commented Oct 24, 2017 via email

@fonnesbeck
Copy link
Member

Does this PR have legs, or can it be closed?

@springcoil
Copy link
Contributor

I think it can be closed

@twiecki twiecki closed this Oct 30, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants