@@ -6,184 +6,12 @@
 
 try:
     os.environ["BITSANDBYTES_NOWELCOME"] = "1"
-    import bitsandbytes as bnb
+    from bitsandbytes import MatmulLtState
+    from bitsandbytes.nn import Linear4bit, Linear8bitLt, Params4bit, Int8Params
 except ImportError:
     raise ImportError("Install bitsandbytes to use 4/8bit compression")
 
 
-class Linear4bit(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        compute_dtype=None,
-        compress_statistics=True,
-        quant_type="fp4",
-        device=torch.device("cpu"),
-    ):
-        super().__init__(input_features, output_features, bias)
-
-        self.weight = bnb.nn.Params4bit(
-            self.weight.data,
-            requires_grad=False,
-            compress_statistics=compress_statistics,
-            quant_type=quant_type,
-        )
-        self.compute_dtype = compute_dtype
-
-    def forward(self, x: torch.Tensor):
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        if getattr(self.weight, "quant_state", None) is None:
-            print(
-                "FP4 quantization state not initialized. Please call .cuda() or"
-                " .to(device) on the LinearFP4 layer first."
-            )
-        inp_dtype = x.dtype
-        if self.compute_dtype is not None:
-            x = x.to(self.compute_dtype)
-
-        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        out = bnb.matmul_4bit(
-            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
-        )
-
-        out = out.to(inp_dtype)
-
-        return out
-
-
-class Linear8bitLt(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        has_fp16_weights=True,
-        threshold=0.0,
-        index=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__(input_features, output_features, bias)
-        self.state = bnb.MatmulLtState()
-        self.index = index
-
-        self.state.threshold = threshold
-        self.state.has_fp16_weights = has_fp16_weights
-        self.state.memory_efficient_backward = False
-        if threshold > 0.0 and not has_fp16_weights:
-            self.state.use_pool = True
-
-        self.weight = bnb.nn.Int8Params(
-            self.weight.data,
-            has_fp16_weights=has_fp16_weights,
-            requires_grad=has_fp16_weights,
-        )
-
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if (
-            not self.state.has_fp16_weights
-            and self.state.CB is None
-            and self.state.CxB is not None
-        ):
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
-
-        try:
-            if reorder_layout:
-                self.weight.data = bnb.autograd._functions.undo_layout(
-                    self.state.CxB, self.state.tile_indices
-                )
-
-            super()._save_to_state_dict(destination, prefix, keep_vars)
-
-            # we only need to save SCB as extra data, because CB for quantized weights
-            # is already stored in weight.data
-            weight_name = "SCB"
-
-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
-            if param_from_weight is not None:
-                destination[key_name] = (
-                    param_from_weight if keep_vars else param_from_weight.detach()
-                )
-            elif not self.state.has_fp16_weights and param_from_state is not None:
-                destination[key_name] = (
-                    param_from_state if keep_vars else param_from_state.detach()
-                )
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
-
-    def _load_from_state_dict(
-        self,
-        state_dict,
-        prefix,
-        local_metadata,
-        strict,
-        missing_keys,
-        unexpected_keys,
-        error_msgs,
-    ):
-        super()._load_from_state_dict(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-        for key in unexpected_keys:
-            input_name = key[len(prefix) :]
-            if input_name == "SCB":
-                if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without quantizing first
-                    raise RuntimeError(
-                        "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
-                        "not supported. Please call module.cuda() before module.load_state_dict()"
-                    )
-
-                input_param = state_dict[key]
-                self.weight.SCB.copy_(input_param)
-                unexpected_keys.remove(key)
-
-    def init_8bit_state(self):
-        self.state.CB = self.weight.CB
-        self.state.SCB = self.weight.SCB
-        self.weight.CB = None
-        self.weight.SCB = None
-
-    def forward(self, x: torch.Tensor):
-        self.state.is_training = self.training
-        if self.weight.CB is not None:
-            self.init_8bit_state()
-
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
-
-        if not self.state.has_fp16_weights:
-            if self.state.CB is not None and self.state.CxB is not None:
-                # we converted 8-bit row major to turing/ampere format in the first inference pass
-                # we no longer need the row-major weight
-                del self.state.CB
-                self.weight.data = self.state.CxB
-        return out
-
-
 def replace_bnb_linear(
     model,
     module_to_convert=[],
@@ -206,6 +34,19 @@ def replace_bnb_linear(
                     module.bias is not None,
                     has_fp16_weights=False,
                     threshold=threshold,
+                    device=torch.device("cpu"),
+                )
+                model._modules[name].state = MatmulLtState()
+                model._modules[name].index = None
+                model._modules[name].state.threshold = threshold
+                model._modules[name].state.has_fp16_weights = False
+                model._modules[name].state.memory_efficient_backward = False
+                if threshold > 0.0:
+                    model._modules[name].state.use_pool = True
+                model._modules[name].weight = Int8Params(
+                    model._modules[name].weight.data,
+                    has_fp16_weights=False,
+                    requires_grad=False,
                 )
             elif q_type in ["bnb_FP4", "bnb_NF4"]:
                 model._modules[name] = nn.utils.skip_init(
@@ -215,5 +56,12 @@ def replace_bnb_linear(
                     module.bias is not None,
                     compute_dtype=compute_dtype,
                     quant_type=q_type[-3:].lower(),  # 'fp4' or 'nf4'
+                    device=torch.device("cpu"),
+                )
+                model._modules[name].weight = Params4bit(
+                    model._modules[name].weight.data,
+                    requires_grad=False,
+                    quant_type=q_type[-3:].lower(),
+                )
                 )
+                model._modules[name].compute_dtype = compute_dtype
     return model
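
For reference, a minimal sketch of how the bitsandbytes layers imported above behave at runtime; the layer sizes and threshold value here are illustrative, not taken from this repo. Quantization happens on the device move, which is why the deleted _load_from_state_dict raised unless module.cuda() was called before load_state_dict():

import torch
from bitsandbytes.nn import Linear8bitLt

# build the int8 layer on CPU; weights stay in float until the device move
layer = Linear8bitLt(512, 512, bias=True, has_fp16_weights=False, threshold=6.0)
layer.load_state_dict(torch.nn.Linear(512, 512).state_dict())

layer = layer.cuda()  # Int8Params.cuda() performs the actual int8 quantization
x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
out = layer(x)  # int8 matmul with fp16 outliers above the 6.0 threshold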
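And a hypothetical driver for replace_bnb_linear itself. The full signature is elided by this diff, so the assumption that it accepts q_type, compute_dtype, and threshold keywords, and matches submodules by the names in module_to_convert, is inferred from the body above; the toy model and module names are placeholders:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# 4-bit path: 'bnb_FP4'/'bnb_NF4' select quant_type via q_type[-3:].lower()
model = replace_bnb_linear(
    model,
    module_to_convert=["0", "2"],  # placeholder submodule names
    q_type="bnb_NF4",
    compute_dtype=torch.bfloat16,
)
model = model.cuda()  # Params4bit quantizes the wrapped weights on the move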