diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py
new file mode 100644
index 00000000000..b0d34ddfd42
--- /dev/null
+++ b/tests/kernels/moe/test_rocm_aiter_topk.py
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+# This is a test for the AITER ops.
+# It tests that the AITER ops
+# 1. are correctly registered as custom ops
+# 2. correctly define the relationship between
+#    the implementation and the fake function
+# 3. can be used with torch.compile
+# This file will be skipped if AITER is not installed
+# or the platform is not ROCm.
+
+import importlib.util
+
+import pytest
+import torch
+
+# this import statement is needed to ensure the ops are registered
+import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401
+from vllm.platforms import current_platform
+
+# need to import once to ensure the ops are registered
+# Check if aiter package is installed
+aiter_available = importlib.util.find_spec("aiter") is not None
+
+pytestmark = pytest.mark.skipif(
+    not (current_platform.is_rocm() and aiter_available),
+    reason="AITER ops are only available on ROCm with aiter package installed")
+
+
+def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
+    """Test that the custom op is correctly registered."""
+    # Check if the op exists in torch.ops.vllm
+    assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')
+
+    # Check if the op is callable
+    assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
+
+
+def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
+    """Test that the op can be used with torch.compile."""
+    # Create test tensors
+    token = 64
+    expert = 256
+    num_expert_group = 8
+    topk = 8
+    topk_group = 4
+    renormalize = True
+    scale_factor = 1.0
+
+    gating_output = torch.randn((token, expert),
+                                dtype=torch.bfloat16,
+                                device="cuda")
+    e_score_correction_bias = torch.randn((expert, ),
+                                          dtype=torch.bfloat16,
+                                          device="cuda")
+
+    device = gating_output.device
+    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+    topk_weights = torch.empty((token, topk),
+                               dtype=torch.float32,
+                               device=device)
+
+    # Define a function that uses the op
+    def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
+                               topk_weights, topk_ids):
+        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
+            gating_output, e_score_correction_bias, topk_weights, topk_ids,
+            num_expert_group, topk_group, renormalize, scale_factor)
+
+    # Verify the op's fake implementation
+    torch.library.opcheck(
+        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
+        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
+        kwargs={
+            "num_expert_group": num_expert_group,
+            "topk_group": topk_group,
+            "need_renorm": renormalize,
+            "routed_scaling_factor": scale_factor
+        },
+        test_utils=("test_faketensor"))
+
+    # Compile the function with appropriate settings
+    compiled_fn = torch.compile(biased_grouped_topk_fn,
+                                fullgraph=True,
+                                backend="inductor",
+                                mode="reduce-overhead",
+                                dynamic=False)
+
+    topk_weights_original = torch.empty((token, topk),
+                                        dtype=torch.float32,
+                                        device=device)
+    topk_ids_original = torch.empty((token, topk),
+                                    dtype=torch.int32,
+                                    device=device)
+
+    topk_weights_compiled = torch.empty((token, topk),
+                                        dtype=torch.float32,
+                                        device=device)
+    topk_ids_compiled = torch.empty((token, topk),
+                                    dtype=torch.int32,
+                                    device=device)
+
+    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
+    biased_grouped_topk_fn(gating_output, e_score_correction_bias,
+                           topk_weights_original, topk_ids_original)
+    compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
+                topk_ids_compiled)
+
+    # Sort the results for comparison since the order might not be deterministic
+    topk_ids_original, indices_original = torch.sort(topk_ids_original)
+    topk_weights_original = torch.gather(topk_weights_original, 1,
+                                         indices_original)
+
+    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
+    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
+                                         indices_compiled)
+
+    # Verify results match
+    assert torch.allclose(topk_weights_original,
+                          topk_weights_compiled,
+                          rtol=1e-2,
+                          atol=1e-2)
+    assert torch.allclose(topk_ids_original, topk_ids_compiled)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index eae029b33e8..5d447754e2b 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -16,6 +16,8 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_moe_enabled)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -27,6 +29,11 @@
     from .fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
+if is_rocm_aiter_moe_enabled():
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
+        rocm_aiter_biased_group_topk as grouped_topk)
+else:
+    from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
 if current_platform.is_tpu():
     # the iterative moe implementation is used until the moe_pallas is fixed
     from .moe_torch_iterative import fused_moe as fused_moe_pallas
@@ -801,8 +808,7 @@ def select_experts(hidden_states: torch.Tensor,
                        custom_routing_function: Optional[Callable] = None,
                        scoring_func: str = "softmax",
                        e_score_correction_bias: Optional[torch.Tensor] = None):
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
-            fused_topk, grouped_topk)
+        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 
         # DeekSeekv2 uses grouped_top_k
         if use_grouped_topk:
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index d42cfce4aad..02e75402f4c 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -216,6 +216,37 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
     pass
 
 
+def rocm_aiter_biased_grouped_topk_impl(
+        gating_output: torch.Tensor,
+        correction_bias: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_expert_group: int,
+        topk_group: int,
+        need_renorm: bool,
+        routed_scaling_factor: float = 1.0  # mul to topk_weights
+) -> None:
+
+    from aiter import biased_grouped_topk
+
+    biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids,
+                        num_expert_group, topk_group, need_renorm,
+                        routed_scaling_factor)
+
+
+def rocm_aiter_biased_grouped_topk_fake(
+        gating_output: torch.Tensor,
+        correction_bias: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_expert_group: int,
+        topk_group: int,
+        need_renorm: bool,
+        routed_scaling_factor: float = 1.0  # mul to topk_weights
+) -> None:
+    pass
+
+
 if current_platform.is_rocm():
 
     direct_register_custom_op(
@@ -258,6 +289,46 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
         dispatch_key=current_platform.dispatch_key,
     )
 
+    direct_register_custom_op(
+        op_name="rocm_aiter_biased_grouped_topk",
+        op_func=rocm_aiter_biased_grouped_topk_impl,
+        mutates_args=["topk_weights", "topk_ids"],
+        fake_impl=rocm_aiter_biased_grouped_topk_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+def rocm_aiter_biased_group_topk(
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int = 0,
+        topk_group: int = 0,
+        scoring_func: str = "sigmoid",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert scoring_func == "sigmoid", (
+        "rocm_aiter_biased_group_topk only supports 'sigmoid' scoring_func.")
+    assert e_score_correction_bias is not None, (
+        "'e_score_correction_bias' must not be None.")
+    token = hidden_states.shape[0]
+    device = hidden_states.device
+    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+    topk_weights = torch.empty((token, topk),
+                               dtype=torch.float32,
+                               device=device)
+    torch.ops.vllm.rocm_aiter_biased_grouped_topk(
+        gating_output,
+        e_score_correction_bias,
+        topk_weights,
+        topk_ids,
+        num_expert_group,
+        topk_group,
+        renormalize,
+    )
+    return topk_weights, topk_ids
+
 
 def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
                              w1: torch.Tensor,