33#include " dgemm_vbatch.h"
44#include " source_base/module_device/device.h"
55
6- // ----------------------------------------------------------------------------
7- // FP64-only big-tile dispatch (Phase V4)
8- // ----------------------------------------------------------------------------
9- // Pattern mirrors the C++11-compatible overload trick used in
10- // gint_vl.cpp / gint_rho.cpp (note in those files: "C++11-compatible
11- // alternative to if constexpr"). The non-template double overload is the
12- // preferred candidate when T = double; the template fallback returns
13- // false for every other dtype so the FP32 path stays untouched (per the
14- // V100 plan: 3090 FP32 was the test proxy and the big tile was not
15- // validated on FP32).
16- // Returns true if the big tile dispatched (caller should `return`).
17- // ----------------------------------------------------------------------------
18-
// FP64 overload of the NN big-tile probe.
//
// Because this is a non-template function, overload resolution prefers it
// over the generic template of the same name whenever the element type is
// double.
//
// Parameters:
//   m, n, k               scalar GEMM shape forwarded unchanged to the kernel
//                         wrapper.
//   A_array_d, B_array_d  per-batch input matrix pointer arrays.
//   C_array_d             per-batch output matrix pointer array.
//   lda_d, ldb_d, ldc_d   leading-dimension arrays matching A, B and C.
//   batchCount, stream    batch size and the CUDA stream to launch on.
//   alpha                 scaling factor pointer, forwarded as-is.
//
// Returns true when the 64x64 tile was dispatched (the caller should then
// return immediately); false when the shape is below the thresholds and the
// caller has to select a tile itself.
inline bool nn_try_big_tile_(
    int m, int n, int k,
    const double* const* A_array_d, const int* lda_d,
    const double* const* B_array_d, const int* ldb_d,
    double** C_array_d, const int* ldc_d,
    int batchCount, cudaStream_t stream, const double* alpha)
{
    // Guard clause: the big tile is only used when n >= 48 and m >= 64.
    const bool below_threshold = (n < 48) || (m < 64);
    if (below_threshold)
    {
        return false;
    }

    // 16x16 threads = 256 per block; BLK_M = BLK_N = 64 with BLK_K = 16.
    // Both loaders (A and B) use the same 16x16 thread layout.
    constexpr int kDimX = 16;
    constexpr int kDimY = 16;
    constexpr int kBlkM = 64;
    constexpr int kBlkN = 64;
    constexpr int kBlkK = 16;
    constexpr int kDimXA = 16;
    constexpr int kDimYA = 16;
    constexpr int kDimXB = 16;
    constexpr int kDimYB = 16;

    vbatched_gemm_nn_impl<double,
                          kDimX, kDimY,
                          kBlkM, kBlkN, kBlkK,
                          kDimXA, kDimYA,
                          kDimXB, kDimYB>(
        m, n, k,
        A_array_d, lda_d, B_array_d, ldb_d,
        C_array_d, ldc_d, batchCount, stream, alpha);
    return true;
}
41-
// Generic fallback of the NN big-tile probe.
//
// Selected for every element type T that has no dedicated non-template
// overload. It ignores all of its arguments and never launches anything.
//
// Returns false unconditionally, so the caller always continues with its
// own tile selection.
template <typename T>
inline bool nn_try_big_tile_(
    int,               // m
    int,               // n
    int,               // k
    const T* const*,   // A
    const int*,        // lda
    const T* const*,   // B
    const int*,        // ldb
    T**,               // C
    const int*,        // ldc
    int,               // batch
    cudaStream_t,      // stream
    const T*)          // alpha
{
    return false;
}
52-
// FP64 overload of the TN big-tile probe.
//
// Because this is a non-template function, overload resolution prefers it
// over the generic template of the same name whenever the element type is
// double.
//
// Parameters:
//   m, n, k               scalar GEMM shape forwarded unchanged to the kernel
//                         wrapper.
//   A_array_d, B_array_d  per-batch input matrix pointer arrays.
//   C_array_d             per-batch output matrix pointer array.
//   lda_d, ldb_d, ldc_d   leading-dimension arrays matching A, B and C.
//   batchCount, stream    batch size and the CUDA stream to launch on.
//   alpha                 scaling factor pointer, forwarded as-is.
//
// Returns true when the 64x64 tile was dispatched (the caller should then
// return immediately); false when the shape is below the thresholds and the
// caller has to select a tile itself.
inline bool tn_try_big_tile_(
    int m, int n, int k,
    const double* const* A_array_d, const int* lda_d,
    const double* const* B_array_d, const int* ldb_d,
    double** C_array_d, const int* ldc_d,
    int batchCount, cudaStream_t stream, const double* alpha)
{
    // Axis flip relative to NN: kernel M = wrapper n = nw2 and
    // kernel N = wrapper m = nw1, so the threshold is the same (48) on
    // both axes.
    const bool below_threshold = (n < 48) || (m < 48);
    if (below_threshold)
    {
        return false;
    }

    // 16x16 threads = 256 per block; BLK_M = BLK_N = 64.
    // BLK_K is 16 here (rather than the 32 used by the regular TN tiles) to
    // keep the big-tile shared-memory footprint at roughly 18 KB per block,
    // so 4 blocks/SM still fit on V100's 96 KB.
    constexpr int kDimX = 16;
    constexpr int kDimY = 16;
    constexpr int kBlkM = 64;
    constexpr int kBlkN = 64;
    constexpr int kBlkK = 16;
    constexpr int kDimXA = 16;
    constexpr int kDimYA = 16;
    constexpr int kDimXB = 16;
    constexpr int kDimYB = 16;

    vbatched_gemm_tn_impl<double,
                          kDimX, kDimY,
                          kBlkM, kBlkN, kBlkK,
                          kDimXA, kDimYA,
                          kDimXB, kDimYB>(
        m, n, k,
        A_array_d, lda_d, B_array_d, ldb_d,
        C_array_d, ldc_d, batchCount, stream, alpha);
    return true;
}
77-
// Generic fallback of the TN big-tile probe.
//
// Selected for every element type T that has no dedicated non-template
// overload. It ignores all of its arguments and never launches anything.
//
// Returns false unconditionally, so the caller always continues with its
// own tile selection.
template <typename T>
inline bool tn_try_big_tile_(
    int,               // m
    int,               // n
    int,               // k
    const T* const*,   // A
    const int*,        // lda
    const T* const*,   // B
    const int*,        // ldb
    T**,               // C
    const int*,        // ldc
    int,               // batch
    cudaStream_t,      // stream
    const T*)          // alpha
{
    return false;
}
88-
89- // ----------------------------------------------------------------------------
90- // Shape-exact dispatch
91- // ----------------------------------------------------------------------------
92- //
93- // The caller (phi_operator_gpu.cu) buckets atom pairs by (nw1, nw2) so every
94- // item in a batch has exactly the same (m, n, k). The scalars passed here are
95- // the *exact* per-matrix shapes (not upper bounds), which lets the tile
96- // ladder pick the tightest template and sizes the grid tightly (no
97- // over-launched blocks that short-circuit inside the kernel).
98- //
99- // Kernel-level dimension mapping (after the A/B swap inside
100- // vbatched_gemm_*_impl):
101- //
102- // call | wrapper m | wrapper n | wrapper k
103- // --------|--------------|-------------|---------------
104- // NN | bxyz (large) | nw2 (small) | nw1 (small)
105- // TN | nw1 (small) | nw2 (small) | bxyz (large)
106- //
107- // (m, n, k) flow through as scalars all the way down into the kernel, so
108- // there is no per-batchid M/N/K load and no fill-kernel scratch buffer.
109- // ----------------------------------------------------------------------------
110-
1116template <typename T>
1127void gemm_nn_vbatch (
1138 int m, int n, int k,
@@ -117,77 +12,45 @@ void gemm_nn_vbatch(
11712 int batchCount, cudaStream_t stream,
11813 const T* alpha)
11914{
120- // FP64 big tile (256-thread 64x64). Little's Law says V100 needs
121- // ~300 in-flight FP64 FMAs/SM to saturate; the 16x16-thread 4x4
122- // register tile puts 4096 FMAs/step/block in flight, so one block
123- // already covers the pipe and the second one hides LDS latency.
124- if (nn_try_big_tile_ (m, n, k,
125- A_array_d, lda_d, B_array_d, ldb_d,
126- C_array_d, ldc_d, batchCount, stream, alpha))
127- {
128- return ;
129- }
130-
131- // 4 x 2 ladder (8 instantiations), tuned for V100 / A100:
132- // n (nw2 axis) -> BLK_M in {8, 16, 32, 48} (smallest full-cover)
133- // m (bxyz axis) -> BLK_N in {32, 64} (larger-is-better)
134- // BLK_K fixed at 16 (nw1 axis, <=27)
135- // DIM_X=8, DIM_Y=16 (128 threads/block, unchanged)
136- //
137- // Philosophy vs the prior tail-waste-min ladder:
138- //
139- // That ladder picked BLK_N by minimizing (tail_waste, grid_blocks)
140- // lexicographically. It's the right objective on sm_86 consumer
141- // Ampere (RTX 3090) where FP64 is 1/64 of FP32 -- every masked FMA
142- // there is a full FP64-pipe-bound cycle, so minimizing launched
143- // cells dominates.
144- //
145- // On V100 (sm_70) / A100 (sm_80) FP64 is a first-class pipe
146- // (7.8 / 9.7 TFLOPS peak, ridge ~6-9 FLOP/B), and the inner loop
147- // is LDS-bound for the nw1/nw2 ranges we see (ncu confirms L1/TEX
148- // >= 95% on these tiles). The right objective flips to maximizing
149- // per-block LDS reuse. The wide-LDS inner loop delivers
150- // FMA / LDS = VK * THR_M * THR_N / (THR_M + THR_N)
151- // to the shmem pipe; for the scalar-K-tail regime (nw1 < BLK_K,
152- // hit by nw1 <= 16 on NN) the VK factor drops out but the ratio
153- // shape is the same. Rough V100 FP64 target ratio is 2 FMAs/LDS
154- // (32 FP64-FMA/cycle/SM vs LDS.64 delivering one serving/cycle).
15+ // 4 (nw2 bracket) x 2 (bxyz bracket) = 8 instantiations.
15516 //
156- // At DIM=8x16, BLK_N=64 gives THR_N=4, THR_M=4 -> FMA/LDS = 2.0
157- // (matched). BLK_N=32 drops it to 1.33 (LDS-bound, FP64 headroom
158- // unused). So BLK_N=64 is strictly better for every bxyz >= 48;
159- // only bxyz=27 still prefers BLK_N=32 to cap the N-axis mask waste
160- // below 50%. The intermediate rungs {16, 48} are dropped: {32, 64}
161- // covers bxyz in {27, 48, 64, 80, 100, 125} at its LDS-optimal
162- // point in every case.
163- //
164- // BLK_M retains four rungs: the nw2 axis is tiny (<=44 in practice)
165- // and a wrong-BLK_M costs twice -- masked FMAs *and* a wider sA
166- // row load per K-step. The 48-rung is kept specifically for nw2=44
167- // extended-basis atoms (Ti/Mn/Fe/Co/Ni/Cu/Zn/Zr/Ba); otherwise the
168- // 32-rung falls off to a 2-tile grid at ~31% total waste.
169- //
170- // All eight (BLK_M, BLK_N) satisfy the kernel's BLK_M % DIM_X=0
171- // and BLK_N % DIM_Y=0 constraints, and the tiny-tile {BLK_M=8,
172- // BLK_N=32} rung still has THR_M=1 which compiles cleanly.
173- #define NN_DISPATCH (BLK_M_, BLK_N_ ) \
174- vbatched_gemm_nn_impl<T, 8 , 8 , BLK_M_, BLK_N_, 16 , 8 , 8 , 8 , 8 >( \
175- m, n, k, \
176- A_array_d, lda_d, B_array_d, ldb_d, \
17+ // Mapping into the impl's parameter list is:
18+ // <T, DIM_X, DIM_Y, BLK_M, BLK_N, BLK_K=16,
19+ // DIM_XA=DIM_X, DIM_YA=DIM_Y, DIM_XB=DIM_X, DIM_YB=DIM_Y>
20+ // which satisfies the kernel's tile-divisibility asserts because every
21+ // (BLK_M, BLK_N, BLK_K=16) chosen below is a multiple of the matching
22+ // (DIM_X, DIM_Y) pair.
23+ #define NN_DISPATCH (DX, DY, BM, BN ) \
24+ vbatched_gemm_nn_impl<T, DX, DY, BM, BN, 16 , DX, DY, DX, DY>( \
25+ m, n, k, \
26+ A_array_d, lda_d, B_array_d, ldb_d, \
17727 C_array_d, ldc_d, batchCount, stream, alpha)
17828
179- const int blk_m_tag = (n <= 8 ) ? 0 : (n <= 16 ) ? 1 : (n <= 32 ) ? 2 : 3 ;
180- const int blk_n_tag = (m < 48 ) ? 0 : 1 ; // {32, 64}
29+ // BLK_M bracket -- smallest tile in {8,16,32,48} covering nw2.
30+ const int blk_m_tag = (n <= 8 ) ? 0
31+ : (n <= 16 ) ? 1
32+ : (n <= 32 ) ? 2
33+ : 3 ;
18134
182- switch (blk_m_tag * 2 + blk_n_tag) {
183- case 0 : NN_DISPATCH ( 8 , 32 ); break ;
184- case 1 : NN_DISPATCH ( 8 , 64 ); break ;
185- case 2 : NN_DISPATCH (16 , 32 ); break ;
186- case 3 : NN_DISPATCH (16 , 64 ); break ;
187- case 4 : NN_DISPATCH (32 , 32 ); break ;
188- case 5 : NN_DISPATCH (32 , 64 ); break ;
189- case 6 : NN_DISPATCH (48 , 32 ); break ;
190- case 7 : NN_DISPATCH (48 , 64 ); break ;
35+ // BLK_N bracket -- 32 only when bxyz <=32 (caps mask waste at 50% for
36+ // bxyz=27); 64 for everything else (best LDS reuse).
37+ const int blk_n_tag = (m <= 32 ) ? 0 : 1 ;
38+
39+ switch (blk_m_tag * 2 + blk_n_tag)
40+ {
41+ // BLK_M=8 (nw2 <=8 ). DIM=4x8 -> THR_M=2.
42+ case 0 : NN_DISPATCH ( 4 , 8 , 8 , 32 ); break ; // THR=2*4=8 (under)
43+ case 1 : NN_DISPATCH ( 4 , 8 , 8 , 64 ); break ; // THR=2*8=16 (in band)
44+ // BLK_M=16 (nw2<=16). DIM=4x8 -> THR_M=4.
45+ case 2 : NN_DISPATCH ( 4 , 8 , 16 , 32 ); break ; // THR=4*4=16 (in band)
46+ case 3 : NN_DISPATCH ( 4 , 8 , 16 , 64 ); break ; // THR=4*8=32 (in band)
47+ // BLK_M=32 (nw2<=32). DIM=8x8 -> THR_M=4.
48+ case 4 : NN_DISPATCH ( 8 , 8 , 32 , 32 ); break ; // THR=4*4=16 (in band)
49+ case 5 : NN_DISPATCH ( 8 , 8 , 32 , 64 ); break ; // THR=4*8=32 (in band)
50+ // BLK_M=48 (nw2<=48). DIM=16x8 -> THR_M=3 (cap at 3 to keep
51+ // register pressure room for the BLK_N=64 sibling).
52+ case 6 : NN_DISPATCH (16 , 8 , 48 , 32 ); break ; // THR=3*4=12 (just under)
53+ case 7 : NN_DISPATCH (16 , 8 , 48 , 64 ); break ; // THR=3*8=24 (in band)
19154 }
19255
19356 #undef NN_DISPATCH
@@ -202,75 +65,47 @@ void gemm_tn_vbatch(
20265 int batchCount, cudaStream_t stream,
20366 const T* alpha)
20467{
205- // FP64 big tile (256-thread 64x64). Symmetric n>=48 && m>=48
206- // because, after the kernel's A/B swap, both output axes are small
207- // (kernel M = wrapper n = nw2, kernel N = wrapper m = nw1) and
208- // neither is intrinsically larger than the other.
209- if (tn_try_big_tile_ (m, n, k,
210- A_array_d, lda_d, B_array_d, ldb_d,
211- C_array_d, ldc_d, batchCount, stream, alpha))
212- {
213- return ;
214- }
215-
216- // 4 x 4 ladder (16 instantiations), tuned for V100 / A100:
217- // n (nw2 axis) -> BLK_M in {8, 16, 32, 48}
218- // m (nw1 axis) -> BLK_N in {8, 16, 32, 48}
219- // BLK_K fixed at 32 (bxyz axis)
220- // DIM_X=8, DIM_Y=8 (64 threads/block)
68+ // 4 (nw2 bracket) x 4 (nw1 bracket) = 16 instantiations.
22169 //
222- // Smallest-covering-tile selection, symmetric in both axes. This
223- // is *not* the same choice as NN -- on TN both output axes are
224- // small (nw1, nw2 in {4, 9, 13, 27, 44}) and neither is long
225- // enough to amortize the "prefer bigger" BLK_N logic from NN.
226- // Doubling BLK_* here would just push nw=4/9/13 cases off their
227- // exact-fit tile into a 2-4x mask-waste regime with no LDS-reuse
228- // upside (both axes of the output are already covered by one tile
229- // in this regime; a bigger tile just adds masked FMAs).
230- //
231- // The 48-rung covers nw=44 extended-basis atoms (Ti/Mn/Fe/Co/Ni/
232- // Cu/Zn/Zr/Ba) at ~8% mask waste per axis; without it those cases
233- // fall to a 2-tile BLK=32 grid at ~52% cell-launch waste.
234- //
235- // BLK_K=32 (larger than NN's 16) because K = bxyz here is large
236- // (27-125) and the K-axis tail wastes only shmem loads, never
237- // masked FMAs on the output -- bxyz <= 32 fits in one K-tile,
238- // larger bxyz wraps into 2-4 K-tiles. The modest __syncthreads()
239- // overhead from more K-tiles is cheaper than doubling BLK_K and
240- // forcing a re-tune of the `ra/rb` double-buffer register budget.
241- //
242- // All 16 (BLK_M, BLK_N) pairs are divisible by
243- // DIM_X/DIM_Y/DIM_XA/DIM_YA/DIM_XB/DIM_YB = 8, so every
244- // instantiation compiles to a valid kernel.
245- #define TN_DISPATCH (BLK_M_, BLK_N_ ) \
246- vbatched_gemm_tn_impl<T, 4 , 8 , BLK_M_, BLK_N_, 32 , 4 , 8 , 4 , 8 >( \
70+ // Both output axes here are the small nw axis, so we use the same
71+ // {8,16,32,48} ladder on both. BLK_K = 32 (the bxyz axis -- large).
72+ #define TN_DISPATCH (DX, DY, BM, BN ) \
73+ vbatched_gemm_tn_impl<T, DX, DY, BM, BN, 32 , DX, DY, DX, DY>( \
24774 m, n, k, \
24875 A_array_d, lda_d, B_array_d, ldb_d, \
24976 C_array_d, ldc_d, batchCount, stream, alpha)
25077
251- auto tag_for = [](int x) {
252- return (x <= 8 ) ? 0 : (x <= 16 ) ? 1 : (x <= 32 ) ? 2 : 3 ;
78+ auto bracket = [](int x) {
79+ return (x <= 8 ) ? 0
80+ : (x <= 16 ) ? 1
81+ : (x <= 32 ) ? 2
82+ : 3 ;
25383 };
254- const int blk_m_tag = tag_for (n); // kernel's M-dim grid -> wrapper n
255- const int blk_n_tag = tag_for (m); // kernel's N-dim grid -> wrapper m
84+ const int blk_m_tag = bracket (n); // BLK_M <- nw2
85+ const int blk_n_tag = bracket (m); // BLK_N <- nw1
25686
257- switch (blk_m_tag * 4 + blk_n_tag) {
258- case 0 : TN_DISPATCH ( 8 , 8 ); break ;
259- case 1 : TN_DISPATCH ( 8 , 16 ); break ;
260- case 2 : TN_DISPATCH ( 8 , 32 ); break ;
261- case 3 : TN_DISPATCH ( 8 , 48 ); break ;
262- case 4 : TN_DISPATCH (16 , 8 ); break ;
263- case 5 : TN_DISPATCH (16 , 16 ); break ;
264- case 6 : TN_DISPATCH (16 , 32 ); break ;
265- case 7 : TN_DISPATCH (16 , 48 ); break ;
266- case 8 : TN_DISPATCH (32 , 8 ); break ;
267- case 9 : TN_DISPATCH (32 , 16 ); break ;
268- case 10 : TN_DISPATCH (32 , 32 ); break ;
269- case 11 : TN_DISPATCH (32 , 48 ); break ;
270- case 12 : TN_DISPATCH (48 , 8 ); break ;
271- case 13 : TN_DISPATCH (48 , 16 ); break ;
272- case 14 : TN_DISPATCH (48 , 32 ); break ;
273- case 15 : TN_DISPATCH (48 , 48 ); break ;
87+ switch (blk_m_tag * 4 + blk_n_tag)
88+ {
89+ // BLK_M=8 rungs (nw2<=8). DIM_X=4, THR_M=2.
90+ case 0 : TN_DISPATCH (4 , 8 , 8 , 8 ); break ; // THR=2*1=2 (corner)
91+ case 1 : TN_DISPATCH (4 , 8 , 8 , 16 ); break ; // THR=2*2=4
92+ case 2 : TN_DISPATCH (4 , 8 , 8 , 32 ); break ; // THR=2*4=8
93+ case 3 : TN_DISPATCH (4 , 8 , 8 , 48 ); break ; // THR=2*6=12
94+ // BLK_M=16 rungs (nw2<=16). DIM_X=4, THR_M=4.
95+ case 4 : TN_DISPATCH (4 , 8 , 16 , 8 ); break ; // THR=4*1=4
96+ case 5 : TN_DISPATCH (4 , 8 , 16 , 16 ); break ; // THR=4*2=8
97+ case 6 : TN_DISPATCH (4 , 8 , 16 , 32 ); break ; // THR=4*4=16 (in band)
98+ case 7 : TN_DISPATCH (4 , 8 , 16 , 48 ); break ; // THR=4*6=24 (in band)
99+ // BLK_M=32 rungs (nw2<=32). DIM_X=8, THR_M=4.
100+ case 8 : TN_DISPATCH (8 , 4 , 32 , 8 ); break ; // THR=4*2=8
101+ case 9 : TN_DISPATCH (8 , 4 , 32 , 16 ); break ; // THR=4*4=16 (in band)
102+ case 10 : TN_DISPATCH (8 , 8 , 32 , 32 ); break ; // THR=4*4=16 (in band)
103+ case 11 : TN_DISPATCH (8 , 8 , 32 , 48 ); break ; // THR=4*6=24 (in band)
104+ // BLK_M=48 rungs (nw2<=48). DIM_X=8, THR_M=6.
105+ case 12 : TN_DISPATCH (8 , 4 , 48 , 8 ); break ; // THR=6*2=12
106+ case 13 : TN_DISPATCH (8 , 4 , 48 , 16 ); break ; // THR=6*4=24 (in band)
107+ case 14 : TN_DISPATCH (8 , 8 , 48 , 32 ); break ; // THR=6*4=24 (in band)
108+ case 15 : TN_DISPATCH (8 , 8 , 48 , 48 ); break ; // THR=6*6=36 (top of band)
274109 }
275110
276111 #undef TN_DISPATCH
0 commit comments