
Commit dd62281

Revert "fixing over padding and GPTQ padding bug"
This reverts commit 5bf70c1, which breaks llama-70B + int4 + TP.
1 parent: c6a85b1

File tree: 2 files changed (+35, -30 lines)


model.py (+2, -5)
@@ -4,18 +4,15 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional

 import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import functional as F
-from math import gcd
-from functools import reduce


-def find_multiple(n: int, *args: Tuple[int]) -> int:
-    k = reduce(lambda x,y: x*y//gcd(x,y), args+(1,))
+def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
     return n + k - (n % k)
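For context, the restored two-argument find_multiple simply rounds n up to the next multiple of k; the removed variant instead rounded up to the least common multiple of several factors. A minimal sketch with illustrative sizes (the 1024 target mirrors the padding code in quantize.py; the concrete values below are examples, not taken from the commit):

def find_multiple(n: int, k: int) -> int:
    # round n up to the next multiple of k; returns n unchanged when already aligned
    if n % k == 0:
        return n
    return n + k - (n % k)

assert find_multiple(4096, 1024) == 4096    # already a multiple of 1024, no padding needed
assert find_multiple(11008, 1024) == 11264  # 11008 + 1024 - (11008 % 1024)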

quantize.py (+33, -25)
@@ -17,7 +17,7 @@
 except:
     pass

-from model import Transformer, find_multiple
+from model import Transformer

 ##### Quantization Primitives ######

@@ -376,27 +376,29 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def _calc_padded_size_linear_int4(k, groupsize = 1, inner_k_tiles = 1):
-    return find_multiple(k, groupsize, inner_k_tiles*16)
-
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
+                ))
+            elif padding:
+                setattr(module, name, WeightOnlyInt4Linear(
+                    child.in_features, child.out_features, bias=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding)


 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
+        self.padding = padding
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]

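The net effect of the restored replace_linear_int4 is a per-layer decision: layers whose in_features already satisfy the int4 kernel's divisibility constraints get padding=False, layers that do not but are allowed to pad get padding=True, and everything else is left unquantized. A condensed sketch of just that decision (choose_padding is a hypothetical helper for illustration, not part of the commit):

def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    # the int4 kernel needs in_features divisible by groupsize and by inner_k_tiles * 16
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

def choose_padding(in_features, groupsize, inner_k_tiles, padding):
    # returns the padding flag the layer would be constructed with, or None to skip the layer
    if _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
        return False   # shapes already fit the kernel, no runtime padding
    if padding:
        return True    # pad in_features up at construction time and in forward()
    return None        # leave this nn.Linear unquantized

# illustrative values: with groupsize=128, inner_k_tiles=8 the check is simply in_features % 128 == 0
assert choose_padding(4096, 128, 8, padding=True) is False
assert choose_padding(1000, 128, 8, padding=True) is True
assert choose_padding(1000, 128, 8, padding=False) is None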

@@ -418,9 +420,11 @@ def create_quantized_state_dict(self, use_cuda = True):

                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding_allowed:
+                    if self.padding:
+                        from model import find_multiple
+                        import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
-                        padded_in_features = _calc_padded_size_linear_int4(in_features, 1024)
+                        padded_in_features = find_multiple(in_features, 1024)
                         weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                     else:
                         print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
@@ -435,30 +439,31 @@ def create_quantized_state_dict(self, use_cuda = True):
         return cur_state_dict

     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod

 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+        from model import find_multiple
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
+        self.padding = padding
         self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
         self.quantize_func = lambda w, qparams: \
             group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
         self.dequantize_func = lambda q, qparams: \
             group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
         self.combine_qparams_list_func = lambda qparams_list: \
             [torch.cat(x, dim=1) for x in zip(*qparams_list)]
-        # skip unless padding_allowed=True or its correctly sized
+        # skip unless padding=True or its correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
+            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
         )
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
+            new_k = find_multiple(k, 1024)
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
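The GPTQ handler applies the same rounding to the already-quantized weight q before handing it to the int4 packing op: it computes how many columns short of the next multiple of 1024 the weight is (delta_k) and zero-pads by that amount. A sketch of just the size arithmetic (shapes and dtype are illustrative; the actual _convert_weight_to_int4pack call is omitted):

import torch
import torch.nn.functional as F

def pad_quantized_columns(q: torch.Tensor, multiple: int = 1024) -> torch.Tensor:
    # q has shape (out_features, k); pad k up to the next multiple
    k = q.shape[1]
    new_k = k if k % multiple == 0 else k + multiple - (k % multiple)
    delta_k = new_k - k               # how much we need to pad the weight
    return F.pad(q, pad=(0, delta_k))

q = torch.zeros(4096, 11008, dtype=torch.int32)
assert pad_quantized_columns(q).shape == (4096, 11264)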
@@ -472,7 +477,7 @@ def make_names_and_values_dict_func(q, qparams):


     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod

 class WeightOnlyInt4Linear(torch.nn.Module):
@@ -483,16 +488,17 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
             self, in_features: int, out_features: int,
-            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
+            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
     ) -> None:
         super().__init__()
+        self.padding = padding
+        if padding:
+            from model import find_multiple
+            self.origin_in_features = in_features
+            in_features = find_multiple(in_features, 1024)

-        # always pad if needed since it becomes a noop at runtime if not needed
-        self.origin_in_features = in_features
-        in_features = _calc_padded_size_linear_int4(in_features, groupsize, inner_k_tiles)
         self.in_features = in_features
         self.out_features = out_features
-
         assert not bias, "require bias=False"
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
@@ -510,7 +516,9 @@ def __init__(

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(torch.bfloat16)
-        input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        if self.padding:
+            import torch.nn.functional as F
+            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
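Taken together, the last two hunks restore a module that remembers its original in_features, sizes its stored weight for the rounded-up value, and zero-pads each incoming activation at runtime when padding is enabled. A condensed, self-contained sketch of that behavior (a plain bfloat16 matmul stands in for the packed int4 weight and linear_forward_int4; this is an illustration, not the real WeightOnlyInt4Linear):

import torch
import torch.nn as nn
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    return n if n % k == 0 else n + k - (n % k)

class PaddedLinearSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, padding: bool = True) -> None:
        super().__init__()
        self.padding = padding
        self.origin_in_features = in_features
        if padding:
            in_features = find_multiple(in_features, 1024)   # storage is sized for the padded width
        self.in_features = in_features
        # dense float stand-in; the real module registers packed int4 weights plus scales_and_zeros
        self.register_buffer("weight", torch.zeros(out_features, in_features, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(torch.bfloat16)
        if self.padding:
            # widen the activation so its last dim matches the padded weight
            x = F.pad(x, pad=(0, self.in_features - self.origin_in_features))
        return F.linear(x, self.weight)

layer = PaddedLinearSketch(1000, 4096)
out = layer(torch.randn(2, 1000))
assert out.shape == (2, 4096)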
