Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 796 files
7 changes: 5 additions & 2 deletions aiter/jit/optCompilerConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,16 @@
"f'{AITER_CSRC_DIR}/kernels/moe_align_block_size_kernels.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_cu/asm_fmoe.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_cu/asm_moe_2stage.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_cu/asm_topksoftmax.cu'"
"f'{AITER_CSRC_DIR}/py_itfs_cu/asm_topksoftmax.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_ck/topk_sigmoid_kernels.cu'",
"f'{CK_DIR}/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [
"f'{AITER_CSRC_DIR}/include/ck_tile'"
"f'{AITER_CSRC_DIR}/include/ck_tile'",
"f'{CK_DIR}/example/ck_tile/09_topk_softmax'"
],
"verbose": "False",
"blob_gen_cmd": [
Expand Down
6 changes: 6 additions & 0 deletions aiter/ops/moe_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def topk_softmax_asm(
) -> None: ...


# Fused top-k selection + sigmoid over MoE gating scores (CK-backed; see
# csrc/py_itfs_ck/topk_sigmoid_kernels.cu). Writes the selected expert
# weights into `topk_weights` [tokens, topk] and their expert indices into
# `topk_indices` [tokens, topk], reading scores from `gating_output`
# [tokens, experts]. Output tensors must be pre-allocated by the caller;
# the op returns None.
@compile_ops("module_moe_asm")
def topk_sigmoid(
    topk_weights: Tensor, topk_indices: Tensor, gating_output: Tensor
) -> None: ...


# NOTE(review): presumably reduces `input` into the pre-allocated `output`
# tensor (the reduction axis is defined by the csrc kernel, not visible
# here) — confirm against the moe_sum kernel implementation.
@compile_ops("module_moe_asm")
def moe_sum(input: Tensor, output: Tensor) -> None: ...

Expand Down
4 changes: 4 additions & 0 deletions csrc/include/moe_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,8 @@ void moe_align_block_size(torch::Tensor topk_ids,

// NOTE(review): writes a reduction of `input` into `output`; the reduction
// axis is defined by the kernel implementation — confirm in csrc kernels.
void moe_sum(torch::Tensor& input, torch::Tensor& output);

// Fused top-k selection + sigmoid activation over MoE gating scores.
// Fills the pre-allocated output tensors with the selected expert weights
// and indices; implemented in csrc/py_itfs_ck/topk_sigmoid_kernels.cu.
void topk_sigmoid(torch::Tensor topk_weights, // [tokens, topk]
torch::Tensor topk_indices, // [tokens, topk]
torch::Tensor gating_output); // [tokens, experts]

} // namespace aiter
6 changes: 6 additions & 0 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,12 @@
py::arg("need_renorm"), \
py::arg("routed_scaling_factor") = 1.0f, \
"Apply biased grouped topk softmax to the gating outputs."); \
m.def("topk_sigmoid", \
&aiter::topk_sigmoid, \
py::arg("topk_weights"), \
py::arg("topk_indices"), \
py::arg("gating_output"), \
"Apply topk sigmoid to the gating outputs."); \
m.def("moe_fused_gate", \
&moe_fused_gate, \
py::arg("input"), \
Expand Down
72 changes: 72 additions & 0 deletions csrc/py_itfs_ck/topk_sigmoid_kernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "py_itfs_common.h"

// from CK examples:
#include "topk_softmax_api.hpp"

namespace aiter
{

// Fused top-k + sigmoid routing for MoE gating.
//
// Selects the `topk` largest gating scores per token and applies a sigmoid
// to them, writing weights into `topk_weights` and the chosen expert
// indices into `topk_indices`. Dispatches to the CK-tile topk_softmax
// example kernel with its "sigmoid" activation variant.
//
// Parameters:
//   topk_weights  - [tokens, topk] pre-allocated output weights
//   topk_indices  - [tokens, topk] pre-allocated output expert indices
//   gating_output - [tokens, experts] input gating scores
// Throws std::runtime_error / c10::Error on unsupported dtypes or
// mismatched shapes.
void topk_sigmoid(torch::Tensor topk_weights, // [tokens, topk]
                  torch::Tensor topk_indices, // [tokens, topk]
                  torch::Tensor gating_output) // [tokens, experts]
{
    // Validate shapes early so misuse fails with a clear message instead of
    // corrupting memory inside the kernel.
    TORCH_CHECK(gating_output.dim() == 2,
                "topk_sigmoid: gating_output must be 2-D [tokens, experts]");
    TORCH_CHECK(topk_weights.dim() == 2 && topk_indices.dim() == 2,
                "topk_sigmoid: topk_weights/topk_indices must be 2-D [tokens, topk]");
    TORCH_CHECK(topk_weights.sizes() == topk_indices.sizes(),
                "topk_sigmoid: topk_weights and topk_indices must have identical shapes");
    TORCH_CHECK(topk_weights.size(0) == gating_output.size(0),
                "topk_sigmoid: token-count mismatch between outputs and gating_output");

    // Ensure the kernel launches on the tensors' device and current stream.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Extract dimensions.
    const int tokens  = gating_output.size(0);
    const int experts = gating_output.size(1);
    const int topk    = topk_weights.size(1);

    // Use the tensors' actual row strides instead of assuming contiguity, so
    // row-padded / sliced inputs are handled correctly. The innermost
    // dimension must still be dense for the kernel's element addressing.
    TORCH_CHECK(gating_output.stride(1) == 1 && topk_weights.stride(1) == 1,
                "topk_sigmoid: innermost dimension must be contiguous");
    const int stride_input  = static_cast<int>(gating_output.stride(0));
    const int stride_output = static_cast<int>(topk_weights.stride(0));

    // Map torch dtypes onto the string tags the CK dispatcher expects.
    auto dtype_to_string = [](const auto dtype) -> std::string {
        if(dtype == torch::kFloat16)
        {
            return "fp16";
        }
        else if(dtype == torch::kBFloat16)
        {
            return "bf16";
        }
        else if(dtype == torch::kFloat32)
        {
            return "fp32";
        }
        else
        {
            throw std::runtime_error("invalid datatype for topk_sigmoid: only fp16/bf16/fp32!");
        }
    };
    const std::string input_prec  = dtype_to_string(gating_output.dtype());
    const std::string weight_prec = dtype_to_string(topk_weights.dtype());

    // Prepare kernel arguments; "sigmoid" selects the activation variant of
    // the CK topk_softmax kernel.
    static const std::string activation = "sigmoid";
    topk_softmax_trait trait{input_prec, weight_prec, experts, activation};

    topk_softmax_kargs karg{gating_output.data_ptr(),
                            topk_weights.data_ptr(),
                            topk_indices.data_ptr(),
                            tokens,
                            experts,
                            topk,
                            stride_input,
                            stride_output};

    ck_tile::stream_config sc{stream};

    topk_softmax(trait, karg, sc);
}

} // namespace aiter
78 changes: 78 additions & 0 deletions op_tests/test_moe_topk_sigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

import torch
import aiter
from aiter.test_common import (
checkAllclose,
perftest,
)


@perftest(num_iters=10, num_warmup=1)
def run_torch(gating_output: torch.Tensor, topk: int):
    # Reference routing (llama4 maverick style): take the top-k gate logits
    # per token, then squash them with a sigmoid in fp32.
    top_vals, top_idx = gating_output.topk(topk, dim=-1)
    scores = top_vals.float().sigmoid()
    return scores, top_idx.to(torch.int32)


@perftest(num_iters=10, num_warmup=1)
def run_fused(gating_output: torch.Tensor, topk: int):
    # The fused kernel writes into caller-allocated output buffers:
    # fp32 weights and int32 expert indices, both [tokens, topk].
    tokens = gating_output.shape[0]
    out_shape = (tokens, topk)
    dev = gating_output.device
    router_scores = torch.empty(out_shape, dtype=torch.float32, device=dev)
    router_indices = torch.empty(out_shape, dtype=torch.int32, device=dev)
    aiter.topk_sigmoid(router_scores, router_indices, gating_output)
    return router_scores, router_indices


def test_topk_sigmoid(
    num_experts: int = 128,
    num_tokens: int = 1024,
    topk: int = 4,
    dtype: torch.dtype = torch.float16,
):
    """Compare the fused aiter.topk_sigmoid kernel against a torch reference.

    Builds gating rows that are random permutations of unique values in
    [-1, 1) so top-k selection has no ties and index comparison is exact,
    then checks scores/indices agree and prints benchmark timings.
    """
    # Generate data - each row has only unique values.
    gating_output = (
        torch.arange(-1, 1, 2.0 / num_experts)
        .repeat((num_tokens, 1))
        .to(dtype=dtype, device="cuda")
    )
    permutation = torch.argsort(torch.rand_like(gating_output), dim=-1)
    gating_output = torch.gather(gating_output, dim=-1, index=permutation)
    assert gating_output.is_contiguous()
    # Run benchmarks (clone so neither path can perturb the other's input).
    (scores_torch, indices_torch), avg_torch = run_torch(gating_output.clone(), topk)
    (scores_fused, indices_fused), avg_fused = run_fused(gating_output.clone(), topk)
    # Check correctness.
    score_errors = checkAllclose(scores_torch, scores_fused, tol_err_ratio=0.01)
    index_errors = checkAllclose(indices_torch, indices_fused, tol_err_ratio=0.01)
    # Print some failed rows to aid debugging.
    if score_errors > 0.01 or index_errors > 0.01:
        failed_rows = (indices_torch != indices_fused).any(dim=-1)
        print("Wrong scores:")
        print(scores_torch[failed_rows][:5])
        print(scores_fused[failed_rows][:5])
        print("Wrong indices:")
        print(indices_torch[failed_rows][:5])
        print(indices_fused[failed_rows][:5])
        print("Gating outputs:")
        failed_values = gating_output[failed_rows][:5]
        failed_values, _ = failed_values.sort(dim=-1, descending=True)
        print(failed_values[:, :10])
        # Reduce on-device, then sync once — Python's builtin sum() over a
        # CUDA tensor iterates element-by-element (one sync each) and prints
        # as tensor(...) inside the f-string.
        num_failed = int(failed_rows.sum().item())
        total = failed_rows.numel()
        print(
            f"Number of wrong tokens: {num_failed} / {total}, {100 * num_failed / total:.2f} %"
        )
    # Print run times.
    print(f"Runtime (torch baseline): {avg_torch}")
    print(f"Runtime (fused topk sigmoid): {avg_fused}")
    print(f"Uplift: {avg_torch / avg_fused:.2f}x")


if __name__ == "__main__":
test_topk_sigmoid(dtype=torch.float16)
test_topk_sigmoid(dtype=torch.bfloat16)
Loading