diff --git a/multilora.py b/multilora.py
index f53a2a7..2ef5cf5 100644
--- a/multilora.py
+++ b/multilora.py
@@ -682,6 +682,19 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
             self.reset_lora_parameters(adapter_name)
         self.to(self.weight.device)
 
+    def _get_names_of_computable_adapters(self) -> list[str]:
+        """
+        For a LoRA adapter to be useful we need it to have r > 0 and a non-zero scaling.
+        Here we filter all the available adapters and return only
+        those that are useful for computation (e.g. the forward pass or merging).
+        """
+        return [
+            adapter_name for adapter_name in self.lora_A.keys()
+            if self.lora_alpha[adapter_name] != 0
+            and self.r[adapter_name] > 0
+        ]
+
+
     def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
         self.lora_alpha[adapter_name] = lora_alpha
@@ -806,24 +819,24 @@ def __init__(
 
     # sequential execution allows you to use multiple heterogenous lora adapters
     def merge(self):
-        if self.active_adapter not in self.lora_A.keys():
+        if self.active_adapter not in self._get_names_of_computable_adapters():
+            warnings.warn(f"Skipping merge of adapter {self.active_adapter}")
             return
         if self.merged:
             warnings.warn("Already merged. Nothing to do.")
             return
-        if self.r[self.active_adapter] > 0:
-            self.weight.data += self.get_delta_weight(self.active_adapter)
-            self.merged = True
+        self.weight.data += self.get_delta_weight(self.active_adapter)
+        self.merged = True
 
     def unmerge(self):
-        if self.active_adapter not in self.lora_A.keys():
+        if self.active_adapter not in self._get_names_of_computable_adapters():
+            warnings.warn(f"Skipping unmerge of adapter {self.active_adapter}")
             return
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
-        if self.r[self.active_adapter] > 0:
-            self.weight.data -= self.get_delta_weight(self.active_adapter)
-            self.merged = False
+        self.weight.data -= self.get_delta_weight(self.active_adapter)
+        self.merged = False
 
     def get_delta_weight(self, adapter):
         return (
@@ -834,20 +847,20 @@ def get_delta_weight(self, adapter):
             * self.scaling[adapter]
         )
 
-    
+
     def preprocess_weights(self, k):
         if not self.preprocessed:
             for adapter_name in self.lora_A.keys():
                 self.lora_A[adapter_name].weight.data = self.trim_task_vector(adapter_name, k, self.lora_A[adapter_name].weight.data)
                 self.lora_A[adapter_name].weight.data = self.elect_sign_vector(adapter_name, self.lora_A[adapter_name].weight.data)
                 self.lora_A[adapter_name].weight.data = self.disjoint_merge(adapter_name, self.lora_A[adapter_name].weight.data)
-    
+
                 self.lora_B[adapter_name].weight.data = self.trim_task_vector(adapter_name, k, self.lora_B[adapter_name].weight.data)
                 self.lora_B[adapter_name].weight.data = self.elect_sign_vector(adapter_name, self.lora_B[adapter_name].weight.data)
                 self.lora_B[adapter_name].weight.data = self.disjoint_merge(adapter_name, self.lora_B[adapter_name].weight.data)
 
         self.preprocessed = True
-    
+
     def trim_task_vector(self, adapter_name, k, matrix):
         magnitude = matrix.abs()
         threshold = torch.kthvalue(magnitude.view(-1), int((1 - k / 100) * magnitude.numel())).values
@@ -874,17 +887,14 @@ def forward(self, x: torch.Tensor):
         result = F.linear(x, transpose(
             self.weight, self.fan_in_fan_out), bias=self.bias)
 
-        adapters = list(self.lora_A.keys())
-
+        adapters = self._get_names_of_computable_adapters()
         x = x.to(self.lora_A[adapters[0]].weight.dtype)
 
         for adapter_name in adapters:
-            if self.scaling[adapter_name] != 0:
-
-                temp_result = self.lora_B[adapter_name](
-                    self.lora_A[adapter_name](x)
-                ).to(previous_dtype) * self.scaling[adapter_name]
-                result = torch.add(result, temp_result)
+            temp_result = self.lora_B[adapter_name](
+                self.lora_A[adapter_name](x)
+            ).to(previous_dtype) * self.scaling[adapter_name]
+            result = torch.add(result, temp_result)
 
         result = result.to(previous_dtype)
 
@@ -1029,12 +1039,10 @@ def get_delta_weight(self, adapter):
 
     def forward(self, x: torch.Tensor):
         previous_dtype = x.dtype
-        adapters = list(self.lora_A.keys())
-
-        for adapter_name in adapters:
+        for adapter_name in self._get_names_of_computable_adapters():
             # apply scaling for A
             self.lora_A_active.weight += self.lora_A[adapter_name].weight * self.scaling[adapter_name]
-            self.lora_B_active.weight += self.lora_B[adapter_name].weight
+            self.lora_B_active.weight += self.lora_B[adapter_name].weight
 
         result = F.conv2d(
             x,
@@ -1099,9 +1107,9 @@ def forward(self, x: torch.Tensor):
 
         if self.disable_adapters:
             return result
-    
+
         else:
-            adapters = list(self.lora_A.keys())
+            adapters = self._get_names_of_computable_adapters()
             if not torch.is_autocast_enabled():
                 expected_dtype = result.dtype
 
@@ -1113,14 +1121,14 @@ def forward(self, x: torch.Tensor):
                     self.lora_A[adapter_name](
                         self.lora_dropout[adapter_name](x))
                 ).to(expected_dtype) \
-                    * self.scaling[adapter_name]
+                    * self.scaling[adapter_name]
 
             else:
                 for adapter_name in adapters:
                     result += self.lora_B[adapter_name](
                         self.lora_A[adapter_name](
                             self.lora_dropout[adapter_name](x))
-                    ) * self.scaling[adapter_name]
+                    ) * self.scaling[adapter_name]
 
         return result
 
@@ -1185,18 +1193,18 @@ def preprocess_weights(self, k):
                 self.lora_A[adapter_name].weight.data = self.trim_task_vector(adapter_name, k, self.lora_A[adapter_name].weight.data)
                 self.lora_A[adapter_name].weight.data = self.elect_sign_vector(adapter_name, self.lora_A[adapter_name].weight.data)
                 self.lora_A[adapter_name].weight.data = self.disjoint_merge(adapter_name, self.lora_A[adapter_name].weight.data)
-    
+
                 self.lora_B[adapter_name].weight.data = self.trim_task_vector(adapter_name, k, self.lora_B[adapter_name].weight.data)
                 self.lora_B[adapter_name].weight.data = self.elect_sign_vector(adapter_name, self.lora_B[adapter_name].weight.data)
                 self.lora_B[adapter_name].weight.data = self.disjoint_merge(adapter_name, self.lora_B[adapter_name].weight.data)
-    
+
         self.preprocessed = True
 
     def forward(self, x: torch.Tensor):
         result = super().forward(x)
         # if not self.preprocessed:
         # self.preprocess_weights(25)
         if not self.disable_adapters:
-            adapters = list(self.lora_A.keys())
+            adapters = self._get_names_of_computable_adapters()
             if not torch.is_autocast_enabled():
                 expected_dtype = result.dtype
@@ -1205,7 +1213,7 @@ def forward(self, x: torch.Tensor):
                 for adapter_name in adapters:
                     if self.scaling[adapter_name] != 0:
-    
+
                         temp_result = self.lora_B[adapter_name](
                             self.lora_A[adapter_name](x)
                         ).to(expected_dtype) * self.scaling[adapter_name]
                         result = torch.add(result, temp_result)
@@ -1213,11 +1221,10 @@ def forward(self, x: torch.Tensor):
             else:
                 for adapter_name in adapters:
                     if self.scaling[adapter_name] != 0:
-    
+
                         temp_result = self.lora_B[adapter_name](
                             self.lora_A[adapter_name](x)
                         ).to(expected_dtype) * self.scaling[adapter_name]
                         result = torch.add(result, temp_result)
 
         return result
-
\ No newline at end of file
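
Note (reviewer sketch, not part of the patch): the filtering rule introduced by _get_names_of_computable_adapters can be exercised on its own. In the snippet below, DummyLayer and the adapter names are hypothetical stand-ins for the per-adapter lora_A, lora_alpha and r dictionaries the real layer keeps; only the list comprehension mirrors the code added above.

class DummyLayer:
    def __init__(self):
        # Three hypothetical adapters: one usable, one with zero alpha, one with zero rank.
        self.lora_A = {"task_a": object(), "task_b": object(), "task_c": object()}
        self.lora_alpha = {"task_a": 16, "task_b": 0, "task_c": 8}
        self.r = {"task_a": 8, "task_b": 8, "task_c": 0}

    def _get_names_of_computable_adapters(self) -> list[str]:
        # Same filter as the patch: keep adapters with non-zero alpha and r > 0.
        return [
            adapter_name for adapter_name in self.lora_A.keys()
            if self.lora_alpha[adapter_name] != 0
            and self.r[adapter_name] > 0
        ]

print(DummyLayer()._get_names_of_computable_adapters())  # prints ['task_a']

A zero-alpha adapter contributes nothing to the output (its scaling is zero) and a zero-rank adapter has no weights to apply, so skipping both avoids wasted matrix multiplications in forward() and no-op merges.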