
Commit a1becad

Update
[ghstack-poisoned]
1 parent ab8d5b5 commit a1becad

File tree

1 file changed: +16 -24 lines changed


torchao/quantization/quant_api.py (+16 -24)
@@ -35,11 +35,12 @@
     PlainLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
-    UintxLayout,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
+    UintxLayout,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.inference import Float8MMConfig
@@ -51,36 +52,28 @@
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    TORCH_VERSION_AT_LEAST_2_5,
-    TORCH_VERSION_AT_LEAST_2_6,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
 )
 
-from .autoquant import AutoQuantizableLinearWeight, autoquant
+from .autoquant import autoquant, AutoQuantizableLinearWeight
 from .GPTQ import (
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +908,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 

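For users hitting this warning, the suggested replacement looks like the sketch below. It is not part of the commit: the toy model and the device/dtype setup are illustrative assumptions (the 2:4 sparse kernels expect half-precision weights on CUDA), while quantize_, int8_dynamic_activation_int8_weight, and SemiSparseLayout are existing torchao APIs.

    import torch
    from torchao.dtypes import SemiSparseLayout
    from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

    # Illustrative model; the 2:4 sparse kernels expect fp16/bf16 weights on CUDA.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

    # The non-deprecated spelling: pass the semi-sparse layout via the layout kwarg.
    quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))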
@@ -938,11 +933,10 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
 
     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )

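The floatx-to-float8 migration in this hunk (and the two below) is mechanical: switch the constructor and drop the explicit scale_dtype argument, which to_affine_quantized_float8 no longer takes. A minimal standalone sketch, assuming a 2-D bfloat16 weight; the keyword arguments mirror apply_float8wo_quant above, and both names come from the torchao.dtypes import block updated in this same commit.

    import torch
    from torchao.dtypes import Float8Layout, to_affine_quantized_float8

    weight = torch.randn(256, 128, dtype=torch.bfloat16)

    # One scale per output row, matching block_size = (1, weight.shape[1]) above.
    block_size = (1, weight.shape[1])

    # Previously: to_affine_quantized_floatx(..., scale_dtype=None, ...)
    fp8_weight = to_affine_quantized_float8(
        input_float=weight,
        block_size=block_size,
        target_dtype=torch.float8_e4m3fn,
        _layout=Float8Layout(mm_config=None),
    )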
@@ -1016,11 +1010,10 @@ def _input_activation_quant_func_fp8(
 
     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
             input_float=x,
             block_size=block_size,
             target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
         )
     else:
@@ -1102,11 +1095,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         ), "PerRow quantization only works for bfloat16 precision input weight"
 
         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )

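The public entry points are unchanged by this commit; they now simply build their tensors with to_affine_quantized_float8 internally. A usage sketch, not from the commit, with the model and granularity chosen for illustration: the float8 paths assume a GPU of compute capability 8.9 or newer (hence the is_sm_at_least_89 guard imported above), and PerRow requires bfloat16 weights per the assertion in the hunk above.

    import torch
    from torchao.quantization import float8_dynamic_activation_float8_weight, quantize_
    from torchao.quantization.granularity import PerRow

    # bfloat16 weights, as the PerRow assertion above requires.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

    # Dynamic float8 activations + float8 weights with per-row scales.
    quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))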