Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrites to replace or remove Aeppl CheckParameterValue Ops #5233

Merged
merged 9 commits into from
Dec 10, 2021

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Dec 1, 2021

This PR adds rewrites to replace or remove Aeppl CheckParameterValue Ops in logprob expressions.

Closes #5205
Closes #5204
Closes #4429

  • These rewrites are included by default when calling pymc.aesaraf.compile_pymc, which was previously named compile_rv_inplace
  • pymc.distributions.dist_math.bound was renamed to check_parameters and now returns an expression wrapped in the CheckParameterValue Op
  • The Model.check_bounds flag now only affects graphs at compilation time when pymc.aesaraf.compile_pymc is called within a Model context and that flag is set to False.
  • Most logp/logcdf methods were changed so that the bounding of the value variable is done separately from the parameter checks (missing this in a couple of multivariate distributions). This is in line with how aeppl defines logprob graphs, and to a lesser extent how scipy does things as well. It also provides a solution to Interaction between logcdf and new check_bounds flag #4429
  • Added explicit tests for univariate value and parameter bounds in check_logp, similar to what was already done in check_logcdf, as well as some specialized tests for non-scalar parameters/values in multivariate distributions.

The new dist_math.check_parameters always compresses the conditions to a scalar since that's required by CheckParameterValue. That means that when evaluating a logp vector, either all results get replaced by -np.inf or none do. It no longer switches only those that corresponded to invalid parameters/values.

I don't think it makes sense to generate logprob graphs differently than Aeppl does. If we want to keep the old format we should then reintroduce our own versions of logp/logcdf to replace those that are defined in Aeppl (many common distributions are defined there), or else we will have a mix of expressions that follow distinct logics.

@codecov
Copy link

codecov bot commented Dec 6, 2021

Codecov Report

Merging #5233 (ae850a6) into main (7191e61) will increase coverage by 0.13%.
The diff coverage is 95.89%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5233      +/-   ##
==========================================
+ Coverage   78.98%   79.12%   +0.13%     
==========================================
  Files          88       88              
  Lines       14231    14301      +70     
==========================================
+ Hits        11240    11315      +75     
+ Misses       2991     2986       -5     
Impacted Files Coverage Δ
pymc/sampling_jax.py 0.00% <0.00%> (ø)
pymc/distributions/mixture.py 19.76% <33.33%> (ø)
pymc/distributions/multivariate.py 73.20% <93.75%> (+0.68%) ⬆️
pymc/aesaraf.py 90.10% <96.29%> (+0.39%) ⬆️
pymc/distributions/continuous.py 96.85% <97.77%> (+0.12%) ⬆️
pymc/distributions/bound.py 100.00% <100.00%> (ø)
pymc/distributions/discrete.py 98.46% <100.00%> (+0.10%) ⬆️
pymc/distributions/dist_math.py 86.78% <100.00%> (-0.92%) ⬇️
pymc/gp/util.py 94.68% <100.00%> (ø)
pymc/initial_point.py 100.00% <100.00%> (ø)
... and 5 more

pymc/aesaraf.py Outdated
Comment on lines 931 to 942
for database in ("canonicalize", "stabilize", "specialize", "useless"):
aesara.compile.optdb[database].register(
"local_remove_check_parameter",
local_remove_check_parameter,
use_db_name_as_tag=False,
)

aesara.compile.optdb[database].register(
"local_check_parameter_to_ninf_switch",
local_check_parameter_to_ninf_switch,
use_db_name_as_tag=False,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add these to all those databases?

Also, don't forget that you can set a priority that affects the order in which they're run.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just copied what I've seen elsewhere.

Can you give some recommendation on what is the minimum database(s) where these should be registered, as well as what the priority should be?

Copy link
Contributor

@brandonwillard brandonwillard Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to put those in just the pass named "useless" and have them removed before the other passes. Check out aesara.compile.mode to see the passes and their ordering.

Otherwise, I would think that removing those checks should be a user-configurable option, and, if so, you might not want to include them by default. Instead, an option could be set that includes those rewrites when graphs are compiled by PyMC (e.g. by constructing a Mode object for aesara.function that uses includes=...).

Copy link
Member Author

@ricardoV94 ricardoV94 Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise, I would think that removing those checks should be a user-configurable option, and, if so, you might not want to include them by default. Instead, an option could be set that includes those rewrites when graphs are compiled by PyMC (e.g. by constructing a Mode object for aesara.function that uses includes=...).

The rewrites are not running by default due to the use_db_name_as_tag=False option. I am manually including them in compile_pymc helper function only. Otherwise the new tests that expect the specific Exception would fail.

Copy link
Member Author

@ricardoV94 ricardoV94 Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to put those in just the pass named "useless" and have them removed before the other passes. Check out aesara.compile.mode to see the passes and their ordering.

For some reason putting it in "useless" alone is not doing anything, but putting them in "canonicalize" seems to do the job.

Copy link
Contributor

@brandonwillard brandonwillard Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we need to refactor that OptimizationDatabase code; the whole use_db_name_as_tag thing, and the special way it works with only one DB implementation (i.e. EquilibriumDB), is not good. Just the fact that there's a note about a specific subclass in the base class is a bad sign.

Anyway, "useless" doesn't use the same kind of underlying OptimizationDatabase types as "canonicalize", so I'm guessing that the reason the latter works and the former doesn't is related to that.

@ricardoV94 ricardoV94 force-pushed the logp_asserts branch 3 times, most recently from 18c51db to 4284fc8 Compare December 6, 2021 17:58
@ricardoV94
Copy link
Member Author

The failing dirichlet_multinomial tests should be solved by #5234

@ricardoV94 ricardoV94 marked this pull request as ready for review December 6, 2021 17:59
@ricardoV94 ricardoV94 force-pushed the logp_asserts branch 2 times, most recently from b322a33 to d943a5c Compare December 6, 2021 18:42
…essions

* These rewrites are included by default when calling pymc.aesaraf.compile_pymc, which was previously named compile_rv_inplace
* pymc.distributions.dist_math.bound was renamed to check_parameters and now returns an expression wrapped in the CheckParameterValue Op
* The Model.check_bounds flag now only affects graphs at compilation time when pymc.aesaraf.compile_pymc is called within a Model context and that flag is set to False.
* Allow edges to be infinity in discrete domains
* Automatically assign infinity edges when these are set to (None, None)
* Simplex and MultiSimplex now return a Domain instance
@twiecki twiecki merged commit a4f9657 into pymc-devs:main Dec 10, 2021
@ricardoV94 ricardoV94 deleted the logp_asserts branch December 12, 2021 14:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
3 participants