
Commit f36600c

Authored by sanchitintel and t4c1
Add prefetching code in FP8 GEMM mainloop (#304)

## Summary

After copy-pasting the A/B tile prefetching code from `xe_mma.hpp` into `xe_mma_w8a8.hpp` to make the two files even more similar (they are almost identical, except that `xe_mma_w8a8.hpp` converts the FP8 `A` and `B` operands to FP16; whether the two files can be refactored and merged is beyond the scope of this PR), I measured a performance boost of ~16% for many input shapes on the Intel GPU Max 1550. A condensed sketch of the resulting mainloop structure follows the diff below.

## Performance data

Measured on the PVC 1550 with the DPC++ nightly of March 23, 2025, using the existing benchmark, e.g.

```
./examples/sycl/08_pvc_gemm_f8/08_pvc_gemm_f8 --iterations=1 --m=1024 --n=7168 --k=128
```

I benchmarked by running each GEMM problem only once, to verify that the change (adding prefetching) indeed results in a speedup; an alternative would have been to run multiple iterations with a cache flush between each. As a digression: in real workloads the activation is likely to already be in cache while the weights sit in global memory before a linear op, for example, and neither of the two approaches I considered simulates that scenario.

| M | N | K | L | Latency of one invocation before this change | Latency of one invocation after this change | Speedup |
|---|---|---|---|---|---|---|
| 1024 | 1536 | 7168 | 1 | 3.76 ms | 3.2304 ms | 1.16x |
| 1024 | 1536 | 1536 | 1 | 0.824 ms | 0.7034 ms | 1.17x |
| 1024 | 576 | 7168 | 1 | 3.75 ms | 3.2274 ms | 1.16x |
| 1024 | 2048 | 512 | 1 | 0.2853 ms | 0.2458 ms | 1.16x |
| 1024 | 7168 | 1024 | 1 | 1.54 ms | 1.2762 ms | 1.20x |
| 1024 | 256 | 7168 | 1 | 3.76 ms | 3.2237 ms | 1.16x |
| 1024 | 7168 | 128 | 1 | 0.2270 ms | 0.1997 ms | 1.13x |
| 1 | 1536 | 7168 | 1 | 3.76 ms | 3.1790 ms | 1.18x |
| 1 | 1536 | 1536 | 1 | 0.8206 ms | 0.6925 ms | 1.18x |
| 1 | 576 | 7168 | 1 | 3.76 ms | 2.7831 ms | 1.352x |
| 1 | 2048 | 512 | 1 | 0.2802 ms | 0.2413 ms | 1.16x |
| 1 | 7168 | 1024 | 1 | 0.5504 ms | 0.4669 ms | 1.17x |
| 1 | 256 | 7168 | 1 | 3.7701 ms | 3.1678 ms | 1.19x |
| 1 | 7168 | 128 | 1 | 0.0733 ms | 0.0683 ms | 1.07x |

## Build instructions

```
export IGC_ExtraOCLOptions="-cl-intel-256-GRF-per-thread"
export IGC_VectorAliasBBThreshold=1200
export IGC_VISAOptions="-perfmodel"
mkdir build; cd build;
CC=clang CXX=clang++ cmake .. -GNinja -DCUTLASS_ENABLE_EXAMPLES=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCUTLASS_ENABLE_SYCL=ON -DCUTLASS_SYCL_PROFILING_ENABLED=ON -DDPCPP_SYCL_TARGET=intel_gpu_pvc -DCUTLASS_ENABLE_BENCHMARKS=OFF -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fdiagnostics-color=always"
```

cc @pengzhao-intel Thanks!

Co-authored-by: Tadej Ciglarič <[email protected]>
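(Note: after building with the commands above, the benchmark can be run from the `build` directory exactly as in the example command in the Performance data section, e.g. `./examples/sycl/08_pvc_gemm_f8/08_pvc_gemm_f8 --iterations=1 --m=1024 --n=1536 --k=7168` for the first row of the table.)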
1 parent 5e57dea · commit f36600c

File tree: 1 file changed (+25, -1 lines changed)

include/cutlass/gemm/collective/xe_mma_w8a8.hpp

Lines changed: 25 additions & 1 deletion

```diff
@@ -102,6 +102,9 @@ struct CollectiveMma<MainloopIntelW8A8<Stages, Schedule>, TileShape_, ElementA_,
   static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K);
 
   using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;
+
+  // 32
+  static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K;
   static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
 
   using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
@@ -238,15 +241,31 @@ struct CollectiveMma<MainloopIntelW8A8<Stages, Schedule>, TileShape_, ElementA_,
     // Retile global tile for copies
     Tensor tAgA = thr_copy_A.retile_S(tCgA);
     Tensor tBgB = thr_copy_B.retile_S(tCgB);
+
+    auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(mainloop.tiled_copy_a);
+    auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(mainloop.tiled_copy_b);
+    auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
+    auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx);
+
+    // Partition global tile for prefetch
+    auto pAgA = thr_prefetch_A.partition_S(gA);
+    auto pBgB = thr_prefetch_B.partition_S(gB);
 
     //
     // Mainloop
     //
     const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start));
     constexpr int barrier_scope = 2;
+    int prefetch_k = k_start_idx;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) {
+      prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));
+      prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
+    }
 
     CUTLASS_PRAGMA_UNROLL
-    for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++) {
+    for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
       barrier_arrive(barrier_scope);
 
       // copy fp8 into uint8
@@ -257,6 +276,11 @@ struct CollectiveMma<MainloopIntelW8A8<Stages, Schedule>, TileShape_, ElementA_,
       convert_E4M3_to_FP16(tCrA, tCrA_fp16);
       convert_E4M3_to_FP16(tCrB, tCrB_fp16);
 
+      if (prefetch_k < k_tile_count) {
+        prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));
+        prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
+      }
+
       // compute using fp16
       cute::gemm(tiled_mma, tCrA_fp16, tCrB_fp16, accum);
 
```
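For readers who want the added logic in one place, here is a condensed, self-contained sketch of the prefetch-ahead pattern introduced by the hunks above. It is only an illustration under simplifying assumptions: `Stages`, `k_tile_count`, and the `prefetch_tile` / `load_and_convert_tile` / `mma_tile` helpers are stand-ins for `DispatchPolicy::Stages`, the real k-tile count, and the actual `cute::prefetch`, block copy + `convert_E4M3_to_FP16`, and `cute::gemm` calls in the mainloop; none of them are CUTLASS/CuTe APIs.

```cpp
// Condensed illustration of the prefetch-ahead structure added in this commit.
// Everything here is a stand-in; it is NOT the CUTLASS/CuTe API.
#include <cstdio>

constexpr int Stages       = 3;  // stand-in for DispatchPolicy::Stages
constexpr int k_tile_count = 8;  // toy number of K tiles

// Stand-ins for the real per-tile work.
void prefetch_tile(int k)         { std::printf("prefetch k-tile %d\n", k); }      // cute::prefetch on pAgA/pBgB
void load_and_convert_tile(int k) { std::printf("load+convert k-tile %d\n", k); }  // copy fp8 + convert_E4M3_to_FP16
void mma_tile(int k)              { std::printf("mma k-tile %d\n", k); }           // cute::gemm on the fp16 fragments

int main() {
  const int k_start_idx = 0;

  // Prologue: prefetch the first `Stages` K tiles of A and B before the mainloop.
  int prefetch_k = k_start_idx;
  for (; prefetch_k < Stages; prefetch_k++) {
    prefetch_tile(prefetch_k);
  }

  // Mainloop: `prefetch_k` advances in lockstep with `k_tile`, so each iteration
  // prefetches a tile several iterations ahead while the current tile is being
  // converted and multiplied; the guard keeps the prefetch index in bounds.
  for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx;
       k_tile++, prefetch_k++) {
    load_and_convert_tile(k_tile);
    if (prefetch_k < k_tile_count) {
      prefetch_tile(prefetch_k);
    }
    mma_tile(k_tile);
  }
  return 0;
}
```

The actual change is exactly the two pieces visible in the diff: a prologue loop that warms up the first `DispatchPolicy::Stages` K tiles, and a bounds-guarded prefetch inside the k-loop that keeps running ahead of the tile currently being converted and multiplied.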