
Commit 76abd0c

[Bugfix] Better FP8 supported defaults
1 parent 5b19b93 · commit 76abd0c

File tree

  vllm/model_executor/layers/quantization/utils/fp8_utils.py
  vllm/model_executor/layers/quantization/utils/w8a8_utils.py

2 files changed: +22 −12 lines changed


vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 17 additions & 11 deletions
@@ -15,7 +15,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_block_fp8_supported: bool = True,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
@@ -85,12 +85,14 @@ def apply_w8a8_block_fp8_linear(
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
 def apply_fp8_linear_generic(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    input_group_shape: Tuple[int, int],
-    weight_group_shape: Tuple[int, int],
-    input_scale: Optional[torch.Tensor] = None,  # static scale if one
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_group_shape: Tuple[int, int],
+        weight_group_shape: Tuple[int, int],
+        input_scale: Optional[torch.Tensor] = None,  # static scale if one
+        cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
+        cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     # View input as 2D matrix for fp8 methods
     input = input.view(-1, input.shape[-1])
@@ -105,14 +107,18 @@ def is_dim_blocked(dim, shape, group_shape):
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
             input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight,
-                                           list(weight_group_shape),
-                                           weight_scale)
+        return apply_w8a8_block_fp8_linear(
+            input,
+            weight,
+            list(weight_group_shape),
+            weight_scale,
+            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
     else:
         # Despite having linear in the it doesn't conform to
         # `torch.nn.functional.linear` which is defined as `input @ weight.T`
         # so we explicitly transpose the weight matrix here
         return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                cutlass_fp8_supported=cutlass_fp8_supported,
                                 use_per_token_if_dynamic=\
                                     (input_group_shape == (1, input.shape[1])))
 
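For readers skimming the hunk above, here is a minimal, self-contained sketch of the pattern it applies; the names (probe_cutlass_fp8, inner_fp8_gemm, generic_fp8_gemm) are hypothetical stand-ins, not vLLM APIs. The point is that the generic wrapper now accepts the CUTLASS support flags and forwards them explicitly, so a platform-derived False is no longer silently replaced by a hardcoded True default in the inner call.

    def probe_cutlass_fp8() -> bool:
        # Stand-in for a real hardware capability check; assume support is absent.
        return False


    CUTLASS_FP8_SUPPORTED = probe_cutlass_fp8()  # evaluated once at import time


    def inner_fp8_gemm(x, cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED):
        # Before this commit, the equivalent vLLM default was a hardcoded True.
        return "cutlass_scaled_mm" if cutlass_fp8_supported else "torch._scaled_mm"


    def generic_fp8_gemm(x, cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED):
        # The fix: thread the flag through instead of dropping it, so the inner
        # dispatch sees the platform-derived value.
        return inner_fp8_gemm(x, cutlass_fp8_supported=cutlass_fp8_supported)


    print(generic_fp8_gemm(None))  # -> torch._scaled_mm (correct fallback)
    print(generic_fp8_gemm(None, cutlass_fp8_supported=True))  # -> cutlass_scaled_mm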

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 5 additions & 1 deletion
@@ -42,6 +42,10 @@ def cutlass_block_fp8_supported() -> bool:
     return ops.cutlass_scaled_mm_supports_block_fp8(capability)
 
 
+CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def per_tensor_dequantize(
         tensor: torch.Tensor, inv_scale: Union[float,
                                                torch.Tensor]) -> torch.Tensor:
@@ -109,7 +113,7 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
     use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
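A short note on the mechanics, illustrated with a toy example (hypothetical names, not vLLM code): Python binds a keyword default once, when the def statement executes, so CUTLASS_FP8_SUPPORTED and CUTLASS_BLOCK_FP8_SUPPORTED must be assigned earlier in the module than apply_fp8_linear, which is exactly how this hunk places them.

    def _probe_block_fp8() -> bool:
        # Hypothetical stand-in for ops.cutlass_scaled_mm_supports_block_fp8(...);
        # here we simply report that support is missing.
        return False


    # Bound at import time, before any def that uses it as a default.
    CUTLASS_BLOCK_FP8_SUPPORTED = _probe_block_fp8()


    def apply_block_fp8(
            x, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED):
        # The default was captured when this def ran; reassigning the module-level
        # name afterwards would not change it for existing callers.
        return ("cutlass block path"
                if cutlass_block_fp8_supported else "non-CUTLASS fallback")


    print(apply_block_fp8(None))        # -> non-CUTLASS fallback
    print(apply_block_fp8(None, True))  # -> cutlass block path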
