
Commit 3d4c8de

update to skipinit weights with bnb (#2524)
1 parent 2828eb7

File tree

2 files changed, +23 -175 lines

onmt/modules/bnb_linear.py

+22 -174
@@ -6,184 +6,12 @@
 
 try:
     os.environ["BITSANDBYTES_NOWELCOME"] = "1"
-    import bitsandbytes as bnb
+    from bitsandbytes import MatmulLtState
+    from bitsandbytes.nn import Linear4bit, Linear8bitLt, Params4bit, Int8Params
 except ImportError:
     raise ImportError("Install bitsandbytes to use 4/8bit compression")
 
 
-class Linear4bit(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        compute_dtype=None,
-        compress_statistics=True,
-        quant_type="fp4",
-        device=torch.device("cpu"),
-    ):
-        super().__init__(input_features, output_features, bias)
-
-        self.weight = bnb.nn.Params4bit(
-            self.weight.data,
-            requires_grad=False,
-            compress_statistics=compress_statistics,
-            quant_type=quant_type,
-        )
-        self.compute_dtype = compute_dtype
-
-    def forward(self, x: torch.Tensor):
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        if getattr(self.weight, "quant_state", None) is None:
-            print(
-                "FP4 quantization state not initialized. Please call .cuda() or"
-                " .to(device) on the LinearFP4 layer first."
-            )
-        inp_dtype = x.dtype
-        if self.compute_dtype is not None:
-            x = x.to(self.compute_dtype)
-
-        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        out = bnb.matmul_4bit(
-            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
-        )
-
-        out = out.to(inp_dtype)
-
-        return out
-
-
-class Linear8bitLt(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        has_fp16_weights=True,
-        threshold=0.0,
-        index=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__(input_features, output_features, bias)
-        self.state = bnb.MatmulLtState()
-        self.index = index
-
-        self.state.threshold = threshold
-        self.state.has_fp16_weights = has_fp16_weights
-        self.state.memory_efficient_backward = False
-        if threshold > 0.0 and not has_fp16_weights:
-            self.state.use_pool = True
-
-        self.weight = bnb.nn.Int8Params(
-            self.weight.data,
-            has_fp16_weights=has_fp16_weights,
-            requires_grad=has_fp16_weights,
-        )
-
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if (
-            not self.state.has_fp16_weights
-            and self.state.CB is None
-            and self.state.CxB is not None
-        ):
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
-
-        try:
-            if reorder_layout:
-                self.weight.data = bnb.autograd._functions.undo_layout(
-                    self.state.CxB, self.state.tile_indices
-                )
-
-            super()._save_to_state_dict(destination, prefix, keep_vars)
-
-            # we only need to save SCB as extra data, because CB for quantized weights
-            # is already stored in weight.data
-            weight_name = "SCB"
-
-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
-            if param_from_weight is not None:
-                destination[key_name] = (
-                    param_from_weight if keep_vars else param_from_weight.detach()
-                )
-            elif not self.state.has_fp16_weights and param_from_state is not None:
-                destination[key_name] = (
-                    param_from_state if keep_vars else param_from_state.detach()
-                )
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
-
-    def _load_from_state_dict(
-        self,
-        state_dict,
-        prefix,
-        local_metadata,
-        strict,
-        missing_keys,
-        unexpected_keys,
-        error_msgs,
-    ):
-        super()._load_from_state_dict(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-        for key in unexpected_keys:
-            input_name = key[len(prefix) :]
-            if input_name == "SCB":
-                if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without
-                    raise RuntimeError(
-                        "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
-                        "not supported. Please call module.cuda() before module.load_state_dict()"
-                    )
-
-                input_param = state_dict[key]
-                self.weight.SCB.copy_(input_param)
-                unexpected_keys.remove(key)
-
-    def init_8bit_state(self):
-        self.state.CB = self.weight.CB
-        self.state.SCB = self.weight.SCB
-        self.weight.CB = None
-        self.weight.SCB = None
-
-    def forward(self, x: torch.Tensor):
-        self.state.is_training = self.training
-        if self.weight.CB is not None:
-            self.init_8bit_state()
-
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
-
-        if not self.state.has_fp16_weights:
-            if self.state.CB is not None and self.state.CxB is not None:
-                # we converted 8-bit row major to turing/ampere format in the first inference pass
-                # we no longer need the row-major weight
-                del self.state.CB
-                self.weight.data = self.state.CxB
-        return out
-
-
 def replace_bnb_linear(
     model,
     module_to_convert=[],
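
Net effect of this hunk: the locally maintained Linear4bit and Linear8bitLt subclasses are deleted, and the module now uses the implementations that ship with bitsandbytes itself. A minimal sanity check, assuming the bitsandbytes>=0.41.2 floor pinned in requirements.opt.txt below:

# The names this file now depends on are imported straight from bitsandbytes,
# per the new import lines above (requires bitsandbytes>=0.41.2).
from bitsandbytes import MatmulLtState
from bitsandbytes.nn import Linear4bit, Linear8bitLt, Params4bit, Int8Params

# These resolve to bnb's own modules now, not onmt-local nn.Linear subclasses.
print(Linear8bitLt.__module__, Linear4bit.__module__)
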
@@ -206,6 +34,19 @@ def replace_bnb_linear(
                     module.bias is not None,
                     has_fp16_weights=False,
                     threshold=threshold,
+                    device=torch.device("cpu"),
+                )
+                model._modules[name].state = MatmulLtState()
+                model._modules[name].index = None
+                model._modules[name].state.threshold = threshold
+                model._modules[name].state.has_fp16_weights = False
+                model._modules[name].state.memory_efficient_backward = False
+                if threshold > 0.0:
+                    model._modules[name].state.use_pool = True
+                model._modules[name].weight = Int8Params(
+                    model._modules[name].weight.data,
+                    has_fp16_weights=False,
+                    requires_grad=False,
                 )
             elif q_type in ["bnb_FP4", "bnb_NF4"]:
                 model._modules[name] = nn.utils.skip_init(
@@ -215,5 +56,12 @@ def replace_bnb_linear(
                     module.bias is not None,
                     compute_dtype=compute_dtype,
                     quant_type=q_type[-3:].lower(),  # 'fp4' or 'nf4'
+                    device=torch.device("cpu"),
+                )
+                model._modules[name].weight = Params4bit(
+                    model._modules[name].weight.data,
+                    requires_grad=False,
+                    quant_type=q_type[-3:].lower(),
+                )
+                model._modules[name].compute_dtype = compute_dtype
     return model
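
Taken together, the two hunks above are the "skipinit" pattern from the commit title: nn.utils.skip_init instantiates the bnb layer on the meta device and materializes it on CPU without running random weight initialization, after which the quantization state and the quantized weight wrapper are rebuilt by hand. Below is a minimal standalone sketch of the 8-bit branch; the layer shape, the threshold value 6.0, and the names orig/layer are illustrative, not taken from the diff:

import torch
import torch.nn as nn
from bitsandbytes import MatmulLtState
from bitsandbytes.nn import Linear8bitLt, Int8Params

orig = nn.Linear(64, 64)  # stands in for a layer whose real weights get loaded later

# skip_init builds the module without initializing weights: no time or memory
# is wasted on random values that are immediately replaced.
layer = nn.utils.skip_init(
    Linear8bitLt,
    orig.in_features,
    orig.out_features,
    orig.bias is not None,
    has_fp16_weights=False,
    threshold=6.0,  # illustrative; the diff passes its own `threshold` argument
    device=torch.device("cpu"),
)

# Rebuild the matmul state and wrap the weight as Int8Params, mirroring the
# assignments in the hunks above.
layer.state = MatmulLtState()
layer.index = None
layer.state.threshold = 6.0
layer.state.has_fp16_weights = False
layer.state.memory_efficient_backward = False
layer.state.use_pool = True  # because threshold > 0.0
layer.weight = Int8Params(
    orig.weight.data, has_fp16_weights=False, requires_grad=False
)

# The actual int8 quantization happens when the layer is moved to a GPU, e.g.:
# layer = layer.cuda()
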

requirements.opt.txt

+1 -1

@@ -3,6 +3,6 @@ sentencepiece>=0.1.94,<0.1.98
 subword-nmt>=0.3.7
 rapidfuzz
 scipy
-bitsandbytes>=0.39.1
+bitsandbytes>=0.41.2
 safetensors
 spacy
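
The bitsandbytes floor moves from 0.39.1 to 0.41.2, matching the imports the Python diff above now relies on (MatmulLtState, Linear4bit, Linear8bitLt, Params4bit, Int8Params). A small optional runtime guard, sketched with the standard-library importlib.metadata; the packaging dependency is assumed available:

from importlib.metadata import version
from packaging.version import Version

# Fail fast if the installed bitsandbytes predates the API this commit uses.
installed = Version(version("bitsandbytes"))
assert installed >= Version("0.41.2"), f"bitsandbytes>=0.41.2 required, found {installed}"
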
