Commit 1f91cfd

fix merge
Signed-off-by: Bill Nell <[email protected]>
1 parent 23cf129 commit 1f91cfd

File tree: 2 files changed, +11 -29 lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 23 deletions
@@ -13,7 +13,7 @@
 from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
-from vllm.config import get_current_vllm_config, ParallelConfig
+from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -322,6 +322,7 @@ def __init__(self, moe: MoEConfig):
         super().__init__()
         self.fused_experts = fused_experts
         self.moe = moe
+
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
         if self.rocm_aiter_moe_enabled:
             from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@@ -501,6 +502,8 @@ def forward_cuda(
             indices_type=torch.uint32 if self.moe.use_pplx_kernels else None)
 
         if self.rocm_aiter_moe_enabled:
+            assert not apply_router_weight_on_input
+            assert expert_map is None
             return self.rocm_aiter_fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -510,8 +513,8 @@ def forward_cuda(
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
-            return fused_experts(
-                a1=x,
+            return self.fused_experts(
+                hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
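
Note on this hunk: the fallback path now calls the self.fused_experts callable stored in __init__ (see the @@ -322 hunk above) instead of the module-level fused_experts, so a different experts implementation can be injected per layer instance, and the argument name changes from a1 to hidden_states. A minimal sketch of that pattern, using hypothetical names (ExpertsMethod, default_fused_experts) rather than vLLM's actual classes:

from typing import Callable

import torch


def default_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
                          w2: torch.Tensor) -> torch.Tensor:
    # Stand-in for the module-level fused_experts kernel wrapper.
    return (hidden_states @ w1) @ w2


class ExpertsMethod:
    """Hypothetical holder mirroring self.fused_experts in the diff above."""

    def __init__(self,
                 experts: Callable[..., torch.Tensor] = default_fused_experts):
        # Stored at construction time so a different backend can be swapped
        # in later without touching the forward path.
        self.fused_experts = experts

    def forward(self, x: torch.Tensor, w1: torch.Tensor,
                w2: torch.Tensor) -> torch.Tensor:
        return self.fused_experts(hidden_states=x, w1=w1, w2=w2)


m = ExpertsMethod()
print(m.forward(torch.randn(4, 8), torch.randn(8, 16), torch.randn(16, 8)).shape)
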
@@ -1191,8 +1194,7 @@ def select_experts(hidden_states: torch.Tensor,
                        scoring_func: str = "softmax",
                        e_score_correction_bias: Optional[torch.Tensor] = None,
                        indices_type: Optional[torch.dtype] = None):
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
-            fused_topk, grouped_topk)
+        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 
         # DeekSeekv2 uses grouped_top_k
         if use_grouped_topk:
@@ -1228,24 +1230,6 @@ def select_experts(hidden_states: torch.Tensor,
 
         return topk_weights, topk_ids
 
-    def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
-        assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                             device=x.device,
-                             dtype=x.dtype)
-
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
-        buffer[start:end, :].copy_(x)
-        for idx in range(get_dp_group().world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            get_dp_group().broadcast(buffer[start:end, :], idx)
-
-        return buffer
-
     def must_reduce_shared_expert_outputs(self) -> bool:
         """
         The shared_experts are typically computed using the RowParallelLinear
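
Note on the deleted naive_multicast: it filled a shared buffer by copying the local rank's rows into its own slice and then having every data-parallel rank broadcast its slice in turn. Below is a rough standalone equivalent using plain torch.distributed on the default process group, shown only to illustrate what is being dropped; the dp_rank argument and the cumulative-token layout come from the removed code, the rest is illustrative:

import torch
import torch.distributed as dist


def naive_multicast(x: torch.Tensor, cu_tokens_across_dp_cpu: torch.Tensor,
                    dp_rank: int) -> torch.Tensor:
    # cu_tokens_across_dp_cpu[i] is the cumulative row count up to and
    # including rank i, so the last entry is the total gathered size.
    assert x.dim() == 2
    buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
                         device=x.device, dtype=x.dtype)

    # Place this rank's rows into its own slice of the buffer.
    start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
    end = cu_tokens_across_dp_cpu[dp_rank]
    buffer[start:end, :].copy_(x)

    # Each rank broadcasts its slice; afterwards every rank holds the
    # fully populated buffer.
    for idx in range(dist.get_world_size()):
        start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
        end = cu_tokens_across_dp_cpu[idx]
        dist.broadcast(buffer[start:end, :], src=idx)

    return buffer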

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 4 additions & 6 deletions
@@ -800,7 +800,7 @@ def set_prepare_finalize(
         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
             TritonOrDeepGemmExperts)
 
-        if self.use_marlin:
+        if self.use_marlin or self.rocm_aiter_moe_enabled:
             return False
 
         experts = TritonOrDeepGemmExperts(
@@ -834,9 +834,6 @@ def apply(
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_fused_experts)
-
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@@ -851,6 +848,8 @@ def apply(
        )
 
        if self.rocm_aiter_moe_enabled:
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
+                rocm_aiter_fused_experts)
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
@@ -867,8 +866,7 @@ def apply(
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size)
-
-        if self.use_marlin:
+        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert not apply_router_weight_on_input, (
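
Note on this hunk: with the elif, the ROCm AITER, Marlin, and default paths in apply form one mutually exclusive chain, and the AITER import is only paid when that branch is taken. A toy sketch of the resulting control flow, with placeholder bodies standing in for the real kernels:

import torch


def apply_moe(x: torch.Tensor,
              rocm_aiter_moe_enabled: bool = False,
              use_marlin: bool = False) -> torch.Tensor:
    # Mirrors the branch ordering after this commit: AITER first, then
    # Marlin, then the default fused path. Bodies are placeholders only.
    if rocm_aiter_moe_enabled:
        # The real code imports rocm_aiter_fused_experts lazily here so
        # builds without the AITER backend never import it.
        return x * 2
    elif use_marlin:
        return x * 3
    return x


print(apply_moe(torch.ones(2), use_marlin=True))  # tensor([3., 3.])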
