Commit a147137: fix bnb loading (#2529)

1 parent 78c8908

1 file changed

onmt/model_builder.py

Lines changed: 13 additions & 2 deletions
@@ -95,13 +95,24 @@ def load_test_model(opt, device_id=0, model_path=None):
 
     model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
 
-    if hasattr(model_opt, "quant_type") and model_opt.quant_type not in [
+    if hasattr(model_opt, "quant_type") and model_opt.quant_type in [
         "llm_awq",
         "aawq_gemm",
         "aawq_gemv",
-    ]:
+    ]:  # if the loaded model is an AWQ quantized one, the inference config cannot overwrite it
+        if hasattr(opt, "quant_type") and opt.quant_type != model_opt.quant_type:
+            raise ValueError(
+                "Model is an AWQ quantized model, cannot overwrite with another quant method"
+            )
+
+    elif hasattr(opt, "quant_type") and opt.quant_type not in [
+        "llm_awq",
+        "aawq_gemm",
+        "aawq_gemv",
+    ]:  # we still want to be able to load fp16/32 models with bnb 4-bit to minimize RAM footprint
         model_opt.quant_layers = opt.quant_layers
         model_opt.quant_type = opt.quant_type
+        model_opt.lora_layers = []
 
     if opt.world_size > 1 and opt.parallel_mode == "tensor_parallel":
         model_opt.world_size = opt.world_size
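
For readers outside the diff context, the sketch below restates the commit's decision logic as a self-contained function. It is illustrative only: the helper name merge_quant_opts, the AWQ_TYPES constant, and the argparse.Namespace stand-ins for the checkpoint options (model_opt) and inference options (opt) are assumptions of this sketch, not OpenNMT-py API; the quant_type value "bnb_NF4" in the example is likewise assumed.

    from argparse import Namespace

    # Assumption of this sketch: mirrors the AWQ quant types listed in the diff.
    AWQ_TYPES = ["llm_awq", "aawq_gemm", "aawq_gemv"]

    def merge_quant_opts(model_opt, opt):
        """Hypothetical helper reproducing the commit's quant merge rules."""
        if hasattr(model_opt, "quant_type") and model_opt.quant_type in AWQ_TYPES:
            # Checkpoint was AWQ-quantized: the inference config may not override it.
            if hasattr(opt, "quant_type") and opt.quant_type != model_opt.quant_type:
                raise ValueError(
                    "Model is an AWQ quantized model, cannot overwrite with another quant method"
                )
        elif hasattr(opt, "quant_type") and opt.quant_type not in AWQ_TYPES:
            # Non-AWQ checkpoint (e.g. fp16/fp32): let the inference config request
            # on-the-fly quantization such as bnb 4-bit to reduce RAM usage.
            model_opt.quant_layers = opt.quant_layers
            model_opt.quant_type = opt.quant_type
            model_opt.lora_layers = []
        return model_opt

    # Example: an fp16 checkpoint loaded with bnb 4-bit requested at inference time
    # (the quant_layers values here are placeholders for illustration).
    ckpt_opts = Namespace(quant_type="", quant_layers=[])
    infer_opts = Namespace(quant_type="bnb_NF4", quant_layers=["linear_values", "linear_query"])
    merged = merge_quant_opts(ckpt_opts, infer_opts)
    print(merged.quant_type)  # -> bnb_NF4

The net effect of the fix: the pre-existing merge branch now runs only for non-AWQ requests, AWQ checkpoints explicitly reject a conflicting quant method, and lora_layers is cleared whenever the inference-side quantization settings are applied.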
