Commit 9a15341

neerajprad authored and facebook-github-bot committed
Fix underflow issue with dirichlet sample (pytorch#17488)
Summary:
Addresses pytorch#15738, using fritzo's suggestion. This adds a `torch._sample_dirichlet` method in `Distributions.cpp` and `Distributions.cu`.

- For CPU, this leads to no perf hit, since all we do is promote `alpha` to double when drawing the gamma samples (the gamma sampler uses `accscalar_t`, which is double on CPU, anyway) and cast the result back to float32 on return.
- I have added an analogous method for CUDA as well, but the default sampler for CUDA uses `scalar_t` for efficiency, so I have kept it as such. With this change, I no longer see the bias towards 1 with `float32` that was reported in pytorch#15738, but there is a spurious mode at 0.5, as would be expected. Users need to explicitly use `float64` on GPU to avoid the spurious mode at 0.5. (EDIT: see the note below; the bias issue appears to persist for certain builds.)

Added some tests and checked that there is no perf regression. My experience with C++ is very limited, so apologies in advance if I missed something basic.

cc. ailzhang, fritzo, fmassa

Pull Request resolved: pytorch#17488

Differential Revision: D14410301

Pulled By: ezyang

fbshipit-source-id: 62b2f694b4642685eab06db96d74ce28e05c3992
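
For illustration, here is a minimal Python sketch of the symptom described above; it mirrors the `test_beta_underflow` test added below (exact fractions will vary with the seed):

import torch
from torch.distributions import Beta

# Beta(a, b) samples are built from gamma variates; for tiny concentrations
# the float32 gamma samples can underflow to zero, and clamping both variates
# to the same tiny value collapses the ratio to 0.5 (the spurious mode).
conc = torch.tensor(1e-2)                              # float32 concentration
samples = Beta(conc, conc).sample([50000])
print(((samples - 0.5).abs() < 1e-3).float().mean())   # ~0 after this fix
print(float((samples < 0.1).sum()) / 50000)            # ~0.5: mass near 0
print(float((samples > 0.9).sum()) / 50000)            # ~0.5: mass near 1
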
1 parent 84fe206 commit 9a15341

File tree

5 files changed: +106 −15 lines

aten/src/ATen/native/Distributions.cpp

Lines changed: 37 additions & 0 deletions
@@ -228,4 +228,41 @@ Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
   return ret;
 }

+Tensor _s_dirichlet_cpu(const Tensor& alpha, Generator *gen) {
+  Tensor ret = at::zeros(alpha.sizes(), alpha.options());
+  AT_DISPATCH_FLOATING_TYPES(ret.type(), "dirichlet", [&] {
+    Tensor gamma = at::zeros(alpha.sizes(), alpha.options().dtype(ScalarType::Double));
+    THGenerator* generator = get_generator(gen);
+    std::lock_guard<std::mutex> lock(generator->mutex);
+    /* Generate gamma sample by casting alpha to double to prevent underflow. */
+    CPU_tensor_apply2<double, scalar_t>(gamma, alpha,
+      [generator](double& ret_val, const scalar_t& alpha) {
+        auto uniform_lambda = [generator] () {
+          return THRandom_standard_uniform(generator);
+        };
+        BaseSampler<double, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
+
+        auto normal_lambda = [generator] () {
+          return THRandom_normal(generator, 0.0, 1.0);
+        };
+        BaseSampler<double, decltype(normal_lambda)> standard_normal(normal_lambda);
+        auto sample = sample_gamma<double, double, decltype(uniform_lambda), decltype(normal_lambda)>
+          (alpha, standard_uniform, standard_normal);
+        ret_val = std::max(std::numeric_limits<double>::min(), sample);
+      }
+    );
+    /* Normalize and cast back to scalar_t. */
+    Tensor gamma_sum = gamma.sum(-1, true).expand(alpha.sizes());
+    CPU_tensor_apply3<scalar_t, double, double>(ret, gamma, gamma_sum,
+      [](scalar_t& ret_val, const double& gamma, const double& gamma_sum) {
+        ret_val = gamma / gamma_sum;
+        auto min_val = std::numeric_limits<scalar_t>::min();
+        auto max_val = std::nexttoward(static_cast<scalar_t>(1.0f), 0.0f);
+        ret_val = std::min(max_val, std::max(min_val, ret_val));
+        ret_val = static_cast<scalar_t>(ret_val);
+      }
+    );
+  });
+  return ret;
+}
 }} // namespace at::native
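
In effect, the CPU kernel above draws the gamma variates in double precision, clamps them away from exact zero, normalizes, and clamps the result into the open interval (0, 1) before casting back. A rough Python equivalent of that strategy (a sketch only; `sample_dirichlet_f64_sketch` is a hypothetical helper, and `1 - eps` stands in for the C++ `std::nexttoward(1, 0)` upper bound):

import torch

def sample_dirichlet_f64_sketch(alpha):
    # Draw gamma variates in float64 so tiny alphas don't underflow to 0.
    gamma = torch._standard_gamma(alpha.double())
    # Clamp away from exact zero, as the C++ kernel does with
    # std::numeric_limits<double>::min().
    gamma = gamma.clamp(min=torch.finfo(torch.float64).tiny)
    # Normalize along the last dimension to get Dirichlet samples.
    probs = gamma / gamma.sum(-1, keepdim=True)
    # Clamp into (0, 1) in the target precision, then cast back.
    tiny = torch.finfo(alpha.dtype).tiny
    eps = torch.finfo(alpha.dtype).eps
    return probs.clamp(min=tiny, max=1 - eps).to(alpha.dtype)
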

aten/src/ATen/native/cuda/Distributions.cu

Lines changed: 26 additions & 1 deletion
@@ -82,7 +82,7 @@ void gamma_cuda_kernel(
       };
       BaseSampler<accscalar_t, decltype(normal_lambda)> standard_normal(normal_lambda);
       auto sample = sample_gamma<scalar_t, accscalar_t, decltype(uniform_lambda), decltype(normal_lambda)>(alpha, standard_uniform, standard_normal);
-      auto min_value = std::numeric_limits<scalar_t>::lowest();
+      auto min_value = std::numeric_limits<scalar_t>::min();
       ret_val = (min_value > sample) ? min_value : sample;
     });
 }
@@ -181,6 +181,21 @@ void bernoulli_scalar_cuda_kernel(
   );
 }

+template<typename scalar_t>
+void dirichlet_scalar_cuda_kernel(
+    at::Tensor& ret,
+    const at::Tensor& gamma) {
+  auto gamma_sum = gamma.sum(-1, true).expand(ret.sizes());
+  at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, scalar_t>(ret, gamma, gamma_sum,
+    [] __device__(scalar_t &ret_val, const scalar_t &gamma, const scalar_t &gamma_sum) {
+      ret_val = gamma / gamma_sum;
+      auto min_value = std::numeric_limits<scalar_t>::min();
+      auto max_value = 1 - std::numeric_limits<scalar_t>::epsilon();
+      ret_val = (min_value > ret_val) ? min_value : ret_val;
+      ret_val = (max_value < ret_val) ? max_value : ret_val;
+    });
+}
+
 } // namespace

 namespace at { namespace native {
@@ -200,6 +215,16 @@ Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
   return ret;
 }

+Tensor _s_dirichlet_cuda(const Tensor& alpha, Generator* gen) {
+  Tensor ret = at::empty(alpha.sizes(), alpha.options());
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "dirichlet", [&] {
+    Tensor gamma = at::empty(alpha.sizes(), alpha.options());
+    gamma_cuda_kernel<scalar_t>(gamma, alpha, next_philox_seed(gen, 10));
+    dirichlet_scalar_cuda_kernel<scalar_t>(ret, gamma);
+  });
+  return ret;
+}
+
 Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
   Tensor ret = at::empty(self.sizes(), self.options());
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "_standard_gamma_grad_cuda", [&] {
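
As the summary notes, the CUDA gamma kernel keeps `scalar_t` for efficiency, so on GPU the spurious mode at 0.5 can persist in float32; using float64 sidesteps it. A small usage sketch (assumes a CUDA build; mirrors `test_beta_underflow_gpu` below):

import torch
from torch.distributions import Beta

if torch.cuda.is_available():
    # float64 avoids the float32 underflow path in the CUDA gamma kernel.
    conc = torch.tensor(1e-2, dtype=torch.float64, device='cuda')
    samples = Beta(conc, conc).sample([50000])
    # The clamp in dirichlet_scalar_cuda_kernel keeps samples strictly
    # inside (0, 1).
    assert (samples > 0).all() and (samples < 1).all()
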

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions
@@ -2455,6 +2455,13 @@
   CPU: _s_gamma_cpu
   CUDA: _s_gamma_cuda

+- func: _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: _s_dirichlet_cpu
+    CUDA: _s_dirichlet_cuda
+
 - func: poisson(Tensor self, Generator? generator=None) -> Tensor
   matches_jit_signature: True
   dispatch:
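
With this entry registered, the sampler is exposed as `torch._sample_dirichlet` and dispatches to `_s_dirichlet_cpu` or `_s_dirichlet_cuda` by device. A quick usage sketch:

import torch

# Each row of alpha is a concentration vector; rows of the result are
# draws from Dirichlet(alpha_row) that sum to 1 along the last dim.
alpha = torch.full((4, 3), 0.01)
probs = torch._sample_dirichlet(alpha)
print(probs.sum(-1))                            # ~ tensor of ones
assert (probs > 0).all() and (probs < 1).all()  # clamped into (0, 1)
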

test/test_distributions.py

Lines changed: 34 additions & 1 deletion
@@ -2208,6 +2208,39 @@ def test_beta_sample(self):
         x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
         self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))

+    def test_beta_underflow(self):
+        # For low values of (alpha, beta), the gamma samples can underflow
+        # with float32 and result in a spurious mode at 0.5. To prevent this,
+        # torch._sample_dirichlet works with double precision for intermediate
+        # calculations.
+        set_rng_seed(1)
+        num_samples = 50000
+        for dtype in [torch.float, torch.double]:
+            conc = torch.tensor(1e-2, dtype=dtype)
+            beta_samples = Beta(conc, conc).sample([num_samples])
+            self.assertEqual((beta_samples == 0).sum(), 0)
+            self.assertEqual((beta_samples == 1).sum(), 0)
+            # assert support is concentrated around 0 and 1
+            frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
+            frac_ones = float((beta_samples > 0.9).sum()) / num_samples
+            self.assertEqual(frac_zeros, 0.5, 0.05)
+            self.assertEqual(frac_ones, 0.5, 0.05)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_beta_underflow_gpu(self):
+        set_rng_seed(1)
+        num_samples = 50000
+        conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
+        beta_samples = Beta(conc, conc).sample([num_samples])
+        self.assertEqual((beta_samples == 0).sum(), 0)
+        self.assertEqual((beta_samples == 1).sum(), 0)
+        # assert support is concentrated around 0 and 1
+        frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
+        frac_ones = float((beta_samples > 0.9).sum()) / num_samples
+        # TODO: increase precision once imbalance on GPU is fixed.
+        self.assertEqual(frac_zeros, 0.5, 0.12)
+        self.assertEqual(frac_ones, 0.5, 0.12)
+
     def test_independent_shape(self):
         for Dist, params in EXAMPLES:
             for param in params:
@@ -3375,7 +3408,7 @@ def test_entropy_monte_carlo(self):
                 continue
             x = dist.sample(sample_shape=(60000,))
             expected = -dist.log_prob(x).mean(0)
-            ignore = (expected == inf)
+            ignore = (expected == inf) | (expected == -inf)
             expected[ignore] = actual[ignore]
             self.assertEqual(actual, expected, prec=0.2, message='\n'.join([
                 '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),

torch/distributions/dirichlet.py

Lines changed: 2 additions & 13 deletions
@@ -1,17 +1,8 @@
-from numbers import Number
-
 import torch
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
-from torch.distributions.utils import broadcast_all, clamp_probs
-
-
-def _dirichlet_sample_nograd(concentration):
-    probs = torch._standard_gamma(concentration)
-    probs /= probs.sum(-1, True)
-    return clamp_probs(probs)


 # This helper is exposed for testing.
@@ -24,7 +15,7 @@ def _Dirichlet_backward(x, concentration, grad_output):
 class _Dirichlet(Function):
     @staticmethod
     def forward(ctx, concentration):
-        x = _dirichlet_sample_nograd(concentration)
+        x = torch._sample_dirichlet(concentration)
         ctx.save_for_backward(x, concentration)
         return x
@@ -71,9 +62,7 @@ def expand(self, batch_shape, _instance=None):
     def rsample(self, sample_shape=()):
         shape = self._extended_shape(sample_shape)
         concentration = self.concentration.expand(shape)
-        if isinstance(concentration, torch.Tensor):
-            return _Dirichlet.apply(concentration)
-        return _dirichlet_sample_nograd(concentration)
+        return _Dirichlet.apply(concentration)

     def log_prob(self, value):
         if self._validate_args:
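
Since `rsample` now always routes through `_Dirichlet.apply`, samples stay differentiable while the forward pass uses the new underflow-safe `torch._sample_dirichlet`. A brief sketch:

import torch
from torch.distributions import Dirichlet

conc = torch.tensor([0.01, 0.01], requires_grad=True)
x = Dirichlet(conc).rsample()   # forward: torch._sample_dirichlet
x[0].backward()                 # backward: _Dirichlet_backward
print(conc.grad)                # pathwise gradient w.r.t. concentration
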
