Skip to content

Add rewrites to replace or remove Aeppl CheckParameterValue Ops #5233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 10, 2021
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl=0.0.17
- aeppl=0.0.18
- aesara>=2.2.6
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl=0.0.17
- aeppl=0.0.18
- aesara>=2.2.6
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl=0.0.17
- aeppl=0.0.18
- aesara>=2.2.6
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl=0.0.17
- aeppl=0.0.18
- aesara>=2.2.6
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl=0.0.17
- aeppl=0.0.18
- aesara>=2.2.6
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl=0.0.17
- aeppl=0.0.18
- aesara>=2.2.6
- arviz>=0.11.4
- cachetools
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- defaults
dependencies:
# base dependencies (see install guide for Windows)
- aeppl=0.0.17
- aeppl=0.0.18
- aesara>=2.2.6
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- defaults
dependencies:
# base dependencies (see install guide for Windows)
- aeppl=0.0.17
- aeppl=0.0.18
- aesara>=2.2.6
- arviz>=0.11.2
- cachetools
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ Here is how the example continues:

from pymc.aesaraf import floatX, intX
from pymc.distributions.continuous import PositiveContinuous
from pymc.distributions.dist_math import bound
from pymc.distributions.dist_math import check_parameters


# Subclassing `PositiveContinuous` will dispatch a default `log` transformation
class Blah(PositiveContinuous):

# This will be used by the metaclass `DistributionMeta` to dispatch the
# class `logp` and `logcdf` methods to the `blah` `op`
rv_op = blah
Expand All @@ -158,24 +158,25 @@ class Blah(PositiveContinuous):
def logp(value, param1, param2):
logp_expression = value * (param1 + at.log(param2))

# We use `bound` for parameter validation. After the default expression,
# multiple comma-separated symbolic conditions can be added. Whenever
# a bound is invalidated, the returned expression evaluates to `-np.inf`
return bound(
# A switch is often used to enforce the distribution support domain
bounded_logp_expression = at.switch(
at.ge(value, 0),
logp_expression,
value >= 0,
param2 >= 0,
# There is one sneaky optional keyword argument, that converts an
# otherwise elemwise `bound` to a reduced scalar bound. This is usually
# needed for multivariate distributions where the dimensionality
# of the bound conditions does not match that of the "value" / "logp"
# By default it is set to `True`.
broadcast_conditions=True,
-np.inf,
)

# logcdf works the same way as logp. For bounded variables, it is expected
# to return `-inf` for values below the domain start and `0` for values
# above the domain end, but only when the parameters are valid.
# We use `check_parameters` for parameter validation. After the default expression,
# multiple comma-separated symbolic conditions can be added. Whenever
# a condition is violated, the returned expression raises an error with the message
# defined in the optional `msg` keyword argument.
return check_parameters(
bounded_logp_expression,
param2 >= 0,
msg="param2 >= 0",
)

# logcdf works the same way as logp. For bounded variables, it is expected to return
# `-inf` for values below the domain start and `0` for values above the domain end.
def logcdf(value, param1, param2):
...

Expand Down Expand Up @@ -357,7 +358,7 @@ New distributions should have a rich docstring, following the same format as tha
It generally looks something like this:

```python
r"""Univariate blah distribution.
r"""Univariate blah distribution.

The pdf of this distribution is

Expand Down
78 changes: 73 additions & 5 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
import numpy as np
import scipy.sparse as sps

from aeppl.logprob import CheckParameterValue
from aesara import config, scalar
from aesara.compile.mode import Mode, get_mode
from aesara.gradient import grad
from aesara.graph import local_optimizer
from aesara.graph.basic import (
Apply,
Constant,
Expand Down Expand Up @@ -899,11 +901,65 @@ def take_along_axis(arr, indices, axis=0):
return arr[_make_along_axis_idx(arr_shape, indices, _axis)]


def compile_rv_inplace(inputs, outputs, mode=None, **kwargs):
"""Use ``aesara.function`` with the random_make_inplace optimization always enabled.
@local_optimizer(tracks=[CheckParameterValue])
def local_remove_check_parameter(fgraph, node):
"""Rewrite that removes Aeppl's CheckParameterValue

Using this function ensures that compiled functions containing random
variables will produce new samples on each call.
This is used by ``compile_pymc`` when the model ``check_bounds`` flag is set to False.
"""
if isinstance(node.op, CheckParameterValue):
return [node.inputs[0]]


@local_optimizer(tracks=[CheckParameterValue])
def local_check_parameter_to_ninf_switch(fgraph, node):
    """Rewrite that turns Aeppl's CheckParameterValue into a ``-inf`` switch.

    The first input of the node is the logp expression and the remaining
    inputs are the parameter conditions. The replacement evaluates to the
    logp expression where the condition(s) hold and to ``-np.inf`` elsewhere.
    The new variable is named after the Op's ``msg`` and is cast back to the
    dtype of the original output when the switch changed it.

    Returns ``None`` (no replacement) for nodes with any other Op.
    """
    if not isinstance(node.op, CheckParameterValue):
        return None

    value_expr, *conditions = node.inputs
    if len(conditions) > 1:
        # Several conditions: reduce them to a single boolean condition
        condition = at.all(conditions)
    else:
        (condition,) = conditions

    replacement = at.switch(condition, value_expr, -np.inf)
    replacement.name = node.op.msg

    # Keep the dtype of the output that is being replaced
    expected_dtype = node.outputs[0].dtype
    if replacement.dtype != expected_dtype:
        replacement = at.cast(replacement, expected_dtype)

    return [replacement]


# Register both CheckParameterValue rewrites in the "canonicalize" database under
# their own names, so that `compile_pymc` can include exactly one of them by name.
# NOTE(review): `use_db_name_as_tag=False` presumably keeps these rewrites from being
# selected by a plain "canonicalize" query, so they only run when explicitly
# included -- confirm against aesara's optimization database `register` semantics.
aesara.compile.optdb["canonicalize"].register(
"local_remove_check_parameter",
local_remove_check_parameter,
use_db_name_as_tag=False,
)

# Default alternative: keeps the parameter check as a switch that returns -inf.
aesara.compile.optdb["canonicalize"].register(
"local_check_parameter_to_ninf_switch",
local_check_parameter_to_ninf_switch,
use_db_name_as_tag=False,
)


def compile_pymc(inputs, outputs, mode=None, **kwargs):
"""Use ``aesara.function`` with specialized pymc rewrites always enabled.

Included rewrites
-----------------
random_make_inplace
Ensures that compiled functions containing random variables will produce new
samples on each call.
local_check_parameter_to_ninf_switch
Replaces Aeppl's CheckParameterValue assertions in logp expressions with Switches
that return -inf in place of the assert.

Optional rewrites
-----------------
local_remove_check_parameter
Removes Aeppl's CheckParameterValue assertions from logp expressions. This is used
as an alternative to the default local_check_parameter_to_ninf_switch whenever
this function is called within a model context and the model `check_bounds` flag
is set to False.
"""

# Avoid circular dependency
Expand All @@ -921,8 +977,20 @@ def compile_rv_inplace(inputs, outputs, mode=None, **kwargs):
if not hasattr(rng, "default_update"):
rng.default_update = rv.owner.outputs[0]

# If called inside a model context, see if check_bounds flag is set to False
try:
from pymc.model import modelcontext

model = modelcontext(None)
check_bounds = model.check_bounds
except TypeError:
check_bounds = True
check_parameter_opt = (
"local_check_parameter_to_ninf_switch" if check_bounds else "local_remove_check_parameter"
)

mode = get_mode(mode)
opt_qry = mode.provided_optimizer.including("random_make_inplace")
opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
mode = Mode(linker=mode.linker, optimizer=opt_qry)
aesara_function = aesara.function(inputs, outputs, mode=mode, **kwargs)
return aesara_function
31 changes: 25 additions & 6 deletions pymc/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import aesara.tensor as at
import numpy as np

from aeppl.logprob import logprob
from aesara.tensor import as_tensor_variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import floatX, intX
from pymc.distributions.continuous import BoundedContinuous
from pymc.distributions.dist_math import bound
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous, Discrete
from pymc.distributions.logprob import logp
from pymc.distributions.shape_utils import to_tuple
from pymc.model import modelcontext

Expand Down Expand Up @@ -67,8 +68,17 @@ def logp(value, distribution, lower, upper):
-------
TensorVariable
"""
logp = logprob(distribution, value)
return bound(logp, (value >= lower), (value <= upper))
res = at.switch(
at.or_(at.lt(value, lower), at.gt(value, upper)),
-np.inf,
logp(distribution, value),
)

return check_parameters(
res,
lower <= upper,
msg="lower <= upper",
)


class DiscreteBoundRV(BoundRV):
Expand Down Expand Up @@ -107,8 +117,17 @@ def logp(value, distribution, lower, upper):
-------
TensorVariable
"""
logp = logprob(distribution, value)
return bound(logp, (value >= lower), (value <= upper))
res = at.switch(
at.or_(at.lt(value, lower), at.gt(value, upper)),
-np.inf,
logp(distribution, value),
)

return check_parameters(
res,
lower <= upper,
msg="lower <= upper",
)


class Bound:
Expand Down
Loading