Commit f936444

support mxfp8 quantization (linear layer)

Signed-off-by: wangyao <[email protected]>

1 parent 92353c0 commit f936444

3 files changed: +115 -5 lines

vllm_ascend/quantization/quant_config.py

Lines changed: 4 additions & 4 deletions
@@ -45,8 +45,7 @@
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable,
                                mlp_tp_enable, oproj_tp_enable)
 
-from .utils import get_quant_method
-
+from .utils import get_quant_method, is_mx_quant_type
 
 @register_quantization_config(ASCEND_QUANTIZATION_METHOD)
 class AscendQuantConfig(QuantizationConfig):
@@ -387,8 +386,9 @@ def create_weights(
         set_weight_attrs(param, {"output_dim": 0})
         layer.register_parameter(pergroup_name, param)
         set_weight_attrs(param, extra_weight_attrs)
-        if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
-            setattr(param, "input_dim", 1)
+        if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \
+                or is_mx_quant_type(self.quant_method):
+            setattr(param, "input_dim", 1)
             param.input_dim = 1
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
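Why mark `input_dim` here: in vLLM-style sharded weight loading, an `input_dim` attribute on a parameter generally tells the loader which axis to slice across tensor-parallel ranks. The MXFP8 per-group scale has shape (output_size, input_size // group_size), so, like the second-level scales of the existing double-quant schemes, it must be split along dim 1. A minimal sketch of that convention; `shard_for_rank` is a hypothetical helper used only for illustration, not a vLLM API:

```python
import torch

# Hypothetical helper: slice a full parameter along `input_dim` for one TP rank,
# mirroring what a sharded weight loader does with the attribute set above.
def shard_for_rank(full: torch.Tensor, input_dim: int,
                   tp_rank: int, tp_size: int) -> torch.Tensor:
    shard = full.shape[input_dim] // tp_size
    return full.narrow(input_dim, tp_rank * shard, shard)

# Per-group MXFP8 scale: (output_size, input_size // group_size), group_size = 32.
weight_scale = torch.empty(4096, 11008 // 32, dtype=torch.uint8)
local = shard_for_rank(weight_scale, input_dim=1, tp_rank=0, tp_size=4)
print(local.shape)  # torch.Size([4096, 86])
```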

vllm_ascend/quantization/utils.py

Lines changed: 13 additions & 1 deletion
@@ -15,6 +15,7 @@
 from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
                          AscendW8A8PDMixLinearMethod)
 from .w8a16 import AscendW8A16LinearMethod
+from .w8a8mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
 
 ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
     "W4A16": {
@@ -40,7 +41,10 @@
     },
     "W8A16": {
         "linear": AscendW8A16LinearMethod,
-    }
+    },
+    "W8A8_MXFP8": {
+        "linear": AscendW8A8MXFP8DynamicLinearMethod,
+    },
 }
 
 
@@ -113,3 +117,11 @@ def get_quant_method_modelslim(
     )
     raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
                               f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
+
+
+def is_mx_quant_type(instance: Any) -> bool:
+    mx_quant_type_list = {AscendW8A8MXFP8DynamicLinearMethod}
+    for mx_quant_type in mx_quant_type_list:
+        if isinstance(instance, mx_quant_type):
+            return True
+    return False
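For context, the new map entry means a checkpoint whose quant description reports `W8A8_MXFP8` resolves, for linear layers, to `AscendW8A8MXFP8DynamicLinearMethod`; `is_mx_quant_type` is then an isinstance check on the method object that `quant_config.py` keeps in `self.quant_method`. A rough resolution sketch, assuming it runs inside this module (instantiation is skipped because the constructor reads the live vLLM config):

```python
# Rough sketch: resolve the method class for a quant type and layer kind.
quant_type, layer_kind = "W8A8_MXFP8", "linear"

method_cls = ASCEND_QUANTIZATION_METHOD_MAP.get(quant_type, {}).get(layer_kind)
if method_cls is None:
    raise NotImplementedError(
        f"Unsupported quant type {quant_type!r}; expected one of "
        f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")

assert method_cls is AscendW8A8MXFP8DynamicLinearMethod
# is_mx_quant_type() takes an *instance*; creating one needs a live vLLM
# config (see AscendW8A8MXFP8DynamicLinearMethod.__init__), so it is omitted here.
```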
vllm_ascend/quantization/w8a8mxfp8.py

Lines changed: 98 additions & 0 deletions
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import torch_npu
from vllm.config import get_current_vllm_config

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts


class AscendW8A8MXFP8DynamicLinearMethod:
    """Linear method for Ascend W8A8_MXFP8 dynamic quantization."""

    model_dtype = None

    def __init__(self):
        vllm_config = get_current_vllm_config()
        # MX block size: number of consecutive input elements sharing one scale.
        self.group_size = vllm_config.quant_config.quant_description.get(
            "group_size", 32)

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        # FP8 (E4M3) weight, one value per element.
        params_dict = {
            "weight": torch.empty(output_size,
                                  input_size,
                                  dtype=torch.float8_e4m3fn)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        return {}

    def get_pergroup_param(self, input_size: int, output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        # One shared scale per group of `group_size` input elements,
        # stored as a uint8 (E8M0) exponent.
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(
            output_size, input_size // self.group_size, dtype=torch.uint8)
        return params_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        # Dynamically MX-quantize the activations to FP8, then run the
        # quantized matmul with per-group weight and activation scales.
        quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(
            x, dst_type=torch.float8_e4m3fn)
        pertoken_scale = dynamic_scale
        output_dtype = x.dtype

        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            scale_dtype=torch_npu.float8_e8m0fnu,
            pertoken_scale=pertoken_scale,
            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
            bias=bias,
            output_dtype=output_dtype,
            group_sizes=[1, 1, self.group_size])

        return output

    def process_weights_after_loading(self, layer):
        # Pack the per-group scales pairwise along K, then move the input (K)
        # dimension first for both the weight and the packed scale.
        n_dim, k_dim = layer.weight_scale.data.shape
        layer.weight_scale.data = layer.weight_scale.data.reshape(
            n_dim, k_dim // 2, 2)
        layer.weight.data = layer.weight.data.transpose(0, 1)
        layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)
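The heavy lifting in `apply` is done by the NPU kernels. As a mental model only, here is a pure-PyTorch sketch of the general MX (microscaling) block-quantization idea this file relies on: each group of 32 consecutive elements shares one power-of-two E8M0 scale, and the payloads are stored as `float8_e4m3fn`. This follows the common OCP MX formulation and is an assumption for illustration; it is not the implementation of `torch_npu.npu_dynamic_mx_quant`, whose rounding and saturation details may differ.

```python
import torch

def mx_quant_reference(x: torch.Tensor, group_size: int = 32):
    """Illustrative MX block quantization: FP8 (E4M3) payload + shared E8M0 scale.

    Not the Ascend kernel; a rough CPU sketch of the idea only.
    """
    shape = x.shape
    g = x.reshape(*shape[:-1], -1, group_size).float()

    # Shared power-of-two scale per group, chosen so the group's max magnitude
    # lands in the E4M3 range (E4M3 emax is 8, max normal value 448).
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8.0)

    # Payload: clamp into the representable range, then cast to FP8.
    q = (g / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

    # E8M0 stores only a biased exponent in one byte (bias 127), which is why
    # the "weight_scale" parameter above is a uint8 tensor.
    e8m0 = (torch.log2(scale) + 127).to(torch.uint8)
    return q.reshape(shape), scale.squeeze(-1), e8m0.squeeze(-1)

x = torch.randn(2, 64)
q, scale, e8m0 = mx_quant_reference(x)
deq = (q.float().reshape(2, -1, 32) * scale.unsqueeze(-1)).reshape(2, 64)
print((deq - x).abs().max())  # small block-wise quantization error
```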
