
Commit 93dab0e

int4 gptq shape fix
Summary: redoing 5bf70c1 in a way that doesn't get reverted. Note: needed to fix a device issue as well.

Test Plan:
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: d3928b07e8be7e1e9e98b43584e125b6d60770d6
Pull Request resolved: #142
1 parent: c955dac

2 files changed (+17, −19)


GPTQ.py (+2, −2)
@@ -150,9 +150,9 @@ def __init__(
         }
 
         # trace model for one input
-        one_input = [multi.values[0] for multi in inputs]
+        one_input = [multi.values[0].cpu() for multi in inputs]
         exported_model = torch._dynamo.export(
-            model, aten_graph=True, pre_dispatch=True, tracing_mode="fake"
+            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
         super().__init__(exported_model.graph_module)
         self.new_state_dict = model.state_dict()
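The device issue mentioned in the summary is handled by the two changed lines above: both the captured calibration input and the module are moved to CPU before torch._dynamo.export traces the model under tracing_mode="fake", so the fake tensors created during tracing all live on one device. A minimal sketch of the same pattern on a toy module (the toy nn.Linear and input shapes are placeholders, not the real transformer; the export call mirrors the diff and assumes a PyTorch version where torch._dynamo.export accepts these keyword arguments):

import torch
import torch.nn as nn

# Toy stand-in for the transformer that GPTQ.py actually traces.
model = nn.Linear(8, 8)
one_input = [torch.randn(2, 8).cpu()]  # calibration example, forced onto CPU

# Module and inputs share a device before fake-mode tracing, as in the fix above.
exported_model = torch._dynamo.export(
    model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
)(*one_input)
graph_module = exported_model.graph_module  # graph handed to the GPTQ runner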

quantize.py (+15, −17)
@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
+def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)
 
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
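The new _calc_padded_size helper leans on find_multiple from model.py and ignores its groupsize/innner_k_tiles arguments, simply rounding k up to the next multiple of 1024. A rough sketch of the rounding it relies on (the find_multiple body below is an assumption for illustration, not copied from model.py):

def find_multiple(n: int, k: int) -> int:
    # Assumed behavior: round n up to the nearest multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)

def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
    # As in the diff: the extra arguments do not affect the padded size.
    return find_multiple(k, 1024)

assert _calc_padded_size(4096) == 4096    # already aligned, unchanged
assert _calc_padded_size(11008) == 11264  # rounds up to 11 * 1024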
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)
 
 
 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]

@@ -417,7 +415,7 @@ def create_quantized_state_dict(self):
 
                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
+                    if self.padding_allowed:
                         from model import find_multiple
                         import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
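Only the flag name changes in this hunk; the padding work itself sits in the surrounding (unshown) lines, which pad the weight's in_features dimension out to the next multiple of 1024 before quantization, as the warning message says. A hedged sketch of that step (the helper name and the use of F.pad are illustrative assumptions, not code from quantize.py):

import torch
import torch.nn.functional as F

def pad_weight_for_int4(weight: torch.Tensor, padded_in_features: int) -> torch.Tensor:
    # weight is [out_features, in_features]; zero-pad extra input columns so
    # in_features satisfies the kernel's alignment requirement.
    in_features = weight.shape[1]
    return F.pad(weight, pad=(0, padded_in_features - in_features))

w = torch.randn(4096, 11008, dtype=torch.bfloat16)
print(pad_weight_for_int4(w, 11264).shape)  # torch.Size([4096, 11264])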
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
         return cur_state_dict
 
     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
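For context, the handler's pieces are meant to be chained: construct it over a float model, build the packed int4 state dict, then call convert_for_runtime to swap every eligible nn.Linear for WeightOnlyInt4Linear before loading the quantized weights. The wiring below is inferred from the method names in this diff rather than taken from quantize.py or eval.py, and the toy module plus the assign=True load are illustrative assumptions (the real flow may also require CUDA for the weight packing):

import torch
import torch.nn as nn
from quantize import WeightOnlyInt4QuantHandler  # handler shown in this diff

# Toy module standing in for the Transformer; in_features=4096 already passes the k-check.
model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to(torch.bfloat16)

handler = WeightOnlyInt4QuantHandler(model, groupsize=32, inner_k_tiles=8, padding_allowed=True)
quantized_state_dict = handler.create_quantized_state_dict()  # pack weights to int4
model = handler.convert_for_runtime(use_cuda=True)            # swap nn.Linear -> WeightOnlyInt4Linear
model.load_state_dict(quantized_state_dict, assign=True)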
@@ -485,11 +483,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 
     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
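When a layer is padded, the module keeps origin_in_features so the runtime side can reconcile activations of the original width with weights packed for the rounded-up width. Presumably that means zero-padding the activation's reduction dimension before the int4 matmul, the mirror image of the weight padding sketched earlier; a hedged sketch of the idea (the helper below is illustrative, not code from quantize.py):

import torch
import torch.nn.functional as F

def pad_activation(x: torch.Tensor, origin_in_features: int, in_features: int) -> torch.Tensor:
    # Right-pad the last dimension with zeros so an [*, origin_in_features] activation
    # matches a weight packed for in_features columns; zeros leave the matmul unchanged.
    return F.pad(x, pad=(0, in_features - origin_in_features))

x = torch.randn(4, 11008)
print(pad_activation(x, 11008, 11264).shape)  # torch.Size([4, 11264])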
@@ -597,7 +595,7 @@ def quantize(
 
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
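With {device} added to the suffix, the saved file name encodes both the group size and the device, which is why the eval step in the test plan points at model_int4-gptq.g32.cuda.pth. A quick check of the f-string (the label value is an assumption chosen to reproduce the test plan's file name):

base_name = "model.pth"
label = "_"        # assumed; whatever separator quantize.py passes in as label
groupsize = 32
device = "cuda"

new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
print(new_base_name)  # model_int4-gptq.g32.cuda.pth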
