Skip to content

Commit f6ca3e7

Browse files
Enable FP8_E5M2 GEMM (#352)
### Summary 1. Enable FP8 GEMM for float_e5m2_t dtype. float_e5m2_t -> FP16 conversion code has been adapted/copy-pasted from https://github.com/pytorch/pytorch/blob/dfcfad2112933cc34247421ac0a4d3f19a1806c1/c10/util/Float8_e5m2.h#L30-L43 2. The existing `E4M3 -> FP16` conversion uses many copies, and `E5M2 -> FP16` conversion was also sharing some code with it in the first commit of this PR. Achieved ~85% speedup after eliminating all unnecessary copies for `E5M2 -> FP16` conversion. 3. The FP8 GEMM example is now run for both E5M2 & E4M3. ### Caveat It seems the cost of the `E5M2 -> FP16` conversion can't be estimated by comparing this implementation with an implementation that disables `E5M2 -> FP16` conversion because, weirdly, the throughput is higher with the implementation in this PR than in the case of disabling `FP8 -> FP16` conversion, which uses garbage values for `A` & `B` fragments (but FP8 data is still loaded into registers) - ```cpp //convert_FP8_to_FP16(tCrA, tCrA_fp16); //convert_FP8_to_FP16(tCrB, tCrB_fp16); ``` Some ops involved in the compute may be sensitive to NaNs (maybe present/generated in case of garbage `A` & `B` FP16 values), which may be causing a slowdown in that case. ### E5M2 perf comparison with E4M3 | M | N | K | L | Latency with E4M3 | Latency with E5M2 |Speedup| |--|--|--|--|-----|-----|---| |1024|1536|7168|1|2.9335 ms |0.4216 ms | 6.95x | |1024|1536|1536|1|0.6363 ms |0.0950 ms | 6.69x| |1024|576|7168|1|2.9326 ms | 0.4214 ms| 6.95x | |1024|2048|512|1|0.2203 ms |0.0359 ms | 6.14x | |1024|7168|1024|1|0.8571 ms | 0.1301 ms| 6.59x | |1024|256|7168|1| 2.9286 ms| 0.4209 ms| 6.96x | |1024|7168|128|1|0.1256 ms |0.0269 ms | 4.66x | Intel GPU Max 1550 was used #### Build commands (in cutlass directory) ``` export IGC_ExtraOCLOptions="-cl-intel-256-GRF-per-thread" export IGC_VectorAliasBBThreshold=12000000 export IGC_VISAOptions="-perfmodel" rm -rf build; mkdir build; cd build; CC=clang CXX=clang++ cmake ..
-GNinja -DCUTLASS_ENABLE_EXAMPLES=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCUTLASS_ENABLE_SYCL=ON -DCUTLASS_SYCL_PROFILING_ENABLED=ON -DDPCPP_SYCL_TARGET=intel_gpu_pvc -DCUTLASS_ENABLE_BENCHMARKS=OFF -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fdiagnostics-color=always" ``` cc @pengzhao-intel --------- Co-authored-by: Alejandro Acosta <[email protected]>
1 parent 925745b commit f6ca3e7

File tree

3 files changed

+470
-444
lines changed

3 files changed

+470
-444
lines changed

examples/sycl/08_pvc_gemm_f8/08_pvc_gemm_f8.cpp

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@
2929
*
3030
**************************************************************************************************/
3131
/*! \file
32-
\brief CUTLASS Intel PVC Gemm with float8 (float_e4m3_t) input
32+
\brief CUTLASS Intel PVC Gemm with float8 (float_e4m3_t or float_e5m2_t) input
3333
34-
This example demonstrates GEMM on PVC with float8 input. cutlass::float_e4m3_t is an 8-bit
35-
floating point type with 4-bit exponent, 3-bit mantissa and 1 sign bit. The GEMM in this example
36-
performs the MMA with fp16 input, first upcasting the float_e4m3_t data for both A and B.
34+
This example demonstrates GEMM on PVC with float8 input. The GEMM in this example
35+
performs the MMA with fp16 input, first upcasting the fp8 data for both A and B.
36+
37+
Aside from the input datatypes, this example is identical to 00_pvc_gemm, except that
38+
we're currently being forced to load A with VNNI layout, which probably degrades
39+
performance. Ref: https://github.com/codeplaysoftware/cutlass-sycl/issues/357
3740
38-
Aside from the input datatypes, this example is identical to 00_pvc_gemm.
3941
4042
Verification for this example is a standard fp16 GEMM, with input data upcasted on the host.
4143
@@ -172,7 +174,7 @@ struct ExampleRunner {
172174
// Methods
173175
//
174176
template <typename SrcT, typename DstT>
175-
void convert_e4m3_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) {
177+
void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) {
176178
SrcT* h_src = new SrcT[size];
177179
syclcompat::memcpy(h_src, d_src, size * sizeof(SrcT));
178180
syclcompat::wait();
@@ -193,12 +195,12 @@ struct ExampleRunner {
193195
cutlass::DeviceAllocation<half_t> block_B_fp16(block_B.size());
194196

195197
// fp8 -> fp16
196-
convert_e4m3_to_fp16<float_e4m3_t, half_t>(
198+
convert_fp8_to_fp16<ElementA, half_t>(
197199
block_A.get(),
198200
block_A_fp16.get(),
199201
block_A.size()
200202
);
201-
convert_e4m3_to_fp16<float_e4m3_t, half_t>(
203+
convert_fp8_to_fp16<ElementA, half_t>(
202204
block_B.get(),
203205
block_B_fp16.get(),
204206
block_B.size()
@@ -307,6 +309,11 @@ struct ExampleRunner {
307309
float cute_time = timer.seconds() / options.iterations;
308310
double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
309311
std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
312+
if constexpr (std::is_same_v<ElementA, float_e4m3_t>) {
313+
std::cout << "Datatype: float_e4m3_t"<< std::endl;
314+
} else if constexpr (std::is_same_v<ElementA, float_e5m2_t>) {
315+
std::cout << "Datatype: float_e5m2_t"<< std::endl;
316+
}
310317
printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000);
311318
}
312319

@@ -315,26 +322,9 @@ struct ExampleRunner {
315322

316323
};
317324

318-
int main(int argc, const char** argv)
325+
template<typename ElementType>
326+
int launcher(Options& options)
319327
{
320-
//
321-
// Parse options
322-
//
323-
324-
Options options;
325-
326-
options.parse(argc, argv);
327-
328-
if (options.help) {
329-
options.print_usage(std::cout) << std::endl;
330-
return 0;
331-
}
332-
333-
if (options.error) {
334-
std::cerr << "Aborting execution." << std::endl;
335-
return -1;
336-
}
337-
338328
//
339329
// Run examples
340330
//
@@ -346,10 +336,9 @@ int main(int argc, const char** argv)
346336
bool passed;
347337

348338
using ElementAccumulator = float;
349-
using ElementComputeEpilogue = float;
350-
// TODO: support E5M2
351-
using ElementInputA = cutlass::float_e4m3_t;
352-
using ElementInputB = cutlass::float_e4m3_t;
339+
using ElementComputeEpilogue = float;
340+
using ElementInputA = ElementType;
341+
using ElementInputB = ElementType;
353342
using ElementOutput = float;
354343

355344
using LayoutA = cutlass::layout::RowMajor;
@@ -416,3 +405,26 @@ int main(int argc, const char** argv)
416405

417406
return 0;
418407
}
408+
409+
// Program entry point: parses the command line into Options, handles the
// help/error early exits, then runs the example once per FP8 input type
// (float_e5m2_t first, then float_e4m3_t).
// Returns 0 when options.help is set or after both launches, and -1 when
// options.error is set.
int main(int argc, const char** argv) {
410+
//
411+
// Parse options
412+
//
413+
414+
Options options;
415+
416+
options.parse(argc, argv);
417+
418+
if (options.help) {
419+
options.print_usage(std::cout) << std::endl;
420+
return 0;
421+
}
422+
423+
if (options.error) {
424+
std::cerr << "Aborting execution." << std::endl;
425+
return -1;
426+
}
427+
// NOTE(review): launcher() is declared to return int, but both results are
// discarded here, so this path always exits with 0 - confirm that failures
// are reported some other way.
launcher<cutlass::float_e5m2_t>(options);
428+
launcher<cutlass::float_e4m3_t>(options);
429+
return 0;
430+
}

include/cutlass/fp8_to_fp16.h

Lines changed: 128 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,128 @@
1-
/***************************************************************************************************
2-
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
3-
* SPDX-License-Identifier: BSD-3-Clause
4-
*
5-
* Redistribution and use in source and binary forms, with or without
6-
* modification, are permitted provided that the following conditions are met:
7-
*
8-
* 1. Redistributions of source code must retain the above copyright notice, this
9-
* list of conditions and the following disclaimer.
10-
*
11-
* 2. Redistributions in binary form must reproduce the above copyright notice,
12-
* this list of conditions and the following disclaimer in the documentation
13-
* and/or other materials provided with the distribution.
14-
*
15-
* 3. Neither the name of the copyright holder nor the names of its
16-
* contributors may be used to endorse or promote products derived from
17-
* this software without specific prior written permission.
18-
*
19-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20-
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22-
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25-
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26-
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27-
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28-
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29-
*
30-
**************************************************************************************************/
31-
32-
#pragma once
33-
34-
#include <cutlass/half.h>
35-
#include <cute/util/sycl_vec.hpp>
36-
37-
using half_t = cutlass::half_t;
38-
using uchar16 = cute::intel::uchar16;
39-
using ushort16 = cute::intel::ushort16;
40-
41-
// Widens a 16-lane byte vector to a 16-lane 16-bit vector: lane i of the result
// is static_cast<uint16_t>(x[i]). The loop is fixed at 16 iterations and marked
// for unrolling.
static inline ushort16 convert_ushort16(uchar16 x) {
42-
ushort16 result;
43-
#pragma unroll
44-
for (int i = 0; i < 16; ++i) {
45-
result[i] = static_cast<uint16_t>(x[i]);
46-
}
47-
return result;
48-
}
49-
50-
// 16-lane counterpart of the scalar E4M3_to_FP16: turns 16 FP8 bytes into 16
// 16-bit FP16 bit patterns using lane-wise integer operations only.
// Comments below assume the operators on uchar16/ushort16 act lane-wise on
// unsigned 8-/16-bit elements - TODO confirm against cute/util/sycl_vec.hpp.
static inline ushort16 E4M3_to_FP16_vec16(uchar16 xin) {
51-
// Magnitude (low 7 bits) and the leftover sign bit (bit 7).
uchar16 xa = xin & 0x7F;
52-
uchar16 sgn_x = xin ^ xa;
53-
54-
// zero_mask lane is 1 where the magnitude is exactly zero, else 0.
uchar16 zero_mask;
55-
#pragma unroll
56-
for (int i = 0; i < 16; ++i) {
57-
zero_mask[i] = (xa[i] == 0) ? 1 : 0;
58-
}
59-
// 0x80 only for magnitude 0x7F (the subtraction wraps below zero), else 0.
uchar16 nan_mask = (0x7E - xa) & 0x80;
60-
// NOTE(review): the trailing "& 0x01" leaves only 0 or 1 in each lane, so the
// later "den_mask & 8" and "den_mask &= 0x48" always evaluate to 0. The scalar
// E4M3_to_FP16 uses an all-ones mask at this point; as written, magnitudes
// 1..7 appear to receive no denormal correction - confirm.
uchar16 den_mask = ((xa - 8) >> 7) & 0x01;
61-
62-
// Adds 0x40 on the NaN lane only.
xa += (nan_mask >> 1);
63-
xa |= (den_mask & 8);
64-
den_mask &= 0x48;
65-
// Adds the 0x40 exponent re-bias on every lane except exact zeros.
xa += 0x40 & ~(zero_mask * 0x40);
66-
67-
// Widen to 16 bits and shift the 7 magnitude bits up to bits 13..7.
ushort16 x16 = convert_ushort16(xa) << 7;
68-
ushort16 den_corr = convert_ushort16(den_mask & ~zero_mask) << 7;
69-
70-
// NOTE(review): this is an integer subtraction of bit patterns, whereas the
// scalar routine subtracts the correction as an FP16 value - confirm intent.
ushort16 result = x16 - den_corr;
71-
// Clears bit 7 on lanes whose input magnitude was zero.
result &= ~(convert_ushort16(zero_mask) << 7);
72-
73-
// Move the FP8 sign bit (bit 7) to bit 15 and apply it.
ushort16 sign_ext = convert_ushort16(sgn_x) << 8;
74-
result ^= sign_ext;
75-
76-
return result;
77-
}
78-
79-
// Converts one FP8 byte (E4M3 layout, per the function name) into the 16-bit
// pattern of the equivalent FP16 value, returned as unsigned short.
// The conversion is branch-free: NaN and denormal/zero inputs are patched with
// masks, the exponent is re-biased, the bits are shifted into FP16 position,
// and the denormal correction is subtracted as an FP16 value.
// NOTE(review): x16.f / den_corr.f are read after the .i members were written
// (union type punning) - confirm every target compiler supports this.
static inline unsigned short E4M3_to_FP16(unsigned char xin) {
80-
unsigned char xa, sgn_x, nan_mask, den_mask;
81-
82-
union {
83-
signed short i;
84-
_Float16 f;
85-
} x16, den_corr;
86-
87-
// xa: magnitude (low 7 bits); sgn_x: only bit 7 of the input remains.
xa = xin & 0x7f;
88-
sgn_x = xin ^ xa;
89-
90-
// mask for NaN input
// 0x80 only when xa == 0x7f (0x7e - xa goes negative), otherwise 0.
91-
nan_mask = (0x7e - xa) & 0x80;
92-
// mask for denormal / zero input
// 0xFF when xa < 8, otherwise 0; relies on >> of a negative value being an
// arithmetic shift.
93-
den_mask = (((signed char)(xa - 8)) >> 7);
94-
95-
// apply Nan correction
// Adds 0x40 for the NaN input only; with the bias below xa becomes 0xFF.
96-
xa += (nan_mask >> 1);
97-
// first denormal correction
// Sets bit 3 when xa < 8, then keeps 0x48 as the amount to subtract later.
98-
xa |= (den_mask & 8);
99-
den_mask &= 0x48;
100-
// exponent bias correction
101-
xa += 0x40;
102-
103-
// zero-extend to 16 bits
104-
x16.i = xa;
105-
den_corr.i = den_mask;
106-
// FP16 format
// Shifting by 7 puts the magnitude in bits 14..7, leaving bit 15 for the sign.
107-
x16.i <<= 7;
108-
den_corr.i <<= 7;
109-
110-
// apply correction for denormals/zero
// Floating-point subtraction on the _Float16 view; den_corr is 0 for inputs
// with xa >= 8.
111-
x16.f -= den_corr.f;
112-
113-
// finally, apply the sign
// sgn_x << 8 moves the FP8 sign from bit 7 to bit 15.
114-
x16.i ^= (((signed short)sgn_x) << 8);
115-
116-
return (unsigned short)x16.i;
117-
}
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
32+
#pragma once
33+
34+
#include <cutlass/half.h>
35+
#include <cute/util/sycl_vec.hpp>
36+
37+
using uchar16 = cute::intel::uchar16;
38+
using ushort16 = cute::intel::ushort16;
39+
40+
// Widens a 16-lane byte vector to a 16-lane 16-bit vector: lane i of the result
// is static_cast<uint16_t>(x[i]). The loop is fixed at 16 iterations and marked
// for unrolling.
static inline ushort16 convert_ushort16(uchar16 x) {
41+
ushort16 result;
42+
#pragma unroll
43+
for (int i = 0; i < 16; ++i) {
44+
result[i] = static_cast<uint16_t>(x[i]);
45+
}
46+
return result;
47+
}
48+
49+
// Converts one FP8 byte (E4M3 layout, per the function name) into the 16-bit
// pattern of the equivalent FP16 value, returned as unsigned short.
// The conversion is branch-free: NaN and denormal/zero inputs are patched with
// masks, the exponent is re-biased, the bits are shifted into FP16 position,
// and the denormal correction is subtracted as an FP16 value.
// NOTE(review): x16.f / den_corr.f are read after the .i members were written
// (union type punning) - confirm every target compiler supports this.
static inline unsigned short E4M3_to_FP16(unsigned char xin) {
50+
unsigned char xa, sgn_x, nan_mask, den_mask;
51+
52+
union {
53+
signed short i;
54+
_Float16 f;
55+
} x16, den_corr;
56+
57+
// xa: magnitude (low 7 bits); sgn_x: only bit 7 of the input remains.
xa = xin & 0x7f;
58+
sgn_x = xin ^ xa;
59+
60+
// mask for NaN input
// 0x80 only when xa == 0x7f (0x7e - xa goes negative), otherwise 0.
61+
nan_mask = (0x7e - xa) & 0x80;
62+
// mask for denormal / zero input
// 0xFF when xa < 8, otherwise 0; relies on >> of a negative value being an
// arithmetic shift.
63+
den_mask = (((signed char)(xa - 8)) >> 7);
64+
65+
// apply Nan correction
// Adds 0x40 for the NaN input only; with the bias below xa becomes 0xFF.
66+
xa += (nan_mask >> 1);
67+
// first denormal correction
// Sets bit 3 when xa < 8, then keeps 0x48 as the amount to subtract later.
68+
xa |= (den_mask & 8);
69+
den_mask &= 0x48;
70+
// exponent bias correction
71+
xa += 0x40;
72+
73+
// zero-extend to 16 bits
74+
x16.i = xa;
75+
den_corr.i = den_mask;
76+
// FP16 format
// Shifting by 7 puts the magnitude in bits 14..7, leaving bit 15 for the sign.
77+
x16.i <<= 7;
78+
den_corr.i <<= 7;
79+
80+
// apply correction for denormals/zero
// Floating-point subtraction on the _Float16 view; den_corr is 0 for inputs
// with xa >= 8.
81+
x16.f -= den_corr.f;
82+
83+
// finally, apply the sign
// sgn_x << 8 moves the FP8 sign from bit 7 to bit 15.
84+
x16.i ^= (((signed short)sgn_x) << 8);
85+
86+
return (unsigned short)x16.i;
87+
}
88+
89+
90+
91+
// 16-lane counterpart of the scalar E4M3_to_FP16: turns 16 FP8 bytes into 16
// 16-bit FP16 bit patterns using lane-wise integer operations only.
// Comments below assume the operators on uchar16/ushort16 act lane-wise on
// unsigned 8-/16-bit elements - TODO confirm against cute/util/sycl_vec.hpp.
static inline ushort16 E4M3_to_FP16_chunk16(uchar16 xin) {
92+
// Magnitude (low 7 bits) and the leftover sign bit (bit 7).
uchar16 xa = xin & 0x7F;
93+
uchar16 sgn_x = xin ^ xa;
94+
95+
// zero_mask lane is 1 where the magnitude is exactly zero, else 0.
uchar16 zero_mask;
96+
#pragma unroll
97+
for (int i = 0; i < 16; ++i) {
98+
zero_mask[i] = (xa[i] == 0) ? 1 : 0;
99+
}
100+
// 0x80 only for magnitude 0x7F (the subtraction wraps below zero), else 0.
uchar16 nan_mask = (0x7E - xa) & 0x80;
101+
// NOTE(review): the trailing "& 0x01" leaves only 0 or 1 in each lane, so the
// later "den_mask & 8" and "den_mask &= 0x48" always evaluate to 0. The scalar
// E4M3_to_FP16 uses an all-ones mask at this point; as written, magnitudes
// 1..7 appear to receive no denormal correction - confirm.
uchar16 den_mask = ((xa - 8) >> 7) & 0x01;
102+
103+
// Adds 0x40 on the NaN lane only.
xa += (nan_mask >> 1);
104+
xa |= (den_mask & 8);
105+
den_mask &= 0x48;
106+
// Adds the 0x40 exponent re-bias on every lane except exact zeros.
xa += 0x40 & ~(zero_mask * 0x40);
107+
108+
// Widen to 16 bits and shift the 7 magnitude bits up to bits 13..7.
ushort16 x16 = convert_ushort16(xa) << 7;
109+
ushort16 den_corr = convert_ushort16(den_mask & ~zero_mask) << 7;
110+
111+
// NOTE(review): this is an integer subtraction of bit patterns, whereas the
// scalar routine subtracts the correction as an FP16 value - confirm intent.
ushort16 result = x16 - den_corr;
112+
// Clears bit 7 on lanes whose input magnitude was zero.
result &= ~(convert_ushort16(zero_mask) << 7);
113+
114+
// Move the FP8 sign bit (bit 7) to bit 15 and apply it.
ushort16 sign_ext = convert_ushort16(sgn_x) << 8;
115+
result ^= sign_ext;
116+
117+
return result;
118+
}
119+
120+
121+
// Converts N FP8 bytes (E5M2 layout, per the function name) to N 16-bit FP16
// bit patterns, written element by element into xout.
//   xin  - N source bytes, read only.
//   xout - N destination 16-bit words; every element is overwritten.
// Each output is the input byte zero-extended to 16 bits and shifted left by 8,
// i.e. the FP8 byte becomes the high byte and the low byte is zero.
// NOTE(review): presumably exact because E5M2 uses the same sign/exponent
// widths as FP16, so no re-biasing is needed - confirm against the referenced
// PyTorch source.
template<int N>
122+
static inline void E5M2_to_FP16(cutlass::Array<uint8_t, N> const &xin, cutlass::Array<uint16_t, N> &xout) {
123+
// Adapted from https://github.com/pytorch/pytorch/blob/dfcfad2112933cc34247421ac0a4d3f19a1806c1/c10/util/Float8_e5m2.h#L30-L43
124+
CUTLASS_PRAGMA_UNROLL
125+
for (int i = 0; i < N; i++) {
126+
xout[i] = (static_cast<uint16_t>(xin[i])) << 8;
127+
}
128+
}

0 commit comments

Comments
 (0)