Skip to content

Commit 8dbfc75

Browse files
authored
Fix ExGaussian logp (#4050)
* Fix exgaussian logp * Updated release notes
1 parent 78cbf30 commit 8dbfc75

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

Diff for: RELEASE-NOTES.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
### Maintenance
66
- Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)).
7-
7+
- Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
88

99
### Documentation
1010

Diff for: pymc3/distributions/continuous.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
nodes in PyMC.
1919
"""
2020
import numpy as np
21+
import theano
2122
import theano.tensor as tt
2223
from scipy import stats
2324
from scipy.special import expit
@@ -3268,13 +3269,21 @@ def logp(self, value):
32683269
sigma = self.sigma
32693270
nu = self.nu
32703271

3271-
# This condition suggested by exGAUS.R from gamlss
3272-
lp = tt.switch(tt.gt(nu, 0.05 * sigma),
3273-
- tt.log(nu) + (mu - value) / nu + 0.5 * (sigma / nu)**2
3274-
+ logpow(std_cdf((value - mu) / sigma - sigma / nu), 1.),
3275-
- tt.log(sigma * tt.sqrt(2 * np.pi))
3276-
- 0.5 * ((value - mu) / sigma)**2)
3277-
return bound(lp, sigma > 0., nu > 0.)
3272+
standardized_val = (value - mu) / sigma
3273+
cdf_val = std_cdf(standardized_val - sigma / nu)
3274+
cdf_val_safe = tt.switch(tt.eq(cdf_val, 0), np.finfo(theano.config.floatX).eps, cdf_val)
3275+
3276+
# This condition is suggested by exGAUS.R from gamlss
3277+
lp = tt.switch(
3278+
tt.gt(nu, 0.05 * sigma),
3279+
-tt.log(nu)
3280+
+ (mu - value) / nu
3281+
+ 0.5 * (sigma / nu) ** 2
3282+
+ logpow(cdf_val_safe, 1.0),
3283+
-tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * standardized_val ** 2,
3284+
)
3285+
3286+
return bound(lp, sigma > 0.0, nu > 0.0)
32783287

32793288
def _repr_latex_(self, name=None, dist=None):
32803289
if dist is None:

0 commit comments

Comments
 (0)