
Commit 7d5ba09

mikekgfb and Jack-Khuu authored
Implement the AO API in torchchat quantization handlers and unify logic. (#1291)
* Align AO and CHAT quantizers in quantize.py

  Implement the AO API in torchchat quantization handlers and unify logic:

  1 - implement .quantize() for TC quantization handlers and support args to make them consistent with AO
  2 - remove special handling for various combinations of parameters and use validate_args before calling with **q_kwargs
  3 - remove the check probing whether a8wx loaded successfully, and install an error-reporting handler if loading failed, which will be called as the quant handler and issue an error
  4 - unify the TC and AO quantization handler dicts with shared calling logic

* Typo / Docs

  Added a comment and a missing self parameter

* Update quantize.py

  Fix typo (func -> q.__init__)

* Update quantize.py

  Fix arg order (args with defaults after args without defaults)

* Fix typo

  Fixed 2 typos.

* Fix default args

* Update quantize.py

* Update quantize.py

* Update quantize.py

---------

Co-authored-by: Jack-Khuu <[email protected]>
1 parent 1ba88ad commit 7d5ba09

File tree

1 file changed: +66 −50 lines

torchchat/utils/quantize.py (+66 −50)
@@ -26,7 +26,7 @@

 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional
+from typing import Dict, Optional, Callable, Any, List

 import torch
 import torch.nn as nn
@@ -54,6 +54,33 @@
 # Flag for whether the a8wxdq quantizer is available.
 a8wxdq_load_error: Optional[Exception] = None

+#########################################################################
+### handle arg validation ###
+
+import inspect
+
+def get_named_parameters(func: Callable) -> List[str]:
+    # Get the signature of the function
+    signature = inspect.signature(func)
+
+    # Extract the parameters from the signature
+    parameters = signature.parameters
+
+    # Filter and return named parameters
+    named_params = [
+        name for name, param in parameters.items()
+        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+    ]
+    return named_params
+
+def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+    # Iterate over a copy of the keys: deleting from a dict while iterating
+    # over its live key view raises a RuntimeError.
+    for key in list(q_kwargs.keys()):
+        if key not in named_params:
+            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            del q_kwargs[key]
+    return q_kwargs
+
+
 #########################################################################
 ### torchchat quantization API ###

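A quick sketch of what these two helpers do, assuming they are importable from the module this diff touches. FakeQuantHandler and the sample kwargs are illustrative only, not part of the commit:

from typing import Optional

import torch.nn as nn

from torchchat.utils.quantize import get_named_parameters, validate_args

class FakeQuantHandler:
    # hypothetical handler, used only to exercise the helpers
    def __init__(self, model: Optional[nn.Module] = None, device="cpu",
                 precision=None, *, groupsize: int = 128):
        ...

named = get_named_parameters(FakeQuantHandler.__init__)
# -> ['self', 'model', 'device', 'precision', 'groupsize']
# ('self' appears because __init__ is inspected as a plain function)

kwargs = validate_args(named, {"groupsize": 256, "scheme": "gptq"}, "fake")
# prints: Specification for quantizer fake has extraneous key scheme. Ignoring.
# kwargs is now {"groupsize": 256}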
@@ -79,56 +106,32 @@ def quantize_model(
         quantize_options = json.loads(quantize_options)

     for quantizer, q_kwargs in quantize_options.items():
-        # Test if a8wxdq quantizer is available; Surface error if not.
-        if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None:
-            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
-
-        if (
-            quantizer not in quantizer_class_dict
-            and quantizer not in ao_quantizer_class_dict
-        ):
+        if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
-        if quantizer in ao_quantizer_class_dict:
+        else:
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
+
             # We set global precision from quantize options if it is specified at cli.py:485
             # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
             precision = get_precision()

-            try:
-                if quantizer == "linear:a8wxdq":
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        device=device,
-                        precision=precision,
-                        bitwidth=q_kwargs.get("bitwidth", 4),
-                        groupsize=q_kwargs.get("groupsize", 128),
-                        has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
-                    )
-                else:
-                    # Easier to ask forgiveness than permission
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], device=device, precision=precision
-                    )
-            except TypeError as e:
-                if "unexpected keyword argument 'device'" in str(e):
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], precision=precision
-                    )
-                elif "unexpected keyword argument 'precision'" in str(e):
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], device=device
-                    )
-                else:
-                    raise e
+            q = quantizer_class_dict[quantizer]
+            named_params = get_named_parameters(q.__init__)
+            q_kwargs = validate_args(named_params, q_kwargs, quantizer)
+
+            # Handle tokenizer for scenarios where the quantizer needs to tokenize sample inputs
+            if "tokenizer" in named_params:
+                q_kwargs["tokenizer"] = tokenizer
+            quant_handler = q(device=device, precision=precision, **q_kwargs)
+
+            # quantize model
             model = quant_handler.quantize(model)
-        else:
-            model = quantizer_class_dict[quantizer](
-                model, device=device, tokenizer=tokenizer, **q_kwargs
-            ).quantized_model()
+


 #########################################################################
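End to end, the loop above turns a JSON quantization spec into uniform constructor calls, with unknown keys dropped by validate_args instead of surfacing as a TypeError. A hedged sketch of the flow; the option string, device, and dtype are examples ("linear:int8" and its groupsize kwarg exist in torchchat, the "bogus" key does not):

import json

import torch

# model: an nn.Module defined elsewhere; quantizer_class_dict and the helpers
# come from the module shown in this diff.
quantize_options = json.loads('{"linear:int8": {"groupsize": 256, "bogus": 1}}')

for quantizer, q_kwargs in quantize_options.items():
    q = quantizer_class_dict[quantizer]            # WeightOnlyInt8QuantHandler
    named_params = get_named_parameters(q.__init__)
    q_kwargs = validate_args(named_params, q_kwargs, quantizer)
    # prints: Specification for quantizer linear:int8 has extraneous key bogus. Ignoring.
    handler = q(device="cpu", precision=torch.float16, **q_kwargs)
    model = handler.quantize(model)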
@@ -137,7 +140,7 @@ def quantize_model(


 class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
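Because quantize_model now constructs every handler as q(device=device, precision=precision, **q_kwargs), each handler's __init__ must accept model, device, precision, and tokenizer keywords; that is what this and the following signature changes provide. A minimal sketch of a conforming handler (MyScaleHandler and its scale kwarg are hypothetical):

from typing import Optional

import torch.nn as nn

class MyScaleHandler(QuantHandler):
    # hypothetical handler; shows only the unified constructor contract
    def __init__(self, model: Optional[nn.Module] = None, device="cpu",
                 precision=None, tokenizer=None, *, scale: float = 1.0):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer
        self.scale = scale

    def quantize(self, model: nn.Module) -> nn.Module:
        # no-op "quantization", for illustration only
        return model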
@@ -154,13 +157,18 @@ def quantized_model(self) -> nn.Module:
         self.model_.load_state_dict(model_updated_state_dict)
         return self.model_

+    # fallback for TC QuantHandlers that do not implement the method .quantize()
+    def quantize(self, model: nn.Module) -> nn.Module:
+        self.model_ = model
+        return self.quantized_model()
+

 #########################################################################
 ### wrapper for setting precision as a QuantHandler ###


 class PrecisionHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, dtype):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
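The base-class quantize() fallback means older torchchat handlers that only implement quantized_model() still satisfy the new entry point. A sketch, assuming the QuantHandler above; LegacyHandler is hypothetical:

import torch.nn as nn

class LegacyHandler(QuantHandler):
    # implements only the old-style entry point
    def quantized_model(self) -> nn.Module:
        return self.model_  # identity transform, for illustration only

# model: an nn.Module defined elsewhere
handler = LegacyHandler(device="cpu", precision=None)
quantized = handler.quantize(model)  # base class sets self.model_, then delegates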
@@ -169,6 +177,9 @@ def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
         dtype = name_to_dtype(dtype, device)
         self.dtype = dtype

+        # We simply ignore precision, because dtype carries the precision arg (possibly as a string)
+        # maybe: assert(precision in [self.dtype, None])
+
     def create_quantized_state_dict(self) -> Dict:  # "StateDict"
         pass

@@ -186,7 +197,7 @@ def quantized_model(self) -> nn.Module:


 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, accelerator):
         self.model_ = model

         if isinstance(accelerator, str):
@@ -573,8 +584,9 @@ def et_forward(self, input: torch.Tensor) -> torch.Tensor:
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         node_type: str = "*",
@@ -774,8 +786,9 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
 class EmbeddingOnlyQuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         bitwidth: int = 8,
@@ -868,9 +881,6 @@ def quantized_model(self) -> nn.Module:
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
-}
-
-ao_quantizer_class_dict = {
     "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
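With the two tables merged, dispatch no longer needs to know whether a name was historically a TC handler or an AO quantizer. An illustrative check against the merged table:

# Former AO entries and torchchat-native entries resolve through one table:
assert quantizer_class_dict["linear:int4"] is Int4WeightOnlyQuantizer
assert quantizer_class_dict["linear:int8"] is WeightOnlyInt8QuantHandler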
@@ -890,7 +900,7 @@ def quantized_model(self) -> nn.Module:
     sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
     torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
     from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
-    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer

     # Try loading custom op
     try:
@@ -903,4 +913,10 @@ def quantized_model(self) -> nn.Module:
         print("Slow fallback kernels will be used.")

 except Exception as e:
+    class ErrorHandler(QuantHandler):
+        def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None):
+            global a8wxdq_load_error
+            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+
     a8wxdq_load_error = e
+    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
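The net effect is that the import failure is reported lazily, only when a user actually selects linear:a8wxdq, instead of via the up-front probe that the first hunk removed. A sketch of the resulting behavior (the error text depends on whatever exception was captured at import time):

# Assuming the experimental import failed, the table maps the name to ErrorHandler,
# so merely constructing the handler raises with the captured load error:
q = quantizer_class_dict["linear:a8wxdq"]   # ErrorHandler
handler = q(device="cpu", precision=None)   # raises: Exception("Note: Failed to load
                                            #   torchao experimental a8wxdq quantizer with error: ...")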
