@@ -15,7 +15,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_block_fp8_supported: bool = True,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
@@ -85,12 +85,14 @@ def apply_w8a8_block_fp8_linear(
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
 def apply_fp8_linear_generic(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_group_shape: Tuple[int, int],
-        weight_group_shape: Tuple[int, int],
-        input_scale: Optional[torch.Tensor] = None,  # static scale if one
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_group_shape: Tuple[int, int],
+        weight_group_shape: Tuple[int, int],
+        input_scale: Optional[torch.Tensor] = None,  # static scale if one
+        cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
+        cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     # View input as 2D matrix for fp8 methods
     input = input.view(-1, input.shape[-1])
@@ -105,14 +107,18 @@ def is_dim_blocked(dim, shape, group_shape):
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and \
         input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight,
-                                           list(weight_group_shape),
-                                           weight_scale)
+        return apply_w8a8_block_fp8_linear(
+            input,
+            weight,
+            list(weight_group_shape),
+            weight_scale,
+            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
     else:
         # Despite having "linear" in the name, this doesn't conform to
         # `torch.nn.functional.linear`, which is defined as `input @ weight.T`,
         # so we explicitly transpose the weight matrix here
         return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                cutlass_fp8_supported=cutlass_fp8_supported,
                                 use_per_token_if_dynamic=\
                                     (input_group_shape == (1, input.shape[1])))

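For reference, a minimal usage sketch of the updated entry point (not part of this diff): the import path and the shapes below are assumptions for illustration, and a CUDA build of vLLM with fp8 support is required.

```python
# Hypothetical usage sketch. Assumes apply_fp8_linear_generic lives in
# vllm.model_executor.layers.quantization.utils.fp8_utils (import path is an
# assumption) and that a CUDA device with fp8 support is available.
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_linear_generic)

# (N, K) fp8 weight with 128x128 blocked scales; the (1, 128) input group
# shape matches weight_group_shape[1], so the blocked branch is taken and the
# new CUTLASS_BLOCK_FP8_SUPPORTED default decides between CUTLASS and the
# fallback kernel without every caller probing the platform itself.
x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda")
w = torch.randn(256, 512, device="cuda").to(torch.float8_e4m3fn)
w_scale = torch.rand(256 // 128, 512 // 128,
                     dtype=torch.float32, device="cuda")

out = apply_fp8_linear_generic(x, w, w_scale,
                               input_group_shape=(1, 128),
                               weight_group_shape=(128, 128))
```

Since the new flags are keyword parameters whose defaults are the module-level constants, callers that need to force a path (e.g., in tests) can still pass `cutlass_fp8_supported=False` or `cutlass_block_fp8_supported=False` explicitly.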