@@ -63,25 +63,25 @@ std::tuple<uint64_t, dim3, dim3> calc_execution_policy(const int64_t total_eleme
6363// grid stride loop kernel for distributions
6464template <typename accscalar_t , int unroll_factor, typename dist_t , typename transform_t >
6565C10_LAUNCH_BOUNDS_2 (block_size_bound, grid_size_bound)
66- __global__ void distribution_elementwise_grid_stride_kernel (int numel,
66+ __global__ void distribution_elementwise_grid_stride_kernel (int64_t numel,
6767 PhiloxCudaState philox_args,
6868 const dist_t dist_func,
6969 const transform_t transform_func) {
7070 auto seeds = at::cuda::philox::unpack (philox_args);
71- int idx = blockIdx.x * blockDim.x + threadIdx.x ;
71+ int64_t idx = blockIdx.x * blockDim.x + threadIdx.x ;
7272 curandStatePhilox4_32_10_t state;
7373 curand_init (std::get<0 >(seeds),
7474 idx,
7575 std::get<1 >(seeds),
7676 &state);
7777
78- int rounded_size = ((numel - 1 )/(blockDim.x * gridDim.x * unroll_factor)+1 ) *
78+ int64_t rounded_size = ((numel - 1 )/(blockDim.x * gridDim.x * unroll_factor)+1 ) *
7979 blockDim.x * gridDim.x * unroll_factor;
80- for (int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
80+ for (int64_t linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
8181 auto rand = dist_func (&state);
8282 #pragma unroll
8383 for (int ii = 0 ; ii < unroll_factor; ii++) {
84- int li = linear_index + blockDim.x * gridDim.x * ii;
84+ int64_t li = linear_index + blockDim.x * gridDim.x * ii;
8585 if (li < numel) {
8686 transform_func (li, static_cast <accscalar_t >((&rand.x )[ii]));
8787 }
0 commit comments