Skip to content

Commit 96b2441

Browse files
Spaaktwiecki
andauthored
refactor _repr_latex functionality (#4065)
* treat array of shape (1,) as scalar * adding generic machinery for generating string representations * double quotes for strings * Update pymc3/distributions/distribution.py Co-authored-by: Thomas Wiecki <[email protected]> * restoring return None behavior of TransformedDistribution::_repr_latex_ * extra . for import * renaming _distr_parameters() to _distr_parameters_for_repr() to avoid confusion * replacing old _repr_latex_ functionality with new one * adding new repr functionality to Deterministic * replacing old with new str repr functionality in PyMC3Variable * ensure that TransformedDistribution does not mess up its str repr * new str repr functionality in Model * don't touch __str__ for now * fixing _repr_latex_ for Simulator Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 7e57248 commit 96b2441

11 files changed

+201
-612
lines changed

pymc3/distributions/continuous.py

+36-272
Large diffs are not rendered by default.

pymc3/distributions/discrete.py

+2-138
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from scipy import stats
1818
import warnings
1919

20-
from pymc3.util import get_variable_name
2120
from .dist_math import bound, factln, binomln, betaln, logpow, random_choice
2221
from .distribution import Discrete, draw_values, generate_samples
2322
from .shape_utils import broadcast_distribution_samples
@@ -123,15 +122,6 @@ def logp(self, value):
123122
0 <= value, value <= n,
124123
0 <= p, p <= 1)
125124

126-
def _repr_latex_(self, name=None, dist=None):
127-
if dist is None:
128-
dist = self
129-
n = dist.n
130-
p = dist.p
131-
name = r'\text{%s}' % name
132-
return r'${} \sim \text{{Binomial}}(\mathit{{n}}={},~\mathit{{p}}={})$'.format(name,
133-
get_variable_name(n),
134-
get_variable_name(p))
135125

136126
class BetaBinomial(Discrete):
137127
R"""
@@ -259,16 +249,6 @@ def logp(self, value):
259249
value >= 0, value <= self.n,
260250
alpha > 0, beta > 0)
261251

262-
def _repr_latex_(self, name=None, dist=None):
263-
if dist is None:
264-
dist = self
265-
alpha = dist.alpha
266-
beta = dist.beta
267-
name = r'\text{%s}' % name
268-
return r'${} \sim \text{{BetaBinomial}}(\mathit{{alpha}}={},~\mathit{{beta}}={})$'.format(name,
269-
get_variable_name(alpha),
270-
get_variable_name(beta))
271-
272252

273253
class Bernoulli(Discrete):
274254
R"""Bernoulli log-likelihood
@@ -371,13 +351,8 @@ def logp(self, value):
371351
value >= 0, value <= 1,
372352
p >= 0, p <= 1)
373353

374-
def _repr_latex_(self, name=None, dist=None):
375-
if dist is None:
376-
dist = self
377-
p = dist.p
378-
name = r'\text{%s}' % name
379-
return r'${} \sim \text{{Bernoulli}}(\mathit{{p}}={})$'.format(name,
380-
get_variable_name(p))
354+
def _distr_parameters_for_repr(self):
355+
return ["p"]
381356

382357

383358
class DiscreteWeibull(Discrete):
@@ -486,16 +461,6 @@ def random(self, point=None, size=None):
486461
dist_shape=self.shape,
487462
size=size)
488463

489-
def _repr_latex_(self, name=None, dist=None):
490-
if dist is None:
491-
dist = self
492-
q = dist.q
493-
beta = dist.beta
494-
name = r'\text{%s}' % name
495-
return r'${} \sim \text{{DiscreteWeibull}}(\mathit{{q}}={},~\mathit{{beta}}={})$'.format(name,
496-
get_variable_name(q),
497-
get_variable_name(beta))
498-
499464

500465
class Poisson(Discrete):
501466
R"""
@@ -590,14 +555,6 @@ def logp(self, value):
590555
return tt.switch(tt.eq(mu, 0) * tt.eq(value, 0),
591556
0, log_prob)
592557

593-
def _repr_latex_(self, name=None, dist=None):
594-
if dist is None:
595-
dist = self
596-
mu = dist.mu
597-
name = r'\text{%s}' % name
598-
return r'${} \sim \text{{Poisson}}(\mathit{{mu}}={})$'.format(name,
599-
get_variable_name(mu))
600-
601558

602559
class NegativeBinomial(Discrete):
603560
R"""
@@ -717,16 +674,6 @@ def logp(self, value):
717674
Poisson.dist(self.mu).logp(value),
718675
negbinom)
719676

720-
def _repr_latex_(self, name=None, dist=None):
721-
if dist is None:
722-
dist = self
723-
mu = dist.mu
724-
alpha = dist.alpha
725-
name = r'\text{%s}' % name
726-
return r'${} \sim \text{{NegativeBinomial}}(\mathit{{mu}}={},~\mathit{{alpha}}={})$'.format(name,
727-
get_variable_name(mu),
728-
get_variable_name(alpha))
729-
730677

731678
class Geometric(Discrete):
732679
R"""
@@ -810,14 +757,6 @@ def logp(self, value):
810757
return bound(tt.log(p) + logpow(1 - p, value - 1),
811758
0 <= p, p <= 1, value >= 1)
812759

813-
def _repr_latex_(self, name=None, dist=None):
814-
if dist is None:
815-
dist = self
816-
p = dist.p
817-
name = r'\text{%s}' % name
818-
return r'${} \sim \text{{Geometric}}(\mathit{{p}}={})$'.format(name,
819-
get_variable_name(p))
820-
821760

822761
class DiscreteUniform(Discrete):
823762
R"""
@@ -913,16 +852,6 @@ def logp(self, value):
913852
return bound(-tt.log(upper - lower + 1),
914853
lower <= value, value <= upper)
915854

916-
def _repr_latex_(self, name=None, dist=None):
917-
if dist is None:
918-
dist = self
919-
lower = dist.lower
920-
upper = dist.upper
921-
name = r'\text{%s}' % name
922-
return r'${} \sim \text{{DiscreteUniform}}(\mathit{{lower}}={},~\mathit{{upper}}={})$'.format(name,
923-
get_variable_name(lower),
924-
get_variable_name(upper))
925-
926855

927856
class Categorical(Discrete):
928857
R"""
@@ -1044,14 +973,6 @@ def logp(self, value):
1044973
return bound(a, value >= 0, value <= (k - 1),
1045974
tt.all(p_ >= 0, axis=-1), tt.all(p <= 1, axis=-1))
1046975

1047-
def _repr_latex_(self, name=None, dist=None):
1048-
if dist is None:
1049-
dist = self
1050-
p = dist.p
1051-
name = r'\text{%s}' % name
1052-
return r'${} \sim \text{{Categorical}}(\mathit{{p}}={})$'.format(name,
1053-
get_variable_name(p))
1054-
1055976

1056977
class Constant(Discrete):
1057978
r"""
@@ -1112,12 +1033,6 @@ def logp(self, value):
11121033
c = self.c
11131034
return bound(0, tt.eq(value, c))
11141035

1115-
def _repr_latex_(self, name=None, dist=None):
1116-
if dist is None:
1117-
dist = self
1118-
name = r'\text{%s}' % name
1119-
return r'${} \sim \text{{Constant}}()$'.format(name)
1120-
11211036

11221037
ConstantDist = Constant
11231038

@@ -1231,16 +1146,6 @@ def logp(self, value):
12311146
0 <= psi, psi <= 1,
12321147
0 <= theta)
12331148

1234-
def _repr_latex_(self, name=None, dist=None):
1235-
if dist is None:
1236-
dist = self
1237-
theta = dist.theta
1238-
psi = dist.psi
1239-
name = r'\text{%s}' % name
1240-
return r'${} \sim \text{{ZeroInflatedPoisson}}(\mathit{{theta}}={},~\mathit{{psi}}={})$'.format(name,
1241-
get_variable_name(theta),
1242-
get_variable_name(psi))
1243-
12441149

12451150
class ZeroInflatedBinomial(Discrete):
12461151
R"""
@@ -1354,22 +1259,6 @@ def logp(self, value):
13541259
0 <= psi, psi <= 1,
13551260
0 <= p, p <= 1)
13561261

1357-
def _repr_latex_(self, name=None, dist=None):
1358-
if dist is None:
1359-
dist = self
1360-
n = dist.n
1361-
p = dist.p
1362-
psi = dist.psi
1363-
1364-
name_n = get_variable_name(n)
1365-
name_p = get_variable_name(p)
1366-
name_psi = get_variable_name(psi)
1367-
name = r'\text{%s}' % name
1368-
return (r'${} \sim \text{{ZeroInflatedBinomial}}'
1369-
r'(\mathit{{n}}={},~\mathit{{p}}={},~'
1370-
r'\mathit{{psi}}={})$'
1371-
.format(name, name_n, name_p, name_psi))
1372-
13731262

13741263
class ZeroInflatedNegativeBinomial(Discrete):
13751264
R"""
@@ -1523,22 +1412,6 @@ def logp(self, value):
15231412
0 <= psi, psi <= 1,
15241413
mu > 0, alpha > 0)
15251414

1526-
def _repr_latex_(self, name=None, dist=None):
1527-
if dist is None:
1528-
dist = self
1529-
mu = dist.mu
1530-
alpha = dist.alpha
1531-
psi = dist.psi
1532-
1533-
name_mu = get_variable_name(mu)
1534-
name_alpha = get_variable_name(alpha)
1535-
name_psi = get_variable_name(psi)
1536-
name = r'\text{%s}' % name
1537-
return (r'${} \sim \text{{ZeroInflatedNegativeBinomial}}'
1538-
r'(\mathit{{mu}}={},~\mathit{{alpha}}={},~'
1539-
r'\mathit{{psi}}={})$'
1540-
.format(name, name_mu, name_alpha, name_psi))
1541-
15421415

15431416
class OrderedLogistic(Categorical):
15441417
R"""
@@ -1619,12 +1492,3 @@ def __init__(self, eta, cutpoints, *args, **kwargs):
16191492
p = p_cum[..., 1:] - p_cum[..., :-1]
16201493

16211494
super().__init__(p=p, *args, **kwargs)
1622-
1623-
def _repr_latex_(self, name=None, dist=None):
1624-
if dist is None:
1625-
dist = self
1626-
name_eta = get_variable_name(dist.eta)
1627-
name_cutpoints = get_variable_name(dist.cutpoints)
1628-
return (r'${} \sim \text{{OrderedLogistic}}'
1629-
r'(\mathit{{eta}}={}, \mathit{{cutpoints}}={}$'
1630-
.format(name, name_eta, name_cutpoints))

pymc3/distributions/distribution.py

+47-2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
import numbers
1616
import contextvars
1717
import dill
18+
import inspect
1819
from typing import TYPE_CHECKING
1920
if TYPE_CHECKING:
2021
from typing import Optional, Callable
2122

2223
import numpy as np
2324
import theano.tensor as tt
2425
from theano import function
26+
from ..util import get_repr_for_variable
2527
import theano
2628
from ..memoize import memoize
2729
from ..model import (
@@ -135,9 +137,46 @@ def getattr_value(self, val):
135137

136138
return val
137139

138-
def _repr_latex_(self, name=None, dist=None):
140+
def _distr_parameters_for_repr(self):
141+
"""Return the names of the parameters for this distribution (e.g. "mu"
142+
and "sigma" for Normal). Used in generating string (and LaTeX etc.)
143+
representations of Distribution objects. By default based on inspection
144+
of __init__, but can be overwritten if necessary (e.g. to avoid including
145+
"sd" and "tau").
146+
"""
147+
return inspect.getfullargspec(self.__init__).args[1:]
148+
149+
def _distr_name_for_repr(self):
150+
return self.__class__.__name__
151+
152+
def _str_repr(self, name=None, dist=None, formatting='plain'):
153+
"""Generate string representation for this distribution, optionally
154+
including LaTeX markup (formatting='latex').
155+
"""
156+
if dist is None:
157+
dist = self
158+
if name is None:
159+
name = '[unnamed]'
160+
161+
param_names = self._distr_parameters_for_repr()
162+
param_values = [get_repr_for_variable(getattr(dist, x), formatting=formatting)
163+
for x in param_names]
164+
165+
if formatting == "latex":
166+
param_string = ",~".join([r"\mathit{{{name}}}={value}".format(name=name,
167+
value=value) for name, value in zip(param_names, param_values)])
168+
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(var_name=name,
169+
distr_name=dist._distr_name_for_repr(), params=param_string)
170+
else:
171+
# 'plain' is default option
172+
param_string = ", ".join(["{name}={value}".format(name=name,
173+
value=value) for name, value in zip(param_names, param_values)])
174+
return "{var_name} ~ {distr_name}({params})".format(var_name=name,
175+
distr_name=dist._distr_name_for_repr(), params=param_string)
176+
177+
def _repr_latex_(self, **kwargs):
139178
"""Magic method name for IPython to use for LaTeX formatting."""
140-
return None
179+
return self._str_repr(formatting="latex", **kwargs)
141180

142181
def logp_nojac(self, *args, **kwargs):
143182
"""Return the logp, but do not include a jacobian term for transforms.
@@ -200,6 +239,9 @@ def logp(self, x):
200239
"""
201240
return tt.zeros_like(x)
202241

242+
def _distr_parameters_for_repr(self):
243+
return []
244+
203245

204246
class Discrete(Distribution):
205247
"""Base class for discrete distributions"""
@@ -501,6 +543,9 @@ def random(self, point=None, size=None, **kwargs):
501543
"Define a custom random method and pass it as kwarg random"
502544
)
503545

546+
def _distr_parameters_for_repr(self):
547+
return []
548+
504549

505550
class _DrawValuesContext(metaclass=ContextMeta, context_class='_DrawValuesContext'):
506551
""" A context manager class used while drawing values with draw_values

pymc3/distributions/mixture.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import theano.tensor as tt
1919
import warnings
2020

21-
from pymc3.util import get_variable_name
2221
from ..math import logsumexp
2322
from .dist_math import bound, random_choice
2423
from .distribution import (Discrete, Distribution, draw_values,
@@ -578,6 +577,8 @@ def random(self, point=None, size=None):
578577
samples = np.reshape(samples, size + dist_shape)
579578
return samples
580579

580+
def _distr_parameters_for_repr(self):
581+
return []
581582

582583
class NormalMixture(Mixture):
583584
R"""
@@ -627,14 +628,5 @@ def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, *
627628
super().__init__(w, Normal.dist(mu, sigma=sigma, shape=comp_shape),
628629
*args, **kwargs)
629630

630-
def _repr_latex_(self, name=None, dist=None):
631-
if dist is None:
632-
dist = self
633-
mu = dist.mu
634-
w = dist.w
635-
sigma = dist.sigma
636-
name = r'\text{%s}' % name
637-
return r'${} \sim \text{{NormalMixture}}(\mathit{{w}}={},~\mathit{{mu}}={},~\mathit{{sigma}}={})$'.format(name,
638-
get_variable_name(w),
639-
get_variable_name(mu),
640-
get_variable_name(sigma))
631+
def _distr_parameters_for_repr(self):
632+
return ["w", "mu", "sigma"]

0 commit comments

Comments
 (0)