@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
    return weight_int4pack, scales_and_zeros

+def _calc_padded_size(k, groupsize=1, inner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)

def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    origin_x_size = x.size()
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                setattr(module, name, WeightOnlyInt4Linear(
                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                ))
        else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)


class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

@@ -417,7 +415,7 @@ def create_quantized_state_dict(self):

                weight = mod.weight.data
                if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
+                    if self.padding_allowed:
                        from model import find_multiple
                        import torch.nn.functional as F
                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
        return cur_state_dict

    def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
        return self.mod

class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -485,11 +483,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):

    def __init__(
        self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
    ) -> None:
        super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
            from model import find_multiple
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)
@@ -597,7 +595,7 @@ def quantize(

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
    else:
        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq]")

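For readers tracing the padding refactor: after this change, callers pass only `padding_allowed`, and `WeightOnlyInt4Linear` decides for itself whether an `in_features` needs padding via `_check_linear_int4_k`. Below is a minimal standalone sketch of that decision. `_check_linear_int4_k` is copied from the diff; the body of `find_multiple` is an assumption about what `model.find_multiple` does (round up to a multiple), and the layer sizes are hypothetical.

def find_multiple(n: int, k: int) -> int:
    # Assumed behavior of model.find_multiple: round n up to a multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)

def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    # Copied from the diff: k must split evenly into quantization groups
    # and into inner-k tiles of 16 elements.
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

# Hypothetical layer sizes for illustration:
print(_check_linear_int4_k(11008, groupsize=128, inner_k_tiles=8))  # True: no padding needed
print(_check_linear_int4_k(11000, groupsize=128, inner_k_tiles=8))  # False: must pad
print(find_multiple(11000, 1024))  # 11264, the padded in_features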
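And a hedged end-to-end sketch of how the renamed flag flows through the handler after this commit. The `model` construction and state-dict handling are placeholders, not part of the diff; only the handler calls mirror the code above.

# Hypothetical driver; `model` stands in for any module tree containing nn.Linear layers.
handler = WeightOnlyInt4QuantHandler(model, groupsize=128, inner_k_tiles=8, padding_allowed=True)
quantized_state_dict = handler.create_quantized_state_dict()
model = handler.convert_for_runtime(use_cuda=True)  # calls replace_linear_int4 with padding_allowed
model.load_state_dict(quantized_state_dict)

With `padding_allowed=True`, every `nn.Linear` is replaced, and each `WeightOnlyInt4Linear` pads its own `in_features` to a multiple of 1024 only when the k-check fails, rather than the caller threading a per-layer `padding` flag through.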