Skip to content

Commit 7adf05d

Browse files
authored
Refactor DiscreteWeibull (#4615)
1 parent cf1aded commit 7adf05d

File tree

3 files changed

+33
-57
lines changed

3 files changed

+33
-57
lines changed

pymc3/distributions/discrete.py

+32-50
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
import aesara.tensor as at
1717
import numpy as np
1818

19-
from aesara.tensor.random.basic import bernoulli, binomial, categorical, nbinom, poisson
19+
from aesara.tensor.random.basic import (
20+
RandomVariable,
21+
bernoulli,
22+
binomial,
23+
categorical,
24+
nbinom,
25+
poisson,
26+
)
2027
from scipy import stats
2128

2229
from pymc3.aesaraf import floatX, intX, take_along_axis
@@ -434,6 +441,22 @@ def _distr_parameters_for_repr(self):
434441
return ["p"]
435442

436443

444+
class DiscreteWeibullRV(RandomVariable):
445+
name = "discrete_weibull"
446+
ndim_supp = 0
447+
ndims_params = [0, 0]
448+
dtype = "int64"
449+
_print_name = ("dWeibull", "\\operatorname{dWeibull}")
450+
451+
@classmethod
452+
def rng_fn(cls, rng, q, beta, size):
453+
p = rng.uniform(size=size)
454+
return np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1
455+
456+
457+
discrete_weibull = DiscreteWeibullRV()
458+
459+
437460
class DiscreteWeibull(Discrete):
438461
R"""Discrete Weibull log-likelihood
439462
@@ -473,51 +496,15 @@ def DiscreteWeibull(q, b, x):
473496
Variance :math:`2 \sum_{x = 1}^{\infty} x q^{x^{\beta}} - \mu - \mu^2`
474497
======== ======================
475498
"""
499+
rv_op = discrete_weibull
476500

477-
def __init__(self, q, beta, *args, **kwargs):
478-
super().__init__(*args, defaults=("median",), **kwargs)
479-
480-
self.q = at.as_tensor_variable(floatX(q))
481-
self.beta = at.as_tensor_variable(floatX(beta))
482-
483-
self.median = self._ppf(0.5)
484-
485-
def _ppf(self, p):
486-
r"""
487-
The percentile point function (the inverse of the cumulative
488-
distribution function) of the discrete Weibull distribution.
489-
"""
490-
q = self.q
491-
beta = self.beta
492-
493-
return (at.ceil(at.power(at.log(1 - p) / at.log(q), 1.0 / beta)) - 1).astype("int64")
494-
495-
def _random(self, q, beta, size=None):
496-
p = np.random.uniform(size=size)
497-
498-
return np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1
499-
500-
def random(self, point=None, size=None):
501-
r"""
502-
Draw random values from DiscreteWeibull distribution.
503-
504-
Parameters
505-
----------
506-
point: dict, optional
507-
Dict of variable values on which random values are to be
508-
conditioned (uses default point if not specified).
509-
size: int, optional
510-
Desired size of random sample (returns one sample if not
511-
specified).
512-
513-
Returns
514-
-------
515-
array
516-
"""
517-
# q, beta = draw_values([self.q, self.beta], point=point, size=size)
518-
# return generate_samples(self._random, q, beta, dist_shape=self.shape, size=size)
501+
@classmethod
502+
def dist(cls, q, beta, *args, **kwargs):
503+
q = at.as_tensor_variable(floatX(q))
504+
beta = at.as_tensor_variable(floatX(beta))
505+
return super().dist([q, beta], **kwargs)
519506

520-
def logp(self, value):
507+
def logp(value, q, beta):
521508
r"""
522509
Calculate log-probability of DiscreteWeibull distribution at specified value.
523510
@@ -531,8 +518,6 @@ def logp(self, value):
531518
-------
532519
TensorVariable
533520
"""
534-
q = self.q
535-
beta = self.beta
536521
return bound(
537522
at.log(at.power(q, at.power(value, beta)) - at.power(q, at.power(value + 1, beta))),
538523
0 <= value,
@@ -541,7 +526,7 @@ def logp(self, value):
541526
0 < beta,
542527
)
543528

544-
def logcdf(self, value):
529+
def logcdf(value, q, beta):
545530
"""
546531
Compute the log of the cumulative distribution function for Discrete Weibull distribution
547532
at the specified value.
@@ -556,9 +541,6 @@ def logcdf(self, value):
556541
-------
557542
TensorVariable
558543
"""
559-
q = self.q
560-
beta = self.beta
561-
562544
return bound(
563545
at.log1p(-at.power(q, at.power(value + 1, beta))),
564546
0 <= value,

pymc3/tests/test_distributions.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,7 @@ def logpow(v, p):
424424

425425
def discrete_weibull_logpmf(value, q, beta):
426426
return floatX(
427-
np.log(
428-
np.power(floatX(q), np.power(floatX(value), floatX(beta)))
429-
- np.power(floatX(q), np.power(floatX(value + 1), floatX(beta)))
430-
)
427+
np.log(np.power(q, np.power(value, beta)) - np.power(q, np.power(value + 1, beta)))
431428
)
432429

433430

@@ -1556,7 +1553,6 @@ def test_bernoulli(self):
15561553
{"p": Unit},
15571554
)
15581555

1559-
@pytest.mark.xfail(reason="Distribution not refactored yet")
15601556
def test_discrete_weibull(self):
15611557
self.check_logp(
15621558
DiscreteWeibull,

pymc3/tests/test_distributions_random.py

-2
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,6 @@ class TestBernoulli(BaseTestCases.BaseTestCase):
426426
params = {"p": 0.5}
427427

428428

429-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
430429
class TestDiscreteWeibull(BaseTestCases.BaseTestCase):
431430
distribution = pm.DiscreteWeibull
432431
params = {"q": 0.25, "beta": 2.0}
@@ -784,7 +783,6 @@ def ref_rand(size, lower, upper):
784783
pm.DiscreteUniform, {"lower": -NatSmall, "upper": NatSmall}, ref_rand=ref_rand
785784
)
786785

787-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
788786
def test_discrete_weibull(self):
789787
def ref_rand(size, q, beta):
790788
u = np.random.uniform(size=size)

0 commit comments

Comments
 (0)