Commit d42a382

Lint torchao folder (#1518)
1 parent 2c3d44c commit d42a382

93 files changed: +101873 lines added, -1284 lines removed (only a subset of the changed files is shown below).


ruff.toml (+5, -12)

@@ -3,20 +3,13 @@
 # To add a new path: Simply add it to the 'include' list.
 # Example: To lint all files in every subfolder of 'test', add "test/**/*"
 include = [
-    "torchao/float8/**/*.py",
-    "torchao/quantization/**/*.py",
-    "torchao/dtypes/**/*.py",
-    "torchao/sparsity/**/*.py",
-    "torchao/profiler/**/*.py",
-    "torchao/testing/**/*.py",
-    "torchao/_models/**/*.py",
-    "torchao/kernel/**/*.py",
-    "torchao/prototype/low_bit_optim/**.py",
-    "torchao/utils.py",
-    "torchao/ops.py",
-    "torchao/_executorch_ops.py",
+    "torchao/**/*.py",
     "test/**/*.py",
 ]

+exclude = [
+    "torchao/experimental/**/*.py",
+]
+
 lint.select = ["F", "I"]
 lint.ignore = ["E731"]
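Side note (not part of the commit): the simplified config trades the long per-folder list for one broad include plus an explicit exclude. A rough Python approximation of the resulting coverage, assuming glob-style matching from the repository root; ruff's own matcher may differ in edge cases.

# Rough sketch only: approximate which files the new include/exclude patterns
# cover, using recursive globbing. This is just to illustrate the intent of the
# change, not how ruff resolves patterns internally.
import glob

include = ["torchao/**/*.py", "test/**/*.py"]
exclude = ["torchao/experimental/**/*.py"]

included = {p for pat in include for p in glob.glob(pat, recursive=True)}
excluded = {p for pat in exclude for p in glob.glob(pat, recursive=True)}

print(f"{len(included - excluded)} files would be linted")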

torchao/__init__.py (+14, -8)

@@ -1,28 +1,33 @@
-import torch
 import logging

 # torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy'
 import warnings
-warnings.filterwarnings("ignore", message="Failed to initialize NumPy: No module named 'numpy'")
+
+import torch
+
+warnings.filterwarnings(
+    "ignore", message="Failed to initialize NumPy: No module named 'numpy'"
+)


 # We use this "hack" to set torchao.__version__ correctly
 # the version of ao is dependent on environment variables for multiple architectures
 # For local development this will default to whatever is version.txt
 # For release builds this will be set the version+architecture_postfix
-from importlib.metadata import version, PackageNotFoundError
+from importlib.metadata import PackageNotFoundError, version
+
 try:
     __version__ = version("torchao")
 except PackageNotFoundError:
-    __version__ = 'unknown'  # In case this logic breaks don't break the build
+    __version__ = "unknown"  # In case this logic breaks don't break the build

 _IS_FBCODE = (
-    hasattr(torch._utils_internal, "IS_FBSOURCE") and
-    torch._utils_internal.IS_FBSOURCE
+    hasattr(torch._utils_internal, "IS_FBSOURCE") and torch._utils_internal.IS_FBSOURCE
 )
 if not _IS_FBCODE:
     try:
         from pathlib import Path
+
         so_files = list(Path(__file__).parent.glob("_C*.so"))
         assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
         torch.ops.load_library(so_files[0])
@@ -34,14 +39,15 @@
     autoquant,
     quantize_,
 )
-from . import dtypes
-from . import testing
+
+from . import dtypes, testing

 __all__ = [
     "dtypes",
     "autoquant",
     "quantize_",
     "testing",
+    "ops",
 ]

 # test-pytorchbot
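Side note (not part of the commit): a minimal sketch of the version-fallback pattern that the reordered importlib.metadata import above supports.

# Minimal sketch of the version-fallback pattern used in torchao/__init__.py:
# importlib.metadata.version() reads the installed distribution's metadata and
# raises PackageNotFoundError when no distribution is installed (e.g. a plain
# source checkout), in which case a placeholder string is used instead.
from importlib.metadata import PackageNotFoundError, version

try:
    pkg_version = version("torchao")  # e.g. "0.7.0" for an installed wheel
except PackageNotFoundError:
    pkg_version = "unknown"  # keep imports working in development checkouts

print(pkg_version)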

torchao/prototype/autoround/__init__.py (+6)

@@ -3,3 +3,9 @@
     prepare_model_for_applying_auto_round_,
 )
 from torchao.prototype.autoround.multi_tensor import MultiTensor
+
+__all__ = [
+    "apply_auto_round",
+    "prepare_model_for_applying_auto_round_",
+    "MultiTensor",
+]
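Side note (not from the commit): with __all__ defined, the re-exports in __init__.py are treated as the package's intended public surface rather than unused imports, and star-imports expose only the listed names. A tiny sketch, assuming the import path shown in the diff above:

# Sketch: __all__ documents the public names of the prototype package.
from torchao.prototype import autoround

print(autoround.__all__)
# ['apply_auto_round', 'prepare_model_for_applying_auto_round_', 'MultiTensor']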

torchao/prototype/autoround/autoround_llm.py (+1, -2)

@@ -1,5 +1,4 @@
 import argparse
-import logging
 from typing import Optional

 import torch
@@ -67,7 +66,7 @@ def quantize_model_with_autoround_(
     multi_t_input_ids = MultiTensor(input_ids_lst)

     # The optimization is applied during the forward pass
-    out = model(multi_t_input_ids)
+    model(multi_t_input_ids)

     # Step 3. Apply the quantization
     quantize_(model, apply_auto_round(), is_target_module, device=device)

torchao/prototype/autoround/core.py (+1, -2)

@@ -3,12 +3,11 @@
 from typing import Any, Callable, Dict, Optional, Tuple

 import torch
-from torch.utils._pytree import tree_flatten, tree_unflatten

 import torchao.prototype.autoround.utils as ar_utils
 import torchao.quantization as ao_quant
 from torchao.dtypes import TensorCoreTiledLayout, to_affine_quantized_intx_static
-from torchao.prototype.autoround.multi_tensor import _multi_tensor_config, MultiTensor
+from torchao.prototype.autoround.multi_tensor import MultiTensor, _multi_tensor_config
 from torchao.quantization.quant_primitives import ZeroPointDomain
 from torchao.utils import find_multiple

torchao/prototype/autoround/eval_autoround.py (+2, -2)

@@ -42,7 +42,7 @@ def run_evaluation(model, tokenizer, tasks, compile=False, batch_size=4):
         from lm_eval.evaluator import evaluate
         from lm_eval.models.huggingface import HFLM
         from lm_eval.tasks import get_task_dict
-    except ImportError as e:
+    except ImportError:
         print(
             """
 Error: The 'lm_eval' module was not found.
@@ -70,7 +70,7 @@ def bench_accuracy(model, tokenizer, tasks, msg=""):
     from torchao.prototype.autoround.hf_eval_utils import run_evaluation

     torch.cuda.empty_cache()
-    res = run_evaluation(model, tokenizer, tasks=tasks)
+    run_evaluation(model, tokenizer, tasks=tasks)
     torch.cuda.empty_cache()

torchao/prototype/autoround/multi_tensor.py (+2, -1)

@@ -101,7 +101,8 @@ def grouped_to_flat(cls, grouped):
             min(
                 [True]
                 + [  # handle situation where tuples have size 0
-                    tup[0] == x for x in tup  # check all elements match
+                    tup[0] == x
+                    for x in tup  # check all elements match
                 ]
             )
             for tup in flat_tups

torchao/prototype/autoround/utils.py (+1, -1)

@@ -146,7 +146,7 @@ def get_float_model_info(model_name_or_path, torch_dtype=torch.float32):
     logging.warning(f"Detected decoder class: {decoder_cls}")
     if decoder_cls is None:
         raise ValueError(
-            f"Cannot detect the decoder class from the model, please provide it manually."
+            "Cannot detect the decoder class from the model, please provide it manually."
         )
     return model, tokenizer, decoder_cls

torchao/prototype/awq/__init__.py (+8, -2)

@@ -1,2 +1,8 @@
-from .api import insert_awq_observer_, awq_uintx
-from .core import AWQObservedLinear
+from .api import awq_uintx, insert_awq_observer_
+from .core import AWQObservedLinear
+
+__all__ = [
+    "awq_uintx",
+    "insert_awq_observer_",
+    "AWQObservedLinear",
+]

torchao/prototype/awq/api.py (+73, -41)

@@ -1,28 +1,37 @@
 import torch
-import torch.nn.functional as F

+from torchao.dtypes import (
+    TensorCoreTiledLayout,
+    to_affine_quantized_intx,
+)
+from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
+from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
 from torchao.quantization.granularity import PerGroup
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.quantization.quant_primitives import (
+    _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     ZeroPointDomain,
-    _DTYPE_TO_QVALUE_BOUNDS,
 )
-from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
-from torchao.dtypes import(
-    to_affine_quantized_intx,
-    TensorCoreTiledLayout,
+
+from .core import (
+    AWQObservedLinear,
+    AWQObserver,
 )
-from .core import(
-    AWQObserver,
-    AWQObservedLinear,
-)

+assert (
+    len(_DTYPE_TO_BIT_WIDTH) > 0
+), "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+"

-assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+"

-def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, validation_sequence_len: int, quant_dtype: torch.dtype = torch.uint4, scale_search_space_size: int = 20, group_size: int = 128):
+def insert_awq_observer_(
+    model: torch.nn.Module,
+    n_validation_examples: int,
+    validation_sequence_len: int,
+    quant_dtype: torch.dtype = torch.uint4,
+    scale_search_space_size: int = 20,
+    group_size: int = 128,
+):
     """
     Inserts AWQObserver into Linear layers of a given model.

@@ -35,58 +44,75 @@ def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, val
         group_size: Quantization granularity. Use -1 for channel wise quantization
     """
     _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
-    assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
+    assert (
+        quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
+    ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
     # AQT config
     mapping_type = MappingType.ASYMMETRIC
     quantization_granularity = PerGroup(group_size)
     quant_min = 0
-    quant_max = 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1
+    quant_max = (
+        255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1
+    )
     eps = torch.finfo(torch.float32).eps
     preserve_zero = True
     zero_point_dtype = torch.int64
     zero_point_domain = ZeroPointDomain.INT
-

     def replace_with_observer(layer):
         # creates observer and replaces linear layers with AWQObservedLinear layers
         observer = AWQObserver(
             layer.weight,
-            layer.bias,
-            quantization_granularity,
+            layer.bias,
+            quantization_granularity,
             mapping_type,
-            quant_dtype,
+            quant_dtype,
             n_validation_examples,
             validation_sequence_len,
             scale_search_space_size,
-            preserve_zero = preserve_zero,
-            zero_point_domain = zero_point_domain,
-            zero_point_dtype = zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=zero_point_domain,
+            zero_point_dtype=zero_point_dtype,
             quant_min=quant_min,
-            quant_max = quant_max,
-            eps = eps)
+            quant_max=quant_max,
+            eps=eps,
+        )
         return AWQObservedLinear.from_float(layer, observer)
+
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)

+
 def _observed_linear_subclass_inserter(constructor):
     """
     Replaces unquantized AWQObservedLinear instances with quantized linear instances.

     Args:
         constructor: the function which applies quantization to the AWQObservedLinear layer
     """
+
     def insert_subclass(observed_linear):
         # creates the new linear layer using constructor
-        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, observed_linear.bias!=None, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
-        linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False)
+        linear = torch.nn.Linear(
+            observed_linear.in_features,
+            observed_linear.out_features,
+            observed_linear.bias != None,
+            device=observed_linear.weight.device,
+            dtype=observed_linear.weight.dtype,
+        )
+        linear.weight = torch.nn.Parameter(
+            constructor(observed_linear), requires_grad=False
+        )
         linear.bias = observed_linear.bias
         return linear

     return insert_subclass
-

-def awq_uintx(quant_dtype: torch.dtype = torch.uint4,
-              group_size: int = 64,
-              use_hqq: bool = False,):
+
+def awq_uintx(
+    quant_dtype: torch.dtype = torch.uint4,
+    group_size: int = 64,
+    use_hqq: bool = False,
+):
     """
     Quantizes linear layers when passed into quantize_()

@@ -95,8 +121,10 @@ def awq_uintx(quant_dtype: torch.dtype = torch.uint4,
         group_size: Quantization granularity. Use -1 for channel wise quantization
         weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
     """
-    assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
-
+    assert (
+        quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
+    ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
+
     def weight_quant_func(observed_linear):
         equalization_scale = observed_linear.act_obs.calculate_qparams()
         # AQT config
@@ -114,24 +142,28 @@ def weight_quant_func(observed_linear):
         zero_point_dtype = torch.int64
         zero_point_domain = ZeroPointDomain.INT
         _layout = UintxLayout(quant_dtype)
-
+
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
         quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
         quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
         qw = to_affine_quantized_intx(
             observed_linear.weight * equalization_scale,
             mapping_type,
-            block_size,
-            target_dtype, quant_min,
-            quant_max, eps,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
             zero_point_dtype=zero_point_dtype,
             preserve_zero=preserve_zero,
             zero_point_domain=zero_point_domain,
             _layout=_layout,
-            use_hqq=use_hqq
+            use_hqq=use_hqq,
         )
-
-        return to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
-
+
+        return to_weight_tensor_with_linear_activation_scale_metadata(
+            qw, equalization_scale
+        )
+
     return _observed_linear_subclass_inserter(weight_quant_func)
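Side note (not part of the commit): a hedged end-to-end sketch of how the reformatted APIs above are wired together, built only from the signatures visible in this diff (insert_awq_observer_, awq_uintx, AWQObservedLinear). The model, calibration batches, and exact argument values are illustrative placeholders, not code from the repository.

# Hypothetical AWQ calibration/quantization flow; `model` and
# `calibration_batches` are placeholders you would supply yourself.
import torch

from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
from torchao.quantization import quantize_

model = ...                  # a torch.nn.Module containing Linear layers
calibration_batches = [...]  # input tensors used as validation examples

# 1. Wrap Linear layers with AWQ observers.
insert_awq_observer_(
    model,
    n_validation_examples=len(calibration_batches),
    validation_sequence_len=512,
    quant_dtype=torch.uint4,
    group_size=128,
)

# 2. Run the calibration data through the model so the observers record activations.
with torch.no_grad():
    for batch in calibration_batches:
        model(batch)

# 3. Replace the observed layers with quantized ones via quantize_().
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=128), is_observed_linear)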
