-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rewrites to replace or remove Aeppl CheckParameterValue Ops #5233
Conversation
c46dafe
to
3b99d14
Compare
b283501
to
5caeb9a
Compare
8c763c6
to
9735754
Compare
Codecov Report
@@ Coverage Diff @@
## main #5233 +/- ##
==========================================
+ Coverage 78.98% 79.12% +0.13%
==========================================
Files 88 88
Lines 14231 14301 +70
==========================================
+ Hits 11240 11315 +75
+ Misses 2991 2986 -5
|
pymc/aesaraf.py
Outdated
for database in ("canonicalize", "stabilize", "specialize", "useless"): | ||
aesara.compile.optdb[database].register( | ||
"local_remove_check_parameter", | ||
local_remove_check_parameter, | ||
use_db_name_as_tag=False, | ||
) | ||
|
||
aesara.compile.optdb[database].register( | ||
"local_check_parameter_to_ninf_switch", | ||
local_check_parameter_to_ninf_switch, | ||
use_db_name_as_tag=False, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why add these to all those databases?
Also, don't forget that you can set a priority that affects the order in which they're run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just copied what I've seen elsewhere.
Can you give some recommendation on what is the minimum database(s) where these should be registered, as well as what the priority should be?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should be able to put those in just the pass named "useless"
and have them removed before the other passes. Check out aesara.compile.mode
to see the passes and their ordering.
Otherwise, I would think that removing those checks should be a user-configurable option, and, if so, you might not want to include them by default. Instead, an option could be set that includes those rewrites when graphs are compiled by PyMC (e.g. by constructing a Mode
object for aesara.function
that uses includes=...
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise, I would think that removing those checks should be a user-configurable option, and, if so, you might not want to include them by default. Instead, an option could be set that includes those rewrites when graphs are compiled by PyMC (e.g. by constructing a Mode object for aesara.function that uses includes=...).
The rewrites are not running by default due to the use_db_name_as_tag=False
option. I am manually including them in compile_pymc
helper function only. Otherwise the new tests that expect the specific Exception would fail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should be able to put those in just the pass named "useless" and have them removed before the other passes. Check out aesara.compile.mode to see the passes and their ordering.
For some reason putting it in "useless" alone is not doing anything, but putting them in "canonicalize" seems to do the job.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like we need to refactor that OptimizationDatabase
code; the whole use_db_name_as_tag
thing, and the special way it works with only one DB implementation (i.e. EquilibriumDB
), is not good. Just the fact that there's a note about a specific subclass in the base class is a bad sign.
Anyway, "useless"
doesn't use the same kind of underlying OptimizationDatabase
types as "canonicalize"
, so I'm guessing that the reason the latter works and the former doesn't is related to that.
18c51db
to
4284fc8
Compare
The failing dirichlet_multinomial tests should be solved by #5234 |
b322a33
to
d943a5c
Compare
…essions * These rewrites are included by default when calling pymc.aesaraf.compile_pymc, which was previously named compile_rv_inplace * pymc.distributions.dist_math.bound was renamed to check_parameters and now returns an expression wrapped in the CheckParameterValue Op * The Model.check_bounds flag now only affects graphs at compilation time when pymc.aesaraf.compile_pymc is called within a Model context and that flag is set to False.
* Allow edges to be infinity in discrete domains * Automatically assign infinity edges when these are set to (None, None) * Simplex and MultiSimplex now return a Domain instance
d943a5c
to
ae850a6
Compare
This PR adds rewrites to replace or remove Aeppl CheckParameterValue Ops in logprob expressions.
Closes #5205
Closes #5204
Closes #4429
The new
dist_math.check_parameters
always compresses the conditions to a scalar since that's required byCheckParameterValue
. That means that when evaluating a logp vector, either all results get replaced by-np.inf
or none do. It no longer switches only those that corresponded to invalid parameters/values.I don't think it makes sense to generate logprob graphs differently than Aeppl does. If we want to keep the old format we should then reintroduce our own versions of
logp
/logcdf
to replace those that are defined in Aeppl (many common distributions are defined there), or else we will have a mix of expressions that follow distinct logics.