
Commit 073cfa1

GLM4.6: support MTP with full graph
Signed-off-by: 1092626063 <[email protected]>
1 parent a539ae7 commit 073cfa1

File tree

3 files changed: +21 −7 lines changed


tests/e2e/nightly/single_node/models/test_glm4_5.py

Lines changed: 9 additions & 5 deletions
@@ -29,6 +29,7 @@
 
 TENSOR_PARALLELS = [8]
 DATA_PARALLELS = [2]
+FULL_GRAPH = [True, False]
 
 prompts = [
     "San Francisco is a",
@@ -65,11 +66,9 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
 @pytest.mark.parametrize("dp_size", DATA_PARALLELS)
-async def test_models(
-    model: str,
-    tp_size: int,
-    dp_size: int,
-) -> None:
+@pytest.mark.parametrize("full_graph", FULL_GRAPH)
+async def test_models(model: str, tp_size: int, dp_size: int,
+                      full_graph: bool) -> None:
     port = get_open_port()
     env_dict = {"HCCL_BUFFSIZE": "1024"}
     server_args = [
@@ -91,6 +90,11 @@ async def test_models(
         "--gpu-memory-utilization",
         "0.9",
     ]
+    if full_graph:
+        server_args += [
+            "--compilation-config",
+            '{"cudagraph_capture_sizes": [1,2,4,8,16], "cudagraph_mode": "FULL_DECODE_ONLY"}'
+        ]
     request_keyword_args: dict[str, Any] = {
         **api_keyword_args,
     }
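
For reference, the --compilation-config payload above can also be assembled and passed to a hand-launched server, as in the minimal sketch below; the checkpoint path and tensor-parallel size are placeholders, not taken from this commit.

# Sketch only: launch a server with the same full-graph options as the test.
# "<glm-4.6-checkpoint>" is a placeholder model path.
import json
import subprocess

compilation_config = {
    "cudagraph_capture_sizes": [1, 2, 4, 8, 16],  # decode batch sizes to capture
    "cudagraph_mode": "FULL_DECODE_ONLY",  # full-graph capture for decode steps only
}

subprocess.run([
    "vllm", "serve", "<glm-4.6-checkpoint>",
    "--tensor-parallel-size", "8",
    "--compilation-config", json.dumps(compilation_config),
])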

vllm_ascend/quantization/quant_config.py

Lines changed: 9 additions & 1 deletion
@@ -173,7 +173,15 @@ def is_layer_skipped_ascend(
                 "are quantized. All shards of fused layers "
                 "to have the same precision.")
         else:
-            is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"
+            # NOTE: In GLM4.6, the MTP draft model shares its LM head weights
+            # with the main model. Therefore, before `load_weights()` runs, some
+            # parameter names may not include the expected prefix and may appear
+            # only with the ".head" suffix. This can trigger a load-time error,
+            # so we fall back to the "lm_head.weight" key here.
+            key = prefix + '.weight'
+            if key not in self.quant_description and ".head" in prefix:
+                key = 'lm_head.weight'
+            is_skipped = self.quant_description[key] == "FLOAT"
 
         assert is_skipped is not None
         return is_skipped
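
To see what the fallback does, here is a toy illustration; the dictionary contents and the helper name resolve_quant_key are invented for this example and are not part of the commit.

# Toy example: before load_weights() runs, the draft model's LM head may be
# visible only under a ".head" prefix, so the lookup falls back to the shared
# "lm_head.weight" entry instead of raising a KeyError.
quant_description = {
    "lm_head.weight": "FLOAT",
    "model.layers.0.self_attn.q_proj.weight": "W8A8",
}

def resolve_quant_key(prefix: str) -> str:
    key = prefix + ".weight"
    if key not in quant_description and ".head" in prefix:
        key = "lm_head.weight"  # MTP draft model shares the main model's LM head
    return key

assert resolve_quant_key("model.mtp.head") == "lm_head.weight"
assert quant_description[resolve_quant_key("model.mtp.head")] == "FLOAT"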

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 3 additions & 1 deletion
@@ -48,7 +48,9 @@
     "DeepseekV32ForCausalLM":
     ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
     "Qwen3NextForCausalLM":
-    ("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
+    ("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP"),
+    "Glm4MoeForCausalLM": ("vllm.model_executor.models.glm4_moe_mtp",
+                           "Glm4MoeMTP")
 }
 
 
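The new entry extends the architecture-to-draft-model registry. Below is a sketch of how such a (module, class-name) pair can be resolved lazily at runtime; the function name resolve_mtp_class is invented for illustration and does not appear in the commit.

# Sketch: resolve an MTP draft-model class from a (module, class-name) pair,
# importing the module only when the architecture is actually requested.
import importlib

_MTP_REGISTRY = {
    "Glm4MoeForCausalLM":
    ("vllm.model_executor.models.glm4_moe_mtp", "Glm4MoeMTP"),
}

def resolve_mtp_class(architecture: str):
    module_name, class_name = _MTP_REGISTRY[architecture]
    module = importlib.import_module(module_name)
    return getattr(module, class_name)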
