-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
2560: fail-fast when trace vars mismatches model vars #2658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
#2560 This largely looks good to me. Some suggestions are to create a regression test with your .py file in it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this -- it is a funny problem, and I think this is a helpful fix. I left perhaps overly specific feedback on the test, since it can be hard to navigate our test suite!
@@ -0,0 +1,15 @@ | |||
import pymc3 as pm | |||
with pm.Model() as model: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be a function in pymc3/tests/test_text_backend.py
. Something like
def test_wont_dump_mismatched_vars():
with pm.Model() as model:
a = pm.Gamma('a', mu=10.0, sd=2.0)
b = pm.Gamma('b', mu=10.0, sd=2.0)
partial_trace = pm.sample(draws=100, tune=0, trace=[model.a, model.a_log__])
assert len(partial_trace.varnames) == 2
with pytest.raises(ValueError):
pm.backends.text.dump('trace.text', partial_trace)
trace_vars = sorted(trace.varnames) | ||
|
||
if model_vars != trace_vars: | ||
raise ValueError('Variables mismatch: model_vars=' + str(model_vars) + ", trace_vars=" + str(trace_vars)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This error might be more explicit:
ValueError('Cannot serialize a trace that does not include all model variables.'
'\n\tmodel variables: {}'
'\n\ttrace variables: {}'.format(', '.join(model_vars), ', '.join(trace_vars))
Valid points @ColCarroll. I too find the tests hard to navigate.
…On 24 Oct 2017 9:35 pm, "Colin" ***@***.***> wrote:
***@***.**** commented on this pull request.
Thanks for this -- it is a funny problem, and I think this is a helpful
fix. I left perhaps overly specific feedback on the test, since it can be
hard to navigate our test suite!
------------------------------
In 2560.py
<#2658 (comment)>:
> @@ -0,0 +1,15 @@
+import pymc3 as pm
+with pm.Model() as model:
this should be a function in pymc3/tests/test_text_backend.py. Something
like
def test_wont_dump_mismatched_vars():
with pm.Model() as model:
a = pm.Gamma('a', mu=10.0, sd=2.0)
b = pm.Gamma('b', mu=10.0, sd=2.0)
partial_trace = pm.sample(draws=100, tune=0, trace=[model.a, model.a_log__])
assert len(partial_trace.varnames) == 2
with pytest.raises(ValueError):
pm.backends.text.dump('trace.text', partial_trace)
------------------------------
In pymc3/backends/text.py
<#2658 (comment)>:
> @@ -200,6 +201,12 @@ def dump(name, trace, chains=None):
os.mkdir(name)
if chains is None:
chains = trace.chains
+ model = modelcontext(model)
+ model_vars = sorted([x.name for x in model.unobserved_RVs])
+ trace_vars = sorted(trace.varnames)
+
+ if model_vars != trace_vars:
+ raise ValueError('Variables mismatch: model_vars=' + str(model_vars) + ", trace_vars=" + str(trace_vars))
This error might be more explicit:
ValueError('Cannot serialize a trace that does not include all model variables.'
'\n\tmodel variables: {}'
'\n\ttrace variables: {}'.format(', '.join(model_vars), ', '.join(trace_vars))
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub
<#2658 (review)>,
or mute the thread
<https://github.com/notifications/unsubscribe-auth/AA8DiB6tW3GVaIu0RMR2JzMFgU3CKDAIks5svkn1gaJpZM4QFBGz>
.
|
Does this PR have legs, or can it be closed? |
I think it can be closed |
I'm not sure about testing, I reproduced mine in 2560.py - but I guess it's better to add it in pymc3/tests//test_text_backend.py. Do I really want to create text files in unit tests though?