Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,11 +1052,11 @@ def logp(value, lower, upper):

def logcdf(value, lower, upper):
res = pt.switch(
pt.le(value, lower),
pt.lt(value, lower),
-np.inf,
pt.switch(
pt.lt(value, upper),
pt.log(pt.minimum(pt.floor(value), upper) - lower + 1) - pt.log(upper - lower + 1),
pt.log(pt.floor(value) - lower + 1) - pt.log(upper - lower + 1),
0,
),
)
Expand Down
98 changes: 77 additions & 21 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import numpy as np
import pytensor.tensor as pt

from pytensor import scan
from pytensor import graph_replace, scan
from pytensor.gradient import jacobian
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
Expand Down Expand Up @@ -123,6 +123,7 @@
filter_measurable_variables,
find_negated_var,
)
from pymc.math import logdiffexp


class Transform(abc.ABC):
Expand Down Expand Up @@ -162,6 +163,8 @@ def __str__(self):
class MeasurableTransform(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a transformed measurable variable."""

__props__ = ("scalar_op", "inplace_pattern", "is_discrete")

valid_scalar_types = (
Exp,
Log,
Expand All @@ -186,16 +189,55 @@ class MeasurableTransform(MeasurableElemwise):
transform_elemwise: Transform
measurable_input_idx: int

def __init__(self, *args, transform: Transform, measurable_input_idx: int, **kwargs):
def __init__(
self, *args, transform: Transform, measurable_input_idx: int, is_discrete: bool, **kwargs
):
self.transform_elemwise = transform
self.measurable_input_idx = measurable_input_idx
self.is_discrete = is_discrete
super().__init__(*args, **kwargs)


def abs_logprob(op, value, x, **kwargs):
"""Compute the log-CDF graph for an absolute value transformation.

For `Y = |X|`, we have `PDF_Y(y) = PDF_Y(-y) + PDF_Y(y)`.
Except for discrete distributions where there's a special case `P(Y=0) = P(X=0)`.
"""
logprob_pos = _logprob_helper(x, value)
logprob_neg = graph_replace(logprob_pos, {value: -value})
if op.is_discrete:
logprob = pt.switch(
pt.eq(value, 0),
logprob_pos,
pt.logaddexp(logprob_pos, logprob_neg),
)
else:
logprob = pt.logaddexp(logprob_pos, logprob_neg)
logprob = pt.where(value < 0, -np.inf, logprob)
return logprob


def abs_logcdf(op, value, x, **kwargs):
"""Compute the log-CDF graph for an absolute value transformation.

For `Y = |X|`, we have `CDF_Y(y) = P(|X| <= y) = P(-y <= X <= y) = CDF_X(y) - CDF_X(-y)`.
"""
logcdf_pos = _logcdf_helper(x, value)
neg_value = -value - 1 if op.is_discrete else -value
logcdf_neg = graph_replace(logcdf_pos, {value: neg_value})
logcdf = logdiffexp(logcdf_pos, logcdf_neg)
logcdf = pt.where(value < 0, -np.inf, logcdf)
return logcdf


@_logprob.register(MeasurableTransform)
def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwargs):
"""Compute the log-probability graph for a `MeasurabeTransform`."""
# TODO: Could other rewrites affect the order of inputs?
if isinstance(op.scalar_op, Abs):
return abs_logprob(op, values[0], *inputs, **kwargs)

(value,) = values
other_inputs = list(inputs)
measurable_input = other_inputs.pop(op.measurable_input_idx)
Expand All @@ -206,6 +248,11 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa

# Some transformations, like squaring may produce multiple backward values
if isinstance(backward_value, tuple):
if op.is_discrete:
# Discrete variables tend to have the tricky x=0 case, get out if we don't have a custom implementation
raise NotImplementedError(
"Logprob of transformed discrete variables with non-injective transforms not implemented"
)
input_logprob = pt.logaddexp(
*(
_logprob_helper(measurable_input, backward_val, **kwargs)
Expand All @@ -224,8 +271,11 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
ndim_supp = value.ndim - input_logprob.ndim
jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))

# Discrete transformations do not need the jacobian adjustment
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really support any discrete cases with det(jacobian) != 1 but I prefer to have this branch already

logprob = input_logprob if op.is_discrete else input_logprob + jacobian

# The jacobian is used to ensure a value in the supported domain was provided
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
return pt.switch(pt.isnan(jacobian), -np.inf, logprob)


MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid)
Expand All @@ -235,6 +285,10 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
@_logcdf.register(MeasurableTransform)
def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs):
"""Compute the log-CDF graph for a `MeasurabeTransform`."""
if isinstance(op.scalar_op, Abs):
# Special case for absolute value transformation
return abs_logcdf(op, value, *inputs, **kwargs)

other_inputs = list(inputs)
measurable_input = other_inputs.pop(op.measurable_input_idx)
backward_value = op.transform_elemwise.backward(value, *other_inputs)
Expand All @@ -244,10 +298,8 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
if isinstance(backward_value, tuple):
raise NotImplementedError

is_discrete = measurable_input.type.dtype.startswith("int")

logcdf = _logcdf_helper(measurable_input, backward_value)
if is_discrete:
if op.is_discrete:
logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
else:
logccdf = pt.log1mexp(logcdf)
Expand All @@ -267,16 +319,13 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
logcdf_zero = _logcdf_helper(measurable_input, 0)
logcdf = pt.switch(
pt.lt(backward_value, 0),
pt.log(pt.exp(logcdf_zero) - pt.exp(logcdf)),
logdiffexp(logcdf_zero, logcdf),
pt.logaddexp(logccdf, logcdf_zero),
)
else:
# We don't know if this Op is monotonically increasing/decreasing
raise NotImplementedError

if is_discrete:
return logcdf

# The jacobian is used to ensure a value in the supported domain was provided
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
return pt.switch(pt.isnan(jacobian), -np.inf, logcdf)
Expand All @@ -285,13 +334,12 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
@_icdf.register(MeasurableTransform)
def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs):
"""Compute the inverse CDF graph for a `MeasurabeTransform`."""
if op.is_discrete:
raise NotImplementedError("icdf of transformed discrete variables not implemented")

other_inputs = list(inputs)
measurable_input = other_inputs.pop(op.measurable_input_idx)

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
raise NotImplementedError("icdf of transformed discrete variables not implemented")

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
Expand Down Expand Up @@ -322,7 +370,7 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
# Fail if transformation is not injective
# A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
if isinstance(op.transform_elemwise.backward(icdf, *other_inputs), tuple):
raise NotImplementedError
raise NotImplementedError("icdf of non-injective transformations not implemented")

return icdf

Expand Down Expand Up @@ -480,15 +528,22 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
[measurable_input] = measurable_inputs
[measurable_output] = node.outputs

# Do not apply rewrite to discrete variables except for their addition and negation
if measurable_input.type.dtype.startswith("int"):
# Do not apply rewrite to discrete variables except if:
# 1. Operation retains a discrete output
# 2. Operation doesn't create holes in the support
# Reason:
# 1. Due to a limitation in our IR we don't know the type of the MeasurableVariable
# We don't want to make other rewrites think they are dealing with continuous variables when they are not
# 2. We don't want to add cumbersome within-domain checks
is_discrete = measurable_input.type.dtype.startswith("int")
if is_discrete:
if not measurable_output.type.dtype.startswith("int"):
return None
if not (
find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add)
isinstance(node.op.scalar_op, Add | Abs)
or find_negated_var(measurable_output) is not None
):
return None
# Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable
if not measurable_output.type.dtype.startswith("int"):
return None

# Check that other inputs are not potentially measurable, in which case this rewrite
# would be invalid
Expand Down Expand Up @@ -544,6 +599,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
scalar_op=scalar_op,
transform=transform,
measurable_input_idx=measurable_input_idx,
is_discrete=is_discrete,
)
transform_out = transform_op.make_node(*transform_inputs).default_output()
return [transform_out]
Expand Down
9 changes: 8 additions & 1 deletion pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,14 @@ def kron_diag(*diags):

def logdiffexp(a, b):
"""Return log(exp(a) - exp(b))."""
return a + pt.log1mexp(b - a)
return pt.where(
# Handle special case where b is -inf
# If a == b == -inf, this will return the correct result of -inf
# whereas the default else branch would get a nan due to -inf - (-inf)
pt.isneginf(b),
a,
a + pt.log1mexp(b - a),
)


invlogit = sigmoid
Expand Down
20 changes: 11 additions & 9 deletions tests/distributions/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@
Nat,
NatSmall,
R,
Rdunif,
Rplus,
Rplusdunif,
Runif,
Simplex,
Unit,
Expand Down Expand Up @@ -95,32 +93,36 @@ def orderedprobit_logpdf(value, eta, cutpoints):


class TestMatchesScipy:
def test_discrete_unif(self):
def test_discrete_uniform(self):
# Choose domain/paramdomain so we test edge cases as well
test_domain = Domain([-np.inf, -10, -1, 0, 1, 10, np.inf], dtype="int64")
test_paramdomain = Domain([-np.inf, 0, 10, np.inf], dtype="int64")
check_logp(
pm.DiscreteUniform,
Rdunif,
{"lower": -Rplusdunif, "upper": Rplusdunif},
test_domain,
{"lower": -test_paramdomain, "upper": test_paramdomain},
lambda value, lower, upper: st.randint.logpmf(value, lower, upper + 1),
skip_paramdomain_outside_edge_test=True,
)
check_logcdf(
pm.DiscreteUniform,
Rdunif,
{"lower": -Rplusdunif, "upper": Rplusdunif},
test_domain,
{"lower": -test_paramdomain, "upper": test_paramdomain},
lambda value, lower, upper: st.randint.logcdf(value, lower, upper + 1),
skip_paramdomain_outside_edge_test=True,
)
check_selfconsistency_discrete_logcdf(
pm.DiscreteUniform,
Domain([-10, 0, 10], "int64"),
{"lower": -Rplusdunif, "upper": Rplusdunif},
{"lower": -test_paramdomain, "upper": test_paramdomain},
)
check_icdf(
pm.DiscreteUniform,
{"lower": -Rplusdunif, "upper": Rplusdunif},
{"lower": -test_paramdomain, "upper": test_paramdomain},
lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1),
skip_paramdomain_outside_edge_test=True,
)

# Custom logp / logcdf check for invalid parameters
invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0)
with pytensor.config.change_flags(mode=Mode("py")):
Expand Down
Loading