ENH: Warn when loading PiSSA/OLoRA together with other adapters (#2186)
Resolves #2184

Since PiSSA/OLoRA modifies the base weights, it should not be combined
with other adapters. We now warn users about that and tell them how to
mitigate this.
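
As a minimal sketch of the scenario this check now catches (the adapter directories `pissa-adapter` and `plain-lora-adapter` are hypothetical placeholders, not part of this change):

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Small test model, same as used in the new test below.
base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")

# Load an adapter that was created with init_lora_weights="pissa" ...
model = PeftModel.from_pretrained(base, "pissa-adapter")

# ... then load a second, unrelated adapter: since PiSSA modified the base
# weights, the two should not be combined, and this call now emits a warning
# pointing to the PiSSA-to-LoRA conversion guide.
model.load_adapter("plain-lora-adapter", adapter_name="other")
```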
BenjaminBossan authored Oct 30, 2024
1 parent 214345e commit ff6dd9e
Showing 4 changed files with 64 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/developer_guides/lora.md
@@ -52,7 +52,7 @@ Alternatively, execute fast SVD, which takes only a few seconds. The number of i
```python
lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
```
- For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/fxmeng/peft/tree/main/examples/pissa_finetuning).
+ For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning).

### OLoRA
[OLoRA](https://arxiv.org/abs/2406.01775) utilizes QR decomposition to initialize the LoRA adapters. OLoRA translates the base weights of the model by a factor of their QR decompositions, i.e., it mutates the weights before performing any training on them. This approach significantly improves stability, accelerates convergence speed, and ultimately achieves superior performance.
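The OLoRA paragraph above describes initialization via QR decomposition. As a brief, hedged illustration of how that option is selected (model and hyperparameters are placeholders, not taken from this diff):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# OLoRA is chosen purely through the init option; the QR-based translation of
# the base weights happens when the adapter is attached, before any training.
base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
olora_model = get_peft_model(base, LoraConfig(init_lora_weights="olora"))
```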
3 changes: 3 additions & 0 deletions src/peft/mixed_model.py
@@ -348,6 +348,9 @@ def get_model_status(self):
def _split_kwargs(cls, kwargs: dict[str, Any]):
return PeftModel._split_kwargs(kwargs)

def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool) -> None:
return PeftModel._check_new_adapter_config(self, peft_config, is_trainable=is_trainable)

def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any):
"""
Load a trained adapter into the model.
29 changes: 25 additions & 4 deletions src/peft/peft_model.py
@@ -1128,6 +1128,29 @@ def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_wei
os.makedirs(base_name)
safe_save_file(safe_dict, new_fname, metadata=metadata)

def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool) -> None:
"""Perform checks on newly added PEFT configs to ensure integrity."""
if peft_config.is_prompt_learning and is_trainable:
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")

# Since PiSSA/OLoRA modifies the base weights, it should not be combined with other adapters.
all_configs = [peft_config] + list(self.peft_config.values())
if len(all_configs) > 1:
if any(getattr(config, "init_lora_weights", None) == "pissa" for config in all_configs):
msg = (
"PiSSA changes the base weights of the model and should thus not be used with other adapters. "
"Consider converting the PiSSA adapter into a normal LoRA adapter: "
"https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning#convert-pissa-to-lora"
)
warnings.warn(msg)
elif any(getattr(config, "init_lora_weights", None) == "olora" for config in all_configs):
msg = (
"OLoRA changes the base weights of the model and should thus not be used with other adapters. "
"Consider converting the OLoRA adapter into a normal LoRA adapter: "
"https://github.com/huggingface/peft/tree/main/examples/olora_finetuning#olora-and-lora"
)
warnings.warn(msg)

def load_adapter(
self,
model_id: Union[str, os.PathLike],
@@ -1191,10 +1214,8 @@ def load_adapter(
ephemeral_gpu_offload=ephemeral_gpu_offload,
**hf_hub_download_kwargs,
)
- if peft_config.is_prompt_learning and is_trainable:
-     raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
- else:
-     peft_config.inference_mode = not is_trainable
+ self._check_new_adapter_config(peft_config, is_trainable=is_trainable)
+ peft_config.inference_mode = not is_trainable
self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)

adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs)
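The new warning messages point users to converting a PiSSA/OLoRA adapter into a plain LoRA adapter. As a hedged sketch of that mitigation, following the linked guide (output directories are placeholders; see the pissa_finetuning example for the exact workflow):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
peft_model = get_peft_model(base, LoraConfig(init_lora_weights="pissa"))

# Keep a copy of the adapter right after PiSSA initialization; it serves as the
# reference point for the conversion below.
peft_model.save_pretrained("pissa-init")

# ... fine-tune peft_model here ...

# Saving with path_initial_model_for_weight_conversion turns the trained PiSSA
# adapter into an ordinary LoRA adapter that no longer depends on the mutated
# base weights and can therefore be combined with other adapters.
peft_model.save_pretrained(
    "pissa-converted-lora",
    path_initial_model_for_weight_conversion="pissa-init",
)
```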
35 changes: 35 additions & 0 deletions tests/test_initialization.py
@@ -948,6 +948,41 @@ def test_lora_config_pissa_olora_warns(self, config_kwargs, should_warn, recwarn
LoraConfig(**config_kwargs)
assert not recwarn.list

@pytest.mark.parametrize("init_method", ["pissa", "olora"])
@pytest.mark.parametrize("pissa_olora_loaded_first", [False, True])
def test_load_pissa_olora_with_other_adapter_warns(self, init_method, pissa_olora_loaded_first, recwarn, tmp_path):
# Since PiSSA/OLoRA modifies the base weights, it should not be combined with other adapters. Check for a
# warning. See #2184.

# create an adapter without PiSSA/OLoRA
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
model = get_peft_model(model, LoraConfig(init_lora_weights=True))
model.save_pretrained(tmp_path / "adapter0")
del model

# create a model with PiSSA/OLoRA
model = AutoModelForCausalLM.from_pretrained(model_id)
model = get_peft_model(model, LoraConfig(init_lora_weights=init_method))
model.save_pretrained(tmp_path / "adapter1")
del model

# load the model
if pissa_olora_loaded_first:
path0, path1 = tmp_path / "adapter1", tmp_path / "adapter0"
else:
path0, path1 = tmp_path / "adapter0", tmp_path / "adapter1"

model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, path0)
model.load_adapter(path1, adapter_name="other")

if init_method == "pissa":
msg = "PiSSA changes the base weights of the model and should thus not be used with other adapters"
else:
msg = "OLoRA changes the base weights of the model and should thus not be used with other adapters"
assert any(str(w.message).startswith(msg) for w in recwarn.list)

def test_lora_rslora_scaling(self):
# default is True
torch.manual_seed(0)