From ff6dd9ed7f42f2c7e312ea4e7c13539d2c312944 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 30 Oct 2024 10:16:37 +0100
Subject: [PATCH] ENH: Warn when loading PiSSA/OLoRA together with other
 adapters (#2186)

Resolves #2184

Since PiSSA/OLoRA modifies the base weights, it should not be combined with
other adapters. We now warn users about that and tell them how to mitigate
this.
---
 docs/source/developer_guides/lora.md |  2 +-
 src/peft/mixed_model.py              |  3 +++
 src/peft/peft_model.py               | 29 +++++++++++++++++++----
 tests/test_initialization.py         | 35 ++++++++++++++++++++++++++++
 4 files changed, 64 insertions(+), 5 deletions(-)

diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md
index ebb60f133f..3824b4082d 100644
--- a/docs/source/developer_guides/lora.md
+++ b/docs/source/developer_guides/lora.md
@@ -52,7 +52,7 @@ Alternatively, execute fast SVD, which takes only a few seconds. The number of i
 ```python
 lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
 ```
-For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/fxmeng/peft/tree/main/examples/pissa_finetuning).
+For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning).
 
 ### OLoRA
 [OLoRA](https://arxiv.org/abs/2406.01775) utilizes QR decomposition to initialize the LoRA adapters. OLoRA translates the base weights of the model by a factor of their QR decompositions, i.e., it mutates the weights before performing any training on them. This approach significantly improves stability, accelerates convergence speed, and ultimately achieves superior performance.
diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py
index 907227de36..098bd0d4fd 100644
--- a/src/peft/mixed_model.py
+++ b/src/peft/mixed_model.py
@@ -348,6 +348,9 @@ def get_model_status(self):
     def _split_kwargs(cls, kwargs: dict[str, Any]):
         return PeftModel._split_kwargs(kwargs)
 
+    def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool) -> None:
+        return PeftModel._check_new_adapter_config(self, peft_config, is_trainable=is_trainable)
+
     def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any):
         """
         Load a trained adapter into the model.
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 4690e76615..a38e0750f2 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -1128,6 +1128,29 @@ def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_wei
                 os.makedirs(base_name)
             safe_save_file(safe_dict, new_fname, metadata=metadata)
 
+    def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool) -> None:
+        """Perform checks on newly added PEFT configs to ensure integrity."""
+        if peft_config.is_prompt_learning and is_trainable:
+            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
+
+        # Since PiSSA/OLoRA modifies the base weights, it should not be combined with other adapters.
+        all_configs = [peft_config] + list(self.peft_config.values())
+        if len(all_configs) > 1:
+            if any(getattr(config, "init_lora_weights", None) == "pissa" for config in all_configs):
+                msg = (
+                    "PiSSA changes the base weights of the model and should thus not be used with other adapters. "
+                    "Consider converting the PiSSA adapter into a normal LoRA adapter: "
+                    "https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning#convert-pissa-to-lora"
+                )
+                warnings.warn(msg)
+            elif any(getattr(config, "init_lora_weights", None) == "olora" for config in all_configs):
+                msg = (
+                    "OLoRA changes the base weights of the model and should thus not be used with other adapters. "
+                    "Consider converting the OLoRA adapter into a normal LoRA adapter: "
+                    "https://github.com/huggingface/peft/tree/main/examples/olora_finetuning#olora-and-lora"
+                )
+                warnings.warn(msg)
+
     def load_adapter(
         self,
         model_id: Union[str, os.PathLike],
@@ -1191,10 +1214,8 @@ def load_adapter(
                 ephemeral_gpu_offload=ephemeral_gpu_offload,
                 **hf_hub_download_kwargs,
             )
-        if peft_config.is_prompt_learning and is_trainable:
-            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
-        else:
-            peft_config.inference_mode = not is_trainable
+        self._check_new_adapter_config(peft_config, is_trainable=is_trainable)
+        peft_config.inference_mode = not is_trainable
         self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)
 
         adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs)
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index 1b1cc2f1f3..4b65161961 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -948,6 +948,41 @@ def test_lora_config_pissa_olora_warns(self, config_kwargs, should_warn, recwarn
             LoraConfig(**config_kwargs)
             assert not recwarn.list
 
+    @pytest.mark.parametrize("init_method", ["pissa", "olora"])
+    @pytest.mark.parametrize("pissa_olora_loaded_first", [False, True])
+    def test_load_pissa_olora_with_other_adapter_warns(self, init_method, pissa_olora_loaded_first, recwarn, tmp_path):
+        # Since PiSSA/OLoRA modifies the base weights, it should not be combined with other adapters. Check for a
+        # warning. See #2184.
+
+        # create an adapter without PiSSA/OLoRA
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model = get_peft_model(model, LoraConfig(init_lora_weights=True))
+        model.save_pretrained(tmp_path / "adapter0")
+        del model
+
+        # create a model with PiSSA/OLoRA
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model = get_peft_model(model, LoraConfig(init_lora_weights=init_method))
+        model.save_pretrained(tmp_path / "adapter1")
+        del model
+
+        # load the model
+        if pissa_olora_loaded_first:
+            path0, path1 = tmp_path / "adapter1", tmp_path / "adapter0"
+        else:
+            path0, path1 = tmp_path / "adapter0", tmp_path / "adapter1"
+
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model = PeftModel.from_pretrained(model, path0)
+        model = model.load_adapter(path1, adapter_name="other")
+
+        if init_method == "pissa":
+            msg = "PiSSA changes the base weights of the model and should thus not be used with other adapters"
+        else:
+            msg = "OLoRA changes the base weights of the model and should thus not be used with other adapters"
+        assert any(str(w.message).startswith(msg) for w in recwarn.list)
+
     def test_lora_rslora_scaling(self):
         # default is True
         torch.manual_seed(0)