
Commit 9532589

eqy authored and pytorchmergebot committed
[CUDA][64-bit indexing] Support 64-bit indexing in distribution_elementwise_grid_stride_kernel (pytorch#141613)
For pytorch#141544

Overhead doesn't seem to be noticeable even on small sizes (e.g., 2**10 elements).

Pull Request resolved: pytorch#141613
Approved by: https://github.com/Skylion007, https://github.com/ngimel
1 parent 7fafaa9 · commit 9532589

2 files changed, +12 -5 lines


Diff for: aten/src/ATen/native/cuda/DistributionTemplates.h

+5 -5

@@ -63,25 +63,25 @@ std::tuple<uint64_t, dim3, dim3> calc_execution_policy(const int64_t total_eleme
 // grid stride loop kernel for distributions
 template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
 C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
-__global__ void distribution_elementwise_grid_stride_kernel(int numel,
+__global__ void distribution_elementwise_grid_stride_kernel(int64_t numel,
                                                             PhiloxCudaState philox_args,
                                                             const dist_t dist_func,
                                                             const transform_t transform_func) {
   auto seeds = at::cuda::philox::unpack(philox_args);
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   curandStatePhilox4_32_10_t state;
   curand_init(std::get<0>(seeds),
               idx,
               std::get<1>(seeds),
               &state);

-  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
+  int64_t rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
         blockDim.x * gridDim.x * unroll_factor;
-  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
+  for(int64_t linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
     auto rand = dist_func(&state);
     #pragma unroll
     for (int ii = 0; ii < unroll_factor; ii++) {
-      int li = linear_index + blockDim.x * gridDim.x * ii;
+      int64_t li = linear_index + blockDim.x * gridDim.x * ii;
       if (li < numel) {
         transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
       }
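
The substance of the change is widening numel and every index derived from it (idx, rounded_size, linear_index, li) from int to int64_t, so the grid-stride loop can address more than 2**31 - 1 elements without overflow. As a standalone illustration of the same pattern, here is a minimal 64-bit grid-stride kernel; it is a sketch, not PyTorch code, and the kernel name fill_ones, the launch configuration, and the last-element check are made up for this example:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel: writes 1 to every element using 64-bit grid-stride
// indexing, so element counts past INT_MAX are handled correctly.
__global__ void fill_ones(int64_t numel, int8_t* out) {
  // Widen before multiplying so the index math never wraps at 2**31.
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = idx; i < numel; i += stride) {
    out[i] = 1;
  }
}

int main() {
  const int64_t numel = (1LL << 31) + 1;  // just past the 32-bit signed limit
  int8_t* buf = nullptr;
  if (cudaMalloc(&buf, numel) != cudaSuccess) {  // needs > 2 GiB free on device
    printf("allocation failed\n");
    return 1;
  }
  fill_ones<<<1024, 256>>>(numel, buf);
  cudaDeviceSynchronize();
  // With 32-bit indexing the tail past index 2**31 - 1 would never be written;
  // with 64-bit indexing the last element reads back as 1.
  int8_t last = 0;
  cudaMemcpy(&last, buf + numel - 1, 1, cudaMemcpyDeviceToHost);
  printf("last element = %d\n", (int)last);
  cudaFree(buf);
  return 0;
}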

Diff for: test/test_cuda.py

+7

@@ -36,6 +36,7 @@
 )
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
+    largeTensorTest,
     onlyCUDA,
     onlyNativeDeviceTypes,
 )
@@ -1051,6 +1052,12 @@ def run(dev: torch.device) -> int:
             abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000
         )

+    @largeTensorTest("20GB", "cuda")
+    def test_randint_generation_for_large_numel(self) -> None:
+        numel = 2**31 + 1
+        s = torch.randint(2, (numel,), device="cuda", dtype=torch.int8).sum()
+        self.assertTrue(s > 0, "expected randint in [0, 1] to generate nonzero values")
+
     @parametrize("dtype", [torch.float32, torch.double])
     def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None:
         # Test if random states do not overlap between consecutive rand/randn calls.
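
The new test draws 2**31 + 1 values uniformly from {0, 1} as int8 (about 2 GiB of data), so with correct generation the expected sum is around 2**30; .sum() on an int8 tensor promotes to int64, so the accumulator itself cannot overflow. A non-positive sum would indicate that the large-numel path misindexed or skipped elements, which is exactly what the int to int64_t change guards against. The largeTensorTest("20GB", "cuda") decorator skips the test on devices without enough memory.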
