@@ -6,184 +6,12 @@
 
 try:
     os.environ["BITSANDBYTES_NOWELCOME"] = "1"
-    import bitsandbytes as bnb
+    from bitsandbytes import MatmulLtState
+    from bitsandbytes.nn import Linear4bit, Linear8bitLt, Params4bit, Int8Params
 except ImportError:
     raise ImportError("Install bitsandbytes to use 4/8bit compression")
 
 
-class Linear4bit(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        compute_dtype=None,
-        compress_statistics=True,
-        quant_type="fp4",
-        device=torch.device("cpu"),
-    ):
-        super().__init__(input_features, output_features, bias)
-
-        self.weight = bnb.nn.Params4bit(
-            self.weight.data,
-            requires_grad=False,
-            compress_statistics=compress_statistics,
-            quant_type=quant_type,
-        )
-        self.compute_dtype = compute_dtype
-
-    def forward(self, x: torch.Tensor):
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        if getattr(self.weight, "quant_state", None) is None:
-            print(
-                "FP4 quantization state not initialized. Please call .cuda() or"
-                " .to(device) on the LinearFP4 layer first."
-            )
-        inp_dtype = x.dtype
-        if self.compute_dtype is not None:
-            x = x.to(self.compute_dtype)
-
-        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        out = bnb.matmul_4bit(
-            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
-        )
-
-        out = out.to(inp_dtype)
-
-        return out
-
-
-class Linear8bitLt(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        has_fp16_weights=True,
-        threshold=0.0,
-        index=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__(input_features, output_features, bias)
-        self.state = bnb.MatmulLtState()
-        self.index = index
-
-        self.state.threshold = threshold
-        self.state.has_fp16_weights = has_fp16_weights
-        self.state.memory_efficient_backward = False
-        if threshold > 0.0 and not has_fp16_weights:
-            self.state.use_pool = True
-
-        self.weight = bnb.nn.Int8Params(
-            self.weight.data,
-            has_fp16_weights=has_fp16_weights,
-            requires_grad=has_fp16_weights,
-        )
-
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if (
-            not self.state.has_fp16_weights
-            and self.state.CB is None
-            and self.state.CxB is not None
-        ):
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
-
-        try:
-            if reorder_layout:
-                self.weight.data = bnb.autograd._functions.undo_layout(
-                    self.state.CxB, self.state.tile_indices
-                )
-
-            super()._save_to_state_dict(destination, prefix, keep_vars)
-
-            # we only need to save SCB as extra data, because CB for quantized weights
-            # is already stored in weight.data
-            weight_name = "SCB"
-
-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
-            if param_from_weight is not None:
-                destination[key_name] = (
-                    param_from_weight if keep_vars else param_from_weight.detach()
-                )
-            elif not self.state.has_fp16_weights and param_from_state is not None:
-                destination[key_name] = (
-                    param_from_state if keep_vars else param_from_state.detach()
-                )
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
-
-    def _load_from_state_dict(
-        self,
-        state_dict,
-        prefix,
-        local_metadata,
-        strict,
-        missing_keys,
-        unexpected_keys,
-        error_msgs,
-    ):
-        super()._load_from_state_dict(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-        for key in unexpected_keys:
-            input_name = key[len(prefix):]
-            if input_name == "SCB":
-                if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without
-                    raise RuntimeError(
-                        "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
-                        "not supported. Please call module.cuda() before module.load_state_dict()"
-                    )
-
-                input_param = state_dict[key]
-                self.weight.SCB.copy_(input_param)
-                unexpected_keys.remove(key)
-
-    def init_8bit_state(self):
-        self.state.CB = self.weight.CB
-        self.state.SCB = self.weight.SCB
-        self.weight.CB = None
-        self.weight.SCB = None
-
-    def forward(self, x: torch.Tensor):
-        self.state.is_training = self.training
-        if self.weight.CB is not None:
-            self.init_8bit_state()
-
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
-
-        if not self.state.has_fp16_weights:
-            if self.state.CB is not None and self.state.CxB is not None:
-                # we converted 8-bit row major to turing/ampere format in the first inference pass
-                # we no longer need the row-major weight
-                del self.state.CB
-                self.weight.data = self.state.CxB
-        return out
-
-
 def replace_bnb_linear(
     model,
     module_to_convert=[],
@@ -206,6 +34,19 @@ def replace_bnb_linear(
                     module.bias is not None,
                     has_fp16_weights=False,
                     threshold=threshold,
+                    device=torch.device("cpu"),
+                )
+                model._modules[name].state = MatmulLtState()
+                model._modules[name].index = None
+                model._modules[name].state.threshold = threshold
+                model._modules[name].state.has_fp16_weights = False
+                model._modules[name].state.memory_efficient_backward = False
+                if threshold > 0.0:
+                    model._modules[name].state.use_pool = True
+                model._modules[name].weight = Int8Params(
+                    model._modules[name].weight.data,
+                    has_fp16_weights=False,
+                    requires_grad=False,
                 )
             elif q_type in ["bnb_FP4", "bnb_NF4"]:
                 model._modules[name] = nn.utils.skip_init(
@@ -215,5 +56,12 @@ def replace_bnb_linear(
                     module.bias is not None,
                     compute_dtype=compute_dtype,
                     quant_type=q_type[-3:].lower(),  # 'fp4' or 'nf4'
+                    device=torch.device("cpu"),
+                )
+                model._modules[name].weight = Params4bit(
+                    model._modules[name].weight.data,
+                    requires_grad=False,
+                    quant_type=q_type[-3:].lower(),
+                )
+                model._modules[name].compute_dtype = compute_dtype
     return model
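A minimal usage sketch of the rewritten replace_bnb_linear, assuming q_type, threshold, and compute_dtype are keyword arguments of the function alongside model and module_to_convert (only the latter two appear in the visible part of the signature), and that module_to_convert lists the names of the child nn.Linear modules to swap out:

import torch
import torch.nn as nn

# Hypothetical toy model; the replaced layers stay on CPU until .cuda()/.to(device)
# is called, at which point bitsandbytes quantizes the Int8Params/Params4bit weights.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# 8-bit path: converted layers get a fresh MatmulLtState and Int8Params weights;
# pass q_type="bnb_FP4" or "bnb_NF4" instead to take the 4-bit Params4bit path.
model = replace_bnb_linear(
    model,
    module_to_convert=["0", "2"],   # assumption: names of child modules to convert
    q_type="bnb_int8",
    threshold=6.0,
    compute_dtype=torch.float16,
)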