adding a method to get only the adapters that are worth computing #14

Open · wants to merge 2 commits into base: v1
73 changes: 40 additions & 33 deletions multilora.py
@@ -682,6 +682,19 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
             self.reset_lora_parameters(adapter_name)
         self.to(self.weight.device)

+    def _get_names_of_computable_adapters(self) -> list[str]:
+        """
+        For a LoRA adapter to be useful it needs r > 0 and a non-zero
+        scaling (i.e. lora_alpha != 0). Filter the available adapters
+        and return only those that are worth computing (e.g. in the
+        forward pass or when merging).
+        """
+        return [
+            adapter_name for adapter_name in self.lora_A.keys()
+            if self.lora_alpha[adapter_name] != 0
+            and self.r[adapter_name] > 0
+        ]
+
+
     def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
         self.r[adapter_name] = r
         self.lora_alpha[adapter_name] = lora_alpha
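
As a quick illustration of what the new helper keeps and drops, here is a standalone sketch with made-up adapter settings; the plain dicts stand in for self.r, self.lora_alpha and self.lora_A, and only the filtering rule mirrors the method above (since scaling is typically lora_alpha / r, a zero lora_alpha is exactly what makes an adapter's scaling zero):

    r = {"task_a": 8, "task_b": 0, "task_c": 4}
    lora_alpha = {"task_a": 16, "task_b": 16, "task_c": 0}
    lora_A = {name: object() for name in r}   # stands in for the real nn.Linear modules

    computable = [
        name for name in lora_A.keys()
        if lora_alpha[name] != 0 and r[name] > 0
    ]
    print(computable)   # ['task_a'] -- task_b has r == 0, task_c has lora_alpha == 0
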
@@ -806,24 +819,24 @@ def __init__(
         # sequential execution allows you to use multiple heterogeneous lora adapters

     def merge(self):
-        if self.active_adapter not in self.lora_A.keys():
+        if self.active_adapter not in self._get_names_of_computable_adapters():
             warnings.warn(f"Skipping merge of adapter {self.active_adapter}")
             return
         if self.merged:
             warnings.warn("Already merged. Nothing to do.")
             return
-        if self.r[self.active_adapter] > 0:
-            self.weight.data += self.get_delta_weight(self.active_adapter)
-            self.merged = True
+        self.weight.data += self.get_delta_weight(self.active_adapter)
+        self.merged = True

     def unmerge(self):
-        if self.active_adapter not in self.lora_A.keys():
+        if self.active_adapter not in self._get_names_of_computable_adapters():
             warnings.warn(f"Skipping un-merge of adapter {self.active_adapter}")
             return
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
-        if self.r[self.active_adapter] > 0:
-            self.weight.data -= self.get_delta_weight(self.active_adapter)
-            self.merged = False
+        self.weight.data -= self.get_delta_weight(self.active_adapter)
+        self.merged = False

     def get_delta_weight(self, adapter):
         return (
@@ -834,20 +847,20 @@ def get_delta_weight(self, adapter):
             * self.scaling[adapter]
         )


     def preprocess_weights(self, k):
         if not self.preprocessed:
             for adapter_name in self.lora_A.keys():
                 self.lora_A[adapter_name].weight.data = self.trim_task_vector(adapter_name, k, self.lora_A[adapter_name].weight.data)
                 self.lora_A[adapter_name].weight.data = self.elect_sign_vector(adapter_name, self.lora_A[adapter_name].weight.data)
                 self.lora_A[adapter_name].weight.data = self.disjoint_merge(adapter_name, self.lora_A[adapter_name].weight.data)

                 self.lora_B[adapter_name].weight.data = self.trim_task_vector(adapter_name, k, self.lora_B[adapter_name].weight.data)
                 self.lora_B[adapter_name].weight.data = self.elect_sign_vector(adapter_name, self.lora_B[adapter_name].weight.data)
                 self.lora_B[adapter_name].weight.data = self.disjoint_merge(adapter_name, self.lora_B[adapter_name].weight.data)
             self.preprocessed = True


     def trim_task_vector(self, adapter_name, k, matrix):
         magnitude = matrix.abs()
         threshold = torch.kthvalue(magnitude.view(-1), int((1 - k / 100) * magnitude.numel())).values
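
The body of trim_task_vector is mostly collapsed in this view. Based on the two visible lines, a plausible standalone sketch of the trim step is below — keep only (roughly) the top k% of entries by magnitude and zero the rest. Everything after the threshold line, plus the clamp that keeps kthvalue's index at least 1, is an assumption rather than the PR's actual code:

    import torch

    def trim_task_vector(matrix: torch.Tensor, k: float) -> torch.Tensor:
        # Zero out all but (roughly) the top k% of entries by absolute magnitude.
        magnitude = matrix.abs()
        kth = max(1, int((1 - k / 100) * magnitude.numel()))
        threshold = torch.kthvalue(magnitude.view(-1), kth).values
        return torch.where(magnitude >= threshold, matrix, torch.zeros_like(matrix))

    trimmed = trim_task_vector(torch.randn(16, 16), k=25)   # keeps ~25% of the entries
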
@@ -874,17 +887,14 @@ def forward(self, x: torch.Tensor):
             result = F.linear(x, transpose(
                 self.weight, self.fan_in_fan_out), bias=self.bias)

-            adapters = list(self.lora_A.keys())
-
+            adapters = self._get_names_of_computable_adapters()
             x = x.to(self.lora_A[adapters[0]].weight.dtype)

             for adapter_name in adapters:
-                if self.scaling[adapter_name] != 0:
-
-                    temp_result = self.lora_B[adapter_name](
-                        self.lora_A[adapter_name](x)
-                    ).to(previous_dtype) * self.scaling[adapter_name]
-                    result = torch.add(result, temp_result)
+                temp_result = self.lora_B[adapter_name](
+                    self.lora_A[adapter_name](x)
+                ).to(previous_dtype) * self.scaling[adapter_name]
+                result = torch.add(result, temp_result)

         result = result.to(previous_dtype)
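
The rewritten loop simply adds each computable adapter's low-rank update on top of the frozen projection. A self-contained sketch of that accumulation, with two invented adapters of different rank (names, shapes and alpha values are arbitrary):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    in_features, out_features = 16, 32
    weight = torch.randn(out_features, in_features)            # frozen base weight
    lora_A = {"task_a": nn.Linear(in_features, 4, bias=False),
              "task_b": nn.Linear(in_features, 8, bias=False)}
    lora_B = {"task_a": nn.Linear(4, out_features, bias=False),
              "task_b": nn.Linear(8, out_features, bias=False)}
    scaling = {"task_a": 16 / 4, "task_b": 16 / 8}             # lora_alpha / r

    x = torch.randn(2, in_features)
    result = F.linear(x, weight)                               # base forward pass
    for name in ("task_a", "task_b"):                          # each computable adapter adds its term
        result = result + lora_B[name](lora_A[name](x)) * scaling[name]
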

@@ -1029,12 +1039,10 @@ def get_delta_weight(self, adapter):
     def forward(self, x: torch.Tensor):
         previous_dtype = x.dtype

-        adapters = list(self.lora_A.keys())
-
-        for adapter_name in adapters:
+        for adapter_name in self._get_names_of_computable_adapters():
             # apply scaling for A
             self.lora_A_active.weight += self.lora_A[adapter_name].weight * self.scaling[adapter_name]
             self.lora_B_active.weight += self.lora_B[adapter_name].weight

         result = F.conv2d(
             x,
@@ -1099,9 +1107,9 @@ def forward(self, x: torch.Tensor):

         if self.disable_adapters:
             return result

         else:
-            adapters = list(self.lora_A.keys())
+            adapters = self._get_names_of_computable_adapters()

             if not torch.is_autocast_enabled():
                 expected_dtype = result.dtype
@@ -1113,14 +1121,14 @@ def forward(self, x: torch.Tensor):
                         self.lora_A[adapter_name](
                             self.lora_dropout[adapter_name](x))
                     ).to(expected_dtype) \
                         * self.scaling[adapter_name]
             else:

                 for adapter_name in adapters:
                     result += self.lora_B[adapter_name](
                         self.lora_A[adapter_name](
                             self.lora_dropout[adapter_name](x))
                     ) * self.scaling[adapter_name]

         return result
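
The two branches above differ only in dtype handling: outside autocast the adapter contribution is cast back to the dtype of the base result before being added. A tiny illustration of that pattern with arbitrary dtypes:

    import torch

    result = torch.randn(4, 8, dtype=torch.float16)   # base layer output, e.g. fp16 weights
    delta = torch.randn(4, 8, dtype=torch.float32)    # adapter contribution computed in fp32

    if not torch.is_autocast_enabled():
        result = result + delta.to(result.dtype)      # explicit cast back to the base dtype
    else:
        result = result + delta                       # autocast path: no explicit cast
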

@@ -1185,18 +1193,18 @@ def preprocess_weights(self, k):
                 self.lora_A[adapter_name].weight.data = self.trim_task_vector(adapter_name, k, self.lora_A[adapter_name].weight.data)
                 self.lora_A[adapter_name].weight.data = self.elect_sign_vector(adapter_name, self.lora_A[adapter_name].weight.data)
                 self.lora_A[adapter_name].weight.data = self.disjoint_merge(adapter_name, self.lora_A[adapter_name].weight.data)

                 self.lora_B[adapter_name].weight.data = self.trim_task_vector(adapter_name, k, self.lora_B[adapter_name].weight.data)
                 self.lora_B[adapter_name].weight.data = self.elect_sign_vector(adapter_name, self.lora_B[adapter_name].weight.data)
                 self.lora_B[adapter_name].weight.data = self.disjoint_merge(adapter_name, self.lora_B[adapter_name].weight.data)

             self.preprocessed = True
     def forward(self, x: torch.Tensor):
         result = super().forward(x)
         # if not self.preprocessed:
         #     self.preprocess_weights(25)
         if not self.disable_adapters:
-            adapters = list(self.lora_A.keys())
+            adapters = self._get_names_of_computable_adapters()

             if not torch.is_autocast_enabled():
                 expected_dtype = result.dtype
@@ -1205,19 +1213,18 @@ def forward(self, x: torch.Tensor):

                 for adapter_name in adapters:
                     if self.scaling[adapter_name] != 0:
                         temp_result = self.lora_B[adapter_name](
                             self.lora_A[adapter_name](x)
                         ).to(expected_dtype) * self.scaling[adapter_name]
                         result = torch.add(result, temp_result)
             else:
                 for adapter_name in adapters:
                     if self.scaling[adapter_name] != 0:
                         temp_result = self.lora_B[adapter_name](
                             self.lora_A[adapter_name](x)
                         ).to(expected_dtype) * self.scaling[adapter_name]
                         result = torch.add(result, temp_result)

         return result
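
The trim / elect-sign / disjoint-merge trio used in preprocess_weights echoes the TIES-merging recipe. Since elect_sign_vector and disjoint_merge are not shown in this diff, the sketch below is only the generic multi-task formulation of those two steps, not necessarily what the per-adapter methods here implement:

    import torch

    def ties_combine(task_vectors: list[torch.Tensor]) -> torch.Tensor:
        # Stack the (already trimmed) task vectors: shape (num_tasks, *param_shape).
        stacked = torch.stack(task_vectors)
        # Elect a sign per parameter: the sign with the larger total mass wins.
        elected_sign = torch.sign(stacked.sum(dim=0))
        # Disjoint merge: average only the entries that agree with the elected sign.
        agrees = torch.sign(stacked) == elected_sign
        counts = agrees.sum(dim=0).clamp(min=1)
        return (stacked * agrees).sum(dim=0) / counts

    merged = ties_combine([torch.randn(8, 8) for _ in range(3)])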