@@ -26,7 +26,7 @@

 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional
+from typing import Dict, Optional, Callable, Any, List

 import torch
 import torch.nn as nn
@@ -54,6 +54,33 @@
 # Flag for whether the a8wxdq quantizer is available.
 a8wxdq_load_error: Optional[Exception] = None

+#########################################################################
+### handle arg validation ###
+
+import inspect
+
+def get_named_parameters(func: Callable) -> List[str]:
+    # Get the signature of the function
+    signature = inspect.signature(func)
+
+    # Extract the parameters from the signature
+    parameters = signature.parameters
+
+    # Filter and return named parameters
+    named_params = [
+        name for name, param in parameters.items()
+        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+    ]
+    return named_params
+
+def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+    for key in list(q_kwargs.keys()):  # snapshot the keys: entries may be deleted below
+        if key not in named_params:
+            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            del q_kwargs[key]
+    return q_kwargs
+
+
 #########################################################################
 ### torchchat quantization API ###

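These two helpers drive the constructor dispatch in `quantize_model` below: `get_named_parameters` lists the keyword-accepting parameters of a quantizer's `__init__`, and `validate_args` drops any user-supplied option the constructor would reject. A minimal sketch of the behavior, using a hypothetical handler class (not part of this diff):

    class FakeHandler:  # hypothetical, for illustration only
        def __init__(self, model=None, device="cpu", precision=None, groupsize=128):
            pass

    params = get_named_parameters(FakeHandler.__init__)
    # params == ["self", "model", "device", "precision", "groupsize"]

    kwargs = validate_args(params, {"groupsize": 256, "bogus": 1}, "fake")
    # prints: Specification for quantizer fake has extraneous key bogus. Ignoring.
    # kwargs == {"groupsize": 256}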
@@ -79,56 +106,32 @@ def quantize_model(
         quantize_options = json.loads(quantize_options)

     for quantizer, q_kwargs in quantize_options.items():
-        # Test if a8wxdq quantizer is available; Surface error if not.
-        if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None:
-            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
-
-        if (
-            quantizer not in quantizer_class_dict
-            and quantizer not in ao_quantizer_class_dict
-        ):
+        if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
-        if quantizer in ao_quantizer_class_dict:
+        else:
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
+
             # We set global precision from quantize options if it is specified at cli.py:485
             # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
             precision = get_precision()

-            try:
-                if quantizer == "linear:a8wxdq":
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        device=device,
-                        precision=precision,
-                        bitwidth=q_kwargs.get("bitwidth", 4),
-                        groupsize=q_kwargs.get("groupsize", 128),
-                        has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
-                    )
-                else:
-                    # Easier to ask forgiveness than permission
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], device=device, precision=precision
-                    )
-            except TypeError as e:
-                if "unexpected keyword argument 'device'" in str(e):
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], precision=precision
-                    )
-                elif "unexpected keyword argument 'precision'" in str(e):
-                    quant_handler = ao_quantizer_class_dict[quantizer](
-                        groupsize=q_kwargs["groupsize"], device=device
-                    )
-                else:
-                    raise e
+            q = quantizer_class_dict[quantizer]
+            named_params = get_named_parameters(q.__init__)
+            q_kwargs = validate_args(named_params, q_kwargs, quantizer)
+
+            # Handle tokenizer for scenarios where the quantizer needs to tokenize sample inputs
+            if "tokenizer" in named_params:
+                q_kwargs["tokenizer"] = tokenizer
+            quant_handler = q(device=device, precision=precision, **q_kwargs)
+
+            # quantize model
             model = quant_handler.quantize(model)
-        else:
-            model = quantizer_class_dict[quantizer](
-                model, device=device, tokenizer=tokenizer, **q_kwargs
-            ).quantized_model()
+


 #########################################################################
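With the two class dictionaries unified, every quantizer, torchchat-native or torchao, is now constructed through the same path: look up the class, introspect its `__init__`, filter the JSON options, and call it with `device` and `precision` always passed. Roughly, for an illustrative option spec (the option string and the `groupsize` parameter of this handler are assumptions, not taken from this diff):

    # e.g. --quantize '{"linear:int8": {"groupsize": 256, "typo_key": 1}}'
    q = quantizer_class_dict["linear:int8"]      # WeightOnlyInt8QuantHandler
    named = get_named_parameters(q.__init__)     # includes "groupsize", not "typo_key"
    kwargs = validate_args(named, {"groupsize": 256, "typo_key": 1}, "linear:int8")
    handler = q(device=device, precision=precision, **kwargs)
    model = handler.quantize(model)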
@@ -137,7 +140,7 @@ def quantize_model(


 class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -154,13 +157,18 @@ def quantized_model(self) -> nn.Module:
         self.model_.load_state_dict(model_updated_state_dict)
         return self.model_

+    # Fallback for torchchat QuantHandlers that do not implement .quantize()
+    def quantize(self, model: nn.Module) -> nn.Module:
+        self.model_ = model
+        return self.quantized_model()
+

 #########################################################################
 ### wrapper for setting precision as a QuantHandler ###


 class PrecisionHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, dtype):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
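The `quantize()` fallback added to `QuantHandler` above is what makes the uniform `quant_handler.quantize(model)` call in `quantize_model` work: the torchao quantizers (`Int4WeightOnlyQuantizer`, `Int8DynActInt4WeightQuantizer`) implement `quantize(model)` themselves, while the torchchat handlers only implement `quantized_model()`, and the base class bridges the two. A sketch under that assumption (the `groupsize` keyword is illustrative):

    handler = WeightOnlyInt8QuantHandler(device="cpu", precision=torch.float32, groupsize=256)
    model = handler.quantize(model)  # no override here, so the base class stores the
                                     # model and delegates to quantized_model()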
@@ -169,6 +177,9 @@ def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
         dtype = name_to_dtype(dtype, device)
         self.dtype = dtype

+        # We simply ignore precision, because dtype is the precision arg (possibly as a string)
+        # maybe: assert(precision in [self.dtype, None])
+
     def create_quantized_state_dict(self) -> Dict:  # "StateDict"
         pass

@@ -186,7 +197,7 @@ def quantized_model(self) -> nn.Module:


 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
+    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None, *, accelerator):
         self.model_ = model

         if isinstance(accelerator, str):
@@ -573,8 +584,9 @@ def et_forward(self, input: torch.Tensor) -> torch.Tensor:
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         node_type: str = "*",
@@ -774,8 +786,9 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
 class EmbeddingOnlyQuantHandler(QuantHandler):
     def __init__(
         self,
-        model: nn.Module,
-        device,
+        model: Optional[nn.Module] = None,
+        device=None,
+        precision=None,
         tokenizer=None,
         *,
         bitwidth: int = 8,
@@ -868,9 +881,6 @@ def quantized_model(self) -> nn.Module:
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
-}
-
-ao_quantizer_class_dict = {
     "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
@@ -890,7 +900,7 @@ def quantized_model(self) -> nn.Module:
     sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
     torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
     from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
-    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer

     # Try loading custom op
     try:
@@ -903,4 +913,10 @@ def quantized_model(self) -> nn.Module:
         print("Slow fallback kernels will be used.")

 except Exception as e:
+    class ErrorHandler(QuantHandler):
+        def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None):
+            global a8wxdq_load_error
+            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+
     a8wxdq_load_error = e
+    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
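Net effect of the `ErrorHandler`: the a8wxdq import failure is no longer special-cased at the top of `quantize_model`; it is recorded once and only surfaces if a user actually selects that quantizer, at handler construction time. Illustrative flow (the argument order of `quantize_model` is an assumption):

    quantize_model(model, "cpu", '{"linear:a8wxdq": {"bitwidth": 4}}')
    # -> quantizer_class_dict["linear:a8wxdq"] is ErrorHandler
    # -> ErrorHandler.__init__ raises:
    #    "Note: Failed to load torchao experimental a8wxdq quantizer with error: ..."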