
Commit 49443d4

jjsjann123 authored and facebook-github-bot committed
TensorIterator cuda launch configs update (pytorch#16224)
Summary: Update launch configs for TensorIterator gpu_reduce_kernel. Enable flexible block dimensions to improve efficiency for reduction cases with a small fast-moving dimension.

Previously, TensorIterator launched blocks with a fixed 32x16 thread shape. For cases like:

  import torch
  torch.randn(2**20, 4, device='cuda').sum(0)

the fixed launch config does not handle coalesced memory access efficiently. The updated launch config enables flexible block dimensions. Combined with the improved reduction scheme (flexible vertical/horizontal reduction instead of the limited warp/block reduction in the old code), it ensures an optimal memory access pattern even when reducing over a dimension with a small stride.

Possible future improvements:
1. Precise dynamic shared memory allocation.
2. Using warp shuffle for vertical (block_y) reduction.

Pull Request resolved: pytorch#16224
Differential Revision: D13806753
Pulled By: soumith
fbshipit-source-id: 37e45c7767b5748cf9ecf894fad306e040e2f79f
1 parent b2135b2 commit 49443d4
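For illustration only, here is a minimal standalone C++ sketch (not part of the commit) of how the new ReduceConfig::set_block_dimension logic picks a block shape. It hard-codes a warp size of 32 instead of calling at::cuda::warp_size(), and it assumes the sum(0) example in the summary reaches the reducer with a fast-changing extent of 4 (dim0) and a reduction extent of 2**20 (dim1):

  #include <algorithm>
  #include <cstdint>
  #include <cstdio>

  // Same power-of-two rounding helper the commit adds to Reduce.cuh.
  static int last_pow2(int n) {
    n |= (n >> 1);
    n |= (n >> 2);
    n |= (n >> 4);
    n |= (n >> 8);
    n |= (n >> 16);
    return std::max(1, n - (n >> 1));
  }

  int main() {
    const int max_threads = 512;  // ReduceConfig::MAX_NUM_THREADS
    const int warp_size = 32;     // assumed; the real code queries at::cuda::warp_size()
    int64_t dim0 = 4;             // assumed fast-changing extent for randn(2**20, 4).sum(0)
    int64_t dim1 = 1 << 20;       // assumed reduction extent
    int dim0_pow2 = dim0 < max_threads ? last_pow2(static_cast<int>(dim0)) : max_threads;
    int dim1_pow2 = dim1 < max_threads ? last_pow2(static_cast<int>(dim1)) : max_threads;
    int block_width  = std::min(dim0_pow2, warp_size);
    int block_height = std::min(dim1_pow2, max_threads / block_width);
    block_width      = std::min(dim0_pow2, max_threads / block_height);
    std::printf("block = %d x %d (%d threads)\n",
                block_width, block_height, block_width * block_height);
    return 0;
  }

Under those assumptions this prints "block = 4 x 128 (512 threads)": the block is made only 4 threads wide to match the 4 contiguous outputs, and the remaining threads are stacked along y for the vertical (block_y) reduction, instead of the old fixed 32x16 shape.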

File tree

1 file changed: +106 -54 lines

aten/src/ATen/native/cuda/Reduce.cuh

@@ -25,11 +25,21 @@ static inline int64_t div_up(int64_t a, int64_t b) {
   return (a + b - 1) / b;
 }

+static inline int last_pow2(int n) {
+  n |= (n >> 1);
+  n |= (n >> 2);
+  n |= (n >> 4);
+  n |= (n >> 8);
+  n |= (n >> 16);
+  return std::max(1, n - (n >> 1));
+}
+
 struct ReduceConfig {
-  static constexpr int LANE = 0;
-  static constexpr int WARP = 1;
+  static constexpr int BLOCK_X = 0;
+  static constexpr int BLOCK_Y = 1;
   static constexpr int CTA = 2;
-  static constexpr int NUM_THREADS = 512;
+
+  static constexpr int MAX_NUM_THREADS = 512;

   ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
     : element_size_bytes(element_size_bytes)
@@ -45,6 +55,19 @@ struct ReduceConfig {
   int input_mult[3] = {0, 0, 0};
   int output_mult[2] = {0, 0};

+  int block_width;
+  int block_height;
+  int num_threads;
+
+  void set_block_dimension(int64_t dim0, int64_t dim1) {
+    int dim0_pow2 = dim0 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim0)) : MAX_NUM_THREADS;
+    int dim1_pow2 = dim1 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim1)) : MAX_NUM_THREADS;
+    block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
+    block_height = std::min(dim1_pow2, int(MAX_NUM_THREADS / block_width));
+    block_width = std::min(dim0_pow2, int(MAX_NUM_THREADS / block_height));
+    num_threads = block_width * block_height;
+  }
+
   int split_input(int parallelism) {
     int step = step_input;
     step_input *= parallelism;
@@ -58,20 +81,19 @@ struct ReduceConfig {
   }

   dim3 block() const {
-    int warp_size = at::cuda::warp_size();
-    return dim3(warp_size, NUM_THREADS / warp_size);
+    return dim3(block_width, block_height);
   }

   dim3 grid() const {
     return dim3(div_up(num_outputs, step_output), ctas_per_output);
   }

-  C10_HOST_DEVICE bool should_warp_reduce() const {
-    return input_mult[LANE] != 0;
+  C10_HOST_DEVICE bool should_block_x_reduce() const {
+    return input_mult[BLOCK_X] != 0;
   }

-  C10_HOST_DEVICE bool should_block_reduce() const {
-    return input_mult[WARP] != 0;
+  C10_HOST_DEVICE bool should_block_y_reduce() const {
+    return input_mult[BLOCK_Y] != 0;
   }

   C10_HOST_DEVICE bool should_global_reduce() const {
@@ -80,25 +102,25 @@ struct ReduceConfig {

   C10_DEVICE bool should_store(int output_idx) const {
     return output_idx < num_outputs &&
-      (!should_warp_reduce() || threadIdx.x == 0) &&
-      (!should_block_reduce() || threadIdx.y == 0);
+      (!should_block_x_reduce() || threadIdx.x == 0) &&
+      (!should_block_y_reduce() || threadIdx.y == 0);
   }

   C10_HOST_DEVICE int input_idx() const {
     int lane = threadIdx.x;
     int warp = threadIdx.y;
     int cta2 = blockIdx.y;
-    return (lane * input_mult[LANE] +
-            warp * input_mult[WARP] +
+    return (lane * input_mult[BLOCK_X] +
+            warp * input_mult[BLOCK_Y] +
             cta2 * input_mult[CTA]);
   }

   C10_HOST_DEVICE int output_idx() const {
     int lane = threadIdx.x;
     int warp = threadIdx.y;
     int cta1 = blockIdx.x;
-    return (lane * output_mult[LANE] +
-            warp * output_mult[WARP] +
+    return (lane * output_mult[BLOCK_X] +
+            warp * output_mult[BLOCK_Y] +
             cta1 * step_output);
   }

@@ -108,25 +130,27 @@ struct ReduceConfig {

   C10_DEVICE int staging_memory_offset(int cta2) const {
     int offset = cta2 + blockIdx.x * gridDim.y;
-    if (!should_warp_reduce()) {
+    if (!should_block_x_reduce()) {
       offset = threadIdx.x + offset * blockDim.x;
     }
     return offset;
   }

   int shared_memory_size() const {
-    if (!should_block_reduce()) {
+    if (!should_block_y_reduce() &&
+        (!should_block_x_reduce() ||
+          block_width <= at::cuda::warp_size())) {
       return 0;
     }
-    return element_size_bytes * NUM_THREADS;
+    return element_size_bytes * num_threads;
   }

   int64_t global_memory_size() const {
     if (!should_global_reduce()) {
       return 0;
     }
     auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
-    if (!should_warp_reduce()) {
+    if (!should_block_x_reduce()) {
       size *= block().x;
     }
     return size;
@@ -267,6 +291,7 @@ struct ReduceOp {
   }

   C10_DEVICE void run() const {
+    extern __shared__ char shared_memory[];
     index_t output_idx = config.output_idx();
     index_t input_idx = config.input_idx();
     auto base_offsets = output_calc.get(output_idx);
@@ -276,17 +301,17 @@ struct ReduceOp {
       auto input_slice = (const char*)src + base_offsets[1];
       value = thread_reduce((const scalar_t*)input_slice);
     }
-    bool should_block_reduce = config.should_block_reduce();
-    if (should_block_reduce) {
-      value = block_reduce(value);
+    bool should_block_y_reduce = config.should_block_y_reduce();
+    if (should_block_y_reduce) {
+      value = block_y_reduce(value, shared_memory);
     }
-    if (config.should_warp_reduce() && (!should_block_reduce || threadIdx.y == 0)) {
-      value = warp_reduce(value);
+    if (config.should_block_x_reduce()) {
+      value = block_x_reduce(value, shared_memory);
     }

     auto out = (out_scalar_t*)((char*)dst + base_offsets[0]);
     if (config.should_global_reduce()) {
-      value = global_reduce(value, out);
+      value = global_reduce(value, out, shared_memory);
     } else if (config.should_store(output_idx)) {
       if (accumulate) {
         value = accumulate_in_output<can_accumulate_in_output>(out, value);
@@ -331,22 +356,38 @@ struct ReduceOp {
     return value;
   }

-  C10_DEVICE arg_t warp_reduce(arg_t value) const {
-    for (int offset = 1; offset < warpSize; offset <<= 1) {
+  C10_DEVICE arg_t block_x_reduce(arg_t value, char* shared_memory) const {
+    int dim_x = blockDim.x;
+    arg_t* shared = (arg_t*)shared_memory;
+    if (dim_x > warpSize) {
+      int address_base = threadIdx.x + threadIdx.y*blockDim.x;
+      shared[address_base] = value;
+      for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
+        __syncthreads();
+        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
+          arg_t other = shared[address_base + offset];
+          value = ops.combine(value, other);
+          shared[address_base] = value;
+        }
+      }
+      dim_x = warpSize;
+    }
+
+    __syncthreads();
+
+    for (int offset = 1; offset < dim_x; offset <<= 1) {
       arg_t other = ops.warp_shfl_down(value, offset);
       value = ops.combine(value, other);
     }
     return value;
   }

-  C10_DEVICE arg_t block_reduce(arg_t value) const {
-    extern __shared__ char shared_memory[];
+  C10_DEVICE arg_t block_y_reduce(arg_t value, char* shared_memory) const {
     arg_t* shared = (arg_t*)shared_memory;
     shared[config.shared_memory_offset(0)] = value;
-    int num_warps = (blockDim.x * blockDim.y) / warpSize;
-    for (int offset = num_warps / 2; offset > 0; offset >>= 1) {
+    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
       __syncthreads();
-      if (threadIdx.y < offset && threadIdx.y + offset < num_warps) {
+      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
         arg_t other = shared[config.shared_memory_offset(offset)];
         value = ops.combine(value, other);
         shared[config.shared_memory_offset(0)] = value;
@@ -356,19 +397,17 @@ struct ReduceOp {
   }

   C10_DEVICE bool mark_block_finished() const {
-    extern __shared__ int is_last_block_done_shared[];
+    __shared__ bool is_last_block_done_shared;

     __syncthreads();
     if (threadIdx.x == 0 && threadIdx.y == 0) {
       int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
-      is_last_block_done_shared[0] = (prev_blocks_finished == gridDim.y - 1);
+      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
     }

-    __syncthreads();
-    bool is_last_block_done = is_last_block_done_shared[0];
     __syncthreads();

-    return is_last_block_done;
+    return is_last_block_done_shared;
   }

   template <bool can_acc>
@@ -409,7 +448,7 @@ struct ReduceOp {
     return ops.project(value);
   }

-  C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out) const {
+  C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out, char* shared_memory) const {
     arg_t* reduce_buffer = (arg_t*)buffer;

     bool should_store = config.should_store(config.output_idx());
@@ -424,7 +463,7 @@ struct ReduceOp {

     if (is_last_block_done) {
       value = ident;
-      if (config.should_warp_reduce()) {
+      if (config.should_block_x_reduce()) {
         index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
         index_t step = blockDim.x * blockDim.y;
         for (; input_offset < config.ctas_per_output; input_offset += step) {
@@ -441,9 +480,9 @@ struct ReduceOp {
           value = ops.combine(value, next);
         }
       }
-      value = block_reduce(value);
-      if (config.should_warp_reduce()) {
-        value = warp_reduce(value);
+      value = block_y_reduce(value, shared_memory);
+      if (config.should_block_x_reduce()) {
+        value = block_x_reduce(value, shared_memory);
       }
       if (should_store) {
         if (accumulate) {
@@ -461,6 +500,7 @@ template<int nt, typename R>
 static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
   dim3 block = config.block();
   dim3 grid = config.grid();
+
   auto stream = at::cuda::getCurrentCUDAStream();
   int shared_memory = config.shared_memory_size();
   reduce_kernel<nt, R><<<grid, block, shared_memory, stream>>>(reduction);
@@ -487,35 +527,47 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
   char* out_data = (char*)iter.data_ptr(0);
   const char* in_data = (char*)iter.data_ptr(1);

-
-  int warp_size = at::cuda::warp_size();
-  int warps_per_cta = ReduceConfig::NUM_THREADS / warp_size;
-
   // Start by assuming that each thread handles a single output and all
   // the inputs for that output.
   int64_t num_outputs = iter.num_output_elements();
   int64_t inputs_per_output = iter.numel() / num_outputs;

   auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);

+  int64_t dim0;
+  int64_t dim1;
+  // adjust block size to fit width to fast changing dimension
+  if (iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
+    dim0 = iter.shape()[0];
+    dim1 = num_outputs;
+  } else {
+    dim0 = iter.shape()[iter.num_reduce_dims()];
+    dim1 = inputs_per_output;
+  }
+
+  config.set_block_dimension(dim0, dim1);
+
+  int block_width = config.block_width;
+  int block_height = config.block_height;
+
   if (iter.ndim() == 0 || iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
     // Split the input across lanes if the input is contiguous in the reduced
     // dimension. This will require reduction between threads using warp
-    // shuffle instructions.
-    config.input_mult[0] = config.split_input(warp_size);
+    // shuffle instructions and shared memory (if block_width > warpSize).
+    config.input_mult[0] = config.split_input(block_width);
   } else {
     // Otherwise split the output across lanes in a warp.
-    config.output_mult[0] = config.split_output(warp_size);
+    config.output_mult[0] = config.split_output(block_width);
   }

-  if (config.values_per_thread() >= warps_per_cta * 16) {
+  if (config.values_per_thread() >= block_height * 16) {
     // Divide the input across warps in a thread-block, if that leaves at least
     // 16 elements to be summed by each thread. This will require inter-warp
     // reduction using shared memory.
-    config.input_mult[1] = config.split_input(warps_per_cta);
+    config.input_mult[1] = config.split_input(block_height);
   } else {
     // Otherwise, each warp handles a separate output.
-    config.output_mult[1] = config.split_output(warps_per_cta);
+    config.output_mult[1] = config.split_output(block_height);
   }

   if (config.values_per_thread() >= 256 && num_outputs <= 4096) {
@@ -556,7 +608,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
     reduce.accumulate = iter.should_accumulate();
     reduce.final_output = iter.is_final_output();

-    launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+    launch_reduce_kernel<ReduceConfig::MAX_NUM_THREADS>(config, reduce);
   } else {
     auto output_calc = make_output_calculator<uint64_t>(iter);
     auto input_calc = make_input_calculator<uint64_t>(iter);
@@ -574,7 +626,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
     reduce.accumulate = false;
     reduce.final_output = true;

-    launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+    launch_reduce_kernel<ReduceConfig::MAX_NUM_THREADS>(config, reduce);
   }
 }
