Commit 48fdd31

Ruff lint (#1646)
lint
1 parent 7815262 commit 48fdd31

3 files changed: +85 -59 lines changed

torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py

+33-17
@@ -21,12 +21,11 @@
 )
 from torchao.dtypes.utils import AQTTensorImpl, Layout
 from torchao.quantization.quant_primitives import (
-    ZeroPointDomain,
     MappingType,
+    ZeroPointDomain,
     choose_qparams_affine,
     quantize_affine,
 )
-
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_6,
 )
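The import reshuffles in this hunk, and the similar ones in quant_api.py and the test file further down, are what Ruff's import-sorting rule (I001, the isort-style check) produces: names inside a parenthesized from-import are sorted, with class-like capitalized names grouped before lowercase function names under the default ordering, and unused or misplaced module imports are dropped or regrouped. A minimal before/after sketch of that rule, assuming it is applied with ruff check --fix:

# Before: the imported names are unsorted, so Ruff's isort rule (I001) flags the block.
from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

# After `ruff check --fix`: names are sorted (capitalized before lowercase),
# which is exactly the MappingType / ZeroPointDomain swap shown in the hunk above.
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
    quantize_affine,
)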
@@ -41,12 +40,14 @@
 handler.setFormatter(formatter)
 logger.addHandler(handler)
 
+
 class Target(Enum):
     """Enum that indicates the backend target"""
 
     NATIVE = auto()
     ATEN = auto()
 
+
 def target_from_str(target: str) -> Target:
     if target.lower() == "native":
         return Target.NATIVE
@@ -55,6 +56,7 @@ def target_from_str(target: str) -> Target:
     else:
         raise ValueError(f"Invalid target: {target}")
 
+
 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
     bit_width: Optional[int]
     group_size: Optional[int]
@@ -157,7 +159,10 @@ def from_plain(
     ):
         assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
         assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
-        assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"
+        assert layout.target in {
+            Target.NATIVE,
+            Target.ATEN,
+        }, f"Unexpected target: {layout.target}"
 
         # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
         # when AOTI supports int
@@ -167,10 +172,14 @@ def from_plain(
         k_tensor = torch.empty(0, k, dtype=torch.int8)
 
         if layout.target == Target.ATEN:
-            assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
+            assert (
+                TORCH_VERSION_AT_LEAST_2_6
+            ), "aten target is requires torch version > 2.6.0"
             int_data = int_data.add(8)
-            int_data = (int_data[::,1::2] << 4 | int_data[::,::2] ).to(torch.uint8)
-            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
+            int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
+            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(
+                int_data, scale, bias, layout.group_size, k, n
+            )
             return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)
 
         if layout.has_weight_zeros:
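Aside from the formatting, the ATEN branch above is the 4-bit packing step: adding 8 maps each signed int4 weight from [-8, 7] into [0, 15], then the odd columns go into the high nibble and the even columns into the low nibble of a single uint8 before torch.ops.aten._dyn_quant_pack_4bit_weight is called. A self-contained sketch of that nibble packing and an illustrative inverse (the unpack helper is not part of this diff, and the widening to int16 is only there to avoid relying on int8 overflow semantics):

import torch


def pack_int4_pairs(int_data: torch.Tensor) -> torch.Tensor:
    """Pack signed int4 values (range [-8, 7]) into uint8, two per byte.

    Mirrors the ATEN branch above: shift into the unsigned range, then place
    odd columns in the high nibble and even columns in the low nibble.
    """
    shifted = int_data.to(torch.int16).add(8)  # [-8, 7] -> [0, 15]
    return ((shifted[:, 1::2] << 4) | shifted[:, ::2]).to(torch.uint8)


def unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_int4_pairs, for illustration only."""
    low = (packed & 0x0F).to(torch.int8) - 8   # even columns
    high = (packed >> 4).to(torch.int8) - 8    # odd columns
    out = torch.empty(packed.shape[0], packed.shape[1] * 2, dtype=torch.int8)
    out[:, ::2] = low
    out[:, 1::2] = high
    return out


w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # toy int4-valued weight, even k
assert torch.equal(unpack_int4_pairs(pack_int4_pairs(w)), w)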
@@ -248,12 +257,11 @@
 def _linear_check(input_tensor, weight_tensor, bias):
     layout = weight_tensor.tensor_impl.get_layout()
     return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
-        bias is None or layout.target == Target.ATEN # Aten target allows bias
+        bias is None or layout.target == Target.ATEN  # Aten target allows bias
     )
 
 
 def _linear_impl(input_tensor, weight_tensor, bias):
-
     def _impl_2d_native(input_tensor, weight_tensor):
         assert input_tensor.dim() == 2
         assert weight_tensor.dim() == 2
@@ -299,14 +307,13 @@ def _impl_2d_aten(input_tensor, weight_tensor):
         group_size = weight_tensor.tensor_impl.get_layout().group_size
         packed_weight = weight_tensor.tensor_impl.packed_weight
         return torch.ops.aten._dyn_quant_matmul_4bit(
-            input_tensor, packed_weight, group_size, k_, n)
+            input_tensor, packed_weight, group_size, k_, n
+        )
 
     target = weight_tensor.tensor_impl.get_layout().target
 
     if target == Target.ATEN:
-        assert (
-            TORCH_VERSION_AT_LEAST_2_6 == 1
-        ), "Target.ATEN requires torch >= 2.6.0"
+        assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0"
         _impl_2d = _impl_2d_aten
     elif target == Target.NATIVE:
         _impl_2d = _impl_2d_native
@@ -327,6 +334,7 @@ def _impl_2d_aten(input_tensor, weight_tensor):
     res = res.reshape(*lead_shape, m, n)
     return res
 
+
 register_aqt_quantized_linear_dispatch(
     _linear_check,
     _linear_impl,
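For context on the register_aqt_quantized_linear_dispatch(_linear_check, _linear_impl) call this hunk touches: the dispatch is registered as a (predicate, implementation) pair, where _linear_check inspects the weight's layout (and, per the change above, accepts a bias only for the ATEN target) and _linear_impl is used whenever the check passes. A generic, self-contained sketch of that registration pattern, not torchao's actual registry internals:

from typing import Any, Callable, List, Tuple

# Minimal stand-in for a (check, impl) dispatch table; illustration only.
_LINEAR_DISPATCH: List[Tuple[Callable[..., bool], Callable[..., Any]]] = []


def register_quantized_linear_dispatch(
    check: Callable[..., bool], impl: Callable[..., Any]
) -> None:
    """Register a predicate/implementation pair, mirroring the shape of
    register_aqt_quantized_linear_dispatch(_linear_check, _linear_impl)."""
    _LINEAR_DISPATCH.append((check, impl))


def dispatch_linear(input_tensor, weight_tensor, bias):
    """Run the first registered implementation whose check accepts the args."""
    for check, impl in _LINEAR_DISPATCH:
        if check(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no registered quantized-linear dispatch matched")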
@@ -354,12 +362,17 @@ def from_hp_to_intx(
         zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
         use_hqq: bool = False,
-        bias: Optional[torch.Tensor] = None
+        bias: Optional[torch.Tensor] = None,
     ):
-        assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
+        assert (
+            use_hqq == False
+        ), "PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
         assert isinstance(
-            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
-        assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
+            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout
+        ), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
+        assert (
+            _layout.target == Target.ATEN
+        ), "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
 
@@ -405,4 +418,7 @@ def from_hp_to_intx(
             dtype=input_float.dtype,
         )
 
-to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
+
+to_packedlinearint8dynamicactivationintxweight_quantized_intx = (
+    PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
+)

torchao/experimental/quant_api.py

+50-39
@@ -4,8 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import sys
 import logging
+import sys
 from typing import Optional, Union
 
 import torch
@@ -15,22 +15,21 @@
     quantize_per_channel_group,
 )
 
+from torchao.dtypes import PlainLayout
 from torchao.quantization.granularity import (
     PerGroup,
     PerRow,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_6,
 )
-from torchao.dtypes import PlainLayout
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
 
 
 handler = logging.StreamHandler(sys.stdout)
-formatter = logging.Formatter(
-    "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 handler.setFormatter(formatter)
 logger.addHandler(handler)
 
@@ -494,8 +493,8 @@ def quantize(self, model: nn.Module) -> nn.Module:
 
 from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
     PackedLinearInt8DynamicActivationIntxWeightLayout,
-    to_packedlinearint8dynamicactivationintxweight_quantized_intx,
     Target,
+    to_packedlinearint8dynamicactivationintxweight_quantized_intx,
 )
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
@@ -515,7 +514,9 @@ def int8_dynamic_activation_intx_weight(
     has_weight_zeros: bool = False,
     weight_mapping_type=MappingType.ASYMMETRIC,
     act_mapping_type=MappingType.ASYMMETRIC,
-    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"), # PlainLayout() also works, but will be slow
+    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
+        target="native"
+    ),  # PlainLayout() also works, but will be slow
 ):
     """
     Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers.
@@ -540,13 +541,10 @@ def int8_dynamic_activation_intx_weight(
     """
 
     def is_torchao_op_skippable(layout):
-        return (
-            isinstance(layout, PlainLayout) or
-            (
-                isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and
-                layout.target == Target.ATEN
-            )
-        )
+        return isinstance(layout, PlainLayout) or (
+            isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
+            and layout.target == Target.ATEN
+        )
 
     if not is_torchao_op_skippable(layout):
         try:
@@ -574,7 +572,10 @@ def is_torchao_op_skippable(layout):
     )
     bit_width = dtype_to_bit_width[weight_dtype]
     layout_arg = layout
-    propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == Target.ATEN
+    propagate_bias = (
+        isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout)
+        and layout_arg.target == Target.ATEN
+    )
 
     def apply(weight, bias: Optional[torch.Tensor] = None):
         if isinstance(granularity, PerGroup):
@@ -612,35 +613,45 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
             target="aten" if layout.target == Target.ATEN else "native",
         )
         if layout.target == Target.ATEN:
-            if weight_dtype != torch.int4 or \
-                has_weight_zeros != True or \
-                weight_mapping_type == MappingType.ASYMMETRIC:
+            if (
+                weight_dtype != torch.int4
+                or has_weight_zeros != True
+                or weight_mapping_type == MappingType.ASYMMETRIC
+            ):
                 raise NotImplementedError(
-                    f"target 'aten' requires:\n"
-                    f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
-                    f"- has_weight_zeros to be True,\n"
-                    f"- weight_dtype to be torch.int4,\n"
-                    f"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR"
+                    "target 'aten' requires:\n"
+                    "- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
+                    "- has_weight_zeros to be True,\n"
+                    "- weight_dtype to be torch.int4,\n"
+                    "- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR"
                 )
-            assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
+            assert (
+                TORCH_VERSION_AT_LEAST_2_6
+            ), "aten target is requires torch version > 2.6.0"
             if torch.backends.kleidiai.is_available():
                 if isinstance(granularity, PerGroup):
-                    scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype
-            tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx
-
-            quantizer_args = [weight,
-                weight_mapping_type,
-                (1, group_size),
-                torch.int32,
-                quant_min,
-                quant_max,
-                torch.finfo(torch.float32).eps,
-                scale_dtype,
-                torch.int8,
-                has_weight_zeros,
-                ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE,
-                layout,
-                False] + ([bias] if propagate_bias else [])
+                    scale_dtype = (
+                        torch.bfloat16
+                    )  # KleidiAI kernel requires bfloat16 scale_dtype
+            tensor_quantizer = (
+                to_packedlinearint8dynamicactivationintxweight_quantized_intx
+            )
+
+            quantizer_args = [
+                weight,
+                weight_mapping_type,
+                (1, group_size),
+                torch.int32,
+                quant_min,
+                quant_max,
+                torch.finfo(torch.float32).eps,
+                scale_dtype,
+                torch.int8,
+                has_weight_zeros,
+                ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE,
+                layout,
+                False,
+            ] + ([bias] if propagate_bias else [])
 
             weight = tensor_quantizer(*quantizer_args)
 
torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py

+2-3
@@ -17,11 +17,9 @@
     int8_dynamic_activation_intx_weight,
 )
 from torchao.quantization.granularity import (
-    PerGroup,
     PerRow,
 )
 from torchao.quantization.quant_api import quantize_
-from torchao.utils import unwrap_tensor_subclass
 from torchao.quantization.quant_primitives import MappingType
 
 
@@ -57,7 +55,8 @@ def test_accuracy(self):
                 has_weight_zeros=has_weight_zeros,
                 weight_mapping_type=weight_mapping_type,
                 layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
-                    target="aten"), # default
+                    target="aten"
+                ),  # default
             ),
         )
 