
Commit 8d9d3d8

lucianopaz authored and twiecki committed
Changed Categorical to work with multidim p at the logp level (#3383)
* Changed Categorical to work with multidim p at the logp level.
* Fixed problems with OrderedLogistic.
* Use np.moveaxis instead of transposing. Also added some more tests.
1 parent f541c5a · commit 8d9d3d8

File tree

5 files changed: +69 -12 lines changed


RELEASE-NOTES.md (+3)

@@ -23,6 +23,9 @@
 - Fixed incorrect usage of `broadcast_distribution_samples` in `DiscreteWeibull`.
 - `Mixture`'s default dtype is now determined by `theano.config.floatX`.
 - `dist_math.random_choice` now handles nd-arrays of category probabilities, and also handles sizes that are not `None`. Also removed the unused `k` kwarg from `dist_math.random_choice`.
+- Changed `Categorical.mode` to preserve all the dimensions of `p` except the last one, which encodes each category's probability.
+- Changed initialization of `Categorical.p`. `p` is now normalized to sum to `1` inside `logp` and `random`, but not during initialization. Normalizing during initialization could hide negative values supplied to `p`, as mentioned in #2082.
+- To be able to test for negative `p` values supplied to `Categorical`, `Categorical.logp` was changed to check that `sum(self.p, axis=-1) == 1` only if `self.p` is not a `Number`, `np.ndarray`, `TensorConstant` or `SharedVariable`. Those cases are normalized automatically to sum to `1`; any other `p` may originate from a `step_method` proposal, where the value of the `self.p` tensor may be set but must nevertheless sum to 1. This may break old code that initialized `p` with a theano expression and relied on the default normalization to get it to sum to 1. `Categorical.logp` now also checks that the used `p` has values no greater than 1.

 ### Deprecations
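As a quick illustration of the headline change (a sketch only, not part of the commit; the shapes mirror the tests added in pymc3/tests/test_distributions_random.py below), `p` may now carry arbitrary leading batch dimensions, with the last axis indexing the categories:

import numpy as np
import pymc3 as pm

# A (3, 7) batch of 5-category probability vectors.
d = pm.Categorical.dist(p=np.ones((3, 7, 5)))
print(d.mode.eval().shape)  # (3, 7): one mode per batch entry
print(d.random().shape)     # (3, 7): one draw per batch entry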

pymc3/distributions/discrete.py (+27 -8)

@@ -1,3 +1,4 @@
+import numbers
 import numpy as np
 import theano
 import theano.tensor as tt
@@ -710,11 +711,17 @@ def __init__(self, p, *args, **kwargs):
         except AttributeError:
             self.k = tt.shape(p)[-1]
         p = tt.as_tensor_variable(floatX(p))
-        self.p = (p.T / tt.sum(p, -1)).T
-        self.mode = tt.argmax(p)
+
+        # From #2082, it may be dangerous to automatically rescale p at this
+        # point without checking for positiveness
+        self.p = p
+        self.mode = tt.argmax(p, axis=-1)
+        if self.mode.ndim == 1:
+            self.mode = tt.squeeze(self.mode)

     def random(self, point=None, size=None):
         p, k = draw_values([self.p, self.k], point=point, size=size)
+        p = p / np.sum(p, axis=-1, keepdims=True)

         return generate_samples(random_choice,
                                 p=p,
@@ -723,21 +730,33 @@ def random(self, point=None, size=None):
                                 size=size)

     def logp(self, value):
-        p = self.p
+        p_ = self.p
         k = self.k

         # Clip values before using them for indexing
         value_clip = tt.clip(value, 0, k - 1)

-        sumto1 = theano.gradient.zero_grad(
-            tt.le(abs(tt.sum(p, axis=-1) - 1), 1e-5))
+        # We must only check that the values sum to 1 if p comes from a
+        # tensor variable, i.e. when p is a step_method proposal. In the
+        # other cases we normalize ourselves
+        if not isinstance(p_, (numbers.Number,
+                               np.ndarray,
+                               tt.TensorConstant,
+                               tt.sharedvar.SharedVariable)):
+            sumto1 = theano.gradient.zero_grad(
+                tt.le(abs(tt.sum(p_, axis=-1) - 1), 1e-5))
+            p = p_
+        else:
+            p = p_ / tt.sum(p_, axis=-1, keepdims=True)
+            sumto1 = True

         if p.ndim > 1:
-            a = tt.log(p[tt.arange(p.shape[0]), value_clip])
+            a = tt.log(np.moveaxis(p, -1, 0)[value_clip])
         else:
             a = tt.log(p[value_clip])

-        return bound(a, value >= 0, value <= (k - 1), sumto1)
+        return bound(a, value >= 0, value <= (k - 1), sumto1,
+                     tt.all(p_ > 0, axis=-1), tt.all(p <= 1, axis=-1))

     def _repr_latex_(self, name=None, dist=None):
         if dist is None:
@@ -1177,7 +1196,7 @@ def __init__(self, eta, cutpoints, *args, **kwargs):
             tt.zeros_like(tt.shape_padright(pa[:, 0])),
             pa,
             tt.ones_like(tt.shape_padright(pa[:, 0]))
-        ], axis=1)
+        ], axis=-1)
         p = p_cum[:, 1:] - p_cum[:, :-1]

         super().__init__(p=p, *args, **kwargs)
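The np.moveaxis trick in logp is easier to see in plain NumPy (an illustrative sketch, not code from the commit): moving the category axis to the front makes "the probability of category k for every batch entry" a single leading-axis lookup, regardless of how many batch dimensions p has, which the old `p[tt.arange(p.shape[0]), value_clip]` indexing could only do for 2d p:

import numpy as np

rng = np.random.RandomState(42)
p = rng.dirichlet(np.ones(5), size=(3, 7))  # batch shape (3, 7), 5 categories

k = 2
# np.moveaxis(p, -1, 0) has shape (5, 3, 7); indexing its first axis picks
# category k's probability across the whole (3, 7) batch at once.
assert np.allclose(np.moveaxis(p, -1, 0)[k], p[..., k])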

pymc3/tests/test_distribution_defaults.py (+8 -1)

@@ -1,5 +1,5 @@
 from ..model import Model
-from ..distributions import DiscreteUniform, Continuous
+from ..distributions import DiscreteUniform, Continuous, Categorical

 import numpy as np
 import pytest
@@ -67,3 +67,10 @@ def test_discrete_uniform_negative():
     with model:
         x = DiscreteUniform('x', lower=-10, upper=0)
     assert model.test_point['x'] == -5
+
+
+def test_categorical_mode():
+    model = Model()
+    with model:
+        x = Categorical('x', p=np.eye(4), shape=4)
+    assert np.allclose(model.test_point['x'], np.arange(4))
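The expected test point follows from the new batched mode. A NumPy sketch of the difference (assuming the same argmax semantics as theano's tt.argmax):

import numpy as np

p = np.eye(4)
print(np.argmax(p))           # 0: old behavior, argmax over the flattened array
print(np.argmax(p, axis=-1))  # [0 1 2 3]: new behavior, one mode per row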

pymc3/tests/test_distributions.py (+28 -3)

@@ -321,7 +321,7 @@ def dirichlet_logpdf(value, a):

 def categorical_logpdf(value, p):
     if value >= 0 and value <= len(p):
-        return floatX(np.log(p[value]))
+        return floatX(np.log(np.moveaxis(p, -1, 0)[value]))
     else:
         return -inf

@@ -346,8 +346,10 @@ def invlogit(x, eps=sys.float_info.epsilon):

 def orderedlogistic_logpdf(value, eta, cutpoints):
     c = np.concatenate(([-np.inf], cutpoints, [np.inf]))
-    p = invlogit(eta - c[value]) - invlogit(eta - c[value + 1])
-    return np.log(p)
+    ps = np.array([invlogit(eta - cc) - invlogit(eta - cc1)
+                   for cc, cc1 in zip(c[:-1], c[1:])])
+    p = ps[value]
+    return np.where(np.all(ps > 0), np.log(p), -np.inf)

 class Simplex:
     def __init__(self, n):
@@ -1079,6 +1081,29 @@ def test_categorical_bounds(self):
         assert np.isinf(x.logp({'x': -1}))
         assert np.isinf(x.logp({'x': 3}))

+    def test_categorical_valid_p(self):
+        with Model():
+            x = Categorical('x', p=np.array([-0.2, 0.3, 0.5]))
+            assert np.isinf(x.logp({'x': 0}))
+            assert np.isinf(x.logp({'x': 1}))
+            assert np.isinf(x.logp({'x': 2}))
+        with Model():
+            # A model where p sums to 1 but contains negative values
+            x = Categorical('x', p=np.array([-0.2, 0.7, 0.5]))
+            assert np.isinf(x.logp({'x': 0}))
+            assert np.isinf(x.logp({'x': 1}))
+            assert np.isinf(x.logp({'x': 2}))
+        with Model():
+            # Hard edge case from #2082: early automatic normalization of p
+            # would hide the negative entries, because when every nonzero
+            # value is negative, dividing by the (negative) sum flips the
+            # signs into a valid-looking simplex
+            x = Categorical('x', p=np.array([-1, -1, 0, 0]))
+            assert np.isinf(x.logp({'x': 0}))
+            assert np.isinf(x.logp({'x': 1}))
+            assert np.isinf(x.logp({'x': 2}))
+            assert np.isinf(x.logp({'x': 3}))
+
     @pytest.mark.parametrize('n', [2, 3, 4])
     def test_categorical(self, n):
         self.pymc3_matches_scipy(Categorical, Domain(range(n), 'int64'), {'p': Simplex(n)},
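The arithmetic behind the "hard edge case" comment above (a sketch, not commit code): when all the nonzero entries of p are negative, normalizing at initialization divides by a negative sum and produces a valid-looking simplex, so the bad input can no longer be detected:

import numpy as np

p = np.array([-1., -1., 0., 0.])
print(p / p.sum())  # [ 0.5  0.5 -0. -0. ]: the negative entries vanish
# With normalization deferred to logp, bound(..., tt.all(p_ > 0, axis=-1), ...)
# still sees the raw p and returns -inf, as the test asserts.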

pymc3/tests/test_distributions_random.py (+3)

@@ -444,6 +444,9 @@ def test_probability_vector_shape(self):
         p = np.ones((10, 5))
         assert pm.Categorical.dist(p=p).random().shape == (10,)
         assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 10)
+        p = np.ones((3, 7, 5))
+        assert pm.Categorical.dist(p=p).random().shape == (3, 7)
+        assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 3, 7)


 class TestScalarParameterSamples(SeededTest):
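These shape expectations rely on the keepdims=True normalization added in random() above: keeping the summed axis as length 1 lets broadcasting rescale every batch entry's probability vector independently, however many leading dimensions p has. A minimal NumPy sketch:

import numpy as np

rng = np.random.RandomState(0)
p = rng.rand(3, 7, 5)
p_norm = p / np.sum(p, axis=-1, keepdims=True)  # still shape (3, 7, 5)
assert np.allclose(p_norm.sum(axis=-1), 1.0)    # each 5-vector sums to 1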
