@@ -66,10 +66,12 @@ def __init__(
 
     def add_input(self, args):
        if self.inputs is None:
-            self.inputs = [MultiInput([arg]) for arg in args]
+            # self.inputs = [MultiInput([arg]) for arg in args]
+            self.inputs = [GPTQMultiTensor([arg]) for arg in args]
        else:
            self.inputs = [
-                multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                # multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                multi.add_tensors(arg) for (multi, arg) in zip(self.inputs, args)
            ]
 
    def get_recorded_inputs(self):
@@ -129,6 +131,199 @@ def cuda(self):
        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
 
 
+class GPTQMultiTensor(torch.Tensor):
+    """
+    Tensor subclass that wraps a list of same-shaped tensors (one per recorded
+    input) and replays torch functions across all of them, intercepting
+    nn.functional.linear so GPTQ can run on each linear layer.
+    """
+    # TODO: need a default shape/dtype
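+    # note: the wrapper subclass owns no storage of its own; the first wrapped
+    # tensor supplies the shape/dtype reported to the dispatcher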
+    @staticmethod
+    def __new__(cls, input, **kwargs):
+        if isinstance(input, (list, tuple)):
+            input = input[0]
+        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
+        shape = kwargs.pop("shape", input.shape)
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, input, **kwargs):
+        self.values = []
+        self.add_tensors(input)
+        self.debug = True
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(data={self.values})"
+
+    def add_tensors(self, input):
+        if isinstance(input, (tuple, list)):
+            for inp in input:
+                self.add_tensors(inp)
+        else:
+            assert isinstance(
+                input, torch.Tensor
+            ), f"GPTQMultiTensor can only add_tensors for Tensors or lists of tensors but got {type(input)}"
+            self.values.append(input)
+        return self
+
+    def count(self):
+        return len(self.values)
+
+    def cuda(self):
+        self.values = [val.cuda() for val in self.values]
+        return self
+
+    def cpu(self):
+        self.values = [val.cpu() for val in self.values]
+        return self
+
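+    # the quantization callbacks are injected from outside (mirroring
+    # GenericGPTQRunner.configure_quantization_mode below)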
+    def configure_quantization_mode(
+        self,
+        get_qparams_func,
+        quantize_func,
+        dequantize_func,
+        combine_qparams_list_func,
+        make_names_and_values_dict_func,
+        skip_layer_func,
+    ):
+        self.get_qparams_func = get_qparams_func
+        self.quantize_func = quantize_func
+        self.dequantize_func = dequantize_func
+        self.combine_qparams_list_func = combine_qparams_list_func
+        self.skip_layer_func = skip_layer_func
+        self.make_names_and_values_dict_func = make_names_and_values_dict_func
+        return self
+
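+    # every torch-level function is replayed once per recorded input; only
+    # nn.functional.linear (with skip_gptq=False) takes the GPTQ path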
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
+        # with torch._C.DisableTorchFunctionSubclass():
+        # is_set_item = str(func) == "<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>"
+        # if is_set_item:
+        # breakpoint()
+        # try:
+        # new_arg1 = [None if x == slice(None) else x for x in args[1]]
+        # return torch.ops.aten.index_put(args[0], new_arg1, args[2])
+        # except Exception as e:
+        # print(e)
+        # print("?A?")
+        # breakpoint()
+        # print("?")
+        # if func == torch.ops.aten.index_put_:
+        # breakpoint()
+
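+        # currently unused: the tensors_to_cuda call sites below are commented out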
+        def tensors_to_cuda(args):
+            new_args = []
+            for x in args:
+                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+            return new_args
+
+        def flat_to_grouped(flat):
+            # size of the biggest MultiTensor
+            multi_tensor_size = max(
+                [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat]
+            )
+            # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
+            grouped = list(
+                zip(
+                    *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat]
+                )
+            )
+            return grouped
+
+        # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+        # where A is a nontensor and the b's and c's are tensors
+        def grouped_to_flat(grouped):
+            # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
+            flat_tups = list(zip(*grouped))
+            # convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+            flattened = [
+                cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
+            ]
+            # need to check that keeping only one element from each nonTensor tuple is OK
+            non_tensors_equal = min([True] + [
+                min([True] + [  # handle situation where tuples have size 0
+                    tup[0] == x for x in tup  # check that all elements match
+                ]) for tup in flat_tups if not isinstance(tup[0], torch.Tensor)  # look at tuples of nonTensors
+            ])
+            return flattened, non_tensors_equal
+
+        kwargs = {} if kwargs is None else kwargs
+        # combine args and kwargs and remove lists and tuples
+        flat_args, spec = tree_flatten((args, kwargs))
+        # move single tensors to cuda
+        # flat_args = tensors_to_cuda(flat_args)
+
+        # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
+        grouped_args = flat_to_grouped(flat_args)
+
+        do_gptq_linear = (
+            func is nn.functional.linear
+            # and id(args[1]) in self.id_to_name
+            and not skip_gptq
+            # and not (self.skip_layer_func)
+        )
+
+        # run the function on each of the grouped inputs and return a multitensor
+        if not do_gptq_linear:
+            outputs = []
+            with torch._C.DisableTorchFunctionSubclass():
+                for inp in grouped_args:
+                    # inp = tensors_to_cuda(inp)
+                    cur_args, cur_kwargs = tree_unflatten(inp, spec)
+                    try:
+                        out = func(*cur_args, **cur_kwargs)
+                        outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
+                    except Exception as e:
+                        print(e)
+                        print("?B?")
+                        breakpoint()
+                        print("?")
+            try:
+                # flatten each output
+                grouped_outputs = [tree_flatten(x)[0] for x in outputs]
+                out_spec = tree_flatten(outputs[0])[1]
+                # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+                flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
+                assert non_tensors_equal, (
+                    f"ERR: found a function in model: {func} which "
+                    + "caused an error in GPTQMultiTensor; the function dispatch only works for functions"
+                    + " with Tensor outputs or whose non-Tensor outputs are identical across all inputs"
+                )
+                return tree_unflatten(flat_outputs, out_spec)
+            except Exception as e:
+                print(e)
+                print("?C?")
+                breakpoint()
+                print("?")
+
+        # do GPTQ if do_gptq_linear is True
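+        # accumulate the GPTQ Hessian H = 2/N * sum_i x_i x_i^T over all recorded
+        # inputs, rescaling the running estimate as each new batch is folded in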
+        total_batches = 0
+        H = 0
+        for inp in grouped_args:
+            # inp = tensors_to_cuda(inp)
+            cur_args, cur_kwargs = tree_unflatten(inp, spec)
+            x = cur_args[0].float()
+            shape = x.shape
+            n = 1 if len(shape) == 2 else shape[0]
+            H *= total_batches / (total_batches + n)
+            total_batches += n
+            x = (
+                (2 / total_batches) ** (1 / 2)
+                * x.reshape(-1, shape[-1]).t().float()
+            )
+            H += x.matmul(x.t())
+        W = args[1].to(H.device)
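+        # debugging placeholder: perturb W instead of quantizing it; the real
+        # call is the commented-out faster_quant below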
+        DQ = W + 0.01
+        # Q, DQ, qparams = args[0].faster_quant(H, W.detach())
+
+        new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_gptq=True)
+        # if args[0].debug:
+        return new_out
+
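+    # __torch_dispatch__ is still a stub; the breakpoint flags any op that
+    # reaches dispatch without being handled by __torch_function__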
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        breakpoint()
+        pass
+
+
 class GenericGPTQRunner(fx.Interpreter):
    """
    This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +345,7 @@ def __init__(
        }
 
        # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]
+        one_input = tuple([multi.values[0].cpu() for multi in inputs])
        exported_model = torch._dynamo.export(
            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
        )(*one_input)
@@ -161,7 +356,7 @@ def __init__(
        self.groupsize = groupsize
        self.inputs = inputs
        self.gptq_done = False
-        self.debug = False
+        self.debug = True
 
    def configure_quantization_mode(
        self,
@@ -312,6 +507,16 @@ def SQNR(x, y):
                print(
                    "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
                )  # matches
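+                # check that quantization is idempotent: re-quantizing DQ
+                # should reproduce Q exactly and leave DQ unchanged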
+                qparams_after = self.get_qparams_func(DQ)
+                Q_after = self.quantize_func(DQ, qparams_after)
+                print(
+                    "abs difference of Q vs quant(DQ)", (Q - Q_after).abs().sum()
+                )
+                DQ_after_after = self.dequantize_func(Q_after, qparams_after).to(DQ.dtype)
+                print(
+                    "SQNR for DQ(Q(DQ)) vs DQ", SQNR(DQ, DQ_after_after)
+                )
+                breakpoint()
 
                print(
                    "SQNR for weight (can be low)", SQNR(W, DQ.cuda())