forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathgemm_grouped_softmax_mainloop_fusion.h
481 lines (395 loc) · 14.9 KB
/
gemm_grouped_softmax_mainloop_fusion.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped GEMM kernel with a softmax fused beforehand (the softmax norm/sum vectors are passed into the GEMM mainloop)
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Grouped GEMM kernel functor with a softmax fused into the mainloop.
///
/// operator() walks the tiles handed out by ProblemVisitor and, for each tile, runs the
/// threadblock-scoped Mma (fed with operands A and B plus the softmax norm/sum vectors)
/// followed by the Epilogue.
template <
typename Mma_, ///< Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///< Epilogue
typename ThreadblockSwizzle_, ///< Threadblock swizzling function
GroupScheduleMode GroupScheduleMode_, ///< Type of scheduling to perform
bool Transposed = false ///< Forwarded (as kTransposed) to MapArguments and ProblemVisitor
>
struct GemmGroupedSoftmaxMainloopFusion {
public:
// Re-export the template arguments under their public names.
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
// Only re-exported here; nothing else in this struct refers to the swizzle type.
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = Transposed;
// Optional transpose
// MapArguments receives the Mma-side element/layout/transform/alignment of A and B, the
// Mma's LayoutC and the kTransposed flag, and yields the public-facing definitions below.
using MapArguments = kernel::detail::MapArguments<
typename Mma::IteratorA::Element,
typename Mma::IteratorA::Layout,
Mma::kTransformA,
Mma::IteratorA::AccessType::kElements,
typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout,
Mma::kTransformB,
Mma::IteratorB::AccessType::kElements,
typename Mma::LayoutC,
kTransposed
>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
// Unlike A/B, the C element type comes from the epilogue's output tile iterator.
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
// Element type of Mma::IteratorNormSum; operator() casts the void* entries of
// ptr_norm / ptr_sum to pointers to this type.
using ElementScaleBias = typename Mma::IteratorNormSum::Element;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
// Threads per threadblock: 32 threads for each warp counted by WarpCount.
static int const kThreadCount = 32 * WarpCount::kCount;
// NOTE(review): kThreadCount is passed for both the third and the fourth template argument.
// Presumably these are the visitor's prefetch tile count and thread count -- confirm against
// gemm_grouped_problem_visitor.h.
using ProblemVisitor = GemmGroupedProblemVisitor<
ThreadblockShape,
kGroupScheduleMode,
kThreadCount,
kThreadCount,
kTransposed>;
//
// Structures
//
/// Argument structure
///
/// Caller-facing description of a group of GEMM problems. Params is built from it: every
/// member except host_problem_sizes is copied or forwarded there. All per-problem arrays
/// are indexed by the problem index inside the device-side operator().
struct Arguments {
//
// Data members
//
// Array of problem extents and its length; both are forwarded to ProblemVisitor::Params.
GemmCoord *problem_sizes{nullptr};
int problem_count{0};
// Copied verbatim into Params; not read anywhere else in this file.
int threadblock_count{0};
// Parameters used to construct the epilogue output functor for every tile.
typename EpilogueOutputOp::Params output_op{};
// Per-problem operand pointers: one entry per problem for A, B, C (source) and D (destination).
ElementA ** ptr_A{nullptr};
ElementB ** ptr_B{nullptr};
ElementC ** ptr_C{nullptr};
ElementC ** ptr_D{nullptr};
// Per-problem softmax norm and sum vectors. Stored type-erased; the kernel casts each entry
// to `ElementScaleBias const *` before handing it to Mma::IteratorNormSum.
void ** ptr_norm{nullptr};
void ** ptr_sum{nullptr};
// Per-problem leading dimensions, used by the kernel to construct the layouts of A, B, C, D.
typename LayoutA::Stride::LongIndex *lda{nullptr};
typename LayoutB::Stride::LongIndex *ldb{nullptr};
typename LayoutC::Stride::LongIndex *ldc{nullptr};
typename LayoutC::Stride::LongIndex *ldd{nullptr};
// Only used by device-level operator
// (it is not copied into Params, so the kernel never sees it)
GemmCoord *host_problem_sizes{nullptr};
//
// Methods
//
/// Default ctor
/// Leaves every pointer null and every count zero (see the member initializers above).
Arguments() = default;
/// Ctor
/// Stores each argument in the member of the same name; nothing is validated or copied deeply.
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord *problem_sizes,
int problem_count,
int threadblock_count,
typename EpilogueOutputOp::Params output_op,
ElementA ** ptr_A,
ElementB ** ptr_B,
ElementC ** ptr_C,
ElementC ** ptr_D,
void ** ptr_norm,
void ** ptr_sum,
typename LayoutA::Stride::LongIndex *lda,
typename LayoutB::Stride::LongIndex *ldb,
typename LayoutC::Stride::LongIndex *ldc,
typename LayoutC::Stride::LongIndex *ldd,
GemmCoord *host_problem_sizes=nullptr
):
problem_sizes(problem_sizes),
problem_count(problem_count),
threadblock_count(threadblock_count),
output_op(output_op),
ptr_A(ptr_A),
ptr_B(ptr_B),
ptr_C(ptr_C),
ptr_D(ptr_D),
ptr_norm(ptr_norm),
ptr_sum(ptr_sum),
lda(lda),
ldb(ldb),
ldc(ldc),
ldd(ldd),
host_problem_sizes(host_problem_sizes)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
///
/// Kernel-side parameter bundle derived from Arguments. The problem list is wrapped in
/// ProblemVisitor::Params; every other field is a straight copy of the Arguments member
/// of the same name.
struct Params {

  // Scheduling parameters consumed by ProblemVisitor inside the kernel.
  typename ProblemVisitor::Params problem_visitor{};
  int threadblock_count{0};

  // Parameters from which the epilogue output functor is constructed.
  typename EpilogueOutputOp::Params output_op{};

  // Per-problem operand pointer arrays.
  ElementA ** ptr_A{nullptr};
  ElementB ** ptr_B{nullptr};
  ElementC ** ptr_C{nullptr};
  ElementC ** ptr_D{nullptr};

  // Per-problem softmax norm / sum vectors (type-erased).
  void ** ptr_norm{nullptr};
  void ** ptr_sum{nullptr};

  // Per-problem leading dimensions.
  typename LayoutA::Stride::LongIndex *lda{nullptr};
  typename LayoutB::Stride::LongIndex *ldb{nullptr};
  typename LayoutC::Stride::LongIndex *ldc{nullptr};
  typename LayoutC::Stride::LongIndex *ldd{nullptr};

  //
  // Methods
  //

  /// Value-initialized parameters: null pointers, zero counts.
  Params() = default;

  /// Builds the parameters from `args`.
  ///
  /// `workspace` and `tile_count` are only forwarded to ProblemVisitor::Params, together
  /// with the problem-size array and problem count taken from `args`.
  CUTLASS_HOST_DEVICE
  Params(
    Arguments const &args,
    void *workspace = nullptr,
    int tile_count = 0
  ) :
    problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count),
    threadblock_count(args.threadblock_count),
    output_op(args.output_op),
    ptr_A(args.ptr_A),
    ptr_B(args.ptr_B),
    ptr_C(args.ptr_C),
    ptr_D(args.ptr_D),
    ptr_norm(args.ptr_norm),
    ptr_sum(args.ptr_sum),
    lda(args.lda),
    ldb(args.ldb),
    ldc(args.ldc),
    ldd(args.ldd)
  { }

  /// Re-derives every field from `args`, exactly as the converting constructor does.
  CUTLASS_HOST_DEVICE
  void update(
    Arguments const &args,
    void *workspace = nullptr,
    int tile_count = 0
  ) {
    // Build a fresh parameter set and take it over member by member, so that the
    // constructor remains the single place that maps Arguments onto Params.
    *this = Params(args, workspace, tile_count);
  }
};
/// Shared memory storage structure
///
/// Passed by reference to operator(), which hands `kernel.main_loop` to the Mma,
/// `kernel.epilogue` to the Epilogue and `problem_visitor` to the ProblemVisitor.
struct SharedStorage {
// The mainloop and the epilogue storage occupy the same bytes. operator() issues a
// __syncthreads() before each tile's mainloop so the previous tile's epilogue has finished
// with this storage first.
union {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
} kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
public:
//
// Methods
//
/// Empty device-side constructor. The struct declares no non-static data members, so there
/// is nothing to initialize; all inputs reach the kernel through operator()'s arguments.
CUTLASS_DEVICE
GemmGroupedSoftmaxMainloopFusion() { }
/// Determines whether kernel satisfies alignment
///
/// NOTE: no check is actually performed -- the problem size is not inspected and
/// Status::kSuccess is returned unconditionally.
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
  // The argument is deliberately discarded; nothing about it is examined.
  static_cast<void>(problem_size);
  return Status::kSuccess;
}
/// Overload taking the full argument structure.
///
/// NOTE: no check is actually performed -- `args` is not inspected and Status::kSuccess
/// is returned unconditionally.
static Status can_implement(Arguments const &args) {
  // The argument is deliberately discarded; nothing about it is examined.
  static_cast<void>(args);
  return Status::kSuccess;
}
/// Executes one GEMM
///
/// Runs the persistent tile loop for this threadblock: the problem visitor is asked for
/// tiles until none remain, and each tile goes through the mainloop and then the epilogue.
///
/// \param params          Kernel parameters: problem visitor params, epilogue functor params
///                        and the per-problem pointer / leading-dimension arrays.
/// \param shared_storage  Storage for the mainloop and epilogue (which alias each other in a
///                        union) and, separately, for the problem visitor.
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

  //
  // These types shadow the type-level definitions and support the ability to implement
  // a 'transposed' GEMM that computes the transposed problems.
  //
  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using ElementC = typename Epilogue::OutputTileIterator::Element;
  using LayoutC = typename Epilogue::OutputTileIterator::Layout;

  //
  // Problem visitor.
  //
  ProblemVisitor problem_visitor(
    params.problem_visitor,
    shared_storage.problem_visitor,
    blockIdx.x);

  // Outer 'persistent' loop to iterate over tiles
  while (problem_visitor.next_tile()) {

    GemmCoord problem_size = problem_visitor.problem_size();
    int32_t problem_idx = problem_visitor.problem_index();
    int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

    GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

    // The tile index is decomposed against the grid's N extent: the quotient selects the
    // tile row (scaled by kM), the remainder selects the tile column (scaled by kN).
    cutlass::gemm::GemmCoord threadblock_offset(
      int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
      int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN,
      0);

    // Load element pointers. Exchange pointers and strides if working on the transpose
    ElementA *ptr_A = reinterpret_cast<ElementA *>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
    typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);

    ElementB *ptr_B = reinterpret_cast<ElementB *>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
    typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_offset.m(),
      0,
    };

    cutlass::MatrixCoord tb_offset_B{
      0,
      threadblock_offset.n()
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      LayoutA(ldm_A),
      ptr_A,
      {problem_size.m(), problem_size.k()},
      thread_idx,
      tb_offset_A);

    typename Mma::IteratorB iterator_B(
      LayoutB(ldm_B),
      ptr_B,
      {problem_size.k(), problem_size.n()},
      thread_idx,
      tb_offset_B);

    // Construct iterator to the softmax norm/sum vector.
    // The type-erased per-problem pointers are cast to the iterator's element type; the
    // iterator is given the problem's M extent and starts at this tile's M offset.
    typename Mma::IteratorNormSum iterator_norm_sum(
      problem_size.m(),
      static_cast<ElementScaleBias const *>(params.ptr_norm[problem_idx]),
      static_cast<ElementScaleBias const *>(params.ptr_sum[problem_idx]),
      thread_idx,
      MatrixCoord(0, threadblock_offset.m())
    );

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

    int lane_idx = threadIdx.x % 32;

    //
    // Matrix multiply phase
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);

    // Number of mainloop iterations: K extent divided by the threadblock tile's K, rounded up
    int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Wait for all threads to finish their epilogue phases from the previous tile.
    // (The mainloop and epilogue storage share one union inside SharedStorage.)
    __syncthreads();

    // Compute threadblock-scoped matrix multiply-add
    mma(
      gemm_k_iterations,
      accumulators,
      iterator_A,
      iterator_B,
      iterator_norm_sum,
      accumulators);

    //
    // Epilogue
    //

    EpilogueOutputOp output_op(params.output_op);

    ElementC *ptr_C = params.ptr_C[problem_idx];
    ElementC *ptr_D = params.ptr_D[problem_idx];

    LayoutC layout_C(params.ldc[problem_idx]);
    LayoutC layout_D(params.ldd[problem_idx]);

    typename Epilogue::OutputTileIterator::Params params_C(layout_C);
    typename Epilogue::OutputTileIterator::Params params_D(layout_D);

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C(
      params_C,
      ptr_C,
      problem_size.mn(),
      thread_idx,
      threadblock_offset.mn()
    );

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params_D,
      ptr_D,
      problem_size.mn(),
      thread_idx,
      threadblock_offset.mn()
    );

    Epilogue epilogue(
      shared_storage.kernel.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Execute the epilogue operator to update the destination tensor.
    epilogue(
      output_op,
      iterator_D,
      accumulators,
      iterator_C);

    // Next tile
    problem_visitor.advance(gridDim.x);
  }
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////