@@ -117,83 +117,77 @@ void gemm_nn_vbatch(
117117 int batchCount, cudaStream_t stream,
118118 const T* alpha)
119119{
120- // Phase V4 (FP64 only): route the largest shapes to a 256-thread
121- // 64x64 big tile. The big tile keeps more independent FMA chains in
122- // flight per block, which V100's strong FP64 pipe needs (Little's
123- // Law: ~300 in-flight FP64 FMAs per SM). The FP32 dispatch path is
124- // left untouched (3090 proxy already matches Iter 02 perf).
120+ // FP64 big tile (256-thread 64x64). Little's Law says V100 needs
121+ // ~300 in-flight FP64 FMAs/SM to saturate; the 16x16-thread 4x4
122+ // register tile puts 4096 FMAs/step/block in flight, so one block
123+ // already covers the pipe and the second one hides LDS latency.
125124 if (nn_try_big_tile_ (m, n, k,
126125 A_array_d, lda_d, B_array_d, ldb_d,
127126 C_array_d, ldc_d, batchCount, stream, alpha))
128127 {
129128 return ;
130129 }
131130
132- // 4x4 ladder (16 instantiations), tuned for Ampere:
133- // n (nw2 axis) -> BLK_M in {8, 16, 32, 48} (threshold ladder)
134- // m (bxyz axis) -> BLK_N in {16, 32, 48, 64} (waste-minimizing)
135- // BLK_K fixed at 16 (nw1 axis, <=13 here)
131+ // 4 x 2 ladder (8 instantiations), tuned for V100 / A100:
132+ // n (nw2 axis) -> BLK_M in {8, 16, 32, 48} (smallest full-cover)
133+ // m (bxyz axis) -> BLK_N in {32, 64} (larger-is-better)
134+ // BLK_K fixed at 16 (nw1 axis, <=27)
135+ // DIM_X=8, DIM_Y=16 (128 threads/block, unchanged)
136136 //
137- // After the A/B swap in vbatched_gemm_nn_impl, the kernel's N-axis covers
138- // the bxyz dimension of the output C. Because M = bxyz is a runtime
139- // scalar that varies across benchmark cases (27, 48, 64, 80, 100, 125)
140- // and the register tile THR_N = BLK_N / DIM_Y is unrolled at compile
141- // time, a BLK_N that does not evenly divide bxyz produces fully-computed
142- // but mostly-masked tiles -- pure FMA waste on the under-full last
143- // grid-y block.
137+ // Philosophy vs the prior tail-waste-min ladder:
144138 //
145- // BLK_N is chosen by minimizing (tail_waste, grid_blocks)
146- // lexicographically over the candidate set. This lands bxyz=48 on
147- // BLK_N=48 (1 block, 0 waste) and bxyz=80/100 on BLK_N=16 (many blocks,
148- // 0 waste), while bxyz=64/125 still pick BLK_N=64 and bxyz=27 still
149- // picks BLK_N=32 (same 5-row tail as BLK_N=16 but 1 block instead of 2).
150- // All four BLK_N values satisfy BLK_N % DIM_Y = BLK_N % DIM_YB = 0, so
151- // the shmem-load loops and register tiles compile without changes.
139+ // That ladder picked BLK_N by minimizing (tail_waste, grid_blocks)
140+ // lexicographically. It's the right objective on sm_86 consumer
141+ // Ampere (RTX 3090) where FP64 is 1/64 of FP32 -- every masked FMA
142+ // there is a full FP64-pipe-bound cycle, so minimizing launched
143+ // cells dominates.
144+ //
145+ // On V100 (sm_70) / A100 (sm_80) FP64 is a first-class pipe
146+ // (7.8 / 9.7 TFLOPS peak, ridge ~6-9 FLOP/B), and the inner loop
147+ // is LDS-bound for the nw1/nw2 ranges we see (ncu confirms L1/TEX
148+ // >= 95% on these tiles). The right objective flips to maximizing
149+ // per-block LDS reuse. The wide-LDS inner loop delivers
150+ // FMA / LDS = VK * THR_M * THR_N / (THR_M + THR_N)
151+ // to the shmem pipe; for the scalar-K-tail regime (nw1 < BLK_K,
152+ // hit by nw1 <= 16 on NN) the VK factor drops out but the ratio
153+ // shape is the same. Rough V100 FP64 target ratio is 2 FMAs/LDS
154+ // (32 FP64-FMA/cycle/SM vs LDS.64 delivering one serving/cycle).
155+ //
156+ // At DIM=8x16, BLK_N=64 gives THR_N=4, THR_M=4 -> FMA/LDS = 2.0
157+ // (matched). BLK_N=32 drops it to 1.33 (LDS-bound, FP64 headroom
158+ // unused). So BLK_N=64 is strictly better for every bxyz >= 48;
159+ // only bxyz=27 still prefers BLK_N=32 to cap the N-axis mask waste
160+ // below 50%. The intermediate rungs {16, 48} are dropped: {32, 64}
161+ // covers bxyz in {27, 48, 64, 80, 100, 125} at its LDS-optimal
162+ // point in every case.
163+ //
164+ // BLK_M retains four rungs: the nw2 axis is tiny (<=44 in practice)
165+ // and a wrong-BLK_M costs twice -- masked FMAs *and* a wider sA
166+ // row load per K-step. The 48-rung is kept specifically for nw2=44
167+ // extended-basis atoms (Ti/Mn/Fe/Co/Ni/Cu/Zn/Zr/Ba); otherwise the
168+ // 32-rung falls off to a 2-tile grid at ~31% total waste.
169+ //
170+ // All eight (BLK_M, BLK_N) satisfy the kernel's BLK_M % DIM_X=0
171+ // and BLK_N % DIM_Y=0 constraints, and the tiny-tile {BLK_M=8,
172+ // BLK_N=32} rung still has THR_M=1 which compiles cleanly.
152173 #define NN_DISPATCH (BLK_M_, BLK_N_ ) \
153174 vbatched_gemm_nn_impl<T, 8 , 16 , BLK_M_ , BLK_N_ , 16 , 8 , 16 , 8 , 16 >( \
154175 m, n, k, \
155176 A_array_d, lda_d, B_array_d, ldb_d, \
156177 C_array_d, ldc_d, batchCount, stream, alpha)
157178
158- // VERIFICATION PATCH 2026-04-22: extend BLK_M ladder to include 48 so
159- // nw2 in (32, 48] (e.g. nw2=44 extended-basis atoms) lands on a 1-tile
160- // grid with ~10% waste instead of a 2-tile BLK_M=32 grid with ~45% waste.
161179 const int blk_m_tag = (n <= 8 ) ? 0 : (n <= 16 ) ? 1 : (n <= 32 ) ? 2 : 3 ;
180+ const int blk_n_tag = (m < 48 ) ? 0 : 1 ; // {32, 64}
162181
163- int blk_n_tag = 0 ;
164- {
165- constexpr int cands[4 ] = {16 , 32 , 48 , 64 };
166- int best_waste = ((m + cands[0 ] - 1 ) / cands[0 ]) * cands[0 ] - m;
167- int best_blocks = (m + cands[0 ] - 1 ) / cands[0 ];
168- for (int i = 1 ; i < 4 ; ++i) {
169- const int blocks = (m + cands[i] - 1 ) / cands[i];
170- const int waste = blocks * cands[i] - m;
171- if (waste < best_waste ||
172- (waste == best_waste && blocks < best_blocks)) {
173- best_waste = waste;
174- best_blocks = blocks;
175- blk_n_tag = i;
176- }
177- }
178- }
179-
180- switch (blk_m_tag * 4 + blk_n_tag) {
181- case 0 : NN_DISPATCH ( 8 , 16 ); break ;
182- case 1 : NN_DISPATCH ( 8 , 32 ); break ;
183- case 2 : NN_DISPATCH ( 8 , 48 ); break ;
184- case 3 : NN_DISPATCH ( 8 , 64 ); break ;
185- case 4 : NN_DISPATCH (16 , 16 ); break ;
186- case 5 : NN_DISPATCH (16 , 32 ); break ;
187- case 6 : NN_DISPATCH (16 , 48 ); break ;
188- case 7 : NN_DISPATCH (16 , 64 ); break ;
189- case 8 : NN_DISPATCH (32 , 16 ); break ;
190- case 9 : NN_DISPATCH (32 , 32 ); break ;
191- case 10 : NN_DISPATCH (32 , 48 ); break ;
192- case 11 : NN_DISPATCH (32 , 64 ); break ;
193- case 12 : NN_DISPATCH (48 , 16 ); break ;
194- case 13 : NN_DISPATCH (48 , 32 ); break ;
195- case 14 : NN_DISPATCH (48 , 48 ); break ;
196- case 15 : NN_DISPATCH (48 , 64 ); break ;
182+ switch (blk_m_tag * 2 + blk_n_tag) {
183+ case 0 : NN_DISPATCH ( 8 , 32 ); break ;
184+ case 1 : NN_DISPATCH ( 8 , 64 ); break ;
185+ case 2 : NN_DISPATCH (16 , 32 ); break ;
186+ case 3 : NN_DISPATCH (16 , 64 ); break ;
187+ case 4 : NN_DISPATCH (32 , 32 ); break ;
188+ case 5 : NN_DISPATCH (32 , 64 ); break ;
189+ case 6 : NN_DISPATCH (48 , 32 ); break ;
190+ case 7 : NN_DISPATCH (48 , 64 ); break ;
197191 }
198192
199193 #undef NN_DISPATCH
@@ -208,41 +202,52 @@ void gemm_tn_vbatch(
208202 int batchCount, cudaStream_t stream,
209203 const T* alpha)
210204{
211- // Phase V4 (FP64 only): 256-thread 64x64 big tile for nw1 >= 48 &&
212- // nw2 >= 48 (axis flip vs NN: kernel M is wrapper n = nw2, kernel N
213- // is wrapper m = nw1, so the per-axis check is symmetric at 48).
205+ // FP64 big tile (256-thread 64x64). Symmetric n>=48 && m>=48
206+ // because, after the kernel's A/B swap, both output axes are small
207+ // (kernel M = wrapper n = nw2, kernel N = wrapper m = nw1) and
208+ // neither is intrinsically larger than the other.
214209 if (tn_try_big_tile_ (m, n, k,
215210 A_array_d, lda_d, B_array_d, ldb_d,
216211 C_array_d, ldc_d, batchCount, stream, alpha))
217212 {
218213 return ;
219214 }
220215
221- // 4x4 ladder (16 instantiations), tuned for A100:
216+ // 4 x 4 ladder (16 instantiations), tuned for V100 / A100:
222217 // n (nw2 axis) -> BLK_M in {8, 16, 32, 48}
223218 // m (nw1 axis) -> BLK_N in {8, 16, 32, 48}
224219 // BLK_K fixed at 32 (bxyz axis)
220+ // DIM_X=8, DIM_Y=8 (64 threads/block)
221+ //
222+ // Smallest-covering-tile selection, symmetric in both axes. This
223+ // is *not* the same choice as NN -- on TN both output axes are
224+ // small (nw1, nw2 in {4, 9, 13, 27, 44}) and neither is long
225+ // enough to amortize the "prefer bigger" BLK_N logic from NN.
226+ // Doubling BLK_* here would just push nw=4/9/13 cases off their
227+ // exact-fit tile into a 2-4x mask-waste regime with no LDS-reuse
228+ // upside (both axes of the output are already covered by one tile
229+ // in this regime; a bigger tile just adds masked FMAs).
230+ //
231+ // The 48-rung covers nw=44 extended-basis atoms (Ti/Mn/Fe/Co/Ni/
232+ // Cu/Zn/Zr/Ba) at ~8% mask waste per axis; without it those cases
233+ // fall to a 2-tile BLK=32 grid at ~52% cell-launch waste.
225234 //
226- // BLK_K is not split by bxyz: the K-axis tail wastes only shmem loads
227- // (not FMAs), so a single BLK_K keeps the template table small while
228- // still covering bxyz in [27, 125] via ceil(bxyz/32) K-tiles. bxyz=27
229- // fits in one tile (5/32 = 16% load waste); larger bxyz wraps into
230- // 2-4 K-tiles with modest __syncthreads() overhead.
235+ // BLK_K=32 (larger than NN's 16) because K = bxyz here is large
236+ // (27-125) and the K-axis tail wastes only shmem loads, never
237+ // masked FMAs on the output -- bxyz <= 32 fits in one K-tile,
238+ // larger bxyz wraps into 2-4 K-tiles. The modest __syncthreads()
239+ // overhead from more K-tiles is cheaper than doubling BLK_K and
240+ // forcing a re-tune of the `ra/rb` double-buffer register budget.
231241 //
232- // Block shape is DIM_X=8 x DIM_Y=8 (64 threads). Every ( BLK_M, BLK_N)
233- // pair is divisible by DIM_X/DIM_Y/DIM_XA/DIM_YA/DIM_XB/DIM_YB = 8,
234- // so all nine combinations compile to valid kernels .
242+ // All 16 ( BLK_M, BLK_N) pairs are divisible by
243+ // DIM_X/DIM_Y/DIM_XA/DIM_YA/DIM_XB/DIM_YB = 8, so every
244+ // instantiation compiles to a valid kernel .
235245 #define TN_DISPATCH (BLK_M_, BLK_N_ ) \
236246 vbatched_gemm_tn_impl<T, 8 , 8 , BLK_M_ , BLK_N_ , 32 , 8 , 8 , 8 , 8 >( \
237247 m, n, k, \
238248 A_array_d, lda_d, B_array_d, ldb_d, \
239249 C_array_d, ldc_d, batchCount, stream, alpha)
240250
241- // VERIFICATION PATCH 2026-04-22: extend both BLK_M and BLK_N ladders up
242- // to 48 so that nw in (32, 48] (extended-basis nw=44 atoms: Ti/Mn/Fe/Co/
243- // Ni/Cu/Zn/Zr/Ba) lands on a 1-tile grid per axis (48^2 cells for 44^2
244- // output, ~19% waste) instead of a 2-tile BLK_M=32 grid (64^2 cells,
245- // ~52% waste).
246251 auto tag_for = [](int x) {
247252 return (x <= 8 ) ? 0 : (x <= 16 ) ? 1 : (x <= 32 ) ? 2 : 3 ;
248253 };
0 commit comments