Mixture random cleanup #3364

Changes from 5 commits
File: pymc3/distributions/mixture.py
@@ -1,13 +1,17 @@
+from collections.abc import Iterable
 import numpy as np
+import theano
 import theano.tensor as tt

 from pymc3.util import get_variable_name
 from ..math import logsumexp
 from .dist_math import bound, random_choice
 from .distribution import (Discrete, Distribution, draw_values,
                            generate_samples, _DrawValuesContext,
-                           _DrawValuesContextBlocker, to_tuple)
+                           _DrawValuesContextBlocker, to_tuple,
+                           broadcast_distribution_samples)
 from .continuous import get_tau_sigma, Normal
+from ..theanof import _conversion_map


 def all_discrete(comp_dists):
@@ -79,9 +83,9 @@ def __init__(self, w, comp_dists, *args, **kwargs):
         defaults = kwargs.pop('defaults', [])

         if all_discrete(comp_dists):
-            dtype = kwargs.pop('dtype', 'int64')
+            default_dtype = _conversion_map[theano.config.floatX]
         else:
-            dtype = kwargs.pop('dtype', 'float64')
+            default_dtype = theano.config.floatX

         try:
             self.mean = (w * self._comp_means()).sum(axis=-1)
@@ -90,6 +94,7 @@ def __init__(self, w, comp_dists, *args, **kwargs):
             defaults.append('mean')
         except AttributeError:
             pass
+        dtype = kwargs.pop('dtype', default_dtype)

         try:
             comp_modes = self._comp_modes()
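The change above replaces the hard-coded `'int64'`/`'float64'` defaults with ones derived from `theano.config.floatX`, so a `float32` configuration yields samples of matching width. A minimal sketch of that selection logic; the conversion table below is an illustrative stand-in, not the actual contents of `pymc3.theanof._conversion_map`:

```python
# Hypothetical stand-in for pymc3.theanof._conversion_map: pick an integer
# dtype whose width matches the configured float dtype.
conversion_map = {'float64': 'int64', 'float32': 'int32', 'float16': 'int16'}

def default_mixture_dtype(all_discrete, floatX='float64'):
    """Mimic the dtype defaulting in Mixture.__init__: discrete components
    map floatX to a matching integer dtype, continuous ones use floatX."""
    if all_discrete:
        return conversion_map[floatX]
    return floatX

print(default_mixture_dtype(all_discrete=True, floatX='float32'))   # int32
print(default_mixture_dtype(all_discrete=False, floatX='float32'))  # float32
```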
@@ -108,29 +113,72 @@ def comp_dists(self):
         return self._comp_dists

     @comp_dists.setter
-    def comp_dists(self, _comp_dists):
-        self._comp_dists = _comp_dists
-        # Tests if the comp_dists can call random with non None size
-        with _DrawValuesContextBlocker():
-            if isinstance(self.comp_dists, (list, tuple)):
-                try:
-                    [comp_dist.random(size=23)
-                     for comp_dist in self.comp_dists]
-                    self._comp_dists_vect = True
-                except Exception:
-                    # The comp_dists cannot call random with non None size or
-                    # without knowledge of the point so we assume that we will
-                    # have to iterate calls to random to get the correct size
-                    self._comp_dists_vect = False
-            else:
-                try:
-                    self.comp_dists.random(size=23)
-                    self._comp_dists_vect = True
-                except Exception:
-                    # The comp_dists cannot call random with non None size or
-                    # without knowledge of the point so we assume that we will
-                    # have to iterate calls to random to get the correct size
-                    self._comp_dists_vect = False
+    def comp_dists(self, comp_dists):
+        if isinstance(comp_dists, Distribution):
+            self._comp_dists = comp_dists
+            self._comp_dist_shapes = to_tuple(comp_dists.shape)
+            self._broadcast_shape = self._comp_dist_shapes
+            self.is_multidim_comp = True
+        elif isinstance(comp_dists, Iterable):
+            if not all((isinstance(comp_dist, Distribution)
+                        for comp_dist in comp_dists)):
+                raise TypeError('Supplied Mixture comp_dists must be a '
+                                'Distribution or an iterable of '
+                                'Distributions.')

Review thread (on the isinstance check):
- "I think this check should be moved to the …"
- "Ok, I can do that. I thought that it would be a good check to make whenever …"

+            self._comp_dists = comp_dists
+            # Now we check the comp_dists distribution shape to see what
+            # the broadcast shape would be. This shape will be the dist_shape
+            # used by generate_samples (the shape of a single random sample
+            # from the mixture)
+            self._comp_dist_shapes = [to_tuple(d.shape) for d in comp_dists]
+            # All component distributions must broadcast with each other
+            try:
+                self._broadcast_shape = np.broadcast(
+                    *[np.empty(shape) for shape in self._comp_dist_shapes]
+                ).shape
+            except Exception:
+                raise TypeError('Supplied comp_dists shapes do not broadcast '
+                                'with each other. comp_dists shapes are: '
+                                '{}'.format(self._comp_dist_shapes))
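The broadcast compatibility check can be tried in isolation; a minimal self-contained sketch with made-up component shapes:

```python
import numpy as np

def broadcast_comp_shapes(comp_dist_shapes):
    """Mimic the setter's check: the component shapes must broadcast
    together, and the broadcast result becomes the mixture's base shape."""
    try:
        return np.broadcast(
            *[np.empty(shape) for shape in comp_dist_shapes]
        ).shape
    except ValueError:
        raise TypeError('Supplied comp_dists shapes do not broadcast '
                        'with each other. comp_dists shapes are: '
                        '{}'.format(comp_dist_shapes))

print(broadcast_comp_shapes([(3,), (5, 3), (1, 3)]))  # (5, 3)
try:
    broadcast_comp_shapes([(3,), (4,)])
except TypeError as err:
    print(err)  # shapes (3,) and (4,) are incompatible
```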
+
+            # We wrap each _comp_dist.random by adding the kwarg raw_size_,
+            # which will be the size attribute passed to _comp_samples.
+            # _comp_samples then calls generate_samples, which may change the
+            # size value to make it compatible with scipy.stats.*.rvs
+            self._generators = []
+            for comp_dist in comp_dists:
+                generator = Mixture._comp_dist_random_wrapper(comp_dist.random)
+                self._generators.append(generator)
+            self.is_multidim_comp = False
+        else:
+            raise TypeError('Cannot handle supplied comp_dist type {}'
+                            .format(type(comp_dists)))
+
+    @staticmethod
+    def _comp_dist_random_wrapper(random):
+        """Wrap the comp_dists.random method to take the kwarg raw_size_ and
+        use its value to replace the size parameter. This is needed because
+        generate_samples makes the size value compatible with
+        scipy.stats.*.rvs, where size has a different meaning than in the
+        distributions' random methods.
+        """
+        def wrapped_random(*args, **kwargs):
+            raw_size_ = kwargs.pop('raw_size_', None)
+            if raw_size_ is not None:
+                if isinstance(raw_size_, np.ndarray):
+                    # This may happen because generate_samples broadcasts
+                    # parameter values
+                    raw_size_ = raw_size_.ravel()[0]
+                else:
+                    raw_size_ = int(raw_size_)
+            # Distribution.random's signature is always (point=None, size=None)
+            # so size could be the second arg or be given as a kwarg
+            if len(args) > 1:
+                # args is a tuple, so rebuild it with the replaced size
+                args = (args[0], raw_size_)
+            else:
+                kwargs['size'] = raw_size_
+            return random(*args, **kwargs)
+        return wrapped_random
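The wrapper pattern above can be exercised on its own. Here is a simplified, self-contained sketch (with a dummy `random` function standing in for a component distribution's `random` method; it drops the positional-args handling) showing how `raw_size_` overrides whatever `size` value the sampling machinery later injects:

```python
import numpy as np

def make_raw_size_wrapper(random):
    """Simplified version of Mixture._comp_dist_random_wrapper: pop the
    raw_size_ kwarg and use it as the real size, overriding the size value
    that generate_samples may have rewritten."""
    def wrapped_random(*args, **kwargs):
        raw_size_ = kwargs.pop('raw_size_', None)
        if raw_size_ is not None:
            if isinstance(raw_size_, np.ndarray):
                # generate_samples may broadcast size with parameter values
                raw_size_ = raw_size_.ravel()[0]
            else:
                raw_size_ = int(raw_size_)
        kwargs['size'] = raw_size_
        return random(*args, **kwargs)
    return wrapped_random

# Dummy stand-in for a distribution's random(point=None, size=None) method
def dummy_random(point=None, size=None):
    return np.random.standard_normal(size if size is not None else ())

wrapped = make_raw_size_wrapper(dummy_random)
# size=1 is the rewritten value; raw_size_=4 is what actually gets used
print(wrapped(size=1, raw_size_=4).shape)  # (4,)
```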
     def _comp_logp(self, value):
         comp_dists = self.comp_dists
@@ -160,35 +208,86 @@ def _comp_modes(self):
                                  for comp_dist in self.comp_dists],
                                 axis=1))

-    def _comp_samples(self, point=None, size=None):
-        if self._comp_dists_vect or size is None:
-            try:
-                return self.comp_dists.random(point=point, size=size)
-            except AttributeError:
-                samples = np.array([comp_dist.random(point=point, size=size)
-                                    for comp_dist in self.comp_dists])
-                samples = np.moveaxis(samples, 0, samples.ndim - 1)
-        else:
-            # We must iterate the calls to random manually
-            size = to_tuple(size)
-            _size = int(np.prod(size))
-            try:
-                samples = np.array([self.comp_dists.random(point=point,
-                                                           size=None)
-                                    for _ in range(_size)])
-                samples = np.reshape(samples, size + samples.shape[1:])
-            except AttributeError:
-                samples = np.array([[comp_dist.random(point=point, size=None)
-                                     for _ in range(_size)]
-                                    for comp_dist in self.comp_dists])
-                samples = np.moveaxis(samples, 0, samples.ndim - 1)
-                samples = np.reshape(samples, size + samples.shape[1:])
+    def _comp_samples(self, point=None, size=None,
+                      comp_dist_shapes=None,
+                      broadcast_shape=None):
+        if self.is_multidim_comp:
+            samples = self._comp_dists.random(point=point, size=size)
+        else:
+            if comp_dist_shapes is None:
+                comp_dist_shapes = self._comp_dist_shapes
+            if broadcast_shape is None:
+                broadcast_shape = self._broadcast_shape
+            samples = []
+            for dist_shape, generator in zip(comp_dist_shapes,
+                                             self._generators):
+                sample = generate_samples(
+                    generator=generator,
+                    dist_shape=dist_shape,
+                    broadcast_shape=broadcast_shape,
+                    point=point,
+                    size=size,
+                    raw_size_=size,
+                )
+                samples.append(sample)
+            samples = np.array(
+                broadcast_distribution_samples(samples, size=size)
+            )
+            # In the logp we assume the last axis holds the mixture components
+            # so we move the axis to the last dimension
+            samples = np.moveaxis(samples, 0, -1)
         if samples.shape[-1] == 1:

Review thread (on the trailing-singleton check):
- "This if statement is legacy code that I don't really understand. Why should we do this test, wouldn't it be done by …"
- "I don't remember as well, but at some point there is an error with shape = (100, 1) and shape = (100,) - not even sure we have a test for that."
- "So the error happens when we try to call …"

             return samples[..., 0]
         else:
             return samples
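The axis shuffling at the end is worth seeing on toy data: `logp` expects the component axis last, and a single-component stack is squeezed so shapes like `(100, 1)` don't leak where `(100,)` is expected. A sketch with fake component draws (shapes made up):

```python
import numpy as np

# Fake draws from 3 components, each of shape (100,), stacked on axis 0
samples = np.stack([np.random.standard_normal(100) for _ in range(3)])
print(samples.shape)  # (3, 100)

# Move the component axis last, as Mixture._comp_samples does
samples = np.moveaxis(samples, 0, -1)
print(samples.shape)  # (100, 3)

# With a single component the trailing singleton is dropped, so downstream
# code sees (100,) instead of (100, 1)
single = np.moveaxis(np.stack([np.random.standard_normal(100)]), 0, -1)
single = single[..., 0] if single.shape[-1] == 1 else single
print(single.shape)  # (100,)
```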
+    def infer_comp_dist_shapes(self, point=None):

Review thread (on infer_comp_dist_shapes):
- "Maybe it is better to call this once in …"
- "The problem is that the …"
- "I see, good point."

+        if self.is_multidim_comp:
+            if len(self._comp_dist_shapes) > 0:
+                comp_dist_shapes = self._comp_dist_shapes
+            else:
+                # Happens when the distribution is a scalar or when it was not
+                # given a shape. In these cases we try to draw a single value
+                # to check its shape; we use the provided point dictionary,
+                # hoping that it can circumvent the Flat and HalfFlat
+                # undrawable distributions.
+                with _DrawValuesContextBlocker():
+                    test_sample = self._comp_dists.random(point=point,
+                                                          size=None)
+                    comp_dist_shapes = test_sample.shape
+            broadcast_shape = comp_dist_shapes
+            includes_mixture = True

Review thread (on includes_mixture):
- "I am not completely sure what is the function of …"
- "You're right. Yes, I'll add a docstring."

+        else:
+            # Now we check the comp_dists distribution shape to see what
+            # the broadcast shape would be. This shape will be the dist_shape
+            # used by generate_samples (the shape of a single random sample
+            # from the mixture)
+            comp_dist_shapes = []
+            for dist_shape, comp_dist in zip(self._comp_dist_shapes,
+                                             self._comp_dists):
+                if dist_shape == tuple():
+                    # Happens when the distribution is a scalar or when it was
+                    # not given a shape. In these cases we try to draw a single
+                    # value to check its shape; we use the provided point
+                    # dictionary, hoping that it can circumvent the Flat and
+                    # HalfFlat undrawable distributions.
+                    with _DrawValuesContextBlocker():
+                        test_sample = comp_dist.random(point=point,
+                                                       size=None)
+                        dist_shape = test_sample.shape
+                comp_dist_shapes.append(dist_shape)
+            # All component distributions must broadcast with each other
+            try:
+                broadcast_shape = np.broadcast(
+                    *[np.empty(shape) for shape in comp_dist_shapes]
+                ).shape
+            except Exception:
+                raise TypeError('Inferred comp_dist shapes do not broadcast '
+                                'with each other. comp_dists inferred shapes '
+                                'are: {}'.format(comp_dist_shapes))
+            includes_mixture = False
+        return comp_dist_shapes, broadcast_shape, includes_mixture
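A rough, self-contained sketch of this inference strategy, with plain callables standing in for component distributions (names hypothetical): shapes declared as empty tuples are discovered by drawing one test value, then everything is broadcast together.

```python
import numpy as np

def infer_shapes(declared_shapes, samplers):
    """Mimic infer_comp_dist_shapes for the iterable case: fall back to a
    test draw when a component declared no shape, then broadcast."""
    shapes = []
    for shape, sampler in zip(declared_shapes, samplers):
        if shape == tuple():
            # No declared shape: draw once and inspect the result
            shape = np.asarray(sampler()).shape
        shapes.append(shape)
    broadcast_shape = np.broadcast(*[np.empty(s) for s in shapes]).shape
    return shapes, broadcast_shape

samplers = [lambda: np.random.standard_normal((2, 3)),
            lambda: np.random.standard_normal(3)]
print(infer_shapes([tuple(), (3,)], samplers))  # ([(2, 3), (3,)], (2, 3))
```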

     def logp(self, value):
         w = self.w

@@ -203,10 +302,9 @@ def random(self, point=None, size=None):
         with _DrawValuesContext() as draw_context:
             # We first need to check w and comp_tmp shapes and recompute size
             w = draw_values([self.w], point=point, size=size)[0]
-            with _DrawValuesContextBlocker():
-                # We don't want to store the values drawn here in the context
-                # because they won't have the correct size
-                comp_tmp = self._comp_samples(point=point, size=None)
+            comp_dist_shapes, broadcast_shape, includes_mixture = (
+                self.infer_comp_dist_shapes(point=point)
+            )

         # When size is not None, it's hard to tell the w parameter shape
         if size is not None and w.shape[:len(size)] == size:
@@ -215,8 +313,12 @@ def random(self, point=None, size=None):
             w_shape = w.shape

         # Try to determine parameter shape and dist_shape
-        param_shape = np.broadcast(np.empty(w_shape),
-                                   comp_tmp).shape
+        if includes_mixture:
+            param_shape = np.broadcast(np.empty(w_shape),
+                                       np.empty(broadcast_shape)).shape
+        else:
+            param_shape = np.broadcast(np.empty(w_shape),
+                                       np.empty(broadcast_shape + (1,))).shape
         if np.asarray(self.shape).size != 0:
             dist_shape = np.broadcast(np.empty(self.shape),
                                       np.empty(param_shape[:-1])).shape
@@ -259,7 +361,11 @@ def random(self, point=None, size=None):
         else:
             output_size = int(np.prod(dist_shape) * param_shape[-1])
         # Get the size we need for the mixture's random call
-        mixture_size = int(output_size // np.prod(comp_tmp.shape))
+        if includes_mixture:
+            mixture_size = int(output_size // np.prod(broadcast_shape))
+        else:
+            mixture_size = int(output_size //
+                               (np.prod(broadcast_shape) * param_shape[-1]))
        if mixture_size == 1 and _size is None:
            mixture_size = None

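The size bookkeeping here is easy to sanity-check by hand with made-up shapes. With three scalar components and a mixture of shape `(10,)`, the iterable branch asks each component generator for 10 draws:

```python
import numpy as np

# Hypothetical numbers: 3 scalar components, mixture shape (10,)
w_shape = (3,)        # three mixture weights
broadcast_shape = ()  # scalar components broadcast to a scalar
dist_shape = (10,)    # shape of one draw from the mixture

# param_shape pads the component shape with the component axis
param_shape = np.broadcast(np.empty(w_shape),
                           np.empty(broadcast_shape + (1,))).shape  # (3,)
output_size = int(np.prod(dist_shape) * param_shape[-1])            # 30
# Iterable comp_dists branch: each component must supply
# output_size / (elements per draw * number of components) samples
mixture_size = int(output_size //
                   (np.prod(broadcast_shape) * param_shape[-1]))    # 10
print(param_shape, output_size, mixture_size)
```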
@@ -277,11 +383,23 @@ def random(self, point=None, size=None):
                                            size=size)
         # Sample from the mixture
         with draw_context:
-            mixed_samples = self._comp_samples(point=point,
-                                               size=mixture_size)
-        w_samples = w_samples.flatten()
+            mixed_samples = self._comp_samples(
+                point=point,
+                size=mixture_size,
+                broadcast_shape=broadcast_shape,
+                comp_dist_shapes=comp_dist_shapes,
+            )
+        # Test that the mixture has the same number of "samples" as w
+        if w_samples.size != (mixed_samples.size // w.shape[-1]):
+            raise ValueError('Inconsistent number of samples from the '
+                             'mixture and mixture weights. Drew {} mixture '
+                             'weight elements, and {} samples from the '
+                             'mixture components.'
+                             .format(w_samples.size,
+                                     mixed_samples.size // w.shape[-1]))
         # Semiflatten the mixture to be able to zip it with w_samples
-        mixed_samples = np.reshape(mixed_samples, (-1, comp_tmp.shape[-1]))
+        w_samples = w_samples.flatten()
+        mixed_samples = np.reshape(mixed_samples, (-1, w.shape[-1]))
         # Select the samples from the mixture
         samples = np.array([mixed[choice] for choice, mixed in
                             zip(w_samples, mixed_samples)])
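After the reshape, `w_samples` holds one component index per flattened draw and `mixed_samples` one row per draw, so the selection is a plain zip-and-index. A toy illustration with made-up numbers:

```python
import numpy as np

# Toy setup: 4 mixture draws, each holding one value per of 3 components
mixed_samples = np.arange(12.0).reshape(4, 3)
# One component choice per draw, e.g. produced by random_choice from w
w_samples = np.array([0, 2, 1, 2])

# Pick, for every draw, the value of the chosen component
samples = np.array([mixed[choice] for choice, mixed in
                    zip(w_samples, mixed_samples)])
print(samples)  # [ 0.  5.  7. 11.]
```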
File: pymc3/distributions/multivariate.py
@@ -267,7 +267,7 @@ def random(self, point=None, size=None):
                 else:
                     std_norm_shape = mu.shape
                 standard_normal = np.random.standard_normal(std_norm_shape)
-                return mu + np.tensordot(standard_normal, chol, axes=[[-1], [-1]])
+                return mu + np.einsum('...ij,...j->...i', chol, standard_normal)

Review thread (on the einsum change):
- "nice!"
- "Thanks! I had not understood …"

             else:
                 mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
                 if mu.shape[-1] != tau[0].shape[-1]:
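A note on why the einsum form is preferable once `chol` carries batch dimensions: `tensordot` contracts only the named axes and forms an outer product over all remaining axes, while `'...ij,...j->...i'` keeps the batch axes aligned, giving one matrix-vector product per batch element. A small sketch with made-up shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
n, batch = 3, 5
chol = np.tril(rng.normal(size=(batch, n, n)))  # one Cholesky per batch
z = rng.normal(size=(batch, n))                 # one noise vector per batch

old = np.tensordot(z, chol, axes=[[-1], [-1]])
new = np.einsum('...ij,...j->...i', chol, z)

# einsum keeps the batch axes aligned: one matrix-vector product each
print(new.shape)  # (5, 3)
# tensordot instead pairs every noise vector with every Cholesky factor
print(old.shape)  # (5, 5, 3)
# The "diagonal" of the tensordot result recovers the aligned products
print(np.allclose(new, old[np.arange(batch), np.arange(batch)]))  # True
```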
Review thread (on the is_multidim_comp naming):
- "The name is a bit confusing - this is only to distinguish between the comp being a list or a distribution, right?"
- "You're totally right. I'll change this."