@@ -25,11 +25,21 @@ static inline int64_t div_up(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

+static inline int last_pow2(int n) {
+  n |= (n >> 1);
+  n |= (n >> 2);
+  n |= (n >> 4);
+  n |= (n >> 8);
+  n |= (n >> 16);
+  return std::max(1, n - (n >> 1));
+}
+
struct ReduceConfig {
-  static constexpr int LANE = 0;
-  static constexpr int WARP = 1;
+  static constexpr int BLOCK_X = 0;
+  static constexpr int BLOCK_Y = 1;
  static constexpr int CTA = 2;
-  static constexpr int NUM_THREADS = 512;
+
+  static constexpr int MAX_NUM_THREADS = 512;

  ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
    : element_size_bytes(element_size_bytes)
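The new `last_pow2` helper rounds a positive count down to the nearest power of two by smearing the highest set bit downward and then keeping only that bit. A minimal host-only sketch (assuming `n >= 1`; the `main` and its values are illustrative, not part of the change):

```cuda
#include <algorithm>
#include <cassert>

// Same bit trick as the helper added above, reproduced standalone.
static inline int last_pow2(int n) {
  n |= (n >> 1);   // copy the highest set bit into the bit below it
  n |= (n >> 2);   // ...and keep doubling the smear...
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);  // now every bit below the highest set bit is 1
  return std::max(1, n - (n >> 1));  // keep only the highest bit, never below 1
}

int main() {
  assert(last_pow2(1) == 1);
  assert(last_pow2(5) == 4);      // rounds down, not up
  assert(last_pow2(512) == 512);  // powers of two map to themselves
  assert(last_pow2(1000) == 512);
  return 0;
}
```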
@@ -45,6 +55,19 @@ struct ReduceConfig {
  int input_mult[3] = {0, 0, 0};
  int output_mult[2] = {0, 0};

+  int block_width;
+  int block_height;
+  int num_threads;
+
+  void set_block_dimension(int64_t dim0, int64_t dim1) {
+    int dim0_pow2 = dim0 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim0)) : MAX_NUM_THREADS;
+    int dim1_pow2 = dim1 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim1)) : MAX_NUM_THREADS;
+    block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
+    block_height = std::min(dim1_pow2, int(MAX_NUM_THREADS / block_width));
+    block_width = std::min(dim0_pow2, int(MAX_NUM_THREADS / block_height));
+    num_threads = block_width * block_height;
+  }
+
  int split_input(int parallelism) {
    int step = step_input;
    step_input *= parallelism;
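`set_block_dimension` picks a block shape whose width tracks the fast-changing dimension: the width starts at one warp, the height takes whatever fits under `MAX_NUM_THREADS`, and the width is then allowed to grow back if the height came out small. A host-only sketch of the same arithmetic, assuming a warp size of 32 (the shapes in `main` are hypothetical):

```cuda
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

namespace {

constexpr int kMaxNumThreads = 512;  // mirrors ReduceConfig::MAX_NUM_THREADS
constexpr int kWarpSize = 32;        // assumption: warp size of 32

int last_pow2(int n) {
  n |= (n >> 1); n |= (n >> 2); n |= (n >> 4); n |= (n >> 8); n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

// Same steps as ReduceConfig::set_block_dimension(), returning {width, height}.
std::pair<int, int> block_shape(int64_t dim0, int64_t dim1) {
  int dim0_pow2 = dim0 < kMaxNumThreads ? last_pow2(static_cast<int>(dim0)) : kMaxNumThreads;
  int dim1_pow2 = dim1 < kMaxNumThreads ? last_pow2(static_cast<int>(dim1)) : kMaxNumThreads;
  int width = std::min(dim0_pow2, kWarpSize);
  int height = std::min(dim1_pow2, kMaxNumThreads / width);
  width = std::min(dim0_pow2, kMaxNumThreads / height);
  return {width, height};
}

}  // namespace

int main() {
  // Long fast-changing dimension, many outputs: one warp wide, 16 rows.
  assert(block_shape(100000, 1000) == std::make_pair(32, 16));
  // Short fast-changing dimension (20 elements): only 16 lanes are useful,
  // so the block grows in y instead.
  assert(block_shape(20, 100000) == std::make_pair(16, 32));
  // Few rows (dim1 = 10): the width may grow past a warp, giving 64 x 8.
  assert(block_shape(2000, 10) == std::make_pair(64, 8));
  return 0;
}
```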
@@ -58,20 +81,19 @@ struct ReduceConfig {
  }

  dim3 block() const {
-    int warp_size = at::cuda::warp_size();
-    return dim3(warp_size, NUM_THREADS / warp_size);
+    return dim3(block_width, block_height);
  }

  dim3 grid() const {
    return dim3(div_up(num_outputs, step_output), ctas_per_output);
  }

-  C10_HOST_DEVICE bool should_warp_reduce() const {
-    return input_mult[LANE] != 0;
+  C10_HOST_DEVICE bool should_block_x_reduce() const {
+    return input_mult[BLOCK_X] != 0;
  }

-  C10_HOST_DEVICE bool should_block_reduce() const {
-    return input_mult[WARP] != 0;
+  C10_HOST_DEVICE bool should_block_y_reduce() const {
+    return input_mult[BLOCK_Y] != 0;
  }

  C10_HOST_DEVICE bool should_global_reduce() const {
@@ -80,25 +102,25 @@ struct ReduceConfig {

  C10_DEVICE bool should_store(int output_idx) const {
    return output_idx < num_outputs &&
-      (!should_warp_reduce() || threadIdx.x == 0) &&
-      (!should_block_reduce() || threadIdx.y == 0);
+      (!should_block_x_reduce() || threadIdx.x == 0) &&
+      (!should_block_y_reduce() || threadIdx.y == 0);
  }

  C10_HOST_DEVICE int input_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta2 = blockIdx.y;
-    return (lane * input_mult[LANE] +
-            warp * input_mult[WARP] +
+    return (lane * input_mult[BLOCK_X] +
+            warp * input_mult[BLOCK_Y] +
            cta2 * input_mult[CTA]);
  }

  C10_HOST_DEVICE int output_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta1 = blockIdx.x;
-    return (lane * output_mult[LANE] +
-            warp * output_mult[WARP] +
+    return (lane * output_mult[BLOCK_X] +
+            warp * output_mult[BLOCK_Y] +
            cta1 * step_output);
  }

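The `input_mult` / `output_mult` entries record which block coordinate a dimension was split across: `split_input()` returns the current `step_input` and then scales it, so successive splits produce strides of 1, then `block_width`, and so on, and `input_idx()` just sums the weighted coordinates. A hypothetical, host-only walkthrough (the 32 x 16 block shape is an assumption):

```cuda
#include <cassert>

int main() {
  int step_input = 1;
  // Same bookkeeping as ReduceConfig::split_input().
  auto split_input = [&](int parallelism) {
    int step = step_input;
    step_input *= parallelism;
    return step;
  };

  int input_mult[3] = {0, 0, 0};
  input_mult[0] = split_input(32);  // BLOCK_X: returns 1, step_input -> 32
  input_mult[1] = split_input(16);  // BLOCK_Y: returns 32, step_input -> 512
  assert(input_mult[0] == 1 && input_mult[1] == 32);

  // input_idx() for a thread at threadIdx = (3, 2), blockIdx.y = 0:
  int idx = 3 * input_mult[0] + 2 * input_mult[1] + 0 * input_mult[2];
  assert(idx == 3 + 2 * 32);  // lanes read adjacent elements; rows stride by 32
  return 0;
}
```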
@@ -108,25 +130,27 @@ struct ReduceConfig {

  C10_DEVICE int staging_memory_offset(int cta2) const {
    int offset = cta2 + blockIdx.x * gridDim.y;
-    if (!should_warp_reduce()) {
+    if (!should_block_x_reduce()) {
      offset = threadIdx.x + offset * blockDim.x;
    }
    return offset;
  }

  int shared_memory_size() const {
-    if (!should_block_reduce()) {
+    if (!should_block_y_reduce() &&
+        (!should_block_x_reduce() ||
+         block_width <= at::cuda::warp_size())) {
      return 0;
    }
-    return element_size_bytes * NUM_THREADS;
+    return element_size_bytes * num_threads;
  }

  int64_t global_memory_size() const {
    if (!should_global_reduce()) {
      return 0;
    }
    auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
-    if (!should_warp_reduce()) {
+    if (!should_block_x_reduce()) {
      size *= block().x;
    }
    return size;
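`shared_memory_size()` now depends on the chosen block shape: staging space is needed for a y-reduction, or for an x-reduction that spans more than one warp; a single-warp x-reduction still finishes with shuffles alone. A host-only mirror of the new condition (warp size 32 and `element_size_bytes == 4` are assumptions used for the examples):

```cuda
#include <cassert>

// Host-only mirror of the updated shared_memory_size() condition.
int shared_memory_size(int element_size_bytes, int block_width, int block_height,
                       bool block_x_reduce, bool block_y_reduce) {
  const int warp_size = 32;  // assumption
  if (!block_y_reduce && (!block_x_reduce || block_width <= warp_size)) {
    return 0;  // a single warp finishes with shuffles, no staging needed
  }
  return element_size_bytes * block_width * block_height;  // one slot per thread
}

int main() {
  // 32 x 1 block, x-reduction only: pure warp shuffles.
  assert(shared_memory_size(4, 32, 1, true, false) == 0);
  // 64 x 8 block, x-reduction wider than a warp: 4 * 512 = 2048 bytes.
  assert(shared_memory_size(4, 64, 8, true, false) == 2048);
  // Any y-reduction stages one value per thread.
  assert(shared_memory_size(4, 32, 16, false, true) == 2048);
  return 0;
}
```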
@@ -267,6 +291,7 @@ struct ReduceOp {
  }

  C10_DEVICE void run() const {
+    extern __shared__ char shared_memory[];
    index_t output_idx = config.output_idx();
    index_t input_idx = config.input_idx();
    auto base_offsets = output_calc.get(output_idx);
@@ -276,17 +301,17 @@ struct ReduceOp {
      auto input_slice = (const char*)src + base_offsets[1];
      value = thread_reduce((const scalar_t*)input_slice);
    }
-    bool should_block_reduce = config.should_block_reduce();
-    if (should_block_reduce) {
-      value = block_reduce(value);
+    bool should_block_y_reduce = config.should_block_y_reduce();
+    if (should_block_y_reduce) {
+      value = block_y_reduce(value, shared_memory);
    }
-    if (config.should_warp_reduce() && (!should_block_reduce || threadIdx.y == 0)) {
-      value = warp_reduce(value);
+    if (config.should_block_x_reduce()) {
+      value = block_x_reduce(value, shared_memory);
    }

    auto out = (out_scalar_t*)((char*)dst + base_offsets[0]);
    if (config.should_global_reduce()) {
-      value = global_reduce(value, out);
+      value = global_reduce(value, out, shared_memory);
    } else if (config.should_store(output_idx)) {
      if (accumulate) {
        value = accumulate_in_output<can_accumulate_in_output>(out, value);
@@ -331,22 +356,38 @@ struct ReduceOp {
    return value;
  }

-  C10_DEVICE arg_t warp_reduce(arg_t value) const {
-    for (int offset = 1; offset < warpSize; offset <<= 1) {
+  C10_DEVICE arg_t block_x_reduce(arg_t value, char* shared_memory) const {
+    int dim_x = blockDim.x;
+    arg_t* shared = (arg_t*)shared_memory;
+    if (dim_x > warpSize) {
+      int address_base = threadIdx.x + threadIdx.y*blockDim.x;
+      shared[address_base] = value;
+      for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
+        __syncthreads();
+        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
+          arg_t other = shared[address_base + offset];
+          value = ops.combine(value, other);
+          shared[address_base] = value;
+        }
+      }
+      dim_x = warpSize;
+    }
+
+    __syncthreads();
+
+    for (int offset = 1; offset < dim_x; offset <<= 1) {
      arg_t other = ops.warp_shfl_down(value, offset);
      value = ops.combine(value, other);
    }
    return value;
  }

-  C10_DEVICE arg_t block_reduce(arg_t value) const {
-    extern __shared__ char shared_memory[];
+  C10_DEVICE arg_t block_y_reduce(arg_t value, char* shared_memory) const {
    arg_t* shared = (arg_t*)shared_memory;
    shared[config.shared_memory_offset(0)] = value;
-    int num_warps = (blockDim.x * blockDim.y) / warpSize;
-    for (int offset = num_warps / 2; offset > 0; offset >>= 1) {
+    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
      __syncthreads();
-      if (threadIdx.y < offset && threadIdx.y + offset < num_warps) {
+      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
        arg_t other = shared[config.shared_memory_offset(offset)];
        value = ops.combine(value, other);
        shared[config.shared_memory_offset(0)] = value;
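`block_x_reduce` works in two phases when the block is wider than a warp: a shared-memory tree halves the number of partials until one warp's worth remains, then warp shuffles finish the job. Below is a standalone sum-reduction sketch of the same pattern, specialized to `float` (the real code goes through `ops.combine` / `ops.warp_shfl_down`); it assumes `blockDim.x` is a power of two and `blockDim.y == 1`:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void block_x_sum(const float* in, float* out) {
  extern __shared__ float shared[];
  float value = in[threadIdx.x];
  int dim_x = blockDim.x;

  // Phase 1: shared-memory tree until one warp's worth of partials remains.
  if (dim_x > warpSize) {
    shared[threadIdx.x] = value;
    for (int offset = dim_x / 2; offset >= warpSize; offset >>= 1) {
      __syncthreads();
      if (threadIdx.x < offset) {  // power-of-two blockDim.x assumed
        value += shared[threadIdx.x + offset];
        shared[threadIdx.x] = value;
      }
    }
    dim_x = warpSize;
  }
  __syncthreads();

  // Phase 2: finish within a single warp using shuffles.
  for (int offset = 1; offset < dim_x; offset <<= 1) {
    value += __shfl_down_sync(0xffffffff, value, offset);
  }
  if (threadIdx.x == 0) {
    *out = value;
  }
}

int main() {
  const int n = 128;
  float h_in[n];
  for (int i = 0; i < n; ++i) h_in[i] = 1.0f;
  float *d_in, *d_out, h_out = 0.0f;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  block_x_sum<<<1, n, n * sizeof(float)>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);  // expect 128.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```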
@@ -356,19 +397,17 @@ struct ReduceOp {
  }

  C10_DEVICE bool mark_block_finished() const {
-    extern __shared__ int is_last_block_done_shared[];
+    __shared__ bool is_last_block_done_shared;

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
-      is_last_block_done_shared[0] = (prev_blocks_finished == gridDim.y - 1);
+      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
    }

-    __syncthreads();
-    bool is_last_block_done = is_last_block_done_shared[0];
    __syncthreads();

-    return is_last_block_done;
+    return is_last_block_done_shared;
  }

  template <bool can_acc>
@@ -409,7 +448,7 @@ struct ReduceOp {
    return ops.project(value);
  }

-  C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out) const {
+  C10_DEVICE arg_t global_reduce(arg_t value, out_scalar_t* out, char* shared_memory) const {
    arg_t* reduce_buffer = (arg_t*)buffer;

    bool should_store = config.should_store(config.output_idx());
@@ -424,7 +463,7 @@ struct ReduceOp {

    if (is_last_block_done) {
      value = ident;
-      if (config.should_warp_reduce()) {
+      if (config.should_block_x_reduce()) {
        index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
        index_t step = blockDim.x * blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
@@ -441,9 +480,9 @@ struct ReduceOp {
          value = ops.combine(value, next);
        }
      }
-      value = block_reduce(value);
-      if (config.should_warp_reduce()) {
-        value = warp_reduce(value);
+      value = block_y_reduce(value, shared_memory);
+      if (config.should_block_x_reduce()) {
+        value = block_x_reduce(value, shared_memory);
      }
      if (should_store) {
        if (accumulate) {
@@ -461,6 +500,7 @@ template<int nt, typename R>
static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
  dim3 block = config.block();
  dim3 grid = config.grid();
+
  auto stream = at::cuda::getCurrentCUDAStream();
  int shared_memory = config.shared_memory_size();
  reduce_kernel<nt, R><<<grid, block, shared_memory, stream>>>(reduction);
@@ -487,35 +527,47 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
  char* out_data = (char*)iter.data_ptr(0);
  const char* in_data = (char*)iter.data_ptr(1);

-
-  int warp_size = at::cuda::warp_size();
-  int warps_per_cta = ReduceConfig::NUM_THREADS / warp_size;
-
  // Start by assuming that each thread handles a single output and all
  // the inputs for that output.
  int64_t num_outputs = iter.num_output_elements();
  int64_t inputs_per_output = iter.numel() / num_outputs;

  auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);

+  int64_t dim0;
+  int64_t dim1;
+  // Adjust the block size so that its width lines up with the fast-changing dimension.
+  if (iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
+    dim0 = iter.shape()[0];
+    dim1 = num_outputs;
+  } else {
+    dim0 = iter.shape()[iter.num_reduce_dims()];
+    dim1 = inputs_per_output;
+  }
+
+  config.set_block_dimension(dim0, dim1);
+
+  int block_width = config.block_width;
+  int block_height = config.block_height;
+
  if (iter.ndim() == 0 || iter.strides(/*arg=*/1)[0] == sizeof(scalar_t)) {
    // Split the input across lanes if the input is contiguous in the reduced
    // dimension. This will require reduction between threads using warp
-    // shuffle instructions.
-    config.input_mult[0] = config.split_input(warp_size);
+    // shuffle instructions and shared memory (if block_width > warpSize).
+    config.input_mult[0] = config.split_input(block_width);
  } else {
    // Otherwise split the output across lanes in a warp.
-    config.output_mult[0] = config.split_output(warp_size);
+    config.output_mult[0] = config.split_output(block_width);
  }

-  if (config.values_per_thread() >= warps_per_cta * 16) {
+  if (config.values_per_thread() >= block_height * 16) {
    // Divide the input across warps in a thread-block, if that leaves at least
    // 16 elements to be summed by each thread. This will require inter-warp
    // reduction using shared memory.
-    config.input_mult[1] = config.split_input(warps_per_cta);
+    config.input_mult[1] = config.split_input(block_height);
  } else {
    // Otherwise, each warp handles a separate output.
-    config.output_mult[1] = config.split_output(warps_per_cta);
+    config.output_mult[1] = config.split_output(block_height);
  }

  if (config.values_per_thread() >= 256 && num_outputs <= 4096) {
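With the configurable block shape, the split heuristics above are driven by `block_width` / `block_height` instead of the fixed warp and warps-per-CTA counts. A hypothetical, host-only walkthrough for a contiguous reduction of 100000 elements per output; here `values_per_thread()` is modelled as `div_up(inputs_per_output, step_input)`, which is an assumption about code outside this diff:

```cuda
#include <cassert>
#include <cstdint>

static int64_t div_up(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t inputs_per_output = 100000;           // assumed workload
  const int block_width = 32, block_height = 16;      // from set_block_dimension()

  int step_input = 1;
  auto split_input = [&](int parallelism) {
    int step = step_input;
    step_input *= parallelism;
    return step;
  };
  auto values_per_thread = [&]() { return div_up(inputs_per_output, step_input); };

  int input_mult[3] = {0, 0, 0};
  input_mult[0] = split_input(block_width);            // contiguous: split across threadIdx.x
  assert(values_per_thread() == 3125);                 // 100000 / 32

  assert(values_per_thread() >= block_height * 16);    // 3125 >= 256
  input_mult[1] = split_input(block_height);           // also split across threadIdx.y
  assert(input_mult[0] == 1 && input_mult[1] == block_width);
  assert(values_per_thread() == 196);                  // div_up(100000, 512)

  // 196 < 256, so the input is not split across CTAs and no global
  // reduction (semaphores + staging buffer) is needed for this shape.
  return 0;
}
```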
@@ -556,7 +608,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
    reduce.accumulate = iter.should_accumulate();
    reduce.final_output = iter.is_final_output();

-    launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+    launch_reduce_kernel<ReduceConfig::MAX_NUM_THREADS>(config, reduce);
  } else {
    auto output_calc = make_output_calculator<uint64_t>(iter);
    auto input_calc = make_input_calculator<uint64_t>(iter);
@@ -574,7 +626,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
    reduce.accumulate = false;
    reduce.final_output = true;

-    launch_reduce_kernel<ReduceConfig::NUM_THREADS>(config, reduce);
+    launch_reduce_kernel<ReduceConfig::MAX_NUM_THREADS>(config, reduce);
  }
}