Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne_group.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz.co
Binary file not shown.
Binary file modified hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz_group.co
Binary file not shown.
115 changes: 103 additions & 12 deletions hsa/gfx942/fmha_v3_fwd/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,35 @@ class fmha_fwd_v3_kernel
int gdx = ((fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo + tg_div - 1) / tg_div;
int gdy = fmha_v3_traits.h;
int gdz = fmha_v3_traits.b;
HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
gdz,
bdx,
1,
1,
0,
s.stream_id_,
NULL,
reinterpret_cast<void**>(&config)));
}

void
launch_kernel_group(fmha_fwd_v3_traits fmha_v3_traits, fmha_fwd_v3_args args, const ck_tile::stream_config& s) const
{
size_t arg_size = sizeof(args);
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
&args,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&arg_size,
HIP_LAUNCH_PARAM_END};

int tg_div = (fmha_v3_traits.mask != 0) ? 2 : 1;

int bdx = 512;
int gdx = fmha_v3_traits.h;
int gdy = fmha_v3_traits.b;
int gdz = ((fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo + tg_div - 1) / tg_div;
HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
Expand Down Expand Up @@ -224,6 +252,69 @@ class fmha_fwd_v3_kernel
);
}

template <typename fmha_fwd_kernel_selector>
float fmha_fwd_v3_group_dispatcher(const ck_tile::stream_config& s, mha_fwd_args a,
const void* seqstart_q_padding_ptr, const void* seqstart_k_padding_ptr)
{
if(s.log_level_ > 0)
std::cout << ", " << FmhaFwdV3Name<fmha_fwd_kernel_selector>::fwd_v3_name << std::flush;

int tune_opt = 5;
if (a.mask_type != 0 && ((a.nhead_q % 8 != 0) || (a.seqlen_q > 16384))) //if num_head is not 8N, or seqlen is bigger than 16K, downgrade to 2and3
{
tune_opt -= 2;
}

fmha_fwd_v3_args args;
args.ptr_o = a.o_ptr;
args.ptr_q = a.q_ptr;
args.ptr_k = a.k_ptr;
args.ptr_v = a.v_ptr;
args.ptr_lse = a.lse_ptr;

args.scalar = a.scale_s;
args.s_seq_len = a.seqlen_q;
args.s_Seqs = a.stride_q * 2;
args.s_Ts = FmhaFwdV3Ts<fmha_fwd_kernel_selector>::ts_qo * a.stride_q * 2;
args.s_Hs = a.nhead_stride_q * 2;
args.s_Bs = a.batch_stride_q * 2;
args.s_gqa = a.nhead_q / a.nhead_k;
args.s_k_Seqs = a.stride_k * 2;
args.s_k_Hs = a.nhead_stride_k * 2;
args.s_k_Bs = a.batch_stride_k * 2;
args.s_opt = tune_opt;
args.s_lse = fmha_fwd_kernel_selector::kStoreLSE;
args.s_kv_seq_len = a.seqlen_k;
args.s_qk_head_dim = a.hdim_q;
args.s_v_head_dim = a.hdim_v;
args.s_q_head_num = a.nhead_q;
args.s_v_Seqs = a.stride_v * 2;
args.s_v_Hs = a.nhead_stride_v * 2;
args.s_v_Bs = a.batch_stride_v * 2;
args.s_o_Seqs = a.stride_o * 2;
args.s_o_Hs = a.nhead_stride_o * 2;
args.s_o_Bs = a.batch_stride_o * 2;

args.s_lse_Hs = a.nhead_stride_lse * 4;
args.ptr_qseq = a.seqstart_q_ptr;
args.ptr_kseq = a.seqstart_k_ptr;
args.ptr_qseq_padding = seqstart_q_padding_ptr == nullptr ? a.seqstart_q_ptr : seqstart_q_padding_ptr;
args.ptr_kseq_padding = seqstart_k_padding_ptr == nullptr ? a.seqstart_k_ptr : seqstart_k_padding_ptr;

auto traits = fmha_fwd_v3_traits{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
a.mask_type,
FmhaFwdV3Ts<fmha_fwd_kernel_selector>::ts_qo,
FmhaFwdV3Ts<fmha_fwd_kernel_selector>::ts_kv};

static thread_local fmha_fwd_v3_kernel impl(FmhaFwdV3Name<fmha_fwd_kernel_selector>::fwd_v3_name, FmhaFwdV3Buf<fmha_fwd_kernel_selector>::fwd_v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){ impl.launch_kernel_group(traits, args, s_); }
);
}

float fmha_fwd_v3(mha_fwd_traits t, mha_fwd_args a, const ck_tile::stream_config& s, const void* seqstart_q_padding_ptr, const void* seqstart_k_padding_ptr,
bool is_v3_api_check) {
float r = -1;
Expand Down Expand Up @@ -354,14 +445,14 @@ class fmha_fwd_v3_kernel
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
else {
using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 1, false, false, 1, GPUArch::gfx942, 0, true>;
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
}
else if(t.how_v3_bf16_cvt == 1) {
Expand All @@ -370,14 +461,14 @@ class fmha_fwd_v3_kernel
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
else {
using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 1, false, false, 1, GPUArch::gfx942, 1, true>;
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
}
else if(t.how_v3_bf16_cvt == 2) {
Expand All @@ -386,14 +477,14 @@ class fmha_fwd_v3_kernel
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
else {
using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 1, false, false, 1, GPUArch::gfx942, 2, true>;
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
}
}
Expand All @@ -404,14 +495,14 @@ class fmha_fwd_v3_kernel
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
else {
using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 0, false, false, 1, GPUArch::gfx942, 0, true>;
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
}
else if(t.how_v3_bf16_cvt == 1) {
Expand All @@ -420,14 +511,14 @@ class fmha_fwd_v3_kernel
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
else {
using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 0, false, false, 1, GPUArch::gfx942, 1, true>;
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
}
else if(t.how_v3_bf16_cvt == 2) {
Expand All @@ -436,14 +527,14 @@ class fmha_fwd_v3_kernel
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
else {
using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 0, false, false, 1, GPUArch::gfx942, 2, true>;
if (is_v3_api_check) {
return 1;
}
r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
r = fmha_fwd_v3_group_dispatcher<fmha_fwd_kernel>(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr);
}
}
}
Expand Down
Loading