Commit dfaa329

fixing GPTQ
Summary:
Trying to fix the issue with the kv_cache update by changing tracing into a tensor subclass. However, it seems we have less success than with the fx tracer.

The fx tracer breaks because k_out[:, :, input_pos] = k_val gets traced as new_var = torch.ops.aten.index_put_(k_out, [None, None, input_pos], k_val), with new_var never accessed afterward: new_var becomes the correct MultiInput value, but is then lost.

The subclass, on the other hand, tries to use the func "<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>", which seems not to mutate k_out, so the attempt to make it a MultiTensor fails.

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 9ed1621201317e5f655132ba11538a67c8aa5a69
Pull Request resolved: #148
1 parent f697317 commit dfaa329
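For context on the __setitem__ failure described in the summary, here is a minimal sketch (not code from this PR; LoggingTensor, the shapes, and the call are illustrative). It shows that a tensor subclass's __torch_function__ receives the slice assignment as the Tensor __setitem__ slot wrapper, whose return value Python discards, so the handler has to mutate args[0] in place for the update to be visible:

import torch

# Illustrative only: a plain (non-wrapper) subclass that logs which function
# __torch_function__ receives for an in-place slice assignment.
class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(func)  # for k_out[:, :, pos] = k_val this prints the
                     # __setitem__ slot wrapper, not torch.ops.aten.index_put_
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

k_out = torch.zeros(1, 1, 8, 4).as_subclass(LoggingTensor)
pos = torch.tensor([0, 1])
k_val = torch.randn(1, 1, 2, 4)
k_out[:, :, pos] = k_val  # Python discards __setitem__'s return value, so the
                          # handler must mutate args[0] itself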

4 files changed (+244 -9 lines)

GPTQ.py

+209-4
@@ -66,10 +66,12 @@ def __init__(
 
     def add_input(self, args):
         if self.inputs is None:
-            self.inputs = [MultiInput([arg]) for arg in args]
+            # self.inputs = [MultiInput([arg]) for arg in args]
+            self.inputs = [GPTQMultiTensor([arg]) for arg in args]
         else:
             self.inputs = [
-                multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                # multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                multi.add_tensors(arg) for (multi, arg) in zip(self.inputs, args)
             ]
 
     def get_recorded_inputs(self):
@@ -129,6 +131,199 @@ def cuda(self):
         self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
 
 
+class GPTQMultiTensor(torch.Tensor):
+    """
+    """
+    # todo need default shape/dtype
+    @staticmethod
+    def __new__(cls, input, **kwargs):
+        if isinstance(input, (list, tuple)):
+            input = input[0]
+        kwargs["dtype"]=kwargs.get("dtype", input.dtype)
+        shape = kwargs.pop("shape", input.shape)
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, input, **kwargs):
+        self.values = []
+        self.add_tensors(input)
+        self.debug = True
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.values})"
+        )
+
+    def add_tensors(self, input):
+        if isinstance(input, (tuple, list)):
+            for inp in input:
+                self.add_tensors(inp)
+        else:
+            assert isinstance(input, torch.Tensor), f"MultiTensor can only use add_input for Tensors or lists of tensors but got {type(input)}"
+            self.values.append(input)
+        return self
+
+    def count(self):
+        return len(self.values)
+
+    def cuda(self):
+        self.values = [val.cuda() for val in self.values]
+        return self
+
+    def cpu(self):
+        self.values = [val.cpu() for val in self.values]
+        return self
+
+    def configure_quantization_mode(
+        self,
+        get_qparams_func,
+        quantize_func,
+        dequantize_func,
+        combine_qparams_list_func,
+        make_names_and_values_dict_func,
+        skip_layer_func,
+    ):
+        self.get_qparams_func = get_qparams_func
+        self.quantize_func = quantize_func
+        self.dequantize_func = dequantize_func
+        self.combine_qparams_list_func = combine_qparams_list_func
+        self.skip_layer_func = skip_layer_func
+        self.make_names_and_values_dict_func = make_names_and_values_dict_func
+        return self
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
+        # with torch._C.DisableTorchFunctionSubclass():
+        # is_set_item = str(func)=="<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>"
+        # if is_set_item:
+        #     breakpoint()
+        #     try:
+        #         new_arg1=[None if x == slice(None) else x for x in args[1]]
+        #         return torch.ops.aten.index_put(args[0], new_arg1, args[2])
+        #     except Exception as e:
+        #         print(e)
+        #         print("?A?")
+        #         breakpoint()
+        #         print("?")
+        # if func == torch.ops.aten.index_put_:
+        #     breakpoint()
+
+        def tensors_to_cuda(args):
+            new_args = []
+            for x in args:
+                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+            return new_args
+
+        def flat_to_grouped(flat):
+            # size of biggest MultiTensor
+            multi_tensor_size = max(
+                [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat]
+            )
+            # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
+            grouped = list(
+                zip(
+                    *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat]
+                )
+            )
+            return grouped
+
+        # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+        # where A is nontensor, b's,c's are tensors
+        def grouped_to_flat(grouped):
+            # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
+            flat_tups = list(zip(*grouped))
+            # convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+            flattened = [
+                cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
+            ]
+            # need to check that getting rid of all but one from each nonTensor tuple is OK
+            non_tensors_equal=min([True]+[
+                min([True]+[  # handle situation where tuples have size 0
+                    tup[0]==x for x in tup  # check all elements match
+                ]) for tup in flat_tups if not isinstance(tup[0], torch.Tensor)  # look at tuples of nonTensors
+            ])
+            return flattened, non_tensors_equal
+
+        kwargs = {} if kwargs is None else kwargs
+        # combine args and kwargs and remove lists and tuples
+        flat_args, spec = tree_flatten((args, kwargs))
+        # move single tensors to cuda
+        # flat_args = tensors_to_cuda(flat_args)
+
+        # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
+        grouped_args = flat_to_grouped(flat_args)
+
+        do_gptq_linear = (
+            func is nn.functional.linear
+            # and id(args[1]) in self.id_to_name
+            and not skip_gptq
+            # and not (self.skip_layer_func)
+        )
+
+        # run function for each of the multitensors and return a multitensor
+        if not do_gptq_linear:
+            outputs = []
+            with torch._C.DisableTorchFunctionSubclass():
+                for inp in grouped_args:
+                    # inp = tensors_to_cuda(inp)
+                    cur_args, cur_kwargs = tree_unflatten(inp, spec)
+                    try:
+                        out = func(*cur_args, **cur_kwargs)
+                        outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
+                    except Exception as e:
+                        print(e)
+                        print("?B?")
+                        breakpoint()
+                        print("?")
+                try:
+                    # each output
+                    grouped_outputs = [tree_flatten(x)[0] for x in outputs]
+                    out_spec = tree_flatten(outputs[0])[1]
+                    # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+                    flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
+                    assert non_tensors_equal, (
+                        f"ERR: found a function in model: {func} which "
+                        +"caused an error in GPTQMultiInput, the function dispatch only works for functions"
+                        +" with Tensor outputs or that have the same non-Tensor output value for all across all inputs"
+                    )
+                    return tree_unflatten(flat_outputs, out_spec)
+                except Exception as e:
+                    print(e)
+                    print("?C?")
+                    breakpoint()
+                    print("?")
+
+        # do GPTQ if quantize_linear is true
+        total_batches = 0
+        H=0
+        for inp in grouped_args:
+            # inp = tensors_to_cuda(inp)
+            cur_args, cur_kwargs = tree_unflatten(inp, spec)
+            x = cur_args[0].float()
+            shape = x.shape
+            n = 1 if len(shape) == 2 else shape[0]
+            H*= total_batches / (total_batches + n)
+            total_batches += n
+            x = (
+                (2 / total_batches) ** (1 / 2) *
+                x.reshape(-1, shape[-1]).t().float()
+            )
+            H += x.matmul(x.t())
+        W = args[1].to(H.device)
+        DQ = W+.01
+        # Q, DQ, qparams = args[0].faster_quant(H, W.detach())
+
+        new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_gptq = True)
+        # if args[0].debug:
+        return new_out
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        breakpoint()
+        pass
+
+
 class GenericGPTQRunner(fx.Interpreter):
     """
     This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +345,7 @@ def __init__(
         }
 
         # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]
+        one_input = tuple([multi.values[0].cpu() for multi in inputs])
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)

@@ -161,7 +356,7 @@ def __init__(
         self.groupsize = groupsize
         self.inputs = inputs
         self.gptq_done = False
-        self.debug = False
+        self.debug = True
 
     def configure_quantization_mode(
         self,

@@ -312,6 +507,16 @@ def SQNR(x, y):
             print(
                 "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
             ) # matches
+            qparams_after = self.get_qparams_func(DQ)
+            Q_after = self.quantize_func(DQ, qparams_after)
+            print(
+                "abs difference of Q-quant(DQ)", (Q-Q_after).abs().sum()
+            )
+            DQ_after_after = self.dequantize_func(Q_after, qparams_after).to(DQ.dtype)
+            print(
+                "SQNR for DQ(Q(DQ)) vs DQ", SQNR(DQ, DQ_after_after)
+            )
+            breakpoint()
 
             print(
                 "SQNR for weight (can be low)", SQNR(W, DQ.cuda())

model.py

+6-1
@@ -78,10 +78,16 @@ def update(self, input_pos, k_val, v_val):
         # input_pos: [S], k_val: [B, H, S, D]
         assert input_pos.shape[0] == k_val.shape[2]
 
+
         k_out = self.k_cache
         v_out = self.v_cache
         k_out[:, :, input_pos] = k_val
+        breakpoint()
         v_out[:, :, input_pos] = v_val
+        breakpoint()
+        # k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
+        # v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
+
 
         return k_out, v_out

@@ -174,7 +180,6 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional
 
         kv_size = self.n_local_heads * self.head_dim
         q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
-
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

quantize.py

+10-4
@@ -11,6 +11,8 @@
 import torch.nn.functional as F
 from sentencepiece import SentencePieceProcessor
 
+from GPTQ import GenericGPTQRunner, InputRecorder
+
 try:
     from GPTQ import GenericGPTQRunner, InputRecorder
     from eval import get_task_dict, evaluate, lm_eval

@@ -286,6 +288,10 @@ def create_quantized_state_dict(
         pad_calibration_inputs,
     ) -> "StateDict":
         inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
+        self.mod=self.mod.to("cpu")
+        inputs=[x.cpu() if hasattr(x, "cpu") else x for x in inputs]
+        self.mod(*inputs)
+
         print("Tracing model for GPTQ")
         GPTQ_runner = GenericGPTQRunner(
             self.mod,

@@ -438,12 +444,12 @@ def convert_for_runtime(self, use_cuda):
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         from model import find_multiple
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
         self.quantize_func = lambda w, qparams: \
             group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)

@@ -453,7 +459,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
             [torch.cat(x, dim=1) for x in zip(*qparams_list)]
         # skip unless padding=True or its correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
+            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
         )
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):

@@ -472,7 +478,7 @@ def make_names_and_values_dict_func(q, qparams):
 
 
     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod
 
 class WeightOnlyInt4Linear(torch.nn.Module):

run.sh

+19
@@ -0,0 +1,19 @@
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
+# echo "base"
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 1
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
+# echo "quant good"
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+
+# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
