Skip to content

Commit 67f7e67

Browse files
committed
optimize the template parameters
1 parent 8313c52 commit 67f7e67

1 file changed

Lines changed: 69 additions & 234 deletions

File tree

source/source_lcao/module_gint/kernel/dgemm_vbatch.cu

Lines changed: 69 additions & 234 deletions
Original file line numberDiff line numberDiff line change
@@ -3,111 +3,6 @@
33
#include "dgemm_vbatch.h"
44
#include "source_base/module_device/device.h"
55

6-
// ----------------------------------------------------------------------------
7-
// FP64-only big-tile dispatch (Phase V4)
8-
// ----------------------------------------------------------------------------
9-
// Pattern mirrors the C++11-compatible overload trick used in
10-
// gint_vl.cpp / gint_rho.cpp (note in those files: "C++11-compatible
11-
// alternative to if constexpr"). The non-template double overload is the
12-
// preferred candidate when T = double; the template fallback returns
13-
// false for every other dtype so the FP32 path stays untouched (per the
14-
// V100 plan: 3090 FP32 was the test proxy and the big tile was not
15-
// validated on FP32).
16-
// Returns true if the big tile dispatched (caller should `return`).
17-
// ----------------------------------------------------------------------------
18-
19-
inline bool nn_try_big_tile_(
20-
int m, int n, int k,
21-
const double* const* A_array_d, const int* lda_d,
22-
const double* const* B_array_d, const int* ldb_d,
23-
double** C_array_d, const int* ldc_d,
24-
int batchCount, cudaStream_t stream, const double* alpha)
25-
{
26-
// 16x16 threads = 256, BLK_M=BLK_N=64, BLK_K=16. THR_M = THR_N = 4
27-
// -> 16 FMAs per inner step, 32 with VK=2. Loaders use DIM_*A=DIM_*B=16.
28-
if (n >= 48 && m >= 64) {
29-
vbatched_gemm_nn_impl<double,
30-
/*DIM_X */ 16, /*DIM_Y */ 16,
31-
/*BLK_M */ 64, /*BLK_N */ 64, /*BLK_K*/ 16,
32-
/*DIM_XA*/ 16, /*DIM_YA*/ 16,
33-
/*DIM_XB*/ 16, /*DIM_YB*/ 16>(
34-
m, n, k,
35-
A_array_d, lda_d, B_array_d, ldb_d,
36-
C_array_d, ldc_d, batchCount, stream, alpha);
37-
return true;
38-
}
39-
return false;
40-
}
41-
42-
template <typename T>
43-
inline bool nn_try_big_tile_(
44-
int /*m*/, int /*n*/, int /*k*/,
45-
const T* const* /*A*/, const int* /*lda*/,
46-
const T* const* /*B*/, const int* /*ldb*/,
47-
T** /*C*/, const int* /*ldc*/,
48-
int /*batch*/, cudaStream_t /*stream*/, const T* /*alpha*/)
49-
{
50-
return false;
51-
}
52-
53-
inline bool tn_try_big_tile_(
54-
int m, int n, int k,
55-
const double* const* A_array_d, const int* lda_d,
56-
const double* const* B_array_d, const int* ldb_d,
57-
double** C_array_d, const int* ldc_d,
58-
int batchCount, cudaStream_t stream, const double* alpha)
59-
{
60-
// Axis flip vs NN: kernel M = wrapper n = nw2, kernel N = wrapper m = nw1.
61-
// Threshold is symmetric at 48 in both axes.
62-
// BLK_K=16 (not 32 as in the existing TN ladder) keeps the big-tile shmem
63-
// footprint at ~18 KB/block so 4 blocks/SM still fit on V100's 96 KB.
64-
if (n >= 48 && m >= 48) {
65-
vbatched_gemm_tn_impl<double,
66-
/*DIM_X */ 16, /*DIM_Y */ 16,
67-
/*BLK_M */ 64, /*BLK_N */ 64, /*BLK_K*/ 16,
68-
/*DIM_XA*/ 16, /*DIM_YA*/ 16,
69-
/*DIM_XB*/ 16, /*DIM_YB*/ 16>(
70-
m, n, k,
71-
A_array_d, lda_d, B_array_d, ldb_d,
72-
C_array_d, ldc_d, batchCount, stream, alpha);
73-
return true;
74-
}
75-
return false;
76-
}
77-
78-
template <typename T>
79-
inline bool tn_try_big_tile_(
80-
int /*m*/, int /*n*/, int /*k*/,
81-
const T* const* /*A*/, const int* /*lda*/,
82-
const T* const* /*B*/, const int* /*ldb*/,
83-
T** /*C*/, const int* /*ldc*/,
84-
int /*batch*/, cudaStream_t /*stream*/, const T* /*alpha*/)
85-
{
86-
return false;
87-
}
88-
89-
// ----------------------------------------------------------------------------
90-
// Shape-exact dispatch
91-
// ----------------------------------------------------------------------------
92-
//
93-
// The caller (phi_operator_gpu.cu) buckets atom pairs by (nw1, nw2) so every
94-
// item in a batch has exactly the same (m, n, k). The scalars passed here are
95-
// the *exact* per-matrix shapes (not upper bounds), which lets the tile
96-
// ladder pick the tightest template and sizes the grid tightly (no
97-
// over-launched blocks that short-circuit inside the kernel).
98-
//
99-
// Kernel-level dimension mapping (after the A/B swap inside
100-
// vbatched_gemm_*_impl):
101-
//
102-
// call | wrapper m | wrapper n | wrapper k
103-
// --------|--------------|-------------|---------------
104-
// NN | bxyz (large) | nw2 (small) | nw1 (small)
105-
// TN | nw1 (small) | nw2 (small) | bxyz (large)
106-
//
107-
// (m, n, k) flow through as scalars all the way down into the kernel, so
108-
// there is no per-batchid M/N/K load and no fill-kernel scratch buffer.
109-
// ----------------------------------------------------------------------------
110-
1116
template<typename T>
1127
void gemm_nn_vbatch(
1138
int m, int n, int k,
@@ -117,77 +12,45 @@ void gemm_nn_vbatch(
11712
int batchCount, cudaStream_t stream,
11813
const T* alpha)
11914
{
120-
// FP64 big tile (256-thread 64x64). Little's Law says V100 needs
121-
// ~300 in-flight FP64 FMAs/SM to saturate; the 16x16-thread 4x4
122-
// register tile puts 4096 FMAs/step/block in flight, so one block
123-
// already covers the pipe and the second one hides LDS latency.
124-
if (nn_try_big_tile_(m, n, k,
125-
A_array_d, lda_d, B_array_d, ldb_d,
126-
C_array_d, ldc_d, batchCount, stream, alpha))
127-
{
128-
return;
129-
}
130-
131-
// 4 x 2 ladder (8 instantiations), tuned for V100 / A100:
132-
// n (nw2 axis) -> BLK_M in {8, 16, 32, 48} (smallest full-cover)
133-
// m (bxyz axis) -> BLK_N in {32, 64} (larger-is-better)
134-
// BLK_K fixed at 16 (nw1 axis, <=27)
135-
// DIM_X=8, DIM_Y=16 (128 threads/block, unchanged)
136-
//
137-
// Philosophy vs the prior tail-waste-min ladder:
138-
//
139-
// That ladder picked BLK_N by minimizing (tail_waste, grid_blocks)
140-
// lexicographically. It's the right objective on sm_86 consumer
141-
// Ampere (RTX 3090) where FP64 is 1/64 of FP32 -- every masked FMA
142-
// there is a full FP64-pipe-bound cycle, so minimizing launched
143-
// cells dominates.
144-
//
145-
// On V100 (sm_70) / A100 (sm_80) FP64 is a first-class pipe
146-
// (7.8 / 9.7 TFLOPS peak, ridge ~6-9 FLOP/B), and the inner loop
147-
// is LDS-bound for the nw1/nw2 ranges we see (ncu confirms L1/TEX
148-
// >= 95% on these tiles). The right objective flips to maximizing
149-
// per-block LDS reuse. The wide-LDS inner loop delivers
150-
// FMA / LDS = VK * THR_M * THR_N / (THR_M + THR_N)
151-
// to the shmem pipe; for the scalar-K-tail regime (nw1 < BLK_K,
152-
// hit by nw1 <= 16 on NN) the VK factor drops out but the ratio
153-
// shape is the same. Rough V100 FP64 target ratio is 2 FMAs/LDS
154-
// (32 FP64-FMA/cycle/SM vs LDS.64 delivering one serving/cycle).
15+
// 4 (nw2 bracket) x 2 (bxyz bracket) = 8 instantiations.
15516
//
156-
// At DIM=8x16, BLK_N=64 gives THR_N=4, THR_M=4 -> FMA/LDS = 2.0
157-
// (matched). BLK_N=32 drops it to 1.33 (LDS-bound, FP64 headroom
158-
// unused). So BLK_N=64 is strictly better for every bxyz >= 48;
159-
// only bxyz=27 still prefers BLK_N=32 to cap the N-axis mask waste
160-
// below 50%. The intermediate rungs {16, 48} are dropped: {32, 64}
161-
// covers bxyz in {27, 48, 64, 80, 100, 125} at its LDS-optimal
162-
// point in every case.
163-
//
164-
// BLK_M retains four rungs: the nw2 axis is tiny (<=44 in practice)
165-
// and a wrong-BLK_M costs twice -- masked FMAs *and* a wider sA
166-
// row load per K-step. The 48-rung is kept specifically for nw2=44
167-
// extended-basis atoms (Ti/Mn/Fe/Co/Ni/Cu/Zn/Zr/Ba); otherwise the
168-
// 32-rung falls off to a 2-tile grid at ~31% total waste.
169-
//
170-
// All eight (BLK_M, BLK_N) satisfy the kernel's BLK_M % DIM_X=0
171-
// and BLK_N % DIM_Y=0 constraints, and the tiny-tile {BLK_M=8,
172-
// BLK_N=32} rung still has THR_M=1 which compiles cleanly.
173-
#define NN_DISPATCH(BLK_M_, BLK_N_) \
174-
vbatched_gemm_nn_impl<T, 8, 8, BLK_M_, BLK_N_, 16, 8, 8, 8, 8>( \
175-
m, n, k, \
176-
A_array_d, lda_d, B_array_d, ldb_d, \
17+
// Mapping into the impl's parameter list is:
18+
// <T, DIM_X, DIM_Y, BLK_M, BLK_N, BLK_K=16,
19+
// DIM_XA=DIM_X, DIM_YA=DIM_Y, DIM_XB=DIM_X, DIM_YB=DIM_Y>
20+
// which satisfies the kernel's tile-divisibility asserts because every
21+
// (BLK_M, BLK_N, BLK_K=16) chosen below is a multiple of the matching
22+
// (DIM_X, DIM_Y) pair.
23+
#define NN_DISPATCH(DX, DY, BM, BN) \
24+
vbatched_gemm_nn_impl<T, DX, DY, BM, BN, 16, DX, DY, DX, DY>( \
25+
m, n, k, \
26+
A_array_d, lda_d, B_array_d, ldb_d, \
17727
C_array_d, ldc_d, batchCount, stream, alpha)
17828

179-
const int blk_m_tag = (n <= 8) ? 0 : (n <= 16) ? 1 : (n <= 32) ? 2 : 3;
180-
const int blk_n_tag = (m < 48) ? 0 : 1; // {32, 64}
29+
// BLK_M bracket -- smallest tile in {8,16,32,48} covering nw2.
30+
const int blk_m_tag = (n <= 8) ? 0
31+
: (n <= 16) ? 1
32+
: (n <= 32) ? 2
33+
: 3;
18134

182-
switch (blk_m_tag * 2 + blk_n_tag) {
183-
case 0: NN_DISPATCH( 8, 32); break;
184-
case 1: NN_DISPATCH( 8, 64); break;
185-
case 2: NN_DISPATCH(16, 32); break;
186-
case 3: NN_DISPATCH(16, 64); break;
187-
case 4: NN_DISPATCH(32, 32); break;
188-
case 5: NN_DISPATCH(32, 64); break;
189-
case 6: NN_DISPATCH(48, 32); break;
190-
case 7: NN_DISPATCH(48, 64); break;
35+
// BLK_N bracket -- 32 only when bxyz <=32 (caps mask waste at 50% for
36+
// bxyz=27); 64 for everything else (best LDS reuse).
37+
const int blk_n_tag = (m <= 32) ? 0 : 1;
38+
39+
switch (blk_m_tag * 2 + blk_n_tag)
40+
{
41+
// BLK_M=8 (nw2 <=8 ). DIM=4x8 -> THR_M=2.
42+
case 0: NN_DISPATCH( 4, 8, 8, 32); break; // THR=2*4=8 (under)
43+
case 1: NN_DISPATCH( 4, 8, 8, 64); break; // THR=2*8=16 (in band)
44+
// BLK_M=16 (nw2<=16). DIM=4x8 -> THR_M=4.
45+
case 2: NN_DISPATCH( 4, 8, 16, 32); break; // THR=4*4=16 (in band)
46+
case 3: NN_DISPATCH( 4, 8, 16, 64); break; // THR=4*8=32 (in band)
47+
// BLK_M=32 (nw2<=32). DIM=8x8 -> THR_M=4.
48+
case 4: NN_DISPATCH( 8, 8, 32, 32); break; // THR=4*4=16 (in band)
49+
case 5: NN_DISPATCH( 8, 8, 32, 64); break; // THR=4*8=32 (in band)
50+
// BLK_M=48 (nw2<=48). DIM=16x8 -> THR_M=3 (cap at 3 to keep
51+
// register pressure room for the BLK_N=64 sibling).
52+
case 6: NN_DISPATCH(16, 8, 48, 32); break; // THR=3*4=12 (just under)
53+
case 7: NN_DISPATCH(16, 8, 48, 64); break; // THR=3*8=24 (in band)
19154
}
19255

19356
#undef NN_DISPATCH
@@ -202,75 +65,47 @@ void gemm_tn_vbatch(
20265
int batchCount, cudaStream_t stream,
20366
const T* alpha)
20467
{
205-
// FP64 big tile (256-thread 64x64). Symmetric n>=48 && m>=48
206-
// because, after the kernel's A/B swap, both output axes are small
207-
// (kernel M = wrapper n = nw2, kernel N = wrapper m = nw1) and
208-
// neither is intrinsically larger than the other.
209-
if (tn_try_big_tile_(m, n, k,
210-
A_array_d, lda_d, B_array_d, ldb_d,
211-
C_array_d, ldc_d, batchCount, stream, alpha))
212-
{
213-
return;
214-
}
215-
216-
// 4 x 4 ladder (16 instantiations), tuned for V100 / A100:
217-
// n (nw2 axis) -> BLK_M in {8, 16, 32, 48}
218-
// m (nw1 axis) -> BLK_N in {8, 16, 32, 48}
219-
// BLK_K fixed at 32 (bxyz axis)
220-
// DIM_X=8, DIM_Y=8 (64 threads/block)
68+
// 4 (nw2 bracket) x 4 (nw1 bracket) = 16 instantiations.
22169
//
222-
// Smallest-covering-tile selection, symmetric in both axes. This
223-
// is *not* the same choice as NN -- on TN both output axes are
224-
// small (nw1, nw2 in {4, 9, 13, 27, 44}) and neither is long
225-
// enough to amortize the "prefer bigger" BLK_N logic from NN.
226-
// Doubling BLK_* here would just push nw=4/9/13 cases off their
227-
// exact-fit tile into a 2-4x mask-waste regime with no LDS-reuse
228-
// upside (both axes of the output are already covered by one tile
229-
// in this regime; a bigger tile just adds masked FMAs).
230-
//
231-
// The 48-rung covers nw=44 extended-basis atoms (Ti/Mn/Fe/Co/Ni/
232-
// Cu/Zn/Zr/Ba) at ~8% mask waste per axis; without it those cases
233-
// fall to a 2-tile BLK=32 grid at ~52% cell-launch waste.
234-
//
235-
// BLK_K=32 (larger than NN's 16) because K = bxyz here is large
236-
// (27-125) and the K-axis tail wastes only shmem loads, never
237-
// masked FMAs on the output -- bxyz <= 32 fits in one K-tile,
238-
// larger bxyz wraps into 2-4 K-tiles. The modest __syncthreads()
239-
// overhead from more K-tiles is cheaper than doubling BLK_K and
240-
// forcing a re-tune of the `ra/rb` double-buffer register budget.
241-
//
242-
// All 16 (BLK_M, BLK_N) pairs are divisible by
243-
// DIM_X/DIM_Y/DIM_XA/DIM_YA/DIM_XB/DIM_YB = 8, so every
244-
// instantiation compiles to a valid kernel.
245-
#define TN_DISPATCH(BLK_M_, BLK_N_) \
246-
vbatched_gemm_tn_impl<T, 4, 8, BLK_M_, BLK_N_, 32, 4, 8, 4, 8>( \
70+
// Both output axes here are the small nw axis, so we use the same
71+
// {8,16,32,48} ladder on both. BLK_K = 32 (the bxyz axis -- large).
72+
#define TN_DISPATCH(DX, DY, BM, BN) \
73+
vbatched_gemm_tn_impl<T, DX, DY, BM, BN, 32, DX, DY, DX, DY>( \
24774
m, n, k, \
24875
A_array_d, lda_d, B_array_d, ldb_d, \
24976
C_array_d, ldc_d, batchCount, stream, alpha)
25077

251-
auto tag_for = [](int x) {
252-
return (x <= 8) ? 0 : (x <= 16) ? 1 : (x <= 32) ? 2 : 3;
78+
auto bracket = [](int x) {
79+
return (x <= 8) ? 0
80+
: (x <= 16) ? 1
81+
: (x <= 32) ? 2
82+
: 3;
25383
};
254-
const int blk_m_tag = tag_for(n); // kernel's M-dim grid -> wrapper n
255-
const int blk_n_tag = tag_for(m); // kernel's N-dim grid -> wrapper m
84+
const int blk_m_tag = bracket(n); // BLK_M <- nw2
85+
const int blk_n_tag = bracket(m); // BLK_N <- nw1
25686

257-
switch (blk_m_tag * 4 + blk_n_tag) {
258-
case 0: TN_DISPATCH( 8, 8); break;
259-
case 1: TN_DISPATCH( 8, 16); break;
260-
case 2: TN_DISPATCH( 8, 32); break;
261-
case 3: TN_DISPATCH( 8, 48); break;
262-
case 4: TN_DISPATCH(16, 8); break;
263-
case 5: TN_DISPATCH(16, 16); break;
264-
case 6: TN_DISPATCH(16, 32); break;
265-
case 7: TN_DISPATCH(16, 48); break;
266-
case 8: TN_DISPATCH(32, 8); break;
267-
case 9: TN_DISPATCH(32, 16); break;
268-
case 10: TN_DISPATCH(32, 32); break;
269-
case 11: TN_DISPATCH(32, 48); break;
270-
case 12: TN_DISPATCH(48, 8); break;
271-
case 13: TN_DISPATCH(48, 16); break;
272-
case 14: TN_DISPATCH(48, 32); break;
273-
case 15: TN_DISPATCH(48, 48); break;
87+
switch (blk_m_tag * 4 + blk_n_tag)
88+
{
89+
// BLK_M=8 rungs (nw2<=8). DIM_X=4, THR_M=2.
90+
case 0: TN_DISPATCH(4, 8, 8, 8); break; // THR=2*1=2 (corner)
91+
case 1: TN_DISPATCH(4, 8, 8, 16); break; // THR=2*2=4
92+
case 2: TN_DISPATCH(4, 8, 8, 32); break; // THR=2*4=8
93+
case 3: TN_DISPATCH(4, 8, 8, 48); break; // THR=2*6=12
94+
// BLK_M=16 rungs (nw2<=16). DIM_X=4, THR_M=4.
95+
case 4: TN_DISPATCH(4, 8, 16, 8); break; // THR=4*1=4
96+
case 5: TN_DISPATCH(4, 8, 16, 16); break; // THR=4*2=8
97+
case 6: TN_DISPATCH(4, 8, 16, 32); break; // THR=4*4=16 (in band)
98+
case 7: TN_DISPATCH(4, 8, 16, 48); break; // THR=4*6=24 (in band)
99+
// BLK_M=32 rungs (nw2<=32). DIM_X=8, THR_M=4.
100+
case 8: TN_DISPATCH(8, 4, 32, 8); break; // THR=4*2=8
101+
case 9: TN_DISPATCH(8, 4, 32, 16); break; // THR=4*4=16 (in band)
102+
case 10: TN_DISPATCH(8, 8, 32, 32); break; // THR=4*4=16 (in band)
103+
case 11: TN_DISPATCH(8, 8, 32, 48); break; // THR=4*6=24 (in band)
104+
// BLK_M=48 rungs (nw2<=48). DIM_X=8, THR_M=6.
105+
case 12: TN_DISPATCH(8, 4, 48, 8); break; // THR=6*2=12
106+
case 13: TN_DISPATCH(8, 4, 48, 16); break; // THR=6*4=24 (in band)
107+
case 14: TN_DISPATCH(8, 8, 48, 32); break; // THR=6*4=24 (in band)
108+
case 15: TN_DISPATCH(8, 8, 48, 48); break; // THR=6*6=36 (top of band)
274109
}
275110

276111
#undef TN_DISPATCH

0 commit comments

Comments
 (0)