[FEAT] [ROCm] [V1]: Add AITER biased group topk for DeepSeekV3 #17955

Merged: 6 commits, May 14, 2025
122 changes: 122 additions & 0 deletions tests/kernels/moe/test_rocm_aiter_topk.py
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# Tests for the AITER ops. They verify that the AITER ops:
# 1. are correctly registered as custom ops,
# 2. correctly pair each implementation with its fake function,
# 3. can be used with torch.compile.
# This file is skipped if AITER is not installed
# or the platform is not ROCm.

import importlib.util

import pytest
import torch

# this import statement is needed to ensure the ops are registered
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401
from vllm.platforms import current_platform

# Check if the aiter package is installed
aiter_available = importlib.util.find_spec("aiter") is not None

pytestmark = pytest.mark.skipif(
not (current_platform.is_rocm() and aiter_available),
reason="AITER ops are only available on ROCm with aiter package installed")


def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
"""Test that the custom op is correctly registered."""
# Check if the op exists in torch.ops.vllm
assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')

# Check if the op is callable
assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)


def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
"""Test that the op can be used with torch.compile."""
# Create test tensors
token = 64
expert = 256
num_expert_group = 8
topk = 8
topk_group = 4
renormalize = True
scale_factor = 1.0

gating_output = torch.randn((token, expert),
dtype=torch.bfloat16,
device="cuda")
e_score_correction_bias = torch.randn((expert, ),
dtype=torch.bfloat16,
device="cuda")

device = gating_output.device
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk),
dtype=torch.float32,
device=device)

# Define a function that uses the op
def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
topk_weights, topk_ids):
return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output, e_score_correction_bias, topk_weights, topk_ids,
num_expert_group, topk_group, renormalize, scale_factor)

# Verify the op's fake implementation
torch.library.opcheck(
torch.ops.vllm.rocm_aiter_biased_grouped_topk,
(gating_output, e_score_correction_bias, topk_weights, topk_ids),
kwargs={
"num_expert_group": num_expert_group,
"topk_group": topk_group,
"need_renorm": renormalize,
"routed_scaling_factor": scale_factor
},
test_utils=("test_faketensor"))

# Compile the function with appropriate settings
compiled_fn = torch.compile(biased_grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)

topk_weights_original = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_original = torch.empty((token, topk),
dtype=torch.int32,
device=device)

topk_weights_compiled = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_compiled = torch.empty((token, topk),
dtype=torch.int32,
device=device)

# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
biased_grouped_topk_fn(gating_output, e_score_correction_bias,
topk_weights_original, topk_ids_original)
compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
topk_ids_compiled)

# Sort the results for comparison since the order might not be deterministic
topk_ids_original, indices_original = torch.sort(topk_ids_original)
topk_weights_original = torch.gather(topk_weights_original, 1,
indices_original)

topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
indices_compiled)

# Verify results match
assert torch.allclose(topk_weights_original,
topk_weights_compiled,
rtol=1e-2,
atol=1e-2)
assert torch.allclose(topk_ids_original, topk_ids_compiled)
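
The torch.compile check above depends on the custom-op/fake-implementation pairing that vLLM sets up through direct_register_custom_op: the fake function only describes output metadata, which is what lets the compiler trace the op without launching the AITER kernel. Below is a minimal sketch of the same pattern using PyTorch's public torch.library API (PyTorch 2.4+); the op name "mylib::scaled_topk" and its body are illustrative and not part of this PR.

# Minimal sketch of the custom-op + fake-impl pattern the test exercises.
# "mylib::scaled_topk" is an illustrative op, not one registered by vLLM.
import torch


@torch.library.custom_op("mylib::scaled_topk", mutates_args=())
def scaled_topk(scores: torch.Tensor, k: int, scale: float) -> torch.Tensor:
    # Real implementation: runs on concrete tensors.
    values, _ = torch.topk(scores * scale, k, dim=-1)
    return values


@scaled_topk.register_fake
def _(scores: torch.Tensor, k: int, scale: float) -> torch.Tensor:
    # Fake implementation: only reports output shape/dtype so that
    # torch.compile can trace the graph without running the kernel.
    return scores.new_empty((*scores.shape[:-1], k))


compiled = torch.compile(lambda s: torch.ops.mylib.scaled_topk(s, 4, 2.0),
                         fullgraph=True)
out = compiled(torch.randn(8, 16))  # traced via the fake, executed via the real impl
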
10 changes: 8 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -16,6 +16,8 @@
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@@ -27,6 +29,11 @@
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_biased_group_topk as grouped_topk)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
@@ -801,8 +808,7 @@ def select_experts(hidden_states: torch.Tensor,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

# DeepSeekV2 uses grouped_topk
if use_grouped_topk:
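
The layer.py change above is a module-level switch: when is_rocm_aiter_moe_enabled() is true, the name grouped_topk is bound to the AITER-backed rocm_aiter_biased_group_topk at import time, so select_experts and other call sites pick up the accelerated routing without further changes. A small self-contained sketch of this import-time dispatch pattern follows; the backend check and both implementations are illustrative stand-ins, not the real kernels.

# Sketch of the import-time dispatch pattern used in layer.py.
# _backend_enabled() stands in for is_rocm_aiter_moe_enabled(); both
# functions below are placeholders rather than the real implementations.
import os

import torch


def _aiter_grouped_topk(scores: torch.Tensor, k: int) -> torch.Tensor:
    values, _ = torch.topk(scores, k, dim=-1)  # placeholder for the AITER path
    return values


def _reference_grouped_topk(scores: torch.Tensor, k: int) -> torch.Tensor:
    values, _ = torch.topk(scores, k, dim=-1)  # placeholder for the reference path
    return values


def _backend_enabled() -> bool:
    # Illustrative flag; the real check also requires ROCm and the aiter package.
    return os.environ.get("USE_FAST_TOPK", "0") == "1"


# Bind the name once at import time; call sites use `grouped_topk`
# without knowing which backend was selected.
grouped_topk = _aiter_grouped_topk if _backend_enabled() else _reference_grouped_topk
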
71 changes: 71 additions & 0 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -216,6 +216,37 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
pass


def rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0 # mul to topk_weights
) -> None:

from aiter import biased_grouped_topk

biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids,
num_expert_group, topk_group, need_renorm,
routed_scaling_factor)


def rocm_aiter_biased_grouped_topk_fake(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0 # mul to topk_weights
) -> None:
pass


if current_platform.is_rocm():

direct_register_custom_op(
@@ -258,6 +289,46 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_biased_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)


def rocm_aiter_biased_group_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "sigmoid",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert scoring_func == "sigmoid", (
"rocm_aiter_biased_group_topk only supports 'sigmoid' scoring_func.")
assert e_score_correction_bias is not None, (
"'e_score_correction_bias' must not be None.")
token = hidden_states.shape[0]
device = hidden_states.device
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk),
dtype=torch.float32,
device=device)
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output,
e_score_correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
renormalize,
)
return topk_weights, topk_ids


def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
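
For completeness, a hedged usage sketch of the new wrapper with DeepSeek-V3-style routing parameters (256 experts in 8 groups, top-4 groups, top-8 experts per token, sigmoid scoring with a correction bias). It assumes a ROCm device with the aiter package installed; only the token count and device are read from hidden_states, so the hidden size below is arbitrary.

# Hedged usage sketch (assumes ROCm + aiter); shapes mirror the test above.
import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_biased_group_topk)

num_tokens, num_experts, hidden_size = 64, 256, 1024  # hidden size is arbitrary here
hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.bfloat16, device="cuda")
gating_output = torch.randn(num_tokens, num_experts,
                            dtype=torch.bfloat16, device="cuda")
e_score_correction_bias = torch.randn(num_experts,
                                      dtype=torch.bfloat16, device="cuda")

# DeepSeek-V3-style routing: 8 expert groups, route within the top 4 groups,
# pick 8 experts per token, renormalize the sigmoid scores.
topk_weights, topk_ids = rocm_aiter_biased_group_topk(
    hidden_states=hidden_states,
    gating_output=gating_output,
    topk=8,
    renormalize=True,
    num_expert_group=8,
    topk_group=4,
    scoring_func="sigmoid",
    e_score_correction_bias=e_score_correction_bias,
)
# topk_weights: (64, 8) float32, topk_ids: (64, 8) int32
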