MoEBERT code reading #31

Open
wants to merge 3 commits into base: main
Empty file added moebert/__init__.py
150 changes: 150 additions & 0 deletions moebert/moe_layer.py
@@ -0,0 +1,150 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
def __init__(self, hidden_size, num_experts, expert, route_method, vocab_size, hash_list):
nn.Module.__init__(self)
self.num_experts = num_experts
self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
self.route_method = route_method
if route_method in ["gate-token", "gate-sentence"]:
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
elif route_method == "hash-random":
self.hash_list = self._random_hash_list(vocab_size)
elif route_method == "hash-balance":
self.hash_list = self._balance_hash_list(hash_list)
else:
raise KeyError("Routing method not supported.")

def _random_hash_list(self, vocab_size):
hash_list = torch.randint(low=0, high=self.num_experts, size=(vocab_size,))
return hash_list
Comment on lines +23 to +25
Routing each token to just any random expert already works well.
https://arxiv.org/pdf/2106.04426.pdf
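A minimal sketch of what this hash routing amounts to, with made-up sizes (vocab of 10, 4 experts): the table is fixed once at init time, and routing is just a lookup on the input ids, as _forward_hash does further down.

import torch

vocab_size, num_experts = 10, 4                            # toy sizes for illustration
hash_list = torch.randint(0, num_experts, (vocab_size,))   # fixed token-id -> expert table
input_ids = torch.tensor([[3, 7, 7, 1]])                   # (bsz=1, seq_len=4)
gate = hash_list[input_ids.view(-1)]                       # expert index per token, shape (4,)
# The same token id always lands on the same expert, e.g. both 7s above share one expert.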


def _balance_hash_list(self, hash_list):
with open(hash_list, "rb") as file:
result = pickle.load(file)
result = torch.tensor(result, dtype=torch.int64)
return result

def _forward_gate_token(self, x):
bsz, seq_len, dim = x.size()
Comment on lines +33 to +34
This is the forward pass that goes through the gate.
Since this is BERT, the input has shape (bsz, seq_len, hid_dim).


x = x.view(-1, dim)
logits_gate = self.gate(x)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
Comment on lines +36 to +39
Reshape to (bsz x seq_len, hid_dim) and pass it through the gate.
The resulting logits_gate has shape (bsz x seq_len, num_experts); take a softmax over the last dimension.
Then take the index of the maximum probability (= gate). gate has shape (bsz x seq_len,), and each value is the index of the expert with the highest probability.
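A quick shape walk-through of the gating above, with hypothetical sizes (bsz=2, seq_len=3, hid_dim=8, 4 experts):

import torch
import torch.nn as nn
import torch.nn.functional as F

bsz, seq_len, dim, num_experts = 2, 3, 8, 4
gate_layer = nn.Linear(dim, num_experts, bias=False)

x = torch.randn(bsz, seq_len, dim)
x = x.view(-1, dim)                          # (bsz*seq_len, dim) = (6, 8)
logits_gate = gate_layer(x)                  # (6, num_experts)
prob_gate = F.softmax(logits_gate, dim=-1)   # per-token routing probabilities
gate = torch.argmax(prob_gate, dim=-1)       # (6,), chosen expert index per token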


order = gate.argsort(0)
num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_tokens.clone()
Comment on lines +41 to +43
Compute the sort indices of gate along dim 0 (ascending), which groups the tokens by expert index.
For example, gate = [2, 0, 1, 0] means the first token picks expert 2, the second picks expert 0, and so on; order = gate.argsort(0) = [1, 3, 2, 0], i.e. the token indices reordered so that expert-0 tokens come first.
num_tokens counts how many tokens are assigned to each expert.

x = x[order] # reorder according to expert number
x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts
Comment on lines +44 to +45
x has shape (bsz x seq_len, hid_dim) and order has shape (bsz x seq_len,).
Indexing x with order sorts its rows by expert index, so tokens routed to the same expert sit next to each other.
split then cuts x into one chunk per expert, sized by how many tokens that expert received,
yielding a tuple like (tokens for expert 0, tokens for expert 1, ...).
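A toy run of this sort-and-split trick (4 tokens, 3 experts, hypothetical values):

import torch
import torch.nn.functional as F

num_experts = 3
gate = torch.tensor([2, 0, 1, 0])                         # chosen expert for each of 4 tokens
order = gate.argsort(0)                                   # tensor([1, 3, 2, 0])
num_tokens = F.one_hot(gate, num_experts).gt(0).sum(0)    # tensor([2, 1, 1]) tokens per expert

x = torch.randn(4, 8)                                     # (tokens, hid_dim)
x = x[order]                                              # rows grouped as expert 0, 0, 1, 2
chunks = x.split(num_tokens.tolist(), dim=0)              # shapes (2, 8), (1, 8), (1, 8)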


# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
Comment on lines +48 to +51
Load balancing loss.
Not sure which paper it is from; it looks like the auxiliary load-balancing loss from Switch Transformers: num_experts * sum_i(P_i * f_i), where P_i is the mean gate probability for expert i and f_i is the fraction of tokens routed to it.
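For intuition, a quick numeric check of this loss with 4 experts: perfectly balanced routing gives 1, while collapsing all probability mass and all tokens onto a single expert gives num_experts.

import torch

num_experts = 4
# Balanced: each expert gets 1/4 of the probability mass and 1/4 of the tokens.
P = torch.full((num_experts,), 0.25)
f = torch.full((num_experts,), 0.25)
print(num_experts * torch.sum(P * f))    # tensor(1.)

# Collapsed: everything goes to expert 0.
P = torch.tensor([1.0, 0.0, 0.0, 0.0])
f = torch.tensor([1.0, 0.0, 0.0, 0.0])
print(num_experts * torch.sum(P * f))    # tensor(4.)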


prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)
Comment on lines +53 to +55
prob_gate has shape (bsz x seq_len, num_experts); gate (bsz x seq_len,) is unsqueezed to (bsz x seq_len, 1) and used in gather, i.e. we keep only the probability of the expert picked by the argmax. After the gather, prob_gate has shape (bsz x seq_len, 1).
Just as x was reordered and split above, the selected gate probabilities are also sorted by expert index and split per expert.
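A small illustration of the gather step with hypothetical probabilities (2 tokens, 3 experts): only the probability of the argmax expert survives.

import torch

prob_gate = torch.tensor([[0.7, 0.2, 0.1],
                          [0.1, 0.3, 0.6]])                # (tokens, num_experts)
gate = prob_gate.argmax(dim=-1)                            # tensor([0, 2])
picked = prob_gate.gather(dim=1, index=gate.unsqueeze(1))  # tensor([[0.7], [0.6]]), shape (tokens, 1)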


def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x
return input_x
Comment on lines +57 to +60
Given input_x, prob_x, and expert_idx, run the input through that expert and multiply the output by the gate probability.


x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
x = torch.vstack(x)
x = x[order.argsort(0)] # restore original order
x = x.view(bsz, seq_len, dim)

return x, balance_loss, gate_load
Comment on lines +62 to +67
x is already a tuple split per expert.
For each expert, the inputs routed to it are passed to forward_expert above.
The results form a list, which is stacked vertically (vstack); the stacked tensor has shape (bsz x seq_len, dim), ordered by expert index.
Indexing with order.argsort(0) restores the original token order (argsort of a permutation is its inverse permutation, so it undoes the earlier x = x[order]; see the check below).
view then reshapes back to the original (bsz, seq_len, dim).
forward ultimately returns x (bsz, seq_len, dim), balance_loss, and gate_load (the number of tokens each expert processes).
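A quick sanity check of the inverse-permutation claim above, on toy values:

import torch

x = torch.arange(5) * 10                 # tensor([ 0, 10, 20, 30, 40])
order = torch.tensor([3, 0, 4, 1, 2])    # some permutation (e.g. tokens sorted by expert)
shuffled = x[order]                      # tensor([30,  0, 40, 10, 20])
restored = shuffled[order.argsort(0)]    # tensor([ 0, 10, 20, 30, 40])
print(torch.equal(restored, x))          # True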


def _forward_gate_sentence(self, x, attention_mask):
x_masked = x * attention_mask.unsqueeze(-1)
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
logits_gate = self.gate(x_average)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)

order = gate.argsort(0)
num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_sentences.clone()
x = x[order] # reorder according to expert number
x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts

# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_sentences.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)

prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)

def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x.unsqueeze(-1)
return input_x

result = []
for i in range(self.num_experts):
if x[i].size(0) > 0:
result.append(forward_expert(x[i], prob_gate[i], i))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order

return result, balance_loss, gate_load

def _forward_sentence_single_expert(self, x, attention_mask):
x_masked = x * attention_mask.unsqueeze(-1)
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
logits_gate = self.gate(x_average)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)

gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
x = self.experts[gate.cpu().item()].forward(x)
return x, 0.0, gate_load

def _forward_hash(self, x, input_ids):
bsz, seq_len, dim = x.size()

x = x.view(-1, dim)
self.hash_list = self.hash_list.to(x.device)
gate = self.hash_list[input_ids.view(-1)]

order = gate.argsort(0)
num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_tokens.clone()
x = x[order] # reorder according to expert number
x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts

x = [self.experts[i].forward(x[i]) for i in range(self.num_experts)]
x = torch.vstack(x)
x = x[order.argsort(0)] # restore original order
x = x.view(bsz, seq_len, dim)

return x, 0.0, gate_load

def forward(self, x, input_ids, attention_mask):
if self.route_method == "gate-token":
x, balance_loss, gate_load = self._forward_gate_token(x)
elif self.route_method == "gate-sentence":
if x.size(0) == 1:
x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
else:
x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
elif self.route_method in ["hash-random", "hash-balance"]:
x, balance_loss, gate_load = self._forward_hash(x, input_ids)
else:
raise KeyError("Routing method not supported.")

return x, balance_loss, gate_load
131 changes: 131 additions & 0 deletions moebert/utils.py
@@ -0,0 +1,131 @@
import numpy as np
import pickle
import torch
import torch.nn as nn

from dataclasses import dataclass
from torch import Tensor
from transformers.activations import ACT2FN
from transformers.file_utils import ModelOutput
from typing import Optional, Tuple


def use_experts(layer_idx):
return True


def process_ffn(model):
if model.config.model_type == "bert":
inner_model = model.bert
else:
raise ValueError("Model type not recognized.")

for i in range(model.config.num_hidden_layers):
model_layer = inner_model.encoder.layer[i]
if model_layer.use_experts:
model_layer.importance_processor.load_experts(model_layer)


class ImportanceProcessor:
def __init__(self, config, layer_idx, num_local_experts, local_group_rank):
self.num_experts = config.moebert_expert_num # total number of experts
self.num_local_experts = num_local_experts # number of experts on this device
self.local_group_rank = local_group_rank # rank in the current process group
self.intermediate_size = config.moebert_expert_dim # FFN hidden dimension
self.share_importance = config.moebert_share_importance # number of shared FFN dimension

importance = ImportanceProcessor.load_importance_single(config.moebert_load_importance)[layer_idx, :]
self.importance = self._split_importance(importance)

self.is_moe = False # safety check

@staticmethod
def load_importance_single(importance_files):
with open(importance_files, "rb") as file:
data = pickle.load(file)
data = data["idx"]
return np.array(data)

def _split_importance(self, arr):
result = []
top_importance = arr[:self.share_importance]
remain = arr[self.share_importance:]
all_experts_remain = []
for i in range(self.num_experts):
all_experts_remain.append(remain[i::self.num_experts])
all_experts_remain = np.array(all_experts_remain)

for i in range(self.num_local_experts):
temp = all_experts_remain[self.num_local_experts * self.local_group_rank + i]
temp = np.concatenate((top_importance, temp))
temp = temp[:self.intermediate_size]
result.append(temp.copy())
result = np.array(result)
return result

def load_experts(self, model_layer):
expert_list = model_layer.experts.experts
fc1_weight_data = model_layer.intermediate.dense.weight.data
fc1_bias_data = model_layer.intermediate.dense.bias.data
fc2_weight_data = model_layer.output.dense.weight.data
fc2_bias_data = model_layer.output.dense.bias.data
layernorm_weight_data = model_layer.output.LayerNorm.weight.data
layernorm_bias_data = model_layer.output.LayerNorm.bias.data
for i in range(self.num_local_experts):
idx = self.importance[i]
expert_list[i].fc1.weight.data = fc1_weight_data[idx, :].clone()
expert_list[i].fc1.bias.data = fc1_bias_data[idx].clone()
expert_list[i].fc2.weight.data = fc2_weight_data[:, idx].clone()
expert_list[i].fc2.bias.data = fc2_bias_data.clone()
expert_list[i].LayerNorm.weight.data = layernorm_weight_data.clone()
expert_list[i].LayerNorm.bias.data = layernorm_bias_data.clone()
del model_layer.intermediate
del model_layer.output
self.is_moe = True


class FeedForward(nn.Module):
def __init__(self, config, intermediate_size, dropout):
nn.Module.__init__(self)

# first layer
self.fc1 = nn.Linear(config.hidden_size, intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act

# second layer
self.fc2 = nn.Linear(intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(dropout)

def forward(self, hidden_states: Tensor):
input_tensor = hidden_states
hidden_states = self.fc1(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states


@dataclass
class MoEModelOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
gate_loss: torch.FloatTensor = None


@dataclass
class MoEModelOutputWithPooling(ModelOutput):
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
gate_loss: torch.FloatTensor = None