@@ -66,10 +66,12 @@ def __init__(
 
     def add_input(self, args):
         if self.inputs is None:
-            self.inputs = [MultiInput([arg]) for arg in args]
+            # self.inputs = [MultiInput([arg]) for arg in args]
+            self.inputs = [GPTQMultiTensor([arg]) for arg in args]
         else:
             self.inputs = [
-                multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                # multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                multi.add_tensors(arg) for (multi, arg) in zip(self.inputs, args)
             ]
 
     def get_recorded_inputs(self):
@@ -129,6 +131,199 @@ def cuda(self):
         self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
 
 
+class GPTQMultiTensor(torch.Tensor):
+    """
+    Tensor subclass that wraps a list of tensors (e.g. one activation per
+    calibration batch) and re-runs each op on every stored value, applying
+    GPTQ when an nn.functional.linear call is intercepted.
+    """
+    # TODO: needs a default shape/dtype
+    @staticmethod
+    def __new__(cls, input, **kwargs):
+        if isinstance(input, (list, tuple)):
+            input = input[0]
+        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
+        shape = kwargs.pop("shape", input.shape)
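+        # _make_wrapper_subclass builds a Tensor "shell" with the given
+        # shape/dtype but no storage of its own; the wrapped data lives in
+        # self.values, populated in __init__ below.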
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, input, **kwargs):
+        self.values = []
+        self.add_tensors(input)
+        self.debug = True
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(data={self.values})"
+
+    def add_tensors(self, input):
+        if isinstance(input, (tuple, list)):
+            for inp in input:
+                self.add_tensors(inp)
+        else:
+            assert isinstance(input, torch.Tensor), (
+                f"GPTQMultiTensor can only add_tensors for Tensors or lists of tensors but got {type(input)}"
+            )
+            self.values.append(input)
+        return self
+
+    def count(self):
+        return len(self.values)
+
+    def cuda(self):
+        self.values = [val.cuda() for val in self.values]
+        return self
+
+    def cpu(self):
+        self.values = [val.cpu() for val in self.values]
+        return self
+
+    def configure_quantization_mode(
+        self,
+        get_qparams_func,
+        quantize_func,
+        dequantize_func,
+        combine_qparams_list_func,
+        make_names_and_values_dict_func,
+        skip_layer_func,
+    ):
+        self.get_qparams_func = get_qparams_func
+        self.quantize_func = quantize_func
+        self.dequantize_func = dequantize_func
+        self.combine_qparams_list_func = combine_qparams_list_func
+        self.skip_layer_func = skip_layer_func
+        self.make_names_and_values_dict_func = make_names_and_values_dict_func
+        return self
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
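+        # Overall flow: flatten args/kwargs, fan the call out once per stored
+        # tensor, regroup the per-call outputs into GPTQMultiTensors, and
+        # intercept nn.functional.linear (unless skip_gptq) to accumulate the
+        # GPTQ Hessian from the calibration inputs.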
+        # with torch._C.DisableTorchFunctionSubclass():
+        #     is_set_item = str(func) == "<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>"
+        #     if is_set_item:
+        #         breakpoint()
+        #         try:
+        #             new_arg1 = [None if x == slice(None) else x for x in args[1]]
+        #             return torch.ops.aten.index_put(args[0], new_arg1, args[2])
+        #         except Exception as e:
+        #             print(e)
+        #             print("?A?")
+        #             breakpoint()
+        #             print("?")
+        # if func == torch.ops.aten.index_put_:
+        #     breakpoint()
+
+        def tensors_to_cuda(args):
+            new_args = []
+            for x in args:
+                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+            return new_args
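+        # (tensors_to_cuda is only referenced from the commented-out call
+        # sites below; it appears to be kept around for device debugging.)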
+
+        def flat_to_grouped(flat):
+            # size of the biggest MultiTensor
+            multi_tensor_size = max(
+                [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat]
+            )
+            # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
+            grouped = list(
+                zip(
+                    *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat]
+                )
+            )
+            return grouped
+
+        # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+        # where A is a non-tensor and the b's and c's are tensors
+        def grouped_to_flat(grouped):
+            # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
+            flat_tups = list(zip(*grouped))
+            # convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+            flattened = [
+                cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
+            ]
+            # check that collapsing each non-tensor tuple to its first element
+            # is safe, i.e. every element in the tuple matches (vacuously true
+            # for empty tuples)
+            non_tensors_equal = all(
+                all(x == tup[0] for x in tup)
+                for tup in flat_tups
+                if not isinstance(tup[0], torch.Tensor)
+            )
+            return flattened, non_tensors_equal
+
+        kwargs = {} if kwargs is None else kwargs
+        # combine args and kwargs and remove lists and tuples
+        flat_args, spec = tree_flatten((args, kwargs))
+        # move single tensors to cuda
+        # flat_args = tensors_to_cuda(flat_args)
+
+        # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
+        grouped_args = flat_to_grouped(flat_args)
+
+        do_gptq_linear = (
+            func is nn.functional.linear
+            # and id(args[1]) in self.id_to_name
+            and not skip_gptq
+            # and not (self.skip_layer_func)
+        )
+
+        # run the function for each group of tensors and return a multi-tensor
+        if not do_gptq_linear:
+            outputs = []
+            with torch._C.DisableTorchFunctionSubclass():
+                for inp in grouped_args:
+                    # inp = tensors_to_cuda(inp)
+                    cur_args, cur_kwargs = tree_unflatten(inp, spec)
+                    try:
+                        out = func(*cur_args, **cur_kwargs)
+                        outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
+                    except Exception as e:
+                        print(e)
+                        print("?B?")
+                        breakpoint()
+                        print("?")
+                try:
+                    # flatten each output
+                    grouped_outputs = [tree_flatten(x)[0] for x in outputs]
+                    out_spec = tree_flatten(outputs[0])[1]
+                    # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+                    flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
+                    assert non_tensors_equal, (
+                        f"ERR: found a function in model: {func} which "
+                        + "caused an error in GPTQMultiTensor; the function dispatch only works for functions"
+                        + " with Tensor outputs or whose non-Tensor outputs are identical across all inputs"
+                    )
+                    return tree_unflatten(flat_outputs, out_spec)
+                except Exception as e:
+                    print(e)
+                    print("?C?")
+                    breakpoint()
+                    print("?")
+
+        # otherwise do GPTQ: accumulate the Hessian over all recorded inputs
+        total_batches = 0
+        H = 0
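+        # The loop below maintains H = (2/N) * sum_i x_i x_i^T as a running
+        # average: each pass rescales the old H by total/(total + n), then
+        # adds the new batch's outer products scaled by 2/total, so after the
+        # loop H covers all N calibration rows seen so far.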
+        for inp in grouped_args:
+            # inp = tensors_to_cuda(inp)
+            cur_args, cur_kwargs = tree_unflatten(inp, spec)
+            x = cur_args[0].float()
+            shape = x.shape
+            n = 1 if len(shape) == 2 else shape[0]
+            H *= total_batches / (total_batches + n)
+            total_batches += n
+            x = ((2 / total_batches) ** (1 / 2)) * x.reshape(-1, shape[-1]).t().float()
+            H += x.matmul(x.t())
+        W = args[1].to(H.device)
+        DQ = W + 0.01
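+        # NOTE: placeholder for debugging; DQ just perturbs W. The real path
+        # is presumably the commented-out faster_quant call on the next line.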
+        # Q, DQ, qparams = args[0].faster_quant(H, W.detach())
+
+        new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_gptq=True)
+        # if args[0].debug:
+        return new_out
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        breakpoint()
+        pass
+
+
 class GenericGPTQRunner(fx.Interpreter):
     """
     This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +345,7 @@ def __init__(
         }
 
         # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]
+        one_input = tuple([multi.values[0].cpu() for multi in inputs])
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
@@ -161,7 +356,7 @@ def __init__(
         self.groupsize = groupsize
         self.inputs = inputs
         self.gptq_done = False
-        self.debug = False
+        self.debug = True
 
     def configure_quantization_mode(
         self,
@@ -312,6 +507,16 @@ def SQNR(x, y):
             print(
                 "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
             )  # matches
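+            # Round-trip check: re-deriving qparams from DQ and re-quantizing
+            # should reproduce Q exactly and give near-infinite SQNR if
+            # quantize/dequantize are exact inverses on dequantized weights.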
+            qparams_after = self.get_qparams_func(DQ)
+            Q_after = self.quantize_func(DQ, qparams_after)
+            print(
+                "abs difference of Q - quant(DQ)", (Q - Q_after).abs().sum()
+            )
+            DQ_after_after = self.dequantize_func(Q_after, qparams_after).to(DQ.dtype)
+            print(
+                "SQNR for DQ(Q(DQ)) vs DQ", SQNR(DQ, DQ_after_after)
+            )
+            breakpoint()
 
             print(
                 "SQNR for weight (can be low)", SQNR(W, DQ.cuda())