@@ -17,7 +17,7 @@
 except:
     pass
 
-from model import Transformer, find_multiple
+from model import Transformer
 
 
 ##### Quantization Primitives ######
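Note: the module-level `find_multiple` import is dropped here; the later hunks import it locally at each use site instead. For reference, a minimal sketch of `find_multiple` consistent with how this diff uses it (the canonical definition lives in `model.py`): it rounds `n` up to the nearest multiple of `k`.

```python
# Assumed behavior of model.find_multiple: round n up to the next multiple of k.
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

assert find_multiple(1024, 1024) == 1024    # already aligned: unchanged
assert find_multiple(11008, 1024) == 11264  # rounded up to the next multiple
```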
@@ -376,27 +376,29 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def _calc_padded_size_linear_int4(k, groupsize=1, inner_k_tiles=1):
-    return find_multiple(k, groupsize, inner_k_tiles * 16)
-
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
+                ))
+            elif padding:
+                setattr(module, name, WeightOnlyInt4Linear(
+                    child.in_features, child.out_features, bias=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
 
 
 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
+        self.padding = padding
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
 
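With `padding` now threaded through explicitly, `replace_linear_int4` swaps eligible layers in with `padding=False` and falls back to `padding=True` only for layers whose `in_features` fails `_check_linear_int4_k`. A hypothetical end-to-end usage sketch (the handler API is from this file; the surrounding model and loading code is assumed):

```python
# Hypothetical driver; `model` is assumed to be a Transformer whose
# nn.Linear submodules should be quantized to int4.
handler = WeightOnlyInt4QuantHandler(model, groupsize=128, inner_k_tiles=8, padding=True)
quantized_state_dict = handler.create_quantized_state_dict()  # pack int4 weights + scales/zeros
model = handler.convert_for_runtime()  # swap nn.Linear -> WeightOnlyInt4Linear
# assign=True avoids shape checks against the old parameters (recent PyTorch).
model.load_state_dict(quantized_state_dict, assign=True)
```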
@@ -418,9 +420,11 @@ def create_quantized_state_dict(self, use_cuda = True):
 
                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding_allowed:
+                    if self.padding:
+                        from model import find_multiple
+                        import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
-                        padded_in_features = _calc_padded_size_linear_int4(in_features, 1024)
+                        padded_in_features = find_multiple(in_features, 1024)
                         weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                     else:
                         print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
@@ -435,30 +439,31 @@ def create_quantized_state_dict(self, use_cuda = True):
         return cur_state_dict
 
     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+        from model import find_multiple
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
+        self.padding = padding
         self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
         self.quantize_func = lambda w, qparams: \
             group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
         self.dequantize_func = lambda q, qparams: \
             group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
         self.combine_qparams_list_func = lambda qparams_list: \
             [torch.cat(x, dim=1) for x in zip(*qparams_list)]
-        # skip unless padding_allowed=True or it's correctly sized
+        # skip unless padding=True or it's correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
+            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
         )
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
+            new_k = find_multiple(k, 1024)
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
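For clarity, `combine_qparams_list_func` above transposes a list of `(scales, zeros)` pairs into per-kind groups before concatenating along `dim=1`; a small self-contained illustration (shapes are made up):

```python
import torch

# Two chunks' worth of qparams, each a (scales, zeros) pair.
qparams_list = [
    (torch.full((8, 2), 0.5), torch.zeros(8, 2)),
    (torch.full((8, 3), 0.5), torch.zeros(8, 3)),
]
# zip(*...) groups all scales together and all zeros together.
combined = [torch.cat(x, dim=1) for x in zip(*qparams_list)]
assert combined[0].shape == (8, 5)  # concatenated scales
assert combined[1].shape == (8, 5)  # concatenated zeros
```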
@@ -472,7 +477,7 @@ def make_names_and_values_dict_func(q, qparams):
 
 
     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod
 
 class WeightOnlyInt4Linear(torch.nn.Module):
@@ -483,16 +488,17 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 
     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
     ) -> None:
         super().__init__()
+        self.padding = padding
+        if padding:
+            from model import find_multiple
+            self.origin_in_features = in_features
+            in_features = find_multiple(in_features, 1024)
 
-        # always pad if needed since it becomes a noop at runtime if not needed
-        self.origin_in_features = in_features
-        in_features = _calc_padded_size_linear_int4(in_features, groupsize, inner_k_tiles)
         self.in_features = in_features
         self.out_features = out_features
-
         assert not bias, "require bias=False"
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
@@ -510,7 +516,9 @@ def __init__(
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(torch.bfloat16)
-        input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        if self.padding:
+            import torch.nn.functional as F
+            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
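One property worth noting about the new `forward` branch: when a layer is constructed with `padding=True` but its `in_features` is already a multiple of 1024, `find_multiple` leaves it unchanged, so `self.in_features == self.origin_in_features` and the pad amount is zero, making `F.pad` effectively a no-op. A small sketch of that behavior:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4096, dtype=torch.bfloat16)
assert torch.equal(F.pad(x, pad=(0, 0)), x)  # zero pad amount: values unchanged
y = F.pad(x, pad=(0, 1024))                  # padded layer: zeros appended on the right
assert y.shape == (2, 5120)
```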