Skip to content

Commit 5a01d6d

Browse files
dzzz2001claude
andcommitted
perf(gint): NN BLK_N ladder -> {32, 64} for V100/A100 LDS-bound FP64
Replace the dynamic tail-waste-minimizing BLK_N selection (over {16, 32, 48, 64}) in gemm_nn_vbatch with a static threshold keeping only {32, 64}. On sm_70/sm_80 the FP64 pipe is first-class and the inner loop is LDS-bound, so per-block LDS reuse (FMA/LDS ratio) dominates masked-FMA waste. At DIM=8x16, BLK_N=64 delivers THR_M*THR_N/(THR_M+THR_N) = 2.0 FMA/LDS -- matched to V100's FP64 throughput; BLK_N=32 drops to 1.33 and underfeeds the pipe. BLK_N=48 is dropped for the same reason; BLK_N=16 is dropped because it had no LDS-optimal regime left once we commit to FP64-heavy tiles. gemm_tn_vbatch keeps its 4x4 {8,16,32,48}^2 ladder unchanged -- both TN output axes (nw1, nw2) are small, so a bigger tile only adds mask waste without unlocking LDS reuse (a single tile already covers the output in the common regime). Comments for both kernels are rewritten to spell out the NN-vs-TN asymmetry and the V100/A100 tuning rationale. Expected impact: speedup on V100/A100 for bxyz in {48, 80, 100} which previously landed on BLK_N in {48, 16, 16}; likely regression on RTX 3090 (Ampere consumer, FP64 1/64 of FP32) where masked-FMA waste dominates LDS bandwidth. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 475b227 commit 5a01d6d

1 file changed

Lines changed: 82 additions & 77 deletions

File tree

source/source_lcao/module_gint/kernel/dgemm_vbatch.cu

Lines changed: 82 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -117,83 +117,77 @@ void gemm_nn_vbatch(
117117
int batchCount, cudaStream_t stream,
118118
const T* alpha)
119119
{
120-
// Phase V4 (FP64 only): route the largest shapes to a 256-thread
121-
// 64x64 big tile. The big tile keeps more independent FMA chains in
122-
// flight per block, which V100's strong FP64 pipe needs (Little's
123-
// Law: ~300 in-flight FP64 FMAs per SM). The FP32 dispatch path is
124-
// left untouched (3090 proxy already matches Iter 02 perf).
120+
// FP64 big tile (256-thread 64x64). Little's Law says V100 needs
121+
// ~300 in-flight FP64 FMAs/SM to saturate; the 16x16-thread 4x4
122+
// register tile puts 4096 FMAs/step/block in flight, so one block
123+
// already covers the pipe and the second one hides LDS latency.
125124
if (nn_try_big_tile_(m, n, k,
126125
A_array_d, lda_d, B_array_d, ldb_d,
127126
C_array_d, ldc_d, batchCount, stream, alpha))
128127
{
129128
return;
130129
}
131130

132-
// 4x4 ladder (16 instantiations), tuned for Ampere:
133-
// n (nw2 axis) -> BLK_M in {8, 16, 32, 48} (threshold ladder)
134-
// m (bxyz axis) -> BLK_N in {16, 32, 48, 64} (waste-minimizing)
135-
// BLK_K fixed at 16 (nw1 axis, <=13 here)
131+
// 4 x 2 ladder (8 instantiations), tuned for V100 / A100:
132+
// n (nw2 axis) -> BLK_M in {8, 16, 32, 48} (smallest full-cover)
133+
// m (bxyz axis) -> BLK_N in {32, 64} (larger-is-better)
134+
// BLK_K fixed at 16 (nw1 axis, <=27)
135+
// DIM_X=8, DIM_Y=16 (128 threads/block, unchanged)
136136
//
137-
// After the A/B swap in vbatched_gemm_nn_impl, the kernel's N-axis covers
138-
// the bxyz dimension of the output C. Because M = bxyz is a runtime
139-
// scalar that varies across benchmark cases (27, 48, 64, 80, 100, 125)
140-
// and the register tile THR_N = BLK_N / DIM_Y is unrolled at compile
141-
// time, a BLK_N that does not evenly divide bxyz produces fully-computed
142-
// but mostly-masked tiles -- pure FMA waste on the under-full last
143-
// grid-y block.
137+
// Philosophy vs the prior tail-waste-min ladder:
144138
//
145-
// BLK_N is chosen by minimizing (tail_waste, grid_blocks)
146-
// lexicographically over the candidate set. This lands bxyz=48 on
147-
// BLK_N=48 (1 block, 0 waste) and bxyz=80/100 on BLK_N=16 (many blocks,
148-
// 0 waste), while bxyz=64/125 still pick BLK_N=64 and bxyz=27 still
149-
// picks BLK_N=32 (same 5-row tail as BLK_N=16 but 1 block instead of 2).
150-
// All four BLK_N values satisfy BLK_N % DIM_Y = BLK_N % DIM_YB = 0, so
151-
// the shmem-load loops and register tiles compile without changes.
139+
// That ladder picked BLK_N by minimizing (tail_waste, grid_blocks)
140+
// lexicographically. It's the right objective on sm_86 consumer
141+
// Ampere (RTX 3090) where FP64 is 1/64 of FP32 -- every masked FMA
142+
// there is a full FP64-pipe-bound cycle, so minimizing launched
143+
// cells dominates.
144+
//
145+
// On V100 (sm_70) / A100 (sm_80) FP64 is a first-class pipe
146+
// (7.8 / 9.7 TFLOPS peak, ridge ~6-9 FLOP/B), and the inner loop
147+
// is LDS-bound for the nw1/nw2 ranges we see (ncu confirms L1/TEX
148+
// >= 95% on these tiles). The right objective flips to maximizing
149+
// per-block LDS reuse. The wide-LDS inner loop delivers
150+
// FMA / LDS = VK * THR_M * THR_N / (THR_M + THR_N)
151+
// to the shmem pipe; for the scalar-K-tail regime (nw1 < BLK_K,
152+
// hit by nw1 <= 16 on NN) the VK factor drops out but the ratio
153+
// shape is the same. Rough V100 FP64 target ratio is 2 FMAs/LDS
154+
// (32 FP64-FMA/cycle/SM vs LDS.64 delivering one serving/cycle).
155+
//
156+
// At DIM=8x16, BLK_N=64 gives THR_N=4, THR_M=4 -> FMA/LDS = 2.0
157+
// (matched). BLK_N=32 drops it to 1.33 (LDS-bound, FP64 headroom
158+
// unused). So BLK_N=64 is strictly better for every bxyz >= 48;
159+
// only bxyz=27 still prefers BLK_N=32 to cap the N-axis mask waste
160+
// below 50%. The intermediate rungs {16, 48} are dropped: {32, 64}
161+
// covers bxyz in {27, 48, 64, 80, 100, 125} at its LDS-optimal
162+
// point in every case.
163+
//
164+
// BLK_M retains four rungs: the nw2 axis is tiny (<=44 in practice)
165+
// and a wrong-BLK_M costs twice -- masked FMAs *and* a wider sA
166+
// row load per K-step. The 48-rung is kept specifically for nw2=44
167+
// extended-basis atoms (Ti/Mn/Fe/Co/Ni/Cu/Zn/Zr/Ba); otherwise the
168+
// 32-rung falls off to a 2-tile grid at ~31% total waste.
169+
//
170+
// All eight (BLK_M, BLK_N) satisfy the kernel's BLK_M % DIM_X=0
171+
// and BLK_N % DIM_Y=0 constraints, and the tiny-tile {BLK_M=8,
172+
// BLK_N=32} rung still has THR_M=1 which compiles cleanly.
152173
#define NN_DISPATCH(BLK_M_, BLK_N_) \
153174
vbatched_gemm_nn_impl<T, 8, 16, BLK_M_, BLK_N_, 16, 8, 16, 8, 16>( \
154175
m, n, k, \
155176
A_array_d, lda_d, B_array_d, ldb_d, \
156177
C_array_d, ldc_d, batchCount, stream, alpha)
157178

158-
// VERIFICATION PATCH 2026-04-22: extend BLK_M ladder to include 48 so
159-
// nw2 in (32, 48] (e.g. nw2=44 extended-basis atoms) lands on a 1-tile
160-
// grid with ~10% waste instead of a 2-tile BLK_M=32 grid with ~45% waste.
161179
const int blk_m_tag = (n <= 8) ? 0 : (n <= 16) ? 1 : (n <= 32) ? 2 : 3;
180+
const int blk_n_tag = (m < 48) ? 0 : 1; // {32, 64}
162181

163-
int blk_n_tag = 0;
164-
{
165-
constexpr int cands[4] = {16, 32, 48, 64};
166-
int best_waste = ((m + cands[0] - 1) / cands[0]) * cands[0] - m;
167-
int best_blocks = (m + cands[0] - 1) / cands[0];
168-
for (int i = 1; i < 4; ++i) {
169-
const int blocks = (m + cands[i] - 1) / cands[i];
170-
const int waste = blocks * cands[i] - m;
171-
if (waste < best_waste ||
172-
(waste == best_waste && blocks < best_blocks)) {
173-
best_waste = waste;
174-
best_blocks = blocks;
175-
blk_n_tag = i;
176-
}
177-
}
178-
}
179-
180-
switch (blk_m_tag * 4 + blk_n_tag) {
181-
case 0: NN_DISPATCH( 8, 16); break;
182-
case 1: NN_DISPATCH( 8, 32); break;
183-
case 2: NN_DISPATCH( 8, 48); break;
184-
case 3: NN_DISPATCH( 8, 64); break;
185-
case 4: NN_DISPATCH(16, 16); break;
186-
case 5: NN_DISPATCH(16, 32); break;
187-
case 6: NN_DISPATCH(16, 48); break;
188-
case 7: NN_DISPATCH(16, 64); break;
189-
case 8: NN_DISPATCH(32, 16); break;
190-
case 9: NN_DISPATCH(32, 32); break;
191-
case 10: NN_DISPATCH(32, 48); break;
192-
case 11: NN_DISPATCH(32, 64); break;
193-
case 12: NN_DISPATCH(48, 16); break;
194-
case 13: NN_DISPATCH(48, 32); break;
195-
case 14: NN_DISPATCH(48, 48); break;
196-
case 15: NN_DISPATCH(48, 64); break;
182+
switch (blk_m_tag * 2 + blk_n_tag) {
183+
case 0: NN_DISPATCH( 8, 32); break;
184+
case 1: NN_DISPATCH( 8, 64); break;
185+
case 2: NN_DISPATCH(16, 32); break;
186+
case 3: NN_DISPATCH(16, 64); break;
187+
case 4: NN_DISPATCH(32, 32); break;
188+
case 5: NN_DISPATCH(32, 64); break;
189+
case 6: NN_DISPATCH(48, 32); break;
190+
case 7: NN_DISPATCH(48, 64); break;
197191
}
198192

199193
#undef NN_DISPATCH
@@ -208,41 +202,52 @@ void gemm_tn_vbatch(
208202
int batchCount, cudaStream_t stream,
209203
const T* alpha)
210204
{
211-
// Phase V4 (FP64 only): 256-thread 64x64 big tile for nw1 >= 48 &&
212-
// nw2 >= 48 (axis flip vs NN: kernel M is wrapper n = nw2, kernel N
213-
// is wrapper m = nw1, so the per-axis check is symmetric at 48).
205+
// FP64 big tile (256-thread 64x64). Symmetric n>=48 && m>=48
206+
// because, after the kernel's A/B swap, both output axes are small
207+
// (kernel M = wrapper n = nw2, kernel N = wrapper m = nw1) and
208+
// neither is intrinsically larger than the other.
214209
if (tn_try_big_tile_(m, n, k,
215210
A_array_d, lda_d, B_array_d, ldb_d,
216211
C_array_d, ldc_d, batchCount, stream, alpha))
217212
{
218213
return;
219214
}
220215

221-
// 4x4 ladder (16 instantiations), tuned for A100:
216+
// 4 x 4 ladder (16 instantiations), tuned for V100 / A100:
222217
// n (nw2 axis) -> BLK_M in {8, 16, 32, 48}
223218
// m (nw1 axis) -> BLK_N in {8, 16, 32, 48}
224219
// BLK_K fixed at 32 (bxyz axis)
220+
// DIM_X=8, DIM_Y=8 (64 threads/block)
221+
//
222+
// Smallest-covering-tile selection, symmetric in both axes. This
223+
// is *not* the same choice as NN -- on TN both output axes are
224+
// small (nw1, nw2 in {4, 9, 13, 27, 44}) and neither is long
225+
// enough to amortize the "prefer bigger" BLK_N logic from NN.
226+
// Doubling BLK_* here would just push nw=4/9/13 cases off their
227+
// exact-fit tile into a 2-4x mask-waste regime with no LDS-reuse
228+
// upside (both axes of the output are already covered by one tile
229+
// in this regime; a bigger tile just adds masked FMAs).
230+
//
231+
// The 48-rung covers nw=44 extended-basis atoms (Ti/Mn/Fe/Co/Ni/
232+
// Cu/Zn/Zr/Ba) at ~8% mask waste per axis; without it those cases
233+
// fall to a 2-tile BLK=32 grid at ~52% cell-launch waste.
225234
//
226-
// BLK_K is not split by bxyz: the K-axis tail wastes only shmem loads
227-
// (not FMAs), so a single BLK_K keeps the template table small while
228-
// still covering bxyz in [27, 125] via ceil(bxyz/32) K-tiles. bxyz=27
229-
// fits in one tile (5/32 = 16% load waste); larger bxyz wraps into
230-
// 2-4 K-tiles with modest __syncthreads() overhead.
235+
// BLK_K=32 (larger than NN's 16) because K = bxyz here is large
236+
// (27-125) and the K-axis tail wastes only shmem loads, never
237+
// masked FMAs on the output -- bxyz <= 32 fits in one K-tile,
238+
// larger bxyz wraps into 2-4 K-tiles. The modest __syncthreads()
239+
// overhead from more K-tiles is cheaper than doubling BLK_K and
240+
// forcing a re-tune of the `ra/rb` double-buffer register budget.
231241
//
232-
// Block shape is DIM_X=8 x DIM_Y=8 (64 threads). Every (BLK_M, BLK_N)
233-
// pair is divisible by DIM_X/DIM_Y/DIM_XA/DIM_YA/DIM_XB/DIM_YB = 8,
234-
// so all nine combinations compile to valid kernels.
242+
// All 16 (BLK_M, BLK_N) pairs are divisible by
243+
// DIM_X/DIM_Y/DIM_XA/DIM_YA/DIM_XB/DIM_YB = 8, so every
244+
// instantiation compiles to a valid kernel.
235245
#define TN_DISPATCH(BLK_M_, BLK_N_) \
236246
vbatched_gemm_tn_impl<T, 8, 8, BLK_M_, BLK_N_, 32, 8, 8, 8, 8>( \
237247
m, n, k, \
238248
A_array_d, lda_d, B_array_d, ldb_d, \
239249
C_array_d, ldc_d, batchCount, stream, alpha)
240250

241-
// VERIFICATION PATCH 2026-04-22: extend both BLK_M and BLK_N ladders up
242-
// to 48 so that nw in (32, 48] (extended-basis nw=44 atoms: Ti/Mn/Fe/Co/
243-
// Ni/Cu/Zn/Zr/Ba) lands on a 1-tile grid per axis (48^2 cells for 44^2
244-
// output, ~19% waste) instead of a 2-tile BLK_M=32 grid (64^2 cells,
245-
// ~52% waste).
246251
auto tag_for = [](int x) {
247252
return (x <= 8) ? 0 : (x <= 16) ? 1 : (x <= 32) ? 2 : 3;
248253
};

0 commit comments

Comments
 (0)