Commit ff6dd9e (1 parent: 214345e)

ENH: Warn when loading PiSSA/OLoRA together with other adapters (#2186)

Resolves #2184. Since PiSSA/OLoRA modify the base weights, they should not be combined with other adapters. We now warn users about this and tell them how to mitigate it.
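
To illustrate the new behavior, here is a minimal sketch of the situation that now triggers the warning; the adapter paths and the tiny test model are placeholders, not part of the commit:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")

# Load a PiSSA adapter first, then add a second, unrelated adapter.
model = PeftModel.from_pretrained(base, "path/to/pissa-adapter")
model.load_adapter("path/to/lora-adapter", adapter_name="other")
# UserWarning: PiSSA changes the base weights of the model and should thus not
# be used with other adapters. Consider converting the PiSSA adapter into a
# normal LoRA adapter: .../examples/pissa_finetuning#convert-pissa-to-lora
```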

File tree: 4 files changed, +64 −5 lines

docs/source/developer_guides/lora.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -52,7 +52,7 @@ Alternatively, execute fast SVD, which takes only a few seconds. The number of i
 ```python
 lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
 ```
-For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/fxmeng/peft/tree/main/examples/pissa_finetuning).
+For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning).
 
 ### OLoRA
 [OLoRA](https://arxiv.org/abs/2406.01775) utilizes QR decomposition to initialize the LoRA adapters. OLoRA translates the base weights of the model by a factor of their QR decompositions, i.e., it mutates the weights before performing any training on them. This approach significantly improves stability, accelerates convergence speed, and ultimately achieves superior performance.
````
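
As a companion to the documented init schemes, here is a hedged sketch of what concrete configs look like; the iteration count `4` is an arbitrary stand-in for the `[number of iters]` placeholder:

```python
from peft import LoraConfig

# PiSSA with fast SVD, using 4 subspace iterations (any positive integer works).
pissa_config = LoraConfig(init_lora_weights="pissa_niter_4")

# OLoRA's QR-based initialization; like PiSSA, it mutates the base weights.
olora_config = LoraConfig(init_lora_weights="olora")
```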

src/peft/mixed_model.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -348,6 +348,9 @@ def get_model_status(self):
     def _split_kwargs(cls, kwargs: dict[str, Any]):
         return PeftModel._split_kwargs(kwargs)
 
+    def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool) -> None:
+        return PeftModel._check_new_adapter_config(self, peft_config, is_trainable=is_trainable)
+
     def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any):
         """
         Load a trained adapter into the model.
```
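
This delegation means `PeftMixedModel` runs the same config check as `PeftModel`. A sketch of the intended effect, assuming placeholder adapter paths:

```python
from transformers import AutoModelForCausalLM
from peft import PeftMixedModel

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
mixed = PeftMixedModel.from_pretrained(base, "path/to/pissa-adapter")
mixed.load_adapter("path/to/lora-adapter", adapter_name="other")  # should emit the same warning
```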

src/peft/peft_model.py

Lines changed: 25 additions & 4 deletions

```diff
@@ -1128,6 +1128,29 @@ def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_wei
             os.makedirs(base_name)
         safe_save_file(safe_dict, new_fname, metadata=metadata)
 
+    def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool) -> None:
+        """Perform checks on newly added PEFT configs to ensure integrity."""
+        if peft_config.is_prompt_learning and is_trainable:
+            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
+
+        # Since PiSSA/OLoRA modifies the base weights, it should not be combined with other adapters.
+        all_configs = [peft_config] + list(self.peft_config.values())
+        if len(all_configs) > 1:
+            if any(getattr(config, "init_lora_weights", None) == "pissa" for config in all_configs):
+                msg = (
+                    "PiSSA changes the base weights of the model and should thus not be used with other adapters. "
+                    "Consider converting the PiSSA adapter into a normal LoRA adapter: "
+                    "https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning#convert-pissa-to-lora"
+                )
+                warnings.warn(msg)
+            elif any(getattr(config, "init_lora_weights", None) == "olora" for config in all_configs):
+                msg = (
+                    "OLoRA changes the base weights of the model and should thus not be used with other adapters. "
+                    "Consider converting the OLoRA adapter into a normal LoRA adapter: "
+                    "https://github.com/huggingface/peft/tree/main/examples/olora_finetuning#olora-and-lora"
+                )
+                warnings.warn(msg)
+
     def load_adapter(
         self,
         model_id: Union[str, os.PathLike],
@@ -1191,10 +1214,8 @@ def load_adapter(
                 ephemeral_gpu_offload=ephemeral_gpu_offload,
                 **hf_hub_download_kwargs,
             )
-        if peft_config.is_prompt_learning and is_trainable:
-            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
-        else:
-            peft_config.inference_mode = not is_trainable
+        self._check_new_adapter_config(peft_config, is_trainable=is_trainable)
+        peft_config.inference_mode = not is_trainable
         self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)
 
         adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs)
```
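
The mitigation the warning links to is the PiSSA/OLoRA-to-LoRA conversion. A hedged sketch based on the PEFT examples, with placeholder paths; `path_initial_model_for_weight_conversion` should point to the adapter saved before training so the mutated base weights can be factored out:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
model = PeftModel.from_pretrained(base, "path/to/trained-pissa-adapter")

# Re-express the PiSSA adapter as a plain LoRA adapter that no longer relies
# on mutated base weights, so it can safely coexist with other adapters.
model.save_pretrained(
    "pissa-converted-to-lora",
    path_initial_model_for_weight_conversion="path/to/initial-pissa-adapter",
)
```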

tests/test_initialization.py

Lines changed: 35 additions & 0 deletions

```diff
@@ -948,6 +948,41 @@ def test_lora_config_pissa_olora_warns(self, config_kwargs, should_warn, recwarn
             LoraConfig(**config_kwargs)
         assert not recwarn.list
 
+    @pytest.mark.parametrize("init_method", ["pissa", "olora"])
+    @pytest.mark.parametrize("pissa_olora_loaded_first", [False, True])
+    def test_load_pissa_olora_with_other_adapter_warns(self, init_method, pissa_olora_loaded_first, recwarn, tmp_path):
+        # Since PiSSA/OLoRA modifies the base weights, it should not be combined with other adapters. Check for a
+        # warning. See #2184.
+
+        # create an adapter without PiSSA/OLoRA
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model = get_peft_model(model, LoraConfig(init_lora_weights=True))
+        model.save_pretrained(tmp_path / "adapter0")
+        del model
+
+        # create a model with PiSSA/OLoRA
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model = get_peft_model(model, LoraConfig(init_lora_weights=init_method))
+        model.save_pretrained(tmp_path / "adapter1")
+        del model
+
+        # load the model
+        if pissa_olora_loaded_first:
+            path0, path1 = tmp_path / "adapter1", tmp_path / "adapter0"
+        else:
+            path0, path1 = tmp_path / "adapter0", tmp_path / "adapter1"
+
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model = PeftModel.from_pretrained(model, path0)
+        model.load_adapter(path1, adapter_name="other")
+
+        if init_method == "pissa":
+            msg = "PiSSA changes the base weights of the model and should thus not be used with other adapters"
+        else:
+            msg = "OLoRA changes the base weights of the model and should thus not be used with other adapters"
+        assert any(str(w.message).startswith(msg) for w in recwarn.list)
+
     def test_lora_rslora_scaling(self):
         # default is True
         torch.manual_seed(0)
```
