Refactor llama / mixtral / grok for shared features (#267)
Many of these features can be toggled depending on the architecture. Plumbing the configurations separately allows better reuse and a clearer understanding of how the models vary from each other.

Grok uses a softcap; plumbing a value for it enables `sc * tanh(v / sc)`. Grok also has some hardcoded values that have better representations, e.g. `sqrt(6144)` and `sqrt(3)`.

Output normalization is optional but used by Mixtral. Presence of the tensor is sufficient to enable the normalization.
rsuderman authored Oct 16, 2024
1 parent e700bfa commit f5fcd00
Showing 13 changed files with 159 additions and 244 deletions.
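A minimal sketch of the softcap behavior described in the commit message; `softcap` below is an illustrative helper, not part of the sharktank API.

```python
import torch

def softcap(v: torch.Tensor, sc: float) -> torch.Tensor:
    # sc * tanh(v / sc): near-identity for |v| << sc, smoothly
    # saturating toward +/- sc for large magnitudes.
    return sc * torch.tanh(v / sc)

logits = torch.tensor([0.5, 10.0, 100.0])
print(softcap(logits, sc=30.0))  # approx. [0.50, 9.65, 29.92]
```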
13 changes: 8 additions & 5 deletions sharktank/sharktank/export_layer/export_moe.py
@@ -5,9 +5,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import torch.nn.functional as F

from iree.turbine.aot import *

from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
from sharktank.layers.mixture_of_experts_block import MoeBlock
from ..utils import cli


@@ -37,21 +40,21 @@ def main():
action="store_true",
)
parser.add_argument(
"--use-grok",
help="Enable to export Grok model's version of MOE block",
"--use-gelu",
help="Enable to use gelu for moe activation",
action="store_true",
)

args = cli.parse(parser)

bs = args.batch_size

model = PreGatherMoeBlock(
model = MoeBlock(
theta=make_moe_block_theta()("blk.0"),
expert_count=8,
expert_used_count=2,
rms_epsilon=1e-5,
use_grok=args.use_grok,
moe_activation=F.gelu if args.use_gelu else F.silu,
)
fxb = FxProgramsBuilder(model)
input = make_rand_torch((bs, 32, 6144))
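A hedged sketch of the new toggle in this export script: the architecture-specific `--use-grok` boolean becomes a generic activation choice. The argument name is taken from the diff; everything else is illustrative.

```python
import argparse

import torch.nn.functional as F

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-gelu",
    help="Enable to use gelu for moe activation",
    action="store_true",
)
args = parser.parse_args(["--use-gelu"])

# The MoE block now receives a plain callable instead of a model flag.
moe_activation = F.gelu if args.use_gelu else F.silu
```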
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -16,6 +16,6 @@
from .paged_llama_attention_block import PagedLlamaAttentionBlock
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock
from .mixture_of_experts_block import MoeBlock

from .configs import *
12 changes: 5 additions & 7 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -12,7 +12,7 @@
from .base import ThetaLayer
from .linear import LinearLayer
from ..types import Theta, DefaultPrimitiveTensor
from ..ops import einsum_2args
from ..ops import einsum_2args, elementwise

__all__ = [
"FFNMOE",
@@ -24,15 +24,15 @@ class PreGatherFFNMOE(ThetaLayer):
def __init__(
self,
theta: Theta,
use_grok: bool = False,
activation=F.silu,
):

super().__init__(theta)
self.use_grok = use_grok

self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
self.ffn_up = theta.tensor("ffn_up_exps", "weight")
self.ffn_down = theta.tensor("ffn_down_exps", "weight")
self.activation = activation

def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
inputs = inputs[:, :]
@@ -63,10 +63,8 @@ def forward(
experts: torch.Tensor,
expert_gate: torch.Tensor,
):
if self.use_grok:
ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts))
else:
ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts))
ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
ffn_gate = elementwise(self.activation, ffn_gate)

ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
ffn_down = self.pre_matmul_gather(
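A self-contained sketch of the same pattern inside `PreGatherFFNMOE`: the `use_grok` branch is replaced by an activation callable applied to the gathered gate projection. Plain torch calls stand in for `sharktank.ops.elementwise`; the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F

def gate_activation(gate_proj: torch.Tensor, activation=F.silu) -> torch.Tensor:
    # The caller injects gelu (Grok) or silu (llama/mixtral);
    # the layer itself no longer branches on the architecture.
    return activation(gate_proj)

x = torch.randn(4, 8, 32)
assert torch.allclose(gate_activation(x, F.gelu), F.gelu(x))
assert torch.allclose(gate_activation(x), F.silu(x))
```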
6 changes: 3 additions & 3 deletions sharktank/sharktank/layers/llama_attention_block.py
@@ -6,8 +6,6 @@

from typing import Optional

import math

import torch
import torch.nn.functional as F

@@ -110,7 +108,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
values = values.transpose(1, 2)

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt(
self.head_dim
)

# Apply attention mask.
if attention_mask is not None:
114 changes: 9 additions & 105 deletions sharktank/sharktank/layers/mixture_of_experts_block.py
Original file line number Diff line number Diff line change
@@ -16,12 +16,11 @@
from .ffn_moe_block import FFNMOE, PreGatherFFNMOE

__all__ = [
"SparseMoeBlock",
"PreGatherMoeBlock",
"MoeBlock",
]


class SparseMoeBlock(ThetaLayer):
class MoeBlock(ThetaLayer):
"""
This implementation considers MoE operations as block-sparse
operations to support imbalanced token assignments to experts.
@@ -35,108 +34,12 @@ def __init__(
expert_count: int,
expert_used_count: int,
rms_epsilon: float,
):
super().__init__(theta)

# Add router gate
self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))

# Add FFN norm
self.add_module(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)

# Add FFN output norm
self.add_module(
"layer_output_norm",
RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
)

# Add expert_count x FFN
self.experts = nn.ModuleList(
[FFNMOE(theta, expert_idx=i) for i in range(expert_count)]
)

self.expert_count = expert_count
self.expert_used_count = expert_used_count

def forward(
self,
h: torch.Tensor,
):
ffn_input = self.ffn_norm(h)
batch_size, sequence_length, feature_dim = ffn_input.shape
ffn_input = ffn_input.view(-1, feature_dim)

# For each token, the router calculates the router weights for all experts
# router_logits: (batch_size * sequence_length, expert_count)
router_logits = self.ffn_gate_inp(ffn_input)
router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

# Select top k experts from router weights
router_weights, top_k_experts = torch.topk(
router_weights, self.expert_used_count, dim=-1
)
router_weights /= router_weights.sum(dim=-1, keepdim=True)
router_weights = router_weights.to(ffn_input.dtype)

moe_output = torch.zeros(
(batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype
)

# Create an expert mask by one hot encoding the selected top k experts
# used to index which expert is to be invoked for each token
# expert_mask: (expert_count, expert_used_count, sequence_length)
expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute(
2, 1, 0
)

# Iterate over all experts in the model
for expert_idx in range(self.expert_count):
expert_layer = self.experts[expert_idx]
top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx])

# Given the hidden states, index the tokens assigned to this expert
# and calculate the current expert's hidden state and weigh the
# output expert hidden states by the router weights, based on the
# appropriate tokens
current_expert_tokens = ffn_input[None, token_idx]

current_expert = (
expert_layer(current_expert_tokens)
* router_weights[token_idx, top_k_expert_idx, None]
)

current_expert = current_expert.reshape(-1, feature_dim)

moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype))
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

moe_output = self.layer_output_norm(moe_output)
return h + moe_output


class PreGatherMoeBlock(ThetaLayer):
"""
This implementation considers MoE operations as block-sparse
operations to support imbalanced token assignments to experts.
This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens
(or reduced performance).
"""

def __init__(
self,
theta: Theta,
expert_count: int,
expert_used_count: int,
rms_epsilon: float,
use_grok: Optional[bool] = False,
moe_activation=F.silu,
):
super().__init__(theta)

self.expert_count = expert_count
self.expert_used_count = expert_used_count
self.use_grok = use_grok

# Add router gate
self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))
@@ -146,15 +49,17 @@ def __init__(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)

# Add FFN output norm layer for Grok
if self.use_grok:
# Add optional FFN output norm layer
if theta.optional_tensor("layer_output_norm") is not None:
self.add_module(
"layer_output_norm",
RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
)
else:
self.add_module("layer_output_norm", torch.nn.Identity())

# Add expert_count x FFN
self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok)
self.experts = PreGatherFFNMOE(theta, activation=moe_activation)

def forward(
self,
@@ -180,7 +85,6 @@ def forward(
moe_output = self.experts(ffn_input, top_k_experts, expert_gate)
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

if self.use_grok:
moe_output = self.layer_output_norm(moe_output)
moe_output = self.layer_output_norm(moe_output)

return h + moe_output
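A toy, self-contained illustration of the optional output norm above: presence of a `layer_output_norm` tensor is what enables the normalization, and `torch.nn.Identity` keeps the forward path unconditional. The `RMSNorm` and `make_output_norm` below are minimal stand-ins, not sharktank's `RMSNormLayer` or theta API.

```python
import torch

class RMSNorm(torch.nn.Module):
    # Minimal RMSNorm used only for this sketch.
    def __init__(self, weight: torch.Tensor, eps: float = 1e-5):
        super().__init__()
        self.weight, self.eps = weight, eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight

def make_output_norm(theta: dict, eps: float = 1e-5) -> torch.nn.Module:
    # Wire in the norm only when the tensor exists; otherwise a no-op.
    weight = theta.get("layer_output_norm")
    return RMSNorm(weight, eps) if weight is not None else torch.nn.Identity()

# Mixtral-style theta carries the tensor; a llama-style theta does not.
mixtral_norm = make_output_norm({"layer_output_norm": torch.ones(8)})
llama_norm = make_output_norm({})
```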
36 changes: 20 additions & 16 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,7 +37,8 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
use_grok: Optional[bool] = False,
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
):
super().__init__(theta)

@@ -46,7 +47,8 @@ def __init__(
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
self.use_grok = use_grok
self.attention_scale = attention_scale
self.softcap = softcap

self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
@@ -56,7 +58,12 @@
self.add_module("attn_v", LinearLayer(theta("attn_v")))
self.add_module("attn_output", LinearLayer(theta("attn_output")))

if self.use_grok:
if theta.optional_tensor("attn_output_norm") is None:
self.add_module(
"attn_output_norm",
torch.nn.Identity(),
)
else:
self.add_module(
"attn_output_norm",
RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon),
@@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
keys = xk.transpose(1, 2)
values = xv.transpose(1, 2)

attn_weights = ops.matmul(xq, keys.transpose(2, 3))
if self.attention_scale is None:
attn_weights = attn_weights / math.sqrt(self.head_dim)
else:
attn_weights = attn_weights * self.attention_scale

# Flash attention.
if not self.use_grok:
attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(
self.head_dim
)
elif self.use_grok:
attn_weights = ops.matmul(xq, keys.transpose(2, 3))
attn_weights = 30.0 * torch.tanh(
attn_weights * (0.08838834764831845 / 30.0)
)
if self.softcap is not None:
attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)

self.assert_not_nan(attn_weights)

# Apply attention mask.
@@ -172,12 +179,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

# Project.
attn_output = self.attn_output(attn_output)

if self.use_grok:
attn_output = self.attn_output_norm(attn_output)
attn_output = self.attn_output_norm(attn_output)

h = h + attn_output

return h

def transact_cache_direct(
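A hedged, self-contained sketch of the refactored attention scaling above: the default remains 1/sqrt(head_dim), an explicit `attention_scale` covers Grok's hardcoded constant, and `softcap` applies `sc * tanh(v / sc)` when set. The free-standing function is illustrative; in the layer this logic runs on `attn_weights` inline.

```python
import math
from typing import Optional

import torch

def scale_and_softcap(
    attn_weights: torch.Tensor,
    head_dim: int,
    attention_scale: Optional[float] = None,
    softcap: Optional[float] = None,
) -> torch.Tensor:
    # Default scaling, or an architecture-provided scale (e.g. Grok).
    if attention_scale is None:
        attn_weights = attn_weights / math.sqrt(head_dim)
    else:
        attn_weights = attn_weights * attention_scale
    # Optional softcap: sc * tanh(v / sc).
    if softcap is not None:
        attn_weights = softcap * torch.tanh(attn_weights / softcap)
    return attn_weights

# Grok-like configuration: explicit scale plus a softcap of 30.
logits = torch.randn(1, 8, 16, 16)
capped = scale_and_softcap(
    logits, head_dim=128, attention_scale=0.08838834764831845, softcap=30.0
)
```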