
Commit 48328fb

yanboliang authored and Chillee committed
transposed w2 to have reduction dim be innermost dim
Fix checkpoint conversion and TP; update perf numbers.
1 parent b7995e4 commit 48328fb

File tree

6 files changed, +27 -26 lines


README.md (+3 -3)
@@ -22,10 +22,10 @@ Please check the rest of this page about benchmark of LLaMA family models.
 ### Mixtral 8x7B
 We also supported [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) which is a high-quality sparse mixture of experts (MoE) model, the average token generation rates are:
 
-|                  | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
+|                  | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
 |------------------|---------|-----------|--------|------------|
-|baseline(bfloat16)| OOM | 78.75 | 118.23 | 203.69 |
-| int8 | 56.04 | 99.91 | 149.53 | 218.48 |
+|baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
+| int8 | 97.92 | 155.03 | 216.87 | 279.35 |
 
 Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).

mixtral-moe/README.md (+3 -4)
@@ -12,11 +12,10 @@ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
 ## Benchmarks
 Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
 
-|                  | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
+|                  | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
 |------------------|---------|-----------|--------|------------|
-|baseline(bfloat16)| OOM | 78.75 | 118.23 | 203.69 |
-| int8 | 56.04 | 99.91 | 149.53 | 218.48 |
-
+|baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
+| int8 | 97.92 | 155.03 | 216.87 | 279.35 |
 
 
 ## Generate Text

mixtral-moe/model.py (+6 -6)
@@ -188,16 +188,16 @@ class ConditionalFeedForward(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
         self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
 
     def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-        w1_weights = self.w1[expert_indices].transpose(-1, -2)  # [T, A, D, D]
-        w3_weights = self.w3[expert_indices].transpose(-1, -2)  # [T, A, D, D]
+        w1_weights = self.w1[expert_indices]  # [T, A, D, D]
+        w3_weights = self.w3[expert_indices]  # [T, A, D, D]
         w2_weights = self.w2[expert_indices]  # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
         return expert_outs
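A small shape-check of the new contraction may help when skimming the diff. It is a minimal sketch, not code from the commit; the sizes E (experts), D (model dim), I (intermediate dim), T (tokens), and A (active experts per token) are hypothetical. Storing w2 as [E, D, I] makes the contracted (intermediate) axis the innermost, contiguous one, so no runtime transpose is needed before the einsum.

```python
import torch
import torch.nn.functional as F

E, D, I, T, A = 8, 16, 32, 5, 2          # hypothetical sizes, not from the commit
w1 = torch.randn(E, I, D)                # [experts, intermediate, dim]
w3 = torch.randn(E, I, D)
w2 = torch.randn(E, D, I)                # transposed: the reduction (intermediate) dim is innermost
x = torch.randn(T, D)
expert_indices = torch.randint(0, E, (T, A))

w1_w = w1[expert_indices]                # [T, A, I, D]
w3_w = w3[expert_indices]                # [T, A, I, D]
w2_w = w2[expert_indices]                # [T, A, D, I]
x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_w))   # [T, A, I]
x3 = torch.einsum('ti,taoi -> tao', x, w3_w)           # [T, A, I]
out = torch.einsum('tao,taio -> tai', x1 * x3, w2_w)   # [T, A, D]
assert out.shape == (T, A, D)
assert w2_w.stride(-1) == 1              # contracted axis of w2 is contiguous
```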

mixtral-moe/quantize.py (+9 -9)
@@ -75,11 +75,11 @@ def create_quantized_state_dict(self):
                 cur_state_dict[f"{fqn}.weight"] = int8_weight
                 cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
             elif isinstance(mod, ConditionalFeedForward):
-                num_experts, intermediate_size, dim = mod.w1.shape
                 for weight_idx in range(0, 3):
                     weight_name = f"w{weight_idx + 1}"
                     scales_name = f"scales{weight_idx + 1}"
                     weight = getattr(mod, weight_name)
+                    num_experts, intermediate_size, dim = weight.shape
 
                     bit8_weight_list = []
                     scales_list = []
@@ -125,20 +125,20 @@ def __init__(self, num_experts, intermediate_size, dim, target_dtype):
         self.target_dtype = target_dtype
 
         self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
-        self.register_buffer("w2", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
+        self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype))
         self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
 
         self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
-        self.register_buffer("scales2", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
+        self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16))
         self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
 
     def forward(self, x, expert_indices):
-        w1_weights = (self.w1.to(x.dtype)[expert_indices] * self.scales1[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2)  # [T, A, D, D]
-        w3_weights = (self.w3.to(x.dtype)[expert_indices] * self.scales3[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2)  # [T, A, D, D]
-        w2_weights = (self.w2.to(x.dtype)[expert_indices] * self.scales2[expert_indices].to(x.dtype).unsqueeze(-1))  # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        w1_weights = self.w1.to(x.dtype)[expert_indices]  # [T, A, D, D]
+        w3_weights = self.w3.to(x.dtype)[expert_indices]  # [T, A, D, D]
+        w2_weights = self.w2.to(x.dtype)[expert_indices]
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype)  # [T, A, D, D]
         return expert_outs
 
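The int8 path now applies the per-output-channel scales to the einsum result instead of to the dequantized weights. A minimal sketch of why that is equivalent, using hypothetical shapes for a single expert rather than the repo's classes:

```python
import torch

T, D, I = 4, 16, 32                      # hypothetical sizes
x = torch.randn(T, D)
w = torch.randn(I, D)                    # one expert's w1, rows = output channels
scales = torch.rand(I) + 0.5             # per-output-channel scales

# scale the weight rows before the contraction ...
scale_weights_first = torch.einsum('ti,oi -> to', x, w * scales[:, None])
# ... or scale each output channel of the result afterwards
scale_output_after = torch.einsum('ti,oi -> to', x, w) * scales
torch.testing.assert_close(scale_weights_first, scale_output_after)
```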

mixtral-moe/scripts/convert_hf_checkpoint.py (+4 -2)
@@ -76,9 +76,11 @@ def convert_hf_checkpoint(
             del final_result[key]
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
-        if "w1" in key or "w2" in key or "w3" in key:
+        elif "w1" in key or "w3" in key:
             final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
-        if "gate" in key:
+        elif "w2" in key:
+            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
+        elif "gate" in key:
             final_result[key] = final_result[key].contiguous()
 
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")

mixtral-moe/tp.py (+2 -2)
@@ -99,12 +99,12 @@ def shard_qkv(qkv, dim, weight_splits):
 def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None:
     mlp.cond_ffn.w1 = nn.Parameter(shard(mlp.cond_ffn.w1, 1), requires_grad=False)
     mlp.cond_ffn.w3 = nn.Parameter(shard(mlp.cond_ffn.w3, 1), requires_grad=False)
-    mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 1), requires_grad=False)
+    mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 2), requires_grad=False)
 
     if hasattr(mlp.cond_ffn, "scales1"):
         mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False)
         mlp.cond_ffn.scales3 = nn.Parameter(shard(mlp.cond_ffn.scales3, 1), requires_grad=False)
-        mlp.cond_ffn.scales2 = nn.Parameter(shard(mlp.cond_ffn.scales2, 1), requires_grad=False)
+        mlp.cond_ffn.scales2 = nn.Parameter(mlp.cond_ffn.scales2, requires_grad=False)
 
     world_size = _get_world_size()
     mlp.cond_ffn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
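A minimal sketch of the sharding implication, with hypothetical sizes and a stand-in shard_dim helper rather than the repo's shard(): w1 and w3 are split along their intermediate axis (dim 1 of [E, I, D]); since w2 is now [E, D, I], the same axis sits at dim 2. scales2 is per output channel (length dim), which is not split, so each rank keeps a full copy and the partial expert outputs are summed by the all_reduce forward hook shown in the surrounding context.

```python
import torch

E, D, I = 8, 16, 32                      # hypothetical sizes
world_size, rank = 2, 0                  # pretend 2-way tensor parallelism, viewed from rank 0

def shard_dim(t: torch.Tensor, dim: int) -> torch.Tensor:
    # stand-in for the repo's shard(): keep this rank's even chunk along `dim`
    return torch.tensor_split(t, world_size, dim=dim)[rank]

w1 = torch.randn(E, I, D)
w2 = torch.randn(E, D, I)                # transposed layout
scales2 = torch.rand(E, D)               # per-output-channel scales for w2

assert shard_dim(w1, 1).shape == (E, I // world_size, D)   # split the intermediate axis
assert shard_dim(w2, 2).shape == (E, D, I // world_size)   # same axis, now at dim 2
assert scales2.shape == (E, D)           # output dim is not sharded, so scales2 stays replicated
```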
