forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathpvc_flash_attn_runner.hpp
627 lines (517 loc) · 27.6 KB
/
pvc_flash_attn_runner.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "flash_attention_v2/collective/fmha_fusion.hpp"
#include "flash_attention_v2/kernel/tile_scheduler.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "flash_attention_v2/kernel/xe_flash_attn_gemm.hpp"
#include "flash_attention_v2/collective/xe_flash_attn_epilogue.hpp"
#include "flash_attention_v2/collective/xe_flash_attn_softmax_epilogue.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/sycl_event_manager.hpp"
#include <cute/tensor.hpp>
#include <cmath>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "helper.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "sycl_common.hpp"
using namespace cute;
// Command line options parsing
struct Options {
  bool help = false;       // --help was given
  bool error = false;      // set by parse() when the requested shape is invalid
  bool is_causal = false;  // apply the causal mask
  bool varlen = false;     // use variable sequence lengths
  std::string scheduler = "Individual";
  int batch = 32;
  int num_heads_q = 16;
  int num_heads_kv = 16;
  int seq_len_qo = 512;
  int seq_len_kv = 512;
  int head_size_qk = 128;
  int head_size_vo = 128;
  int iterations = 100;
  // Derived from head_size_qk (declared above, so it is already initialized here) to stay
  // consistent with the value parse() computes.
  float softmax_scale = 1.f / std::sqrt(static_cast<float>(head_size_qk));

  Options() = default;

  /// Parses the command line into this object.
  ///
  /// Sets `help` and returns immediately when --help is present. Otherwise reads every option;
  /// num_heads_kv defaults to num_heads_q, seq_len_kv to seq_len_qo, and head_size_qk to
  /// head_size_vo. Derives softmax_scale = 1/sqrt(head_size_qk) and sets `error` when the
  /// resulting shape cannot be run.
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);
    if (cmd.check_cmd_line_flag("help")) {
      help = true;
      return;
    }
    if (cmd.check_cmd_line_flag("is_causal")) {
      is_causal = true;
    }
    if (cmd.check_cmd_line_flag("varlen")) {
      varlen = true;
    }
    cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual"));
    cmd.get_cmd_line_argument("batch", batch, 32);
    cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16);
    cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, num_heads_q);
    cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 512);
    cmd.get_cmd_line_argument("seq_len_kv", seq_len_kv, seq_len_qo);
    cmd.get_cmd_line_argument("head_size_vo", head_size_vo, 128);
    cmd.get_cmd_line_argument("head_size_qk", head_size_qk, head_size_vo);
    cmd.get_cmd_line_argument("iterations", iterations, 100);
    softmax_scale = 1.f / std::sqrt(static_cast<float>(head_size_qk));

    // Every dimension must be positive, and the KV heads must evenly divide the query heads
    // (grouped-query attention: each KV head serves num_heads_q / num_heads_kv query heads).
    // Short-circuit evaluation guarantees the modulo never divides by zero.
    error = batch <= 0 || num_heads_q <= 0 || num_heads_kv <= 0 || seq_len_qo <= 0 ||
            seq_len_kv <= 0 || head_size_qk <= 0 || head_size_vo <= 0 ||
            num_heads_q % num_heads_kv != 0;
  }

  /// Prints the usage statement.
  std::ostream &print_usage(std::ostream &out) const {
    out << "PVC Flash Attention v2 Example\n\n"
        << "Options:\n\n"
        << " --help If specified, displays this usage statement\n\n"
        << " --is_causal Apply Causal Mask to the output of first Matmul\n"
        << " --varlen Enable variable sequence length\n"
        << " --scheduler=\"Value\" Choose between Individual or Persistent Scheduler\n"
        << " --batch=<int> Sets the Batch Size of the Multi-Head Self Attention module\n"
        << " --num_heads_q=<int> Sets the Number of Attention Heads for Query input in the Multi-Head Self Attention module\n"
        << " --num_heads_kv=<int> Sets the Number of Attention Heads for Key-Value pair in the Multi-Head Self Attention module\n"
        << " --seq_len_qo=<int> Sets the Sequence length of the Query input in Multi-Head Self Attention module\n"
        << " --seq_len_kv=<int> Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n"
        << " --head_size_qk=<int> Sets the Attention Head dimension of the 1st Matrix Multiplication in Multi-Head Self Attention module\n"
        << " --head_size_vo=<int> Sets the Attention Head dimension of the 2nd Matrix Multiplication in Multi-Head Self Attention module\n"
        << " --iterations=<int> Iterations\n\n";
    return out;
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Flash Attention takes 3 input matrices: (K)eys, (Q)ueries and (V)alues, and writes one
// (O)utput matrix. These layout tags are used both to configure the kernel (FMHAConfig) and to
// build the host-side reference in ExampleRunner::verify.
// Q: per head, a (seq_len_qo x head_size_qk) row-major matrix. The reference also reuses this
// layout for its intermediate S/P score matrices.
using LayoutQ = cutlass::layout::RowMajor;
// K: per head, viewed as a (head_size_qk x seq_len_kv) column-major matrix, i.e. the K^T operand
// of the first GEMM (Q * K^T); each key token's head_size_qk values are contiguous.
using LayoutK = cutlass::layout::ColumnMajor;
// V: per head, a (seq_len_kv x head_size_vo) row-major matrix (operand of P * V).
using LayoutV = cutlass::layout::RowMajor;
// O: per head, a (seq_len_qo x head_size_vo) row-major matrix.
using LayoutO = cutlass::layout::RowMajor;
template <class GemmKernel, bool isVarLen> struct ExampleRunner {
  using StrideQ = typename GemmKernel::StrideQ;
  using StrideK = typename GemmKernel::StrideK;
  using StrideV = typename GemmKernel::StrideV;
  using StrideO = typename GemmKernel::StrideO;
  using ElementQ = typename GemmKernel::ElementQ;
  using ElementK = typename GemmKernel::ElementK;
  using ElementV = typename GemmKernel::ElementV;
  using ElementAcc = typename GemmKernel::ElementAccumulator;
  using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue;
  using ElementOutput = typename CollectiveEpilogue::ElementOutput;
  using ElementCompute = typename CollectiveEpilogue::ElementCompute;
  using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
  using ProblemShapeType = typename GemmKernel::ProblemShape;

  //
  // Data members
  //

  /// Packed strides of the device tensors, computed by initialize().
  StrideQ stride_Q;
  StrideK stride_K;
  StrideV stride_V;
  StrideO stride_O;
  /// Base seed for the random initialization of Q, K and V.
  uint64_t seed = 0;
  cutlass::DeviceAllocation<ElementQ> block_Q;
  cutlass::DeviceAllocation<ElementK> block_K;
  cutlass::DeviceAllocation<ElementV> block_V;
  /// Output written by the kernel.
  cutlass::DeviceAllocation<ElementOutput> block_O;
  /// Output written by the reference computation in verify().
  cutlass::DeviceAllocation<ElementOutput> block_ref_O;
  /// Host prefix sums of the per-batch sequence lengths (varlen only; first entry is 0).
  std::vector<int> cumulative_seqlen_q;
  std::vector<int> cumulative_seqlen_kv;
  /// Device copies of the prefix sums; the varlen launch problem shape points at these.
  cutlass::DeviceAllocation<int> device_cumulative_seqlen_q;
  cutlass::DeviceAllocation<int> device_cumulative_seqlen_kv;

  //
  // Methods
  //

  /// Computes a reference attention output into `block_ref_O` and compares it with `block_O`.
  ///
  /// The reference works one (batch, query head) pair at a time: S = Q * K^T is computed on the
  /// device, masking and softmax (scaled by 1/sqrt(head_size_qk)) are applied on the host, and
  /// O = P * V is computed on the device again.
  ///
  /// @param problem_size  Launch problem shape. For varlen, only the maximum sequence lengths are
  ///                      read from it; per-batch lengths come from the host prefix sums.
  /// @param is_causal     Apply the causal mask in the reference.
  /// @return true when the outputs agree within the relative tolerance used below; false also
  ///         when num_heads_kv does not evenly divide num_heads_q.
  bool verify(ProblemShapeType problem_size, bool is_causal) {
    if constexpr (isVarLen) {
      int max_seq_len_q = static_cast<int>(get<3>(problem_size));
      int max_seq_len_kv = static_cast<int>(get<4>(problem_size));
      get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()};
      get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()};
    }
    auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,5,6>(problem_size);

    // Grouped-query attention: each run of q_group_size consecutive query heads shares one KV
    // head. Guard against shapes where this mapping is undefined (division by zero or KV heads
    // that do not evenly divide the query heads, which would index past the K/V buffers).
    if (num_heads_kv <= 0 || num_heads_q % num_heads_kv != 0) {
      std::cerr << "verify: num_heads_q (" << num_heads_q << ") must be a positive multiple of num_heads_kv ("
                << num_heads_kv << ")" << std::endl;
      return false;
    }
    const int q_group_size = num_heads_q / num_heads_kv;

    // Element offsets of the current batch's first head in each packed tensor. size_t avoids
    // int overflow for large batch * heads * seq_len * head_size products.
    std::size_t offset_q = 0;
    std::size_t offset_k = 0;
    std::size_t offset_v = 0;
    std::size_t offset_o = 0;
    // loop over the batch dimension to compute the output
    // to avoid the risk of running out of device memory
    for (int b = 0; b < batch; b++) {
      int seq_len_qo, seq_len_kv;
      if constexpr (isVarLen) {
        auto logical_problem_shape = cutlass::fmha::collective::apply_variable_length(problem_size, b);
        seq_len_qo = get<3>(logical_problem_shape);
        seq_len_kv = get<4>(logical_problem_shape);
      } else {
        seq_len_qo = get<3>(problem_size);
        seq_len_kv = get<4>(problem_size);
      }
      // Elements of one head's slice of each tensor for this batch.
      const std::size_t q_head_elems = static_cast<std::size_t>(seq_len_qo) * head_size_qk;
      const std::size_t k_head_elems = static_cast<std::size_t>(seq_len_kv) * head_size_qk;
      const std::size_t v_head_elems = static_cast<std::size_t>(seq_len_kv) * head_size_vo;
      const std::size_t o_head_elems = static_cast<std::size_t>(seq_len_qo) * head_size_vo;
      // Softmax divisor sqrt(head_size_qk); loop-invariant, so computed once per batch.
      const ElementOutput sqrt_d = std::sqrt(static_cast<ElementOutput>(head_size_qk));

      for (int h = 0; h < num_heads_q; h++) {
        // Query heads [g * q_group_size, (g + 1) * q_group_size) all read KV head g.
        const int kv_head = h / q_group_size;
        const std::size_t head_offset_q = offset_q + static_cast<std::size_t>(h) * q_head_elems;
        const std::size_t head_offset_k = offset_k + static_cast<std::size_t>(kv_head) * k_head_elems;
        const std::size_t head_offset_v = offset_v + static_cast<std::size_t>(kv_head) * v_head_elems;
        const std::size_t head_offset_o = offset_o + static_cast<std::size_t>(h) * o_head_elems;

        cutlass::DeviceAllocation<ElementOutput> block_S;
        block_S.reset(static_cast<std::size_t>(seq_len_qo) * seq_len_kv);
        cutlass::TensorRef ref_Q(block_Q.get() + head_offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
        cutlass::TensorRef ref_K(block_K.get() + head_offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
        cutlass::TensorRef ref_V(block_V.get() + head_offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
        cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
        cutlass::TensorRef ref_O(block_ref_O.get() + head_offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));

        // S = Q * K^T
        cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
                                                cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
                                                0.f, ref_S, ref_S, ElementAccumulator(0),
                                                1,                         // batch_count
                                                seq_len_qo * head_size_qk, // batch_stride_Q
                                                seq_len_kv * head_size_qk, // batch_stride_K
                                                seq_len_qo * seq_len_kv,   // batch_stride_S
                                                seq_len_qo * seq_len_kv    // batch_stride_S
        );
        syclcompat::wait();
        std::vector<ElementOutput> host_S(block_S.size());
        syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
        syclcompat::wait();
        // delete this memory as it is no longer needed
        block_S.reset();

        auto offset = cute::min(seq_len_qo, seq_len_kv);
        auto discard_seq_coord = seq_len_qo - offset;
        auto full_tile_offset = seq_len_kv - offset;
        if (is_causal) {
          // Bottom-right aligned causal mask: the last query row sees every key, and row r may
          // attend to column c only when (c - full_tile_offset) <= (r - discard_seq_coord).
          for (int row = 0; row < seq_len_qo; row++) {
            for (int col = 0; col < seq_len_kv; col++) {
              if ((col - full_tile_offset) > (row - discard_seq_coord))
                host_S[col + row * seq_len_kv] = -INFINITY;
            }
          }
        }

        // compute max element per row of S
        std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          max_vec[row] = host_S[idx++];
          for (int col = 1; col < seq_len_kv; col++, idx++) {
            if (max_vec[row] < host_S[idx])
              max_vec[row] = host_S[idx];
          }
        }

        // compute exp of the scaled, max-shifted S
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          for (int col = 0; col < seq_len_kv; col++, idx++) {
            host_S[idx] = expf((host_S[idx] - max_vec[row]) / sqrt_d);
          }
        }

        // compute sum per row of S, then normalize each row to obtain the softmax
        std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          for (int col = 0; col < seq_len_kv; col++, idx++) {
            sum_vec[row] += host_S[idx];
          }
          idx = row * seq_len_kv;
          for (int col = 0; col < seq_len_kv; col++, idx++) {
            // Rows above discard_seq_coord are fully masked under the causal mask (their max is
            // -inf, so exp produced NaN); their probabilities are defined as zero.
            if (is_causal && row < discard_seq_coord) {
              host_S[idx] = 0;
            } else {
              host_S[idx] /= sum_vec[row];
            }
          }
        }

        std::vector<ElementV> host_P(host_S.size());
        for (std::size_t p = 0; p < host_P.size(); p++)
          host_P[p] = static_cast<ElementV>(host_S[p]);

        cutlass::DeviceAllocation<ElementV> block_P;
        block_P.reset(host_P.size());
        syclcompat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
        syclcompat::wait();
        cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));

        // O = P * V
        cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
                                                cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
                                                0.f, ref_O, ref_O, ElementAccumulator(0),
                                                1,                         // batch_count
                                                seq_len_qo * seq_len_kv,   // batch_stride_P
                                                seq_len_kv * head_size_vo, // batch_stride_V
                                                seq_len_qo * head_size_vo, // batch_stride_O
                                                seq_len_qo * head_size_vo  // batch_stride_O
        );
        syclcompat::wait();
        // delete this memory as it is no longer needed
        block_P.reset();
      }

      // Advance past every head of this batch.
      offset_q += static_cast<std::size_t>(num_heads_q) * q_head_elems;
      offset_k += static_cast<std::size_t>(num_heads_kv) * k_head_elems;
      offset_v += static_cast<std::size_t>(num_heads_kv) * v_head_elems;
      offset_o += static_cast<std::size_t>(num_heads_q) * o_head_elems;
    }
    syclcompat::wait();

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
                                                                          block_O.size(), 0.5f, 0.5f);
    return passed;
  }

  /// Generates the per-batch sequence lengths for the variable-length mode.
  ///
  /// Fills `cumulative_seqlen_q` / `cumulative_seqlen_kv` with prefix sums starting at 0. When
  /// `VarlenSame` is true every batch uses the nominal lengths from `problem_size`; otherwise the
  /// lengths are drawn from N(len, len / 2), resampled until positive, using a fixed RNG seed.
  ///
  /// @return (problem_size_for_init, problem_size_for_launch). The first collapses all batches
  ///         into one with the total sequence lengths and is used to size the tensors. The second
  ///         keeps batch/heads/head sizes and carries the maximum sequence lengths as
  ///         VariableLength entries; their cumulative pointers are filled in by initialize().
  template<class ProblemShape>
  auto initialize_varlen(const ProblemShape& problem_size, const bool VarlenSame = true) {
    int num_batches = get<0>(problem_size);

    // generate Q as --b times
    //    gaussian (--Q, --Q / 2) sampled positive
    //    track cumulative
    std::mt19937 rng(0x202305151552ull);
    std::normal_distribution<double> dist_q(get<3>(problem_size), get<3>(problem_size) / 2);
    std::normal_distribution<double> dist_kv(get<4>(problem_size), get<4>(problem_size) / 2);

    auto generate_positive_int = [](auto& dist, auto& gen) {
      int result = 0;
      do {
        result = static_cast<int>(dist(gen));
      } while (result <= 0);
      return result;
    };

    cumulative_seqlen_q = {0};
    cumulative_seqlen_kv = {0};

    int total_seqlen_q = 0;
    int total_seqlen_kv = 0;
    int max_seqlen_q = 0;
    int max_seqlen_kv = 0;

    for (int i = 0; i < num_batches; i++) {
      int seqlen_q = VarlenSame ? get<3>(problem_size) : generate_positive_int(dist_q, rng);
      int seqlen_kv = VarlenSame ? get<4>(problem_size) : generate_positive_int(dist_kv, rng);

      total_seqlen_q += seqlen_q;
      total_seqlen_kv += seqlen_kv;

      max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
      max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);

      cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
      cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
    }

    ProblemShape problem_size_for_init = problem_size;
    get<0>(problem_size_for_init) = 1;
    get<3>(problem_size_for_init) = total_seqlen_q;
    get<4>(problem_size_for_init) = total_seqlen_kv;

    ProblemShapeType problem_size_for_launch;
    get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_q};
    get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv};
    get<5>(problem_size_for_launch) = get<5>(problem_size);
    get<6>(problem_size_for_launch) = get<6>(problem_size);
    get<0>(problem_size_for_launch) = get<0>(problem_size);
    get<1>(problem_size_for_launch) = get<1>(problem_size);
    get<2>(problem_size_for_launch) = get<2>(problem_size);

    return cute::make_tuple(problem_size_for_init, problem_size_for_launch);
  }

  /// Initialize operands to be used in the GEMM and reference GEMM.
  ///
  /// Computes the packed strides, allocates Q/K/V/O and the reference O, fills Q/K/V with
  /// seeded random data, and uploads the cumulative sequence lengths when present.
  ///
  /// @return the problem shape to launch with; for varlen its VariableLength entries point at
  ///         the device prefix-sum buffers owned by this runner.
  ProblemShapeType initialize(const Options &options) {
    auto problem_shape_in =
        cute::make_tuple(options.batch, options.num_heads_q, options.num_heads_kv, options.seq_len_qo, options.seq_len_kv, options.head_size_qk, options.head_size_vo);

    ProblemShapeType problem_shape;
    decltype(problem_shape_in) problem_size;

    if constexpr (isVarLen) {
      auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
      problem_shape = problem_shape_launch;
      problem_size = problem_shape_init;
    }
    else {
      problem_size = problem_shape_in;
      problem_shape = problem_shape_in;
    }

    auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_size;

    stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, batch * num_heads_q));
    stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv));
    stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, batch * num_heads_kv));
    stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, batch * num_heads_q));

    block_Q.reset(static_cast<std::size_t>(batch) * num_heads_q * seq_len_qo * head_size_qk);
    block_K.reset(static_cast<std::size_t>(batch) * num_heads_kv * seq_len_kv * head_size_qk);
    block_V.reset(static_cast<std::size_t>(batch) * num_heads_kv * seq_len_kv * head_size_vo);
    block_O.reset(static_cast<std::size_t>(batch) * num_heads_q * seq_len_qo * head_size_vo);
    block_ref_O.reset(static_cast<std::size_t>(batch) * num_heads_q * seq_len_qo * head_size_vo);

    initialize_block(block_Q, seed + 2023);
    initialize_block(block_K, seed + 2022);
    initialize_block(block_V, seed + 2021);

    if (!cumulative_seqlen_q.empty()) {
      device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
      device_cumulative_seqlen_q.copy_from_host(
        cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
    }
    if (!cumulative_seqlen_kv.empty()) {
      device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
      device_cumulative_seqlen_kv.copy_from_host(
        cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
    }

    if constexpr (isVarLen) {
      get<3>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
      get<4>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
    }

    return problem_shape;
  }

  // Note that the GemmUniversalAdapter currently doesn't support flash attention, which is why this
  // secondary `run` function is required to launch the kernel.
  /// Launches one instance of the kernel with `params` and registers the resulting event with
  /// the EventManager singleton. Does not wait for completion.
  static void run(typename GemmKernel::Params params) {
    dim3 const block = GemmKernel::get_block_shape();
    dim3 const grid = GemmKernel::get_grid_shape(params);

    // configure smem size and carveout
    int smem_size = GemmKernel::SharedStorageSize;

    const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z);
    const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z);

// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension
#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY)
    using namespace syclcompat::experimental;
    auto event = launch<cutlass::device_kernel<GemmKernel>>(
        launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)},
                      kernel_properties{sycl_exp::sub_group_size<GemmKernel::DispatchPolicy::SubgroupSize>}},
        params);
#else
    syclcompat::experimental::launch_properties launch_props {
      sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size),
    };
    syclcompat::experimental::kernel_properties kernel_props{
      sycl::ext::oneapi::experimental::sub_group_size<GemmKernel::DispatchPolicy::SubgroupSize>
    };
    syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props};
    auto event = syclcompat::experimental::launch<cutlass::device_kernel<GemmKernel>>(policy, params);
#endif

    EventManager::getInstance().addEvent(event);
  }

  /// Initializes the operands, runs the kernel once, verifies it against the reference and,
  /// when `options.iterations > 0`, times that many launches and prints throughput figures.
  ///
  /// @return kErrorInvalidProblem if the kernel cannot implement the shape, kErrorInternal if
  ///         verification fails, kSuccess otherwise.
  cutlass::Status run(const Options &options, const cutlass::KernelHardwareInfo &hw_info) {

    ProblemShapeType problem_size = initialize(options);

    typename GemmKernel::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      {block_Q.get(), stride_Q, block_K.get(), stride_K, block_V.get(), stride_V},
      {options.softmax_scale},
      {block_O.get(), stride_O},
      hw_info};

    // Define device-global scratch memory
    size_t workspace_size = GemmKernel::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    if (!GemmKernel::can_implement(arguments)) {
      std::cout << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' << options.num_heads_kv << 'x' <<
        options.seq_len_qo << 'x' << options.seq_len_kv << 'x' << options.head_size_qk << 'x'  << options.head_size_vo
        << (options.is_causal ? "xCausal" : "xNonCausal") << std::endl;
      return cutlass::Status::kErrorInvalidProblem;
    }

    // Initialize the workspace
    CUTLASS_CHECK(GemmKernel::initialize_workspace(arguments, workspace.get()));

    // Convert host-side arguments to device-side arguments to be passed to the kernel
    auto params = GemmKernel::to_underlying_arguments(arguments, workspace.get());

    // Run the GEMM
    run(params);

    syclcompat::wait();

    // Verify that the result is correct
    bool passed = verify(problem_size, options.is_causal);
    std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;

    if (!passed) {
      return cutlass::Status::kErrorInternal;
    }

    if (options.iterations > 0) {
      GPU_Clock timer;
      timer.start();
      for (int i = 0; i < options.iterations; ++i) {
        run(params);
      }
      syclcompat::wait();

      // With the causal mask roughly half of the key/value sequence is visited per query row.
      double effective_seq_len_kv = options.is_causal ?
                                    options.seq_len_kv / 2.0 :
                                    options.seq_len_kv;
      double cute_time = timer.seconds() / options.iterations;
      double flops_qk = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * effective_seq_len_kv * options.head_size_qk;
      double flops_pv = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * options.head_size_vo * effective_seq_len_kv;
      double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time;

      // Minimum memory traffic: read Q and the (effective) K and V once, write O once.
      // K/V are sized by num_heads_kv (shared across query-head groups), and each tensor is
      // counted with its own element width (O is written as ElementOutput).
      double bytes_q = static_cast<double>(sizeof(ElementQ)) * options.batch * options.num_heads_q * options.seq_len_qo * options.head_size_qk;
      double bytes_k = static_cast<double>(sizeof(ElementK)) * options.batch * options.num_heads_kv * effective_seq_len_kv * options.head_size_qk;
      double bytes_v = static_cast<double>(sizeof(ElementV)) * options.batch * options.num_heads_kv * effective_seq_len_kv * options.head_size_vo;
      double bytes_o = static_cast<double>(sizeof(ElementOutput)) * options.batch * options.num_heads_q * options.seq_len_qo * options.head_size_vo;
      double gbps = ((bytes_q + bytes_k + bytes_v + bytes_o) * 1e-9) / (cute_time);

      std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q  << "\tNumHeads_kv: " << options.num_heads_kv  << "\tSeq Length QO: " << options.seq_len_qo
                << "\tSeq Length KV: " << options.seq_len_kv << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo
                << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false")
                << "\t Scheduler: " << options.scheduler;
      printf("\nPerformance:   %4.3f  GB/s,    %4.3f  TFlop/s,   %6.4f  ms\n\n", gbps, tflops, cute_time * 1000);
    }

    return cutlass::Status::kSuccess;
  }
};
template <bool Causal, typename TileShape, typename TiledMma> struct FMHAConfig {
template <bool isVarLen, class Scheduler>
static int run(const Options &options) {
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementInputQ = bfloat16_t; // <- data type of elements in input matrix A
using ElementInputKV = bfloat16_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
using GmemTiledCopyQ = XE_2D_U16x16x32_LD_N;
using GmemTiledCopyK = XE_2D_U16x16x16_LD_T; // _T designates a transposed block load operation
using GmemTiledCopyV = XE_2D_U16x32x32_LD_V;
using GmemTiledCopyStore = XE_2D_U32x8x16_ST_N;
using CollectiveEpilogue = cutlass::flash_attention::collective::CollectiveEpilogueAttention<
EpilogueDispatchPolicy, TileShape, ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
GmemTiledCopyStore>;
using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::CollectiveSoftmaxEpilogue<Causal, EpilogueDispatchPolicy, ElementAccumulator>;
using ProblemShapeRegular = cute::tuple<int, int, int, int, int, int, int>;
using namespace cutlass::fmha::collective;
using ProblemShapeVarlen = cute::tuple<int, int, int, VariableLength, VariableLength, int, int>;
using ProblemShapeType = std::conditional_t<isVarLen, ProblemShapeVarlen, ProblemShapeRegular>;
// Mainloop
using CollectiveMainloop = cutlass::flash_attention::collective::CollectiveMmaAttention<
GEMMDispatchPolicy, ProblemShapeType, TileShape, ElementInputQ, cutlass::gemm::TagToStrideA_t<LayoutQ>, ElementInputKV,
cutlass::gemm::TagToStrideB_t<LayoutK>, ElementInputKV, cutlass::gemm::TagToStrideB_t<LayoutV>, TiledMma,
GmemTiledCopyQ, // Q
GmemTiledCopyK, // K
GmemTiledCopyV, // V,
Causal>;
using GemmKernel = cutlass::flash_attention::kernel::GemmUniversalAttention<ProblemShapeType, CollectiveMainloop,
CollectiveSoftmaxEpilogue, CollectiveEpilogue, Scheduler>;
ExampleRunner<GemmKernel, isVarLen> runner;
CUTLASS_CHECK(runner.run(options, hw_info));
return 0;
}
static int run(const Options &options) {
if(options.varlen) {
if(options.scheduler.compare(std::string("Persistent")) == 0) {
return run<true, cutlass::flash_attention::PersistentScheduler>(options);
} else {
return run<true, cutlass::flash_attention::IndividualScheduler>(options);
}
} else {
if(options.scheduler.compare(std::string("Persistent")) == 0) {
return run<false, cutlass::flash_attention::PersistentScheduler>(options);
} else {
return run<false, cutlass::flash_attention::IndividualScheduler>(options);
}
}
}
};