Commit 2d912fb

vllmellm and tjtanaa authored
[FEAT] [ROCm] [V1]: Add AITER biased group topk for DeepSeekV3 (#17955)
Signed-off-by: tjtanaa <[email protected]> Co-authored-by: tjtanaa <[email protected]>
1 parent 12e6c0b commit 2d912fb

File tree

3 files changed: +201 -2 lines changed
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+# This is a test for the AITER ops.
+# It tests if the AITER ops are
+# 1. correctly registered as custom ops
+# 2. correctly defined the relationship between
+#    implementation and fake function
+# 3. can be used with torch.compile
+# This file will be skipped if AITER is not installed
+# and the platform is not ROCm.
+
+import importlib.util
+
+import pytest
+import torch
+
+# this import statement is needed to ensure the ops are registered
+import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401
+from vllm.platforms import current_platform
+
+# need to import once to ensure the ops are registered
+# Check if aiter package is installed
+aiter_available = importlib.util.find_spec("aiter") is not None
+
+pytestmark = pytest.mark.skipif(
+    not (current_platform.is_rocm() and aiter_available),
+    reason="AITER ops are only available on ROCm with aiter package installed")
+
+
+def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
+    """Test that the custom op is correctly registered."""
+    # Check if the op exists in torch.ops.vllm
+    assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')
+
+    # Check if the op is callable
+    assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
+
+
+def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
+    """Test that the op can be used with torch.compile."""
+    # Create test tensors
+    token = 64
+    expert = 256
+    num_expert_group = 8
+    topk = 8
+    topk_group = 4
+    renormalize = True
+    scale_factor = 1.0
+
+    gating_output = torch.randn((token, expert),
+                                dtype=torch.bfloat16,
+                                device="cuda")
+    e_score_correction_bias = torch.randn((expert, ),
+                                          dtype=torch.bfloat16,
+                                          device="cuda")
+
+    device = gating_output.device
+    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+    topk_weights = torch.empty((token, topk),
+                               dtype=torch.float32,
+                               device=device)
+
+    # Define a function that uses the op
+    def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
+                               topk_weights, topk_ids):
+        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
+            gating_output, e_score_correction_bias, topk_weights, topk_ids,
+            num_expert_group, topk_group, renormalize, scale_factor)
+
+    # Verify the op's fake implementation
+    torch.library.opcheck(
+        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
+        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
+        kwargs={
+            "num_expert_group": num_expert_group,
+            "topk_group": topk_group,
+            "need_renorm": renormalize,
+            "routed_scaling_factor": scale_factor
+        },
+        test_utils=("test_faketensor"))
+
+    # Compile the function with appropriate settings
+    compiled_fn = torch.compile(biased_grouped_topk_fn,
+                                fullgraph=True,
+                                backend="inductor",
+                                mode="reduce-overhead",
+                                dynamic=False)
+
+    topk_weights_original = torch.empty((token, topk),
+                                        dtype=torch.float32,
+                                        device=device)
+    topk_ids_original = torch.empty((token, topk),
+                                    dtype=torch.int32,
+                                    device=device)
+
+    topk_weights_compiled = torch.empty((token, topk),
+                                        dtype=torch.float32,
+                                        device=device)
+    topk_ids_compiled = torch.empty((token, topk),
+                                    dtype=torch.int32,
+                                    device=device)
+
+    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
+    biased_grouped_topk_fn(gating_output, e_score_correction_bias,
+                           topk_weights_original, topk_ids_original)
+    compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
+                topk_ids_compiled)
+
+    # Sort the results for comparison since the order might not be deterministic
+    topk_ids_original, indices_original = torch.sort(topk_ids_original)
+    topk_weights_original = torch.gather(topk_weights_original, 1,
+                                         indices_original)
+
+    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
+    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
+                                         indices_compiled)
+
+    # Verify results match
+    assert torch.allclose(topk_weights_original,
+                          topk_weights_compiled,
+                          rtol=1e-2,
+                          atol=1e-2)
+    assert torch.allclose(topk_ids_original, topk_ids_compiled)
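
Because the eager and compiled paths may emit a token's top-k experts in a different order, the test above sorts the expert ids and reorders the weights with the same permutation before asserting equality. Below is a minimal, CPU-only sketch of that comparison idiom; it is not part of the commit and uses made-up values:

import torch

# Two hypothetical routing results that agree only up to ordering.
ids_a = torch.tensor([[3, 1, 7]], dtype=torch.int32)
weights_a = torch.tensor([[0.5, 0.3, 0.2]])
ids_b = torch.tensor([[1, 7, 3]], dtype=torch.int32)
weights_b = torch.tensor([[0.3, 0.2, 0.5]])

# Sort ids per token and apply the same permutation to the weights.
ids_a_sorted, perm_a = torch.sort(ids_a, dim=1)
weights_a_sorted = torch.gather(weights_a, 1, perm_a)
ids_b_sorted, perm_b = torch.sort(ids_b, dim=1)
weights_b_sorted = torch.gather(weights_b, 1, perm_b)

# Order-insensitive comparison, mirroring the assertions in the test.
assert torch.equal(ids_a_sorted, ids_b_sorted)
assert torch.allclose(weights_a_sorted, weights_b_sorted, rtol=1e-2, atol=1e-2)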

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 8 additions & 2 deletions
@@ -17,6 +17,8 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_moe_enabled)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -28,6 +30,11 @@
     from .fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
+if is_rocm_aiter_moe_enabled():
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
+        rocm_aiter_biased_group_topk as grouped_topk)
+else:
+    from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
 if current_platform.is_tpu():
     # the iterative moe implementation is used until the moe_pallas is fixed
     from .moe_torch_iterative import fused_moe as fused_moe_pallas
@@ -802,8 +809,7 @@ def select_experts(hidden_states: torch.Tensor,
                        custom_routing_function: Optional[Callable] = None,
                        scoring_func: str = "softmax",
                        e_score_correction_bias: Optional[torch.Tensor] = None):
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
-            fused_topk, grouped_topk)
+        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

         # DeekSeekv2 uses grouped_top_k
         if use_grouped_topk:
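
The import switch above lets select_experts keep calling a single grouped_topk name while the concrete implementation is picked once at import time: the AITER-backed rocm_aiter_biased_group_topk when is_rocm_aiter_moe_enabled() returns True, otherwise the reference grouped_topk. The following standalone sketch illustrates the same pattern; use_fast_backend, reference_grouped_topk, and fast_grouped_topk are hypothetical stand-ins, not vLLM symbols:

# Hypothetical stand-in for is_rocm_aiter_moe_enabled(); in vLLM the real
# check is driven by platform and environment settings.
def use_fast_backend() -> bool:
    return False


def reference_grouped_topk(scores: list[float], k: int) -> list[int]:
    # Plain top-k by score, standing in for the reference implementation.
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]


def fast_grouped_topk(scores: list[float], k: int) -> list[int]:
    # A faster backend-specific kernel would live here; same call contract.
    return reference_grouped_topk(scores, k)


# Bind the name once at import time, as the diff above does with
# rocm_aiter_biased_group_topk vs. the reference grouped_topk.
if use_fast_backend():
    grouped_topk = fast_grouped_topk
else:
    grouped_topk = reference_grouped_topk

print(grouped_topk([0.1, 0.9, 0.4], k=2))  # -> [1, 2]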

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 71 additions & 0 deletions
@@ -216,6 +216,37 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
     pass


+def rocm_aiter_biased_grouped_topk_impl(
+        gating_output: torch.Tensor,
+        correction_bias: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_expert_group: int,
+        topk_group: int,
+        need_renorm: bool,
+        routed_scaling_factor: float = 1.0  # mul to topk_weights
+) -> None:
+
+    from aiter import biased_grouped_topk
+
+    biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids,
+                        num_expert_group, topk_group, need_renorm,
+                        routed_scaling_factor)
+
+
+def rocm_aiter_biased_grouped_topk_fake(
+        gating_output: torch.Tensor,
+        correction_bias: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_expert_group: int,
+        topk_group: int,
+        need_renorm: bool,
+        routed_scaling_factor: float = 1.0  # mul to topk_weights
+) -> None:
+    pass
+
+
 if current_platform.is_rocm():

     direct_register_custom_op(
@@ -258,6 +289,46 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
         dispatch_key=current_platform.dispatch_key,
     )

+    direct_register_custom_op(
+        op_name="rocm_aiter_biased_grouped_topk",
+        op_func=rocm_aiter_biased_grouped_topk_impl,
+        mutates_args=["topk_weights", "topk_ids"],
+        fake_impl=rocm_aiter_biased_grouped_topk_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+def rocm_aiter_biased_group_topk(
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int = 0,
+        topk_group: int = 0,
+        scoring_func: str = "sigmoid",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert scoring_func == "sigmoid", (
+        "rocm_aiter_biased_group_topk only supports 'sigmoid' scoring_func.")
+    assert e_score_correction_bias is not None, (
+        "'e_score_correction_bias' must not be None.")
+    token = hidden_states.shape[0]
+    device = hidden_states.device
+    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+    topk_weights = torch.empty((token, topk),
+                               dtype=torch.float32,
+                               device=device)
+    torch.ops.vllm.rocm_aiter_biased_grouped_topk(
+        gating_output,
+        e_score_correction_bias,
+        topk_weights,
+        topk_ids,
+        num_expert_group,
+        topk_group,
+        renormalize,
+    )
+    return topk_weights, topk_ids
+

 def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
                              w1: torch.Tensor,
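
For orientation, here is a hedged sketch of how the new rocm_aiter_biased_group_topk wrapper might be called with DeepSeek-V3-style routing parameters (256 experts in 8 groups, top-4 groups, top-8 experts per token). It only runs on ROCm with the aiter package installed, and the tensor sizes are illustrative, not taken from the commit:

import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_biased_group_topk)

num_tokens, num_experts, hidden_size = 16, 256, 1024  # illustrative shapes
hidden_states = torch.randn((num_tokens, hidden_size),
                            dtype=torch.bfloat16, device="cuda")
gating_output = torch.randn((num_tokens, num_experts),
                            dtype=torch.bfloat16, device="cuda")
bias = torch.randn((num_experts, ), dtype=torch.bfloat16, device="cuda")

# The wrapper allocates the output buffers and forwards to the registered
# torch.ops.vllm.rocm_aiter_biased_grouped_topk custom op.
topk_weights, topk_ids = rocm_aiter_biased_group_topk(
    hidden_states,
    gating_output,
    topk=8,
    renormalize=True,
    num_expert_group=8,
    topk_group=4,
    scoring_func="sigmoid",
    e_score_correction_bias=bias,
)
# topk_weights: (num_tokens, 8) float32; topk_ids: (num_tokens, 8) int32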
