forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathgemm_grouped_per_group_scale.h
261 lines (216 loc) · 9.43 KB
/
gemm_grouped_per_group_scale.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped GEMM kernel whose epilogue can scale each GEMM's output by a per-group scalar
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Grouped GEMM kernel that can scale the output of each GEMM in the group by its own scalar.
///
/// The kernel reuses GemmGrouped unchanged (Params, SharedStorage, problem visitor and host-side
/// helpers) and only replaces the device entry point. The single difference from
/// GemmGrouped::operator() is how the epilogue output operator is constructed for each tile:
///
///   - if EpilogueOutputOp is exactly the `epilogue::thread::LinearCombination` specialization
///     described by its own ElementOutput / kCount / ElementAccumulator / ElementCompute /
///     kScale / kRound, it is built as `EpilogueOutputOp(params.output_op, problem_idx)`, so the
///     operator can use the scale belonging to the problem the current tile comes from;
///   - otherwise it is built as `EpilogueOutputOp(params.output_op)`, exactly as in GemmGrouped.
///
/// NOTE(review): the `(Params, int)` LinearCombination constructor is presumably provided by this
/// fork's linear_combination.h — confirm it is available wherever this header is used.
template <
  typename Mma_,                         ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,                    ///! Epilogue
  typename ThreadblockSwizzle_,          ///! Threadblock swizzling function
  GroupScheduleMode GroupScheduleMode_,  ///! Type of scheduling to perform
  bool Transposed = false                ///! Forwarded to GemmGrouped (compute transposed problems)
>
struct GemmGroupedPerGroupScale :
  public GemmGrouped<Mma_, Epilogue_, ThreadblockSwizzle_, GroupScheduleMode_, Transposed> {

  /// The grouped-GEMM kernel being extended. No constructors are inherited: this kernel adds no
  /// state of its own and is driven entirely by Base::Params.
  using Base = GemmGrouped<Mma_, Epilogue_, ThreadblockSwizzle_, GroupScheduleMode_, Transposed>;

  // Type definitions re-exported from the base kernel
  using typename Base::Mma;
  using typename Base::Epilogue;
  using typename Base::EpilogueOutputOp;
  using typename Base::ThreadblockSwizzle;
  using typename Base::Params;
  using typename Base::SharedStorage;

  // Explicitly re-export the kTransposed constant
  static bool const kTransposed = Base::kTransposed;

  /// Device entry point: a persistent loop that computes every tile the problem visitor assigns
  /// to this threadblock, across all problems of the group.
  ///
  /// @param params          kernel parameters: per-problem operand pointers and leading
  ///                        dimensions, problem-visitor state and epilogue operator parameters
  /// @param shared_storage  storage used by the problem visitor, the mainloop and the epilogue
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    //
    // These types shadow the type-level definitions and support the ability to implement
    // a 'transposed' GEMM that computes the transposed problems.
    //
    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Epilogue::OutputTileIterator::Layout;

    //
    // Problem visitor.
    //
    typename Base::ProblemVisitor problem_visitor(
      params.problem_visitor,
      shared_storage.problem_visitor,
      blockIdx.x);

    // Outer 'persistent' loop to iterate over tiles
    while (problem_visitor.next_tile()) {

      GemmCoord problem_size  = problem_visitor.problem_size();
      int32_t problem_idx     = problem_visitor.problem_index();
      int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

      GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

      // Tiles are numbered row-major over the (M, N) tile grid of the current problem.
      cutlass::gemm::GemmCoord threadblock_offset(
        int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
        int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN,
        0);

      // Load element pointers. Exchange pointers and strides if working on the transpose
      ElementA *ptr_A = reinterpret_cast<ElementA *>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
      typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);

      ElementB *ptr_B = reinterpret_cast<ElementB *>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
      typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);

      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        threadblock_offset.m(),
        0,
      };

      cutlass::MatrixCoord tb_offset_B{
        0,
        threadblock_offset.n()
      };

      // Compute position within threadblock
      int thread_idx = threadIdx.x;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        LayoutA(ldm_A),
        ptr_A,
        {problem_size.m(), problem_size.k()},
        thread_idx,
        tb_offset_A);

      typename Mma::IteratorB iterator_B(
        LayoutB(ldm_B),
        ptr_B,
        {problem_size.k(), problem_size.n()},
        thread_idx,
        tb_offset_B);

      typename Mma::FragmentC accumulators;

      accumulators.clear();

      // Broadcast the warp_id computed by lane 0 to ensure dependent code
      // is compiled as warp-uniform.
      int warp_idx = canonical_warp_idx_sync();

      int lane_idx = threadIdx.x % 32;

      //
      // Matrix multiply phase
      //

      // Construct thread-scoped matrix multiply
      Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);

      // Number of K-dimension mainloop iterations (ceiling division by the tile's K extent)
      int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Wait for all threads to finish their epilogue phases from the previous tile.
      __syncthreads();

      // Compute threadblock-scoped matrix multiply-add
      mma(
        gemm_k_iterations,
        accumulators,
        iterator_A,
        iterator_B,
        accumulators);

      //
      // Epilogue
      //

      ElementC *ptr_C = params.ptr_C[problem_idx];
      ElementC *ptr_D = params.ptr_D[problem_idx];

      LayoutC layout_C(params.ldc[problem_idx]);
      LayoutC layout_D(params.ldd[problem_idx]);

      typename Epilogue::OutputTileIterator::Params params_C(layout_C);
      typename Epilogue::OutputTileIterator::Params params_D(layout_D);

      // Tile iterator loading from source tensor.
      typename Epilogue::OutputTileIterator iterator_C(
        params_C,
        ptr_C,
        problem_size.mn(),
        thread_idx,
        threadblock_offset.mn()
      );

      // Tile iterator writing to destination tensor.
      typename Epilogue::OutputTileIterator iterator_D(
        params_D,
        ptr_D,
        problem_size.mn(),
        thread_idx,
        threadblock_offset.mn()
      );

      Epilogue epilogue(
        shared_storage.kernel.epilogue,
        thread_idx,
        warp_idx,
        lane_idx);

      // Per-group scaling is decided entirely by how the output operator is constructed.
      EpilogueOutputOp output_op = make_output_op(params, problem_idx);

      // Execute the epilogue operator to update the destination tensor.
      epilogue(
        output_op,
        iterator_D,
        accumulators,
        iterator_C);

      // Next tile
      problem_visitor.advance(gridDim.x);
    }
  }

 private:

  /// Builds the epilogue output operator used for a tile of problem `problem_idx`.
  ///
  /// When EpilogueOutputOp is exactly the LinearCombination specialization rebuilt from its own
  /// parameters, the per-group `(Params, problem_idx)` constructor is used so the customized
  /// operator can scale this GEMM's output by its own scalar. Any other operator is built with
  /// the ordinary `(Params)` constructor.
  ///
  /// NOTE(review): the comparison only spells out LinearCombination's first six template
  /// arguments. If this tree's LinearCombination has a trailing ElementSource parameter, an
  /// operator whose ElementSource differs from ElementOutput will silently take the unscaled
  /// path — confirm this is intended.
  CUTLASS_DEVICE
  static EpilogueOutputOp make_output_op(Params const &params, int32_t problem_idx) {
    using PerGroupScaleOutputOp = ::cutlass::epilogue::thread::LinearCombination<
      typename EpilogueOutputOp::ElementOutput,
      EpilogueOutputOp::kCount,
      typename EpilogueOutputOp::ElementAccumulator,
      typename EpilogueOutputOp::ElementCompute,
      EpilogueOutputOp::kScale,
      EpilogueOutputOp::kRound>;

    // The discarded branch is never instantiated, so operators lacking the
    // (Params, int) constructor still compile through the else branch.
    if constexpr (platform::is_same<EpilogueOutputOp, PerGroupScaleOutputOp>::value) {
      return EpilogueOutputOp(params.output_op, problem_idx);
    } else {
      return EpilogueOutputOp(params.output_op);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////