// simt_canonical.cu - forked from NVIDIA/cutlass
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
This example requires NVIDIA Maxwell GPU or beyond.
*/
// Standard Library includes
#include <iostream>
#include <sstream>
#include <vector>
// CUTLASS Includes
#include "cutlass/cutlass.h"
#include "cutlass/core_io.h"
#include "cutlass/functional.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/warp/mma_simt.h"
#include "cutlass/epilogue/warp/fragment_iterator_simt.h"
#include "cutlass/epilogue/warp/tile_iterator_simt.h"
// CUTLASS Utility Includes
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/gemm_complex.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
// Overall warp-level problem shape, fixed at compile time.
// A is kM-by-kK, B is kK-by-kN, and C / D are kM-by-kN.
constexpr int kM = 14;
constexpr int kN = 27;
constexpr int kK = 17;
///////////////////////////////////////////////////////////////////////////////////////////////////
// Define a warp-level GEMM operator.
//
// This template could be part of the CUTLASS Template Library or implemented internally. This
// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be
// instantiated in device code.

namespace cutlass {
namespace gemm {
namespace warp {

/// Warp-level GEMM computing D = alpha * (A * B) + beta * C on operands addressed through
/// tensor references.
///
/// Template parameters:
///   Shape          - problem extent; Shape::kM, Shape::kN and Shape::kK are used as the logical
///                    extents handed to the operand / accumulator iterators and as the number of
///                    main-loop steps.
///   ElementA/LayoutA, ElementB/LayoutB, ElementC/LayoutC
///                  - currently NOT consumed: the underlying MmaWarp below hard-codes float
///                    elements with RowMajor A, ColumnMajor B and RowMajor C.
///   ElementScalar  - type of the alpha and beta scaling factors.
template <
  typename Shape,
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementScalar
>
class GemmSimt {
public:

  /// SIMT policy handed to MmaSimt. NOTE(review): presumably a 4x8 arrangement of lanes with a
  /// 2-interleaved row-major lane layout, each lane computing a 4x4x1 product - confirm against
  /// the MmaSimtPolicy documentation.
  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
    cutlass::MatrixShape<4, 8>,
    cutlass::layout::RowMajorInterleaved<2>,
    cutlass::gemm::GemmShape<4, 4, 1>
  >;

  /// Warp-level matrix multiply-accumulate operator over a 16x32x8 warp tile (all-float).
  using MmaWarp = cutlass::gemm::warp::MmaSimt<
    cutlass::gemm::GemmShape<16, 32, 8>,
    float,
    cutlass::layout::RowMajor,
    float,
    cutlass::layout::ColumnMajor,
    float,
    cutlass::layout::RowMajor,
    Policy
  >;

  /// Number of 'K groups', i.e. the trip count of the main loop.
  /// Declared static so that it is a class-level compile-time constant: it adds no per-object
  /// storage and gives the unrolled main loop a constant bound.
  static int const kKgroups = Shape::kK;

  /// Iterator that hands out slices of the accumulator tile, one per epilogue step.
  using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
    typename MmaWarp::Shape,
    typename MmaWarp::ThreadMma,
    layout::RowMajor,                       // SMEM layout
    typename MmaWarp::Policy
  >;

  /// Iterator used both to load the source (C) tile and to store the destination (D) tile.
  using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorSimtCanonical<
    typename MmaWarp::Shape,
    typename MmaWarp::ThreadMma,
    float,                                  // ElementAccumulator
    layout::RowMajor,                       // SMEM layout
    typename MmaWarp::Policy
  >;

  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
  using TensorRefC = typename AccumulatorTileIterator::TensorRef;

public:

  CUTLASS_HOST_DEVICE
  GemmSimt() { }

  /// Executes the warp-level GEMM followed by the linear-scaling epilogue.
  ///
  ///   alpha   - scale applied to the product A * B
  ///   ref_A   - reference to the A operand (logical extent Shape::kM x Shape::kK)
  ///   ref_B   - reference to the B operand (logical extent Shape::kK x Shape::kN)
  ///   beta    - scale applied to the source tile C
  ///   ref_C   - reference to the source tile C (logical extent Shape::kM x Shape::kN)
  ///   ref_D   - reference to the destination tile D (same extent as C)
  ///   lane_id - index of the calling thread, forwarded to every iterator
  CUTLASS_DEVICE
  void operator()(
    ElementScalar alpha,
    TensorRefA ref_A,
    TensorRefB ref_B,
    ElementScalar beta,
    TensorRefC ref_C,
    TensorRefC ref_D,
    int lane_id) const {

    // Instantiate iterators pointing to slices of the A and B matrices
    typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id);
    typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id);

    // Instantiate and clear accumulator tile holding the C matrix
    typename MmaWarp::FragmentC accum;
    accum.clear();

    // Instantiate the warp-level matrix multiply operator
    MmaWarp mma_op;

    // Two fragments per operand: one is consumed by the math while the other is being loaded
    typename MmaWarp::FragmentA frag_A[2];
    typename MmaWarp::FragmentB frag_B[2];

    // Prologue: load the fragments for the first step
    iter_A.load(frag_A[0]);
    iter_B.load(frag_B[0]);

    ++iter_A;
    ++iter_B;

    // Main loop over the K dimension
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups; ++k) {

      // Load the fragments for step k + 1 into the buffer not used by step k.
      // NOTE(review): the final iteration still issues this load one step past Shape::kK;
      // this relies on the iterators tolerating that position - confirm.
      iter_A.load(frag_A[(k + 1) % 2]);
      iter_B.load(frag_B[(k + 1) % 2]);

      ++iter_A;
      ++iter_B;

      // Compute the matrix multiply for step k
      mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum);
    }

    // Instantiate epilogue iterators
    FragmentIterator accum_frag_it(accum);
    AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id);
    AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id);

    // Define function objects for linear scaling operation
    cutlass::multiplies<typename FragmentIterator::Fragment> mul_source;
    cutlass::multiply_add<typename FragmentIterator::Fragment> mul_add_accumulator;

    // Iterate over the epilogue components
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) {

      // Define storage for slices of the accumulators
      typename FragmentIterator::Fragment accum_fragment;
      typename FragmentIterator::Fragment source_fragment;

      // Select a slice of accumulators from the accumulator tile
      accum_frag_it.load(accum_fragment);
      ++accum_frag_it;

      // Load the corresponding slice of the source tile
      source_tile_it.load(source_fragment);
      ++source_tile_it;

      // Compute linear scaling - alpha * AB + beta * C
      source_fragment = mul_source(beta, source_fragment);
      accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment);

      // Store the result through the destination iterator
      dest_tile_it.store(accum_fragment);
      ++dest_tile_it;
    }
  }
};

} // namespace warp
} // namespace gemm
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held
// in Shared Memory.
//
// Computes D = alpha * (A * B) + beta * C for the compile-time problem shape (kM, kN, kK).
//
//   D_gmem - output, kM * kN floats, row-major
//   A_gmem - input,  kM * kK floats, row-major
//   B_gmem - input,  kK * kN floats, column-major (stored as kN columns of kK elements)
//   C_gmem - input,  kM * kN floats, row-major
//
// Launch requirements: threadIdx.x is passed to the GEMM operator as the lane index, so the
// kernel is meant to be launched with a single warp, i.e. one block of 32 threads in x.
// Statically allocates (kM * kK + kN * kK + kM * kN) floats of shared memory.
__global__ void kernel(
  float *D_gmem,
  float alpha,
  float const *A_gmem,
  float const *B_gmem,
  float beta,
  float const *C_gmem) {

  // Define several matrices in shared memory
  __shared__ float A[kM][kK];
  __shared__ float B[kN][kK];
  __shared__ float C[kM][kN];

  // Copy data into SMEM. Every thread of the block takes part: thread t moves the elements
  // t, t + blockDim.x, t + 2 * blockDim.x, ... of each matrix, so neighbouring threads touch
  // neighbouring addresses. The shared and global images use the same linear element order.
  CUTLASS_PRAGMA_NO_UNROLL
  for (int idx = threadIdx.x; idx < kM * kK; idx += blockDim.x) {
    A[idx / kK][idx % kK] = A_gmem[idx];
  }

  CUTLASS_PRAGMA_NO_UNROLL
  for (int idx = threadIdx.x; idx < kN * kK; idx += blockDim.x) {
    B[idx / kK][idx % kK] = B_gmem[idx];
  }

  CUTLASS_PRAGMA_NO_UNROLL
  for (int idx = threadIdx.x; idx < kM * kN; idx += blockDim.x) {
    C[idx / kN][idx % kN] = C_gmem[idx];
  }

  // All staged data must be visible to every thread before the GEMM reads it
  __syncthreads();

  //
  // Instantiate a warp-level GEMM operator given the overall problem shape, the data type of
  // each operand, and the layout of each operand.
  //
  using GemmSimt = cutlass::gemm::warp::GemmSimt<
    cutlass::gemm::GemmShape<kM, kN, kK>,
    float,                            // Data type of A elements
    cutlass::layout::RowMajor,        // Layout of A matrix
    float,                            // Data type of B elements
    cutlass::layout::ColumnMajor,     // Layout of B matrix
    float,                            // Data type of C elements
    cutlass::layout::RowMajor,        // Layout of C matrix
    float                             // Scalar type of alpha and beta
  >;

  // Instantiate the GEMM operator
  GemmSimt gemm;

  // Execute the warp-level GEMM operation. The source and destination references both point
  // at C, so the result overwrites C in shared memory.
  gemm(
    alpha,
    {&A[0][0], kK},
    {&B[0][0], kK},
    beta,
    {&C[0][0], kN},
    {&C[0][0], kN},
    threadIdx.x);

  // The result tile must be completely written before it is read back
  __syncthreads();

  // Copy the result from SMEM to global memory, again cooperatively
  CUTLASS_PRAGMA_NO_UNROLL
  for (int idx = threadIdx.x; idx < kM * kN; idx += blockDim.x) {
    D_gmem[idx] = C[idx / kN][idx % kN];
  }
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// Host driver: fills A, B and C with random data, runs the warp-level GEMM kernel with a single
// warp, computes a host reference and compares the two results.
//
// Returns 0 when the device result equals the host reference, -1 on a CUDA error or a mismatch.
// The command-line arguments are not used.
int main(int argc, const char *arg[]) {

  cutlass::HostTensor<float, cutlass::layout::RowMajor> A({kM, kK});
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B({kK, kN});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> C({kM, kN});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> D({kM, kN});

  uint64_t seed = 2020;
  float max = 8;
  float min = -8;

  std::cout << "Simt canonical GEMM problem size = (" << cutlass::gemm::GemmShape<kM, kN, kK>() <<")" << std::endl;

  cutlass::reference::host::TensorFillRandomUniform(
    A.host_view(),
    seed,
    max,
    min,
    0
  );

  cutlass::reference::host::TensorFillRandomUniform(
    B.host_view(),
    seed + 17,
    max,
    min,
    0
  );

#if 0 // Debug: fill A sequentially and B as Identity matrix for debugging
  cutlass::reference::host::BlockFillSequential(
    A.host_view().data(), A.host_view().capacity());

  cutlass::reference::host::TensorFillIdentity(B.host_view());
#endif

  cutlass::reference::host::TensorFillRandomUniform(
    C.host_view(),
    seed + 31,
    max,
    min,
    0
  );

  // Copy host data to the device
  A.sync_device();
  B.sync_device();
  C.sync_device();
  D.sync_device();

  // A single block of 32 threads: the kernel passes threadIdx.x to the GEMM as the lane index
  dim3 grid(1, 1);
  dim3 block(32, 1, 1);

  float alpha = 1.0f;
  float beta = 0.0f;

  kernel<<< grid, block >>>(
    D.device_data(),
    alpha,
    A.device_data(),
    B.device_data(),
    beta,
    C.device_data()
  );

  // A kernel launch returns no status itself; launch errors must be queried explicitly
  cudaError_t result = cudaGetLastError();
  if (result != cudaSuccess) {
    std::cerr << "Failed to launch kernel: " << cudaGetErrorString(result) << std::endl;
    return -1;
  }

  // Errors raised while the kernel executes surface at the next synchronizing call
  result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "Failed to synchronize device after kernel launch." << std::endl;
    std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl;
    return -1;
  }

  D.sync_host();

  // Compute reference on host. D_ref starts as a copy of C and is then updated in place by the
  // reference GEMM, which receives it as its C operand.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> D_ref({kM, kN}, false);
  cutlass::reference::host::TensorCopy(D_ref.host_view(), C.host_view());

  cutlass::reference::host::Gemm<
    float, cutlass::layout::RowMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::RowMajor,
    float, float> reference_gemm;

  reference_gemm(
    {kM, kN, kK},
    alpha,
    A.host_ref(),
    B.host_ref(),
    beta,
    D_ref.host_ref(),
    float()
  );

  // Verify reference matches computed
  if (!cutlass::reference::host::TensorEquals(
    D.host_view(),
    D_ref.host_view())) {

    // Dump every operand so that a mismatch can be diagnosed from the log alone
    std::cerr
      << "A =\n" << A.host_view()
      << "\n\nB = \n" << B.host_view()
      << "\n\nC = " << C.host_view()
      << "\n\nRef =\n" << D_ref.host_view()
      << "\n\nD =\n" << D.host_view() << "\n\n";

    std::cerr << "Error - device results mismatch host reference." << std::endl;

    return -1;
  }

  std::cout << "Passed" << std::endl;

  return 0;
}
///////////////////////////////////////////////////////////////////////////////////////////////////