forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 85
Expand file tree
/
Copy path04_bmg_grouped_gemm.cpp
More file actions
652 lines (535 loc) · 24.3 KB
/
04_bmg_grouped_gemm.cpp
File metadata and controls
652 lines (535 loc) · 24.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief CUTLASS Intel BMG Group Gemm
This example demonstrates fusing multiple GEMM operations into one kernel.
Note that the scalar arguments of, e.g., the standard 00_bmg_gemm example have been
replaced with vector equivalents, as each individual GEMM has its own inputs and outputs, which
needn't be contiguous in memory. For example, where 00_bmg_gemm receives an `ElementA *`
defining Matrix A, grouped gemm receives an `ElementA **`, i.e. a pointer to pointers, each
pointing to a distinct Matrix A. Likewise, each individual GEMM operation may have its own alpha
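and beta factors for linear combination. This example demonstrates two approaches: if both
`options.alpha` and `options.beta` hold real values (the `--alpha`/`--beta` command-line defaults
are 1 and 0), the same scalars apply to all GEMMs; otherwise per-group values are passed through
device pointer arrays, and whichever of alpha/beta is left at its FLT_MAX sentinel is filled with
random values per GEMM.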
Group GEMM scheduling (cutlass::gemm::GroupScheduler) is more complex than standard GEMM,
because each GEMM may have a unique size, only known at runtime. Thus, the scheduler will
distribute an a priori unknown number of tiles to each work-group. See
include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp for implementation.
Note that for simplicity, this example sets every GEMM in the group to the same shape.
Verification for this example is a conventional GEMM kernel, executed iteratively per group.
To build & run this example (from your build dir):
$ ninja 04_bmg_grouped_gemm
$ ./examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm
Call with `--help` for information about available options.
Note: the code may spill registers once compiled which will result in sub-optimal performance. This is because
of an issue inside Intel Graphics Compiler (IGC) related to VectorAliasBBThreshold being debugged internally.
To avoid register spills, build the example by setting the environment variable:
$ export IGC_VectorAliasBBThreshold=10000
*/
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/collective/xe_array_epilogue.hpp"
#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/collective/collective_mma.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include <cute/tensor.hpp>
#include <random>
#include "cutlass/util/command_line.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "sycl_common.hpp"
#include "helper.h"
#include <cfloat>
using namespace cute;

/// Grouped problem shape: every group in the launch carries its own <M, N, K> extents.
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>;

/// Element type of input matrix A.
using ElementA = bfloat16_t;
/// Element type of input matrix B.
using ElementB = bfloat16_t;

/// Element type of the accumulator.
using ElementAccumulator = float;
/// Element type used by epilogue operations.
using ElementComputeEpilogue = float;
/// Element type of output matrix D.
using ElementOutput = float;
///////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool error = false;
bool help = false;
float alpha, beta;
int iterations;
int m, n, k, groups;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
Options() : error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100),
m(5120), n(4096), k(4096), groups(2) {
problem_sizes_host.reserve(groups);
for(int i = 0; i < groups; i++) {
problem_sizes_host.push_back({m, n, k});
}
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m, 5120);
cmd.get_cmd_line_argument("n", n, 4096);
cmd.get_cmd_line_argument("k", k, 4096);
cmd.get_cmd_line_argument("groups", groups, 2);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations, 100);
assert(groups > 0);
problem_sizes_host.clear();
problem_sizes_host.reserve(groups);
for(int i = 0; i < groups; i++) {
problem_sizes_host.push_back({m, n, k});
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "BMG Grouped GEMM\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "bmg_grouped_gemm" << " --m=5120 --n=4096 --k=4096 --groups=5 --alpha=2.5 --beta=0.5 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
{
// Number of real-valued multiply-adds
uint64_t fmas = uint64_t();
for (auto const & problem : problem_sizes_host) {
fmas += static_cast<uint64_t>(get<0>(problem)) *
static_cast<uint64_t>(get<1>(problem)) *
static_cast<uint64_t>(get<2>(problem));
}
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * uint64_t(fmas);
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
class Gemm
>
struct ExampleRunner {
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using LayoutD = typename Gemm::LayoutD;
using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementAccumulator = ElementOutput;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
// This example defines all matrices in a single allocation (e.g. block_A), but this is not a
// requirement. Matrix base pointers are read from device allocation (e.g. ptr_A)
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementOutput> block_D;
cutlass::DeviceAllocation<ElementOutput> block_ref_D;
cutlass::DeviceAllocation<const ElementA *> ptr_A;
cutlass::DeviceAllocation<const ElementB *> ptr_B;
cutlass::DeviceAllocation<const ElementC *> ptr_C;
cutlass::DeviceAllocation<ElementOutput *> ptr_D;
cutlass::DeviceAllocation<ElementOutput *> ptr_ref_D;
cutlass::DeviceAllocation<StrideA> stride_A;
cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
// Note, this is an array of pointers to alpha and beta scaling values per group
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
uint64_t seed = 0;
//
// Methods
//
bool verify(const Options &options) {
bool passed = true;
// Verify against individual reference GEMMs
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), LayoutA::packed({M, K}));
cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), LayoutB::packed({K, N}));
cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), LayoutC::packed({M, N}));
cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), LayoutD::packed({M, N}));
//
// Compute reference output
//
cutlass::reference::device::GemmComplex(
{M, N, K},
alpha_host.at(i),
ref_A,
cutlass::ComplexTransform::kNone,
ref_B,
cutlass::ComplexTransform::kNone,
beta_host.at(i),
ref_C,
ref_D,
ElementAccumulator(0),
1, // batch_count
M * K, // batch_stride_A
K * N, // batch_stride_B
M * N, // batch_stride_C
M * N // batch_stride_D
);
// Wait for kernel to finish
compat::wait();
// Check if output from CUTLASS kernel and reference kernel are equal or not
passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N);
if(!passed)
break;
}
return passed;
}
/// Allocates device-side data
void allocate(const Options &options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
// Compute total allocation sizes across group
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
// Offset into block allocation of each matrix base pointer
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
int64_t elements_A = M * K;
int64_t elements_B = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
uint64_t seed = 2020;
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
//
// Assign pointers
//
std::vector<ElementA *> ptr_A_host(options.groups);
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementC *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
// Compute offsets, alpha & beta over group on host
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
// Fill host vector of alpha & beta with random values if using per-group values
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
// Fill host ptr vectors with offset addresses into device alpha/beta blocks
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
// Allocate device memory & copy from host
ptr_A.reset(options.groups);
// Per-group alpha and beta
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
// Per-group alpha and beta ptrs
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
// Per-group alpha and beta values - note these are not directly passed to kernel - the pointers
// (alpha_device/beta_device) are passed instead
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options,
const cutlass::KernelHardwareInfo& hw_info,
bool host_problem_shapes_available = true,
bool use_nullptr_c = false)
{
typename Gemm::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
fusion_args.alpha = options.alpha;
fusion_args.beta = options.beta;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
// Single alpha and beta for all groups
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
fusion_args.alpha = 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = alpha_device.get();
fusion_args.beta_ptr_array = beta_device.get();
// One alpha and beta per each group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
}
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup<ProblemShape>::RasterOrderOptions;
// Per-GEMM problem shape info may only exist on the device.
if (host_problem_shapes_available) {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{fusion_args, use_nullptr_c ? nullptr : ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info,
{1, RasterOrderOptions::AlongN}
};
}
else {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{fusion_args, use_nullptr_c ? nullptr : ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info,
{1, RasterOrderOptions::AlongN}
};
}
return arguments;
}
cutlass::Status run(const Options& options,
const cutlass::KernelHardwareInfo& hw_info,
bool host_problem_shapes_available = true,
bool use_nullptr_c = false) {
allocate(options);
initialize(options);
Gemm gemm_op;
auto arguments = args_from_options(options, hw_info, host_problem_shapes_available, use_nullptr_c);
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_op.can_implement(arguments));
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
// Run the GEMM
CUTLASS_CHECK(gemm_op.run());
compat::wait();
// Verify that the result is correct
bool passed = verify(options);
std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
if(!passed) return cutlass::Status::kErrorInternal;
if (options.iterations > 0) {
GPU_Clock timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm_op.run());
}
compat::wait();
float cute_time = timer.seconds() * 1000;
double cute_average_time = double(cute_time) / double(options.iterations);
double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host);
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl;
std::cout << " GFLOPS : " << gflops << std::endl;
}
return cutlass::Status::kSuccess;
}
};
template<bool use_nullptr_c=false>
void launcher(Options& options) {
//
// Run examples
//
// The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This
// information is used by the underlying kernel.
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;
using TiledMma =
TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;
constexpr int PipelineStages = 2;
// Dispatch to grouped gemm algorithm
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
using EpilogueOp =
cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest, !use_nullptr_c>>;
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
decltype(tile_shape(TiledMma()))>;
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
EpilogueDispatchPolicy,
TileShape,
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutC*>,
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD*>,
FusionCallBacks,
XE_2D_U32x8x16_LD_N,
void, void,
XE_2D_U32x8x16_ST_N,
void, void>;
// Mainloop
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
TileShape,
ElementA,
cutlass::gemm::TagToStrideA_t<LayoutA*>,
ElementB,
cutlass::gemm::TagToStrideB_t<LayoutB*>,
TiledMma,
GmemTiledCopyA, void, void, cute::identity, // A
GmemTiledCopyB, void, void, cute::identity // B
>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::GroupScheduler
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
ExampleRunner<Gemm> runner;
CUTLASS_CHECK(runner.run(options, hw_info, true, /* use_nullptr_c = */use_nullptr_c));
}
/// Entry point: parses options, then runs the grouped GEMM with a nullptr ptr_C (only when
/// beta == 0) followed by a run with the real ptr_C.
int main(int argc, const char** argv)
{
  //
  // Parse options
  //

  Options options;

  options.parse(argc, argv);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (options.error) {
    std::cerr << "Aborting execution." << std::endl;
    return -1;
  }

  if (options.beta == 0.f) {
    // The reference kernel doesn't accept nullptr for C, so the nullptr ptr_C epilogue argument
    // is only exercised when beta is 0.
    std::cout << "\n\nUse a nullptr as argument ptr_C of the group GEMM epilogue collective\n\n";
    launcher<true>(options);
    std::cout << "\n\nPass actual ptr_C as an argument to the group GEMM epilogue collective\n\n";
  }

  launcher<false>(options);

  return 0;
}