 except:
     pass

-from model import Transformer, find_multiple
+from model import Transformer

 ##### Quantization Primitives ######

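The two-argument find_multiple used throughout this change lives in model.py; after this diff it is imported locally at each use site instead of at module scope. A minimal sketch of what it computes, inferred from the call sites below (round n up to the nearest multiple of k); the body is an assumption, not a quote of model.py:

# Hypothetical sketch of model.find_multiple, inferred from calls like
# find_multiple(in_features, 1024) in this diff: round n up to the next
# multiple of k, e.g. find_multiple(3000, 1024) == 3072.
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)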
@@ -376,27 +376,29 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def _calc_padded_size_linear_int4(k, groupsize=1, inner_k_tiles=1):
-    return find_multiple(k, groupsize, inner_k_tiles * 16)
-
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
+                ))
+            elif padding:
+                setattr(module, name, WeightOnlyInt4Linear(
+                    child.in_features, child.out_features, bias=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding)


 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
+        self.padding = padding
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
@@ -418,9 +420,11 @@ def create_quantized_state_dict(self, use_cuda=True):

             weight = mod.weight.data
             if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                if self.padding_allowed:
+                if self.padding:
+                    from model import find_multiple
+                    import torch.nn.functional as F
                     print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
-                    padded_in_features = _calc_padded_size_linear_int4(in_features, 1024)
+                    padded_in_features = find_multiple(in_features, 1024)
                     weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                 else:
                     print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
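Spelled out, the padding path above rounds in_features up to a multiple of 1024 and right-pads the weight's last dimension with zeros. A hedged numeric sketch (3000 is an illustrative width that fails the int4 checks; find_multiple as sketched earlier):

import torch
import torch.nn.functional as F

in_features = 3000
padded_in_features = 3072                 # find_multiple(3000, 1024)
weight = torch.randn(4096, in_features)
# F.pad with pad=(0, d) appends d zero columns to the last dimension,
# so quantization sees a [4096, 3072] weight.
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
assert weight.shape == (4096, 3072)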
@@ -435,30 +439,31 @@ def create_quantized_state_dict(self, use_cuda=True):
         return cur_state_dict

     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod

 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+        from model import find_multiple
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
+        self.padding = padding
         self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
         self.quantize_func = lambda w, qparams: \
             group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
         self.dequantize_func = lambda q, qparams: \
             group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
         self.combine_qparams_list_func = lambda qparams_list: \
             [torch.cat(x, dim=1) for x in zip(*qparams_list)]
-        # skip unless padding_allowed=True or it's correctly sized
+        # skip unless padding=True or it's correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
+            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
         )
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
+            new_k = find_multiple(k, 1024)
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
@@ -472,7 +477,7 @@ def make_names_and_values_dict_func(q, qparams):


     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod

 class WeightOnlyInt4Linear(torch.nn.Module):
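Both handlers expose the same two-step flow: build a quantized state dict, then swap the nn.Linear modules for WeightOnlyInt4Linear. A hypothetical driver sketch (model construction and checkpoint I/O elided; only the handler calls mirror this file):

# Assumes `model` is an already-loaded Transformer from model.py.
handler = WeightOnlyInt4QuantHandler(model, groupsize=128, inner_k_tiles=8, padding=True)
quantized_state_dict = handler.create_quantized_state_dict()
model = handler.convert_for_runtime()   # swaps in WeightOnlyInt4Linear modules
model.load_state_dict(quantized_state_dict)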
@@ -483,16 +488,17 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
     ) -> None:
         super().__init__()
+        self.padding = padding
+        if padding:
+            from model import find_multiple
+            self.origin_in_features = in_features
+            in_features = find_multiple(in_features, 1024)

-        # always pad if needed since it becomes a noop at runtime if not needed
-        self.origin_in_features = in_features
-        in_features = _calc_padded_size_linear_int4(in_features, groupsize, inner_k_tiles)
         self.in_features = in_features
         self.out_features = out_features
-
         assert not bias, "require bias=False"
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
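With the flag threaded through, origin_in_features is only recorded (and in_features only rounded up) on the padded path; padding=False trusts the caller's size, which is exactly what replace_linear_int4 passes for already-conforming layers. An illustrative use of the constructor, assuming find_multiple as sketched above:

# padding=True: a 3000-wide layer allocates int4 storage for 3072 inputs
# and remembers the original width for use in forward().
layer = WeightOnlyInt4Linear(3000, 4096, bias=False,
                             groupsize=128, inner_k_tiles=8, padding=True)
assert layer.in_features == 3072
assert layer.origin_in_features == 3000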
@@ -510,7 +516,9 @@ def __init__(

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(torch.bfloat16)
-        input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        if self.padding:
+            import torch.nn.functional as F
+            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
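Guarding the pad behind self.padding skips the attribute lookup and the no-op F.pad on conforming layers; on padded layers the extra activation columns are zeros and meet padded (zero) weight columns, so the matmul result is unchanged. A hedged numeric check of that invariant in float math (sizes illustrative; the int4 kernel is assumed to treat the zero groups the same way):

import torch
import torch.nn.functional as F

w = torch.randn(8, 30)          # [out_features, in_features]
x = torch.randn(2, 30)          # [batch, in_features]
w_pad = F.pad(w, pad=(0, 2))    # pad K: 30 -> 32 with zero columns
x_pad = F.pad(x, pad=(0, 2))
# the zero columns contribute nothing to any dot product
assert torch.allclose(x @ w.t(), x_pad @ w_pad.t(), atol=1e-6)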