MoEBERT code reading #31
base: main
@@ -0,0 +1,150 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    def __init__(self, hidden_size, num_experts, expert, route_method, vocab_size, hash_list):
        nn.Module.__init__(self)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
        self.route_method = route_method
        if route_method in ["gate-token", "gate-sentence"]:
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
        elif route_method == "hash-random":
            self.hash_list = self._random_hash_list(vocab_size)
        elif route_method == "hash-balance":
            self.hash_list = self._balance_hash_list(hash_list)
        else:
            raise KeyError("Routing method not supported.")

    def _random_hash_list(self, vocab_size):
        hash_list = torch.randint(low=0, high=self.num_experts, size=(vocab_size,))
        return hash_list

    def _balance_hash_list(self, hash_list):
        with open(hash_list, "rb") as file:
            result = pickle.load(file)
        result = torch.tensor(result, dtype=torch.int64)
        return result

    def _forward_gate_token(self, x):
        bsz, seq_len, dim = x.size()
Comment on lines +33 to +34
Token-level routing: the forward pass that sends each token through the learned gate.
        x = x.view(-1, dim)
        logits_gate = self.gate(x)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)
Comment on lines +36 to +39
Reshape x to (bsz * seq_len, hid_dim) and pass it through the gate; softmax followed by argmax picks one expert per token.
        order = gate.argsort(0)
        num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_tokens.clone()
Comment on lines +41 to +43
argsort along dim 0 gives the permutation that sorts gate, so tokens end up grouped by expert index; num_tokens counts how many tokens each expert receives.
        x = x[order]  # reorder according to expert number
        x = x.split(num_tokens.tolist(), dim=0)  # a list of length self.num_experts
Comment on lines +44 to +45
Here x has shape (bsz * seq_len, hid_dim); indexing with order and splitting by num_tokens turns it into per-expert chunks (minimal sketch below).
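A minimal, runnable sketch of the reorder-and-split step (the gate values and expert count are invented for illustration; torch.argsort is not guaranteed stable, so the order within one expert's group may vary):

    import torch
    import torch.nn.functional as F

    num_experts = 3
    gate = torch.tensor([2, 0, 1, 0, 2])                     # chosen expert for each of 5 tokens
    x = torch.arange(5).float().unsqueeze(1)                 # 5 "tokens", hid_dim = 1

    order = gate.argsort(0)                                  # e.g. tensor([1, 3, 2, 0, 4]): token indices grouped by expert
    num_tokens = F.one_hot(gate, num_experts).gt(0).sum(0)   # tensor([2, 1, 2]): tokens per expert

    x_sorted = x[order]                                      # expert 0's tokens first, then expert 1's, then expert 2's
    chunks = x_sorted.split(num_tokens.tolist(), dim=0)      # tuple of per-expert chunks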
        # compute the load balancing loss
        P = prob_gate.mean(0)
        temp = num_tokens.float()
        f = temp / temp.sum(0, keepdim=True)
        balance_loss = self.num_experts * torch.sum(P * f)
Comment on lines +48 to +51
Load-balancing loss: P is the mean gate probability per expert, f is the fraction of tokens each expert actually receives (spelled out below).
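Spelling lines +48 to +51 out (the symbols B, P_i, f_i, N are mine, introduced only to describe the computation): with B = bsz * seq_len tokens and N = num_experts,

$$P_i = \frac{1}{B}\sum_{t=1}^{B} p_{t,i}, \qquad f_i = \frac{\#\{t : \arg\max_j p_{t,j} = i\}}{B}, \qquad \mathcal{L}_\text{balance} = N \sum_{i=1}^{N} P_i\, f_i .$$

This equals 1 when both the gate probabilities and the actual token counts are uniform across experts, and grows when probability mass and tokens pile onto the same few experts.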
        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        prob_gate = prob_gate[order]
        prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)
Comment on lines +53 to +55
prob_gate has shape (bsz * seq_len, num_experts), while gate has shape (bsz * seq_len); gate is unsqueezed to (bsz * seq_len, 1) and used as the gather index, so the gather keeps only the probability of the argmax expert for each token. After the gather, prob_gate has shape (bsz * seq_len, 1) (toy check below).
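A toy check of the gather (probabilities invented for illustration):

    import torch

    prob_gate = torch.tensor([[0.1, 0.2, 0.7],
                              [0.8, 0.1, 0.1],
                              [0.2, 0.6, 0.2],
                              [0.5, 0.3, 0.2],
                              [0.3, 0.3, 0.4]])              # (5 tokens, 3 experts)
    gate = prob_gate.argmax(dim=-1)                          # tensor([2, 0, 1, 0, 2])
    picked = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
    # picked -> [[0.7], [0.8], [0.6], [0.5], [0.4]]: each token's selected-expert probability, shape (5, 1)

In the real code this (bsz * seq_len, 1) tensor is then reordered by order and split by num_tokens, so each expert's chunk arrives together with its own scaling factors.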
        def forward_expert(input_x, prob_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            input_x = input_x * prob_x
            return input_x
Comment on lines +57 to +60
Given input_x, prob_x, and expert_idx, run the input through that expert and scale the output by its routing probability.
        x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
        x = torch.vstack(x)
        x = x[order.argsort(0)]  # restore original order
        x = x.view(bsz, seq_len, dim)

        return x, balance_loss, gate_load
Comment on lines +62 to +67
x is already a tuple of per-expert chunks, so each chunk goes through its own expert; the outputs are stacked back with vstack, and indexing by order.argsort(0) restores the original token order (sketch below).
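A small self-contained sketch of the restore step (toy gate reused from the sketch above): order.argsort(0) is the inverse permutation of order, so indexing with it puts every token back where it started.

    import torch

    x = torch.arange(5).float().unsqueeze(1)   # 5 tokens in their original order
    gate = torch.tensor([2, 0, 1, 0, 2])       # toy expert assignment
    order = gate.argsort(0)

    x_sorted = x[order]                        # grouped by expert (what the experts actually consume)
    restored = x_sorted[order.argsort(0)]      # inverse permutation
    assert torch.equal(restored, x)            # original token order recovered (experts act as identity here)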
    def _forward_gate_sentence(self, x, attention_mask):
        x_masked = x * attention_mask.unsqueeze(-1)
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
        logits_gate = self.gate(x_average)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        order = gate.argsort(0)
        num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_sentences.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_sentences.tolist(), dim=0)  # a list of length self.num_experts

        # compute the load balancing loss
        P = prob_gate.mean(0)
        temp = num_sentences.float()
        f = temp / temp.sum(0, keepdim=True)
        balance_loss = self.num_experts * torch.sum(P * f)

        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        prob_gate = prob_gate[order]
        prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)

        def forward_expert(input_x, prob_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            input_x = input_x * prob_x.unsqueeze(-1)
            return input_x

        result = []
        for i in range(self.num_experts):
            if x[i].size(0) > 0:
                result.append(forward_expert(x[i], prob_gate[i], i))
        result = torch.vstack(result)
        result = result[order.argsort(0)]  # restore original order

        return result, balance_loss, gate_load

    def _forward_sentence_single_expert(self, x, attention_mask):
        x_masked = x * attention_mask.unsqueeze(-1)
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
        logits_gate = self.gate(x_average)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        x = self.experts[gate.cpu().item()].forward(x)
        return x, 0.0, gate_load

    def _forward_hash(self, x, input_ids):
        bsz, seq_len, dim = x.size()

        x = x.view(-1, dim)
        self.hash_list = self.hash_list.to(x.device)
        gate = self.hash_list[input_ids.view(-1)]

        order = gate.argsort(0)
        num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_tokens.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_tokens.tolist(), dim=0)  # a list of length self.num_experts

        x = [self.experts[i].forward(x[i]) for i in range(self.num_experts)]
        x = torch.vstack(x)
        x = x[order.argsort(0)]  # restore original order
        x = x.view(bsz, seq_len, dim)

        return x, 0.0, gate_load

    def forward(self, x, input_ids, attention_mask):
        if self.route_method == "gate-token":
            x, balance_loss, gate_load = self._forward_gate_token(x)
        elif self.route_method == "gate-sentence":
            if x.size(0) == 1:
                x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
            else:
                x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
        elif self.route_method in ["hash-random", "hash-balance"]:
            x, balance_loss, gate_load = self._forward_hash(x, input_ids)
        else:
            raise KeyError("Routing method not supported.")

        return x, balance_loss, gate_load
@@ -0,0 +1,131 @@
import numpy as np
import pickle
import torch
import torch.nn as nn

from dataclasses import dataclass
from torch import Tensor
from transformers.activations import ACT2FN
from transformers.file_utils import ModelOutput
from typing import Optional, Tuple


def use_experts(layer_idx):
    return True


def process_ffn(model):
    if model.config.model_type == "bert":
        inner_model = model.bert
    else:
        raise ValueError("Model type not recognized.")

    for i in range(model.config.num_hidden_layers):
        model_layer = inner_model.encoder.layer[i]
        if model_layer.use_experts:
            model_layer.importance_processor.load_experts(model_layer)


class ImportanceProcessor:
    def __init__(self, config, layer_idx, num_local_experts, local_group_rank):
        self.num_experts = config.moebert_expert_num  # total number of experts
        self.num_local_experts = num_local_experts  # number of experts on this device
        self.local_group_rank = local_group_rank  # rank in the current process group
        self.intermediate_size = config.moebert_expert_dim  # FFN hidden dimension
        self.share_importance = config.moebert_share_importance  # number of shared FFN dimension

        importance = ImportanceProcessor.load_importance_single(config.moebert_load_importance)[layer_idx, :]
        self.importance = self._split_importance(importance)

        self.is_moe = False  # safety check

    @staticmethod
    def load_importance_single(importance_files):
        with open(importance_files, "rb") as file:
            data = pickle.load(file)
        data = data["idx"]
        return np.array(data)

    def _split_importance(self, arr):
        result = []
        top_importance = arr[:self.share_importance]
        remain = arr[self.share_importance:]
        all_experts_remain = []
        for i in range(self.num_experts):
            all_experts_remain.append(remain[i::self.num_experts])
        all_experts_remain = np.array(all_experts_remain)

        for i in range(self.num_local_experts):
            temp = all_experts_remain[self.num_local_experts * self.local_group_rank + i]
            temp = np.concatenate((top_importance, temp))
            temp = temp[:self.intermediate_size]
            result.append(temp.copy())
        result = np.array(result)
        return result

    def load_experts(self, model_layer):
        expert_list = model_layer.experts.experts
        fc1_weight_data = model_layer.intermediate.dense.weight.data
        fc1_bias_data = model_layer.intermediate.dense.bias.data
        fc2_weight_data = model_layer.output.dense.weight.data
        fc2_bias_data = model_layer.output.dense.bias.data
        layernorm_weight_data = model_layer.output.LayerNorm.weight.data
        layernorm_bias_data = model_layer.output.LayerNorm.bias.data
        for i in range(self.num_local_experts):
            idx = self.importance[i]
            expert_list[i].fc1.weight.data = fc1_weight_data[idx, :].clone()
            expert_list[i].fc1.bias.data = fc1_bias_data[idx].clone()
            expert_list[i].fc2.weight.data = fc2_weight_data[:, idx].clone()
            expert_list[i].fc2.bias.data = fc2_bias_data.clone()
            expert_list[i].LayerNorm.weight.data = layernorm_weight_data.clone()
            expert_list[i].LayerNorm.bias.data = layernorm_bias_data.clone()
        del model_layer.intermediate
        del model_layer.output
        self.is_moe = True


class FeedForward(nn.Module):
    def __init__(self, config, intermediate_size, dropout):
        nn.Module.__init__(self)

        # first layer
        self.fc1 = nn.Linear(config.hidden_size, intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # second layer
        self.fc2 = nn.Linear(intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: Tensor):
        input_tensor = hidden_states
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


@dataclass
class MoEModelOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    gate_loss: torch.FloatTensor = None


@dataclass
class MoEModelOutputWithPooling(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    gate_loss: torch.FloatTensor = None
Even assigning tokens to random experts works well in practice:
https://arxiv.org/pdf/2106.04426.pdf
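For intuition, a minimal self-contained sketch of that hash-random idea (expert count, vocab size, and batch shape are invented for illustration): the token-id-to-expert table is fixed at construction time, so routing needs no learned gate and no balance loss.

    import torch
    import torch.nn.functional as F

    num_experts, vocab_size = 4, 100
    torch.manual_seed(0)
    hash_list = torch.randint(low=0, high=num_experts, size=(vocab_size,))  # fixed token-id -> expert table

    input_ids = torch.randint(0, vocab_size, (2, 8))        # toy batch: bsz = 2, seq_len = 8
    gate = hash_list[input_ids.view(-1)]                    # expert index per token, no learned parameters
    gate_load = F.one_hot(gate, num_experts).gt(0).sum(0)   # tokens received by each expert
    print(gate_load)                                        # entries sum to bsz * seq_len = 16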