
MoEBERT code reading #31

Open · wants to merge 3 commits into main
Conversation

long8v (Owner) commented May 23, 2022

long8v changed the title from "add moe bert" to "MoEBERT: from BERT to Mixture-of-Experts via Importance-Guided Adaptation" on May 23, 2022
Comment on lines +23 to +25
def _random_hash_list(self, vocab_size):
hash_list = torch.randint(low=0, high=self.num_experts, size=(vocab_size,))
return hash_list

Routing each token to a random expert works well on its own; see the Hash Layers paper:
https://arxiv.org/pdf/2106.04426.pdf
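
A minimal sketch of how a fixed random hash table like this can route tokens, assuming routing is done by vocabulary id (toy sizes; the input_ids/expert_ids names are illustrative, not from the repo):

import torch

num_experts, vocab_size = 4, 10
hash_list = torch.randint(low=0, high=num_experts, size=(vocab_size,))
input_ids = torch.tensor([3, 7, 3, 1])  # token ids in a sequence
expert_ids = hash_list[input_ids]       # same token id -> same expert, every time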

Comment on lines +33 to +34
def _forward_gate_token(self, x):
bsz, seq_len, dim = x.size()

The forward pass through the gate.
Since this is BERT, the input shape is (bsz, seq_len, hid_dim).

Comment on lines +36 to +39
x = x.view(-1, dim)
logits_gate = self.gate(x)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)

Reshape to (bsz * seq_len, hid_dim) and pass through the gate.
logits_gate then has shape (bsz * seq_len, num_experts), and softmax is taken over the last dimension.
The index of the maximum probability is taken (= gate). gate has shape (bsz * seq_len,), where each value is the index of the expert with the highest probability.
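
A toy shape walk-through of this gating step (sizes assumed for illustration; gate_layer stands in for self.gate):

import torch
import torch.nn as nn
import torch.nn.functional as F

bsz, seq_len, dim, num_experts = 2, 3, 8, 4
gate_layer = nn.Linear(dim, num_experts)
x = torch.randn(bsz, seq_len, dim)

x = x.view(-1, dim)                       # (6, 8) = (bsz*seq_len, dim)
logits_gate = gate_layer(x)               # (6, 4) = (bsz*seq_len, num_experts)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)    # (6,)   expert index per token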

Comment on lines +41 to +43
order = gate.argsort(0)
num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_tokens.clone()

argsort sorts gate along dim 0 in ascending order, so the resulting indices group tokens by expert index.
For example, gate = [2, 0, 1, 0] means token 0 picks expert 2, tokens 1 and 3 pick expert 0, and token 2 picks expert 1;
order = gate.argsort(0) = [1, 3, 2, 0], i.e. expert-0 tokens come first.
num_tokens counts how many tokens are assigned to each expert.
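
A concrete check of the example above (toy gate values):

import torch
import torch.nn.functional as F

num_experts = 3
gate = torch.tensor([2, 0, 1, 0])
order = gate.argsort(0)                                 # tensor([1, 3, 2, 0])
num_tokens = F.one_hot(gate, num_experts).gt(0).sum(0)  # tensor([2, 1, 1])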

Comment on lines +44 to +45
x = x[order] # reorder according to expert number
x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts

x has shape (bsz * seq_len, hid_dim) and order has shape (bsz * seq_len,).
Indexing x[order] reorders the token rows by expert index, so tokens assigned to the same expert become contiguous.
split then divides x according to the per-expert token counts,
yielding a tuple (tokens for expert 0, tokens for expert 1, ...).
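
Continuing the same toy example (4 tokens, dim 8, 3 experts):

import torch

x = torch.randn(4, 8)                       # (bsz*seq_len, dim)
order = torch.tensor([1, 3, 2, 0])
num_tokens = [2, 1, 1]
x = x[order]                                # expert-0 rows first, then expert 1, 2
chunks = x.split(num_tokens, dim=0)         # shapes (2, 8), (1, 8), (1, 8)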

Comment on lines +48 to +51
P = prob_gate.mean(0)
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)

The load balancing loss.
This matches the auxiliary load-balancing loss from the Switch Transformers paper (https://arxiv.org/abs/2101.03961): loss = N * sum_i(f_i * P_i), where f_i is the fraction of tokens routed to expert i and P_i is the mean router probability for expert i.
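
A quick sanity check with assumed toy values: under perfectly uniform routing, f_i = P_i = 1/N for every expert, so the loss reaches its minimum value of 1.

import torch

num_experts = 4
P = torch.full((num_experts,), 1 / num_experts)  # mean router probability per expert
f = torch.full((num_experts,), 1 / num_experts)  # fraction of tokens per expert
balance_loss = num_experts * torch.sum(P * f)    # tensor(1.) at perfect balance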

Comment on lines +53 to +55
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)

prob_gate has shape (bsz * seq_len, num_experts); gate (bsz * seq_len,) is unsqueezed to (bsz * seq_len, 1) and used as the index for gather, i.e. this picks out only the probability of each token's selected (argmax) expert. After the gather, prob_gate has shape (bsz * seq_len, 1).
Just as x was reordered and split above, the per-token gate probabilities are also sorted by expert index and split per expert.
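
A toy example of the gather step (values assumed for illustration):

import torch

prob_gate = torch.tensor([[0.1, 0.2, 0.7],
                          [0.6, 0.3, 0.1]])                # (2 tokens, 3 experts)
gate = prob_gate.argmax(dim=-1)                            # tensor([2, 0])
picked = prob_gate.gather(dim=1, index=gate.unsqueeze(1))  # tensor([[0.7], [0.6]])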

Comment on lines +57 to +60
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x
return input_x

Given input_x, prob_x, and expert_idx, forward input_x through that expert and scale the output by the gate probability.

Comment on lines +62 to +67
x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
x = torch.vstack(x)
x = x[order.argsort(0)] # restore original order
x = x.view(bsz, seq_len, dim)

return x, balance_loss, gate_load

x is already a tuple split per expert.
For each expert, the inputs routed to it are passed to forward_expert above.
The result is a list, and vstack stacks it into a (bsz * seq_len, dim) tensor, still ordered by expert index.
Indexing with order.argsort(0) restores the original token order: the argsort of a permutation is its inverse permutation, so it undoes the earlier reordering.
view then reshapes back to the original (bsz, seq_len, dim).
The final return values of forward are x (bsz, seq_len, dim), balance_loss, and gate_load (the number of tokens each expert handles).
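
A toy check that order.argsort(0) really is the inverse permutation:

import torch

order = torch.tensor([1, 3, 2, 0])    # permutation used to group tokens by expert
inv = order.argsort(0)                # tensor([3, 0, 2, 1]), the inverse permutation
x = torch.tensor([10, 11, 12, 13])
assert torch.equal(x[order][inv], x)  # round trip restores the original order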

long8v changed the title from "MoEBERT: from BERT to Mixture-of-Experts via Importance-Guided Adaptation" to "MoEBERT code reading" on Jul 21, 2022