Skip to content

Commit 9ef2947

Browse files
lucianopaztwiecki
authored and committed
Fix random_choice to handle multidim p and sizes that are not None (#3380)
* Fixed dist_math.random_choice to handle multidimensional p and also non-None sizes correctly.
* Fixed mixture distribution conflict.
* Moved to_tuple from distribution.py to dist_math.py
1 parent 6187eee commit 9ef2947

File tree

6 files changed

+37
-25
lines changed

6 files changed

+37
-25
lines changed

RELEASE-NOTES.md

+1
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
- Added tests for mixtures of multidimensional distributions to the test suite.
2323
- Fixed incorrect usage of `broadcast_distribution_samples` in `DiscreteWeibull`.
2424
- `Mixture`'s default dtype is now determined by `theano.config.floatX`.
25+
- `dist_math.random_choice` now handles nd-arrays of category probabilities, and also handles sizes that are not `None`. Also removed unused `k` kwarg from `dist_math.random_choice`.
2526

2627
### Deprecations
2728

pymc3/distributions/dist_math.py

+29-6
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,17 @@
2020
c = - .5 * np.log(2. * np.pi)
2121

2222

23+
def to_tuple(shape):
    """Convert ints, arrays, and Nones to tuples.

    Parameters
    ----------
    shape : None, int, or array-like
        The shape specification to normalise.

    Returns
    -------
    tuple
        An empty tuple when ``shape`` is ``None`` or has no elements;
        otherwise a tuple holding the elements obtained by iterating over
        ``np.atleast_1d(shape)``.
    """
    # None must be handled before numpy sees it: np.atleast_1d(None) would
    # produce a one-element object array instead of an empty shape.
    if shape is None:
        return tuple()
    # atleast_1d promotes scalars to length-1 arrays so a bare int becomes a
    # 1-tuple. Iterating over an array with no elements yields nothing, so
    # empty inputs naturally collapse to ().
    return tuple(np.atleast_1d(shape))
32+
33+
2334
def bound(logp, *conditions, **kwargs):
2435
"""
2536
Bounds a log probability density with several conditions.
@@ -308,11 +319,12 @@ def random_choice(*args, **kwargs):
308319
309320
Args:
310321
p: array
311-
Probability of each class
312-
size: int
313-
Number of draws to return
314-
k: int
315-
Number of bins
322+
Probability of each class. If p.ndim > 1, the last axis is
323+
interpreted as the probability of each class, and numpy.random.choice
324+
is iterated for every other axis element.
325+
size: int or tuple
326+
Shape of the desired output array. If p is multidimensional, size
327+
should broadcast with p.shape[:-1].
316328
317329
Returns:
318330
random sample: array
@@ -323,8 +335,19 @@ def random_choice(*args, **kwargs):
323335
k = p.shape[-1]
324336

325337
if p.ndim > 1:
326-
# If a 2d vector of probabilities is passed return a sample for each row of categorical probability
338+
# If p is an nd-array, the last axis is interpreted as the class
339+
# probability. We must iterate over the elements of all the other
340+
# dimensions.
341+
# We first ensure that p is broadcasted to the output's shape
342+
size = to_tuple(size) + (1,)
343+
p = np.broadcast_arrays(p, np.empty(size))[0]
344+
out_shape = p.shape[:-1]
345+
# np.random.choice accepts 1D p arrays, so we semiflatten p to
346+
# iterate calls using the last axis as the category probabilities
347+
p = np.reshape(p, (-1, p.shape[-1]))
327348
samples = np.array([np.random.choice(k, p=p_) for p_ in p])
349+
# We reshape to the desired output shape
350+
samples = np.reshape(samples, out_shape)
328351
else:
329352
samples = np.random.choice(k, p=p, size=size)
330353
return samples

pymc3/distributions/distribution.py

+1-11
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
ObservedRV, MultiObservedRV, Context, InitContextMeta
1111
)
1212
from ..vartypes import string_types
13+
from .dist_math import to_tuple
1314

1415
__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',
1516
'NoDistribution', 'TensorType', 'draw_values', 'generate_samples']
@@ -553,17 +554,6 @@ def _draw_value(param, point=None, givens=None, size=None):
553554
return output
554555
raise ValueError('Unexpected type in draw_value: %s' % type(param))
555556

556-
557-
def to_tuple(shape):
558-
"""Convert ints, arrays, and Nones to tuples"""
559-
if shape is None:
560-
return tuple()
561-
temp = np.atleast_1d(shape)
562-
if temp.size == 0:
563-
return tuple()
564-
else:
565-
return tuple(temp)
566-
567557
def _is_one_d(dist_shape):
568558
if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):
569559
return True

pymc3/distributions/mixture.py

+4-7
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,10 @@
55

66
from pymc3.util import get_variable_name
77
from ..math import logsumexp
8-
from .dist_math import bound, random_choice
8+
from .dist_math import bound, random_choice, to_tuple
99
from .distribution import (Discrete, Distribution, draw_values,
1010
generate_samples, _DrawValuesContext,
11-
_DrawValuesContextBlocker, to_tuple,
11+
_DrawValuesContextBlocker,
1212
broadcast_distribution_samples)
1313
from .continuous import get_tau_sigma, Normal
1414
from ..theanof import _conversion_map
@@ -464,11 +464,8 @@ def random(self, point=None, size=None):
464464
# mixture mixture components, and the rest is all about size,
465465
# dist_shape and broadcasting
466466
w_ = np.reshape(w, (-1, w.shape[-1]))
467-
w_samples = generate_samples(random_choice,
468-
p=w_,
469-
broadcast_shape=w.shape[:-1] or (1,),
470-
dist_shape=w.shape[:-1] or (1,),
471-
size=None) # w's shape already includes size
467+
w_samples = random_choice(p=w_,
468+
size=None) # w's shape already includes size
472469
# Now we broadcast the chosen components to the dist_shape
473470
w_samples = np.reshape(w_samples, w.shape[:-1])
474471
if size is not None and dist_shape[:len(size)] != size:

pymc3/tests/test_distributions_random.py

+1
Original file line number | Diff line number | Diff line change
@@ -443,6 +443,7 @@ def test_probability_vector_shape(self):
443443
"""Check that if a 2d array of probabilities are passed to categorical correct shape is returned"""
444444
p = np.ones((10, 5))
445445
assert pm.Categorical.dist(p=p).random().shape == (10,)
446+
assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 10)
446447

447448

448449
class TestScalarParameterSamples(SeededTest):

pymc3/tests/test_mixture.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from pymc3.theanof import floatX
1212
import theano
1313
from theano import tensor as tt
14-
from pymc3.distributions.distribution import to_tuple
14+
from pymc3.distributions.dist_math import to_tuple
1515

1616
# Generate data
1717
def generate_normal_mixture_data(w, mu, sd, size=1000):

0 commit comments

Comments
 (0)