
Commit 5b9f29e

wangyao-iguanguan0308 authored and committed
support mxfp8 quantization (qwen dense) (vllm-project#5723)
### What this PR does / why we need it?
Support mxfp8 quantization (qwen linear layer).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@2f4e654

Signed-off-by: wangyao <[email protected]>
1 parent fd4b7d9 commit 5b9f29e

File tree: 3 files changed, +112 -3 lines changed

vllm_ascend/quantization/quant_config.py

Lines changed: 3 additions & 2 deletions

@@ -45,7 +45,7 @@
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable,
                                mlp_tp_enable, oproj_tp_enable)

-from .utils import get_quant_method
+from .utils import get_quant_method, is_mx_quant_type


 @register_quantization_config(ASCEND_QUANTIZATION_METHOD)
@@ -401,7 +401,8 @@ def create_weights(
             set_weight_attrs(param, {"output_dim": 0})
             layer.register_parameter(pergroup_name, param)
             set_weight_attrs(param, extra_weight_attrs)
-            if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
+            if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \
+                    or is_mx_quant_type(self.quant_method):
                 setattr(param, "input_dim", 1)
                 param.input_dim = 1
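A note on the new condition: the MXFP8 linear method added in this PR creates a per-group weight_scale of shape (output_size, input_size // group_size), whose second axis follows the layer's input dimension, so the parameter needs input_dim = 1 even though its name is plain "weight_scale". Below is a minimal CPU-only sketch of that shape relationship; it is illustrative only and not the actual vLLM weight loader.

```python
import torch

# Illustrative shapes only. For MXFP8, the per-group scale from
# get_pergroup_param() is laid out as (output_size, input_size // group_size),
# so its dim 1 tracks the layer's input dimension, hence input_dim = 1.
output_size, input_size, group_size, tp_size = 8, 64, 32, 2

weight = torch.empty(output_size, input_size)
weight_scale = torch.empty(output_size, input_size // group_size)

# Sharding the input dimension across TP ranks splits both tensors on dim 1.
weight_shard = weight.chunk(tp_size, dim=1)[0]        # shape (8, 32)
scale_shard = weight_scale.chunk(tp_size, dim=1)[0]   # shape (8, 1)
```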

vllm_ascend/quantization/utils.py

Lines changed: 11 additions & 1 deletion

@@ -14,6 +14,7 @@
                            AscendW8A8DynamicLinearMethod)
 from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
                          AscendW8A8PDMixLinearMethod)
+from .w8a8mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
 from .w8a16 import AscendW8A16LinearMethod

 ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
@@ -40,7 +41,10 @@
     },
     "W8A16": {
         "linear": AscendW8A16LinearMethod,
-    }
+    },
+    "W8A8_MXFP8": {
+        "linear": AscendW8A8MXFP8DynamicLinearMethod,
+    },
 }


@@ -113,3 +117,9 @@ def get_quant_method_modelslim(
     )
     raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
                               f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
+
+
+def is_mx_quant_type(instance: Any) -> bool:
+    """Checks if the quantization method is an MX (microscaling) quant type."""
+    MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, )
+    return isinstance(instance, MX_QUANT_TYPES)
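A quick sanity check of the wiring above, illustrative only: the imports assume an environment where vllm_ascend and torch_npu are importable, and the method class is not constructed normally because its __init__ reads group_size from the live vLLM config.

```python
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
                                             is_mx_quant_type)
from vllm_ascend.quantization.w8a8mxfp8 import AscendW8A8MXFP8DynamicLinearMethod

# The new "W8A8_MXFP8" entry resolves linear layers to the MXFP8 method.
linear_cls = ASCEND_QUANTIZATION_METHOD_MAP["W8A8_MXFP8"]["linear"]
assert linear_cls is AscendW8A8MXFP8DynamicLinearMethod

# is_mx_quant_type() is a plain isinstance check, used by create_weights() in
# quant_config.py to mark per-group scales with input_dim = 1.
method = object.__new__(AscendW8A8MXFP8DynamicLinearMethod)  # sketch: bypass __init__
assert is_mx_quant_type(method)
assert not is_mx_quant_type(object())
```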
vllm_ascend/quantization/w8a8mxfp8.py

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch_npu
+from vllm.config import get_current_vllm_config
+
+
+class AscendW8A8MXFP8DynamicLinearMethod:
+    """Linear method for Ascend W8A8_MXFP8 dynamic quantization.
+    """
+    model_dtype = None
+
+    def __init__(self):
+        vllm_config = get_current_vllm_config()
+        self.group_size = vllm_config.quant_config.quant_description.get(
+            "group_size", 32)
+
+    @staticmethod
+    def get_weight(input_size: int, output_size: int,
+                   params_dtype: torch.dtype) -> Dict[str, Any]:
+        params_dict = {
+            "weight":
+            torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)
+        }
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        return {}
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        return {}
+
+    def get_pergroup_param(self,
+                           input_size: int,
+                           output_size: int,
+                           params_dtype: torch.dtype,
+                           layer_type: Optional[str] = None) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["weight_scale"] = torch.empty(output_size,
+                                                  input_size //
+                                                  self.group_size,
+                                                  dtype=torch.uint8)
+        return params_dict
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+
+        quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(
+            x, dst_type=torch.float8_e4m3fn)
+        pertoken_scale = dynamic_scale
+        output_dtype = x.dtype
+
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
+            layer.weight,
+            layer.weight_scale,
+            scale_dtype=torch_npu.float8_e8m0fnu,
+            pertoken_scale=pertoken_scale,
+            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
+            bias=bias,
+            output_dtype=output_dtype,
+            group_sizes=[1, 1, self.group_size])
+
+        return output
+
+    def process_weights_after_loading(self, layer):
+        n_dim, k_dim = layer.weight_scale.data.shape
+        layer.weight_scale.data = layer.weight_scale.data.reshape(
+            n_dim, k_dim // 2, 2)
+        layer.weight.data = layer.weight.data.transpose(0, 1)
+        layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)
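To make the weight layout concrete: the weight is created as (output_size, input_size) float8_e4m3fn and the per-group scale as (output_size, input_size // group_size) uint8; process_weights_after_loading packs the scale pairwise along the group axis and transposes both tensors so the input dimension comes first. The following CPU-only sketch mirrors just that reshape/transpose arithmetic, not the torch_npu kernels.

```python
import torch

# Shapes follow get_weight() / get_pergroup_param() above:
#   weight:       (output_size, input_size)               fp8 e4m3
#   weight_scale: (output_size, input_size // group_size) uint8 (e8m0 bit pattern)
output_size, input_size, group_size = 128, 256, 32

weight = torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)
weight_scale = torch.empty(output_size, input_size // group_size,
                           dtype=torch.uint8)

# process_weights_after_loading(): pack scales pairwise along the group axis,
# then move the input (K) dimension to the front of both tensors.
n_dim, k_dim = weight_scale.shape
weight_scale = weight_scale.reshape(n_dim, k_dim // 2, 2)
weight = weight.transpose(0, 1)
weight_scale = weight_scale.transpose(0, 1)

print(weight.shape)        # torch.Size([256, 128])
print(weight_scale.shape)  # torch.Size([4, 128, 2])
```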
