// ampere_conv_kernel.h (forked from NVIDIA/cutlass)
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_atom.hpp"
#include <random>
#include "cutlass/util/print_error.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_mma.hpp"
using namespace cute;
struct AmpereUnpredicatedFprop {
//
// Static config for conv problem shape
//
using D = _6;
using H = _4;
using W = _4;
using T = _3;
using R = _3;
using S = _3;
using Z = _4;
using P = _2;
using Q = _2;
using C = _64;
using K = _128;
// Tiler config
using Tiler_K = decltype(cute::min(K{}, _128{}));
using Tiler_C = decltype(cute::min(C{}, _32{}));
using Tiler_N = _4;
using TileM = Tiler_K;
using TileN = Shape<Tiler_N, Z, P, Q>;
using TileK = Shape<Tiler_C,_1,_1,_1>;
using PIPE = _3;
using TilerFlt = Shape<TileM, TileK>;
using TilerAct = Shape<TileN, TileK>;
using TilerOut = Shape<TileM, TileN>;
using TileSizeM = Int<size(TileM{})>;
using TileSizeN = Int<size(TileN{})>;
using TileSizeK = Int<size(TileK{})>;
static constexpr int Stages = PIPE::value;
using ElementFlt = tfloat32_t;
using ElementAct = tfloat32_t;
using ElementOut = float;
using TiledMma = TiledMMA<
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
Layout<Shape<_2,_2,_1>>,
Tile<_32,_32,Underscore>>;
static constexpr int MaxThreadsPerBlock = size(TiledMma{});
static constexpr int MinBlocksPerMultiprocessor = 1;
union SharedStorage {
struct {
ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})];
ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})];
} mainloop;
struct {
ElementOut sCMatrix[size(TileM{}) * size(TileN{})];
} epilogue;
};
//
// Stencil tensor
//
using GmemLayoutFlt = decltype(make_ordered_layout(
Shape< K, Shape< C, T, R, S>>{},
tuple<_4, tuple<_0,_3,_2,_1>>{}));
// We have 64 elements * 32b each in the major mode that we can vectorize
// Max vector size is 128b, so lay 16 threads along the major mode with a vector size of 4
// Rest along the minor mode
using GmemTiledCopyFlt = decltype(make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementFlt>{},
Layout<Shape <_16, _8>,
Stride< _8, _1>>{},
Layout<Shape < _1, _4>>{}));
// Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
// using SmemLayoutFlt = decltype(
// composition(Swizzle<3,2,3>{},
// make_ordered_layout(
// Shape<TileSizeM,TileSizeK,PIPE>{},
// tuple< _1, _0, _2>{})));
using SmemLayoutAtomFlt = decltype(
composition(Swizzle<1,2,3>{},
Layout<Shape <_8,Shape <_4, _2>>,
Stride<_4,Stride<_1,_32>>>{}));
using SmemCopyAtomFlt = Copy_Atom<SM75_U32x4_LDSM_N, ElementFlt>;
//
// Activation tensor
//
// Activation tensor is major in the contraction mode, so vectorize that mode first
// Then lay out the rest of the threads along the other mode
using GmemTiledCopyAct = decltype(make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementAct>{},
Layout<Shape <_16, _8>,
Stride< _8, _1>>{},
Layout<Shape < _1, _4>>{}));
// Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses
// using SmemLayoutAct = decltype(
// composition(Swizzle<3,2,3>{},
// make_ordered_layout(
// Shape<TileSizeN,TileSizeK,PIPE>{},
// tuple< _1, _0, _2>{})));
using SmemLayoutAtomAct = decltype(
composition(Swizzle<1,2,3>{},
Layout<Shape <_8,Shape <_4, _2>>,
Stride<_4,Stride<_1,_32>>>{}));
using SmemCopyAtomAct = Copy_Atom<SM75_U32x4_LDSM_N, ElementAct>;
//
// Output tensor
//
using GmemTiledCopyOut = decltype(make_tiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, ElementAct>{},
Layout<Shape <_8, _16>,
Stride<_1, _8>>{},
Layout<Shape <_4, _1>>{}));
using SmemCopyAtomOut = Copy_Atom<UniversalCopy<uint32_t>, ElementOut>;
// This can be optimized to make accesses BCF, but we use a col-major layout here to show off composability
using SmemLayoutOut = Layout<Shape<TileSizeM, TileSizeN>>;
//
// Conv functor
//
template <class EngineFlt, class TensorActivation, class TensorOutput>
void __device__
operator()(cute::Tensor<EngineFlt, GmemLayoutFlt> mFlt, // ( K, (C,T,R,S))
TensorActivation mAct, // ((N,Z,P,Q), (C,T,R,S))
TensorOutput mOut, // ( K, (N,Z,P,Q))
char* smem_buf) const {
using namespace cute;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma<
cutlass::gemm::MainloopSm80CpAsyncUnpredicated<PIPE::value>,
Shape<TileM,TileN,TileK>,
ElementFlt,
Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
ElementAct,
Underscore, // Ignore the stride, we are passing full cute::Tensor to operator()
TiledMma,
GmemTiledCopyFlt,
SmemLayoutAtomFlt,
SmemCopyAtomFlt,
cute::identity,
GmemTiledCopyAct,
SmemLayoutAtomAct,
SmemCopyAtomAct,
cute::identity>;
TiledMma tiled_mma;
Tensor accum = partition_fragment_C(tiled_mma, TilerOut{});
clear(accum);
// Set up tensors
// NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW
Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_)); // (BLK_M,BLK_K,m',k')
Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_)); // (BLK_N,BLK_K,n',_1)
Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_)); // (BLK_M,BLK_N,m',n')
// Compute m_coord and n_coord with their post-tiled shapes
auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk));
auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk));
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k')
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,_1)
Tensor gC = gC_mn(_,_,m_coord,n_coord); // (BLK_M,BLK_N)
auto k_tile_iter = cute::make_coord_iterator(size<2>(gA));
int k_tile_count = size<2>(gA);
CollectiveMainloop collective_mma;
collective_mma(
accum,
gA,
gB,
accum,
k_tile_iter, k_tile_count,
Underscore{}, // no residue since we do not support predication
threadIdx.x,
smem_buf);
//
// Epilogue
//
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{});
auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma);
auto smem_thr_copy_C = smem_tiled_copy_C.get_slice(threadIdx.x);
auto tCrC = smem_thr_copy_C.retile_S(accum);
auto tCsC = smem_thr_copy_C.partition_D(sC);
copy(smem_tiled_copy_C, tCrC, tCsC);
__syncthreads();
GmemTiledCopyOut gmem_tiled_copy_C;
auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x);
auto tDsC = gmem_thr_copy_C.partition_S(sC);
auto tDgC = gmem_thr_copy_C.partition_D(gC);
copy(gmem_tiled_copy_C, tDsC, tDgC);
#if 0
if (thread0()) {
print("mAct = "); print(mAct); print('\n');
print("mFlt = "); print(mFlt); print('\n');
print("mOut = "); print(mOut); print('\n');
print("gA = "); print(gA); print('\n');
print("gB = "); print(gB); print('\n');
print("gC = "); print(gC); print('\n');
print("sA = "); print(sA.layout()); print('\n');
print("sB = "); print(sB.layout()); print('\n');
print("sC = "); print(sC.layout()); print('\n');
print("tAgA = "); print(tAgA.layout()); print('\n');
print("tBgB = "); print(tBgB.layout()); print('\n');
print("tAsA = "); print(tAsA.layout()); print('\n');
print("tBsB = "); print(tBsB.layout()); print('\n');
print("tCsA = "); print(tCsA.layout()); print('\n');
print("tCsB = "); print(tCsB.layout()); print('\n');
print("tCrC = "); print(tCrC.layout()); print('\n');
print("tCsC = "); print(tCsC.layout()); print('\n');
print("tDsC = "); print(tDsC.layout()); print('\n');
print("tDgC = "); print(tDgC.layout()); print('\n');
print("gmem tiled copy A = "); print(gmem_tiled_copy_A); print('\n');
print("gmem tiled copy B = "); print(gmem_tiled_copy_B); print('\n');
print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n');
print("k_tile_count = "); print(size<2>(gA)); print('\n');
print("k_tile_iter = "); print(*k_tile_iter); print('\n');
print("K_BLOCK_MAX = "); print(K_BLOCK_MAX); print('\n');
}
#endif
}
};
/// Host reference for the fprop convolution, evaluated as a plain GEMM over the logical
/// (M,K) x (N,K) -> (M,N) views of the tensors, followed by a comparison with the result
/// under test.
///
/// @param mStencil    filter tensor,     logical MK: ( K,        (C,T,R,S))
/// @param mActivation activation tensor, logical NK: ((N,Z,P,Q), (C,T,R,S))
/// @param mOutput     output under test, logical MN: ( K,        (N,Z,P,Q)); only read
/// @param mOutputRef  reference output,  logical MN; overwritten with the reference result
/// @return whatever print_relative_error reports for mOutput vs. mOutputRef with an error
///         margin of 0.01
///
/// The products are accumulated in float. The outer loop is parallelised with OpenMP when
/// it is available; each iteration writes a distinct row of mOutputRef.
template <class TensorFlt, class TensorAct, class TensorOut>
inline int
fprop_reference(
    TensorFlt mStencil,    // Logical MK: ( K,        (C,T,R,S))
    TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S))
    TensorOut mOutput,     // Logical MN: ( K,        (N,Z,P,Q))
    TensorOut mOutputRef) {
  // Only the flattened GEMM extents are needed; the individual conv extents are never used.
  size_t const K    = static_cast<size_t>(size<0>(mOutputRef));
  size_t const NZPQ = static_cast<size_t>(size<1>(mOutputRef));
  size_t const CTRS = static_cast<size_t>(size<1>(mStencil));

#if defined(_OPENMP)
  #pragma omp parallel for
#endif
  for (size_t logical_m = 0; logical_m < K; ++logical_m) {
    for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) {
      float accumulator = 0.0f;
      // Dot product over the whole contraction mode (C,T,R,S)
      for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) {
        accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k);
      }
      mOutputRef(logical_m, logical_n) = accumulator;
    }
  }

  return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01);
}