
Commit aead69f

dan-garvey, archana-ramalingam, and IanNod authored

moe (#162)

Throwing this up as a PR to make it easier to view.

Co-authored-by: archana-ramalingam <[email protected]>
Co-authored-by: Ian <[email protected]>
1 parent e051c37 commit aead69f

20 files changed: +1412 -228 lines

sharktank/sharktank/examples/export_paged_llm_v1.py

+5 -1

@@ -16,6 +16,7 @@

 # TODO: Should be using a base class with the protocol supported.
 from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
+from ..models.mixtral.mixtral import *


 def main():
@@ -52,7 +53,10 @@ def main():
     llama_config = LlamaModelConfig(hp)
     llama_config.static_tables = False  # Rely on the compiler for hoisting tables.
     llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
-    model = PagedLlamaModelV1(dataset.root_theta, llama_config)
+    if llama_config.hp.expert_count:
+        model = PagedMixtralModelV1(dataset.root_theta, llama_config)
+    else:
+        model = PagedLlamaModelV1(dataset.root_theta, llama_config)

 def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
     return {
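The change above dispatches on the GGUF hyperparameters: a nonzero expert_count selects the Mixtral (MoE) model, otherwise the dense Llama model; the same branch is repeated in paged_llm_v1.py below. A minimal, self-contained sketch of that selection logic (the dataclass and the string return values are illustrative stand-ins, not the real LlamaModelConfig or model classes):

from dataclasses import dataclass


# Hypothetical stand-in for the hp field of LlamaModelConfig, for illustration only.
@dataclass
class FakeHParams:
    expert_count: int = 0  # 0 for dense Llama GGUFs; e.g. 8 for Mixtral 8x7B


def model_class_name(hp: FakeHParams) -> str:
    # Mirrors the branch added above: truthy expert_count -> MoE model.
    return "PagedMixtralModelV1" if hp.expert_count else "PagedLlamaModelV1"


print(model_class_name(FakeHParams(expert_count=8)))  # PagedMixtralModelV1
print(model_class_name(FakeHParams(expert_count=0)))  # PagedLlamaModelV1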

sharktank/sharktank/examples/paged_llm_v1.py

+6 -1

@@ -17,6 +17,7 @@
 from ..types import *

 # TODO: Should be using a base class with the protocol supported.
+from ..models.mixtral.mixtral import *
 from ..models.llama.llama import *
 from ..utils.debugging import trace_tensor
 from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
@@ -236,7 +237,11 @@ def main():
         activation_dtype=activation_dtype,
         attention_dtype=activation_dtype,
     )
-    model = PagedLlamaModelV1(dataset.root_theta, config)
+
+    if config.hp.expert_count:
+        model = PagedMixtralModelV1(dataset.root_theta, config)
+    else:
+        model = PagedLlamaModelV1(dataset.root_theta, config)
     if args.save_intermediates_path:
         from ..utils.patching import SaveModuleResultTensorsPatch

(new file, +144 lines)

@@ -0,0 +1,144 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+
+import torch
+
+from sharktank.layers import *
+from sharktank.types import *
+from sharktank.models.mixtral.mixtral import *
+
+
+def main(args: list[str]):
+    from ..utils import cli
+
+    torch.no_grad().__enter__()
+
+    parser = cli.create_parser()
+    cli.add_input_dataset_options(parser)
+    args = cli.parse(parser)
+
+    dataset = cli.get_input_dataset(args)
+    hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
+    llama_config = LlamaModelConfig(hp)
+    llama_config.kv_cache_type = "direct"
+    llama_config.activation_dtype = torch.float16
+    model = PagedMixtralModelV1(dataset.root_theta, llama_config)
+
+    # bs ("batch size") == 1
+    cache_state = model.cache.allocate(bs=1)
+
+    start_index = 0
+    tokens = torch.tensor(
+        [
+            [
+                1,
+                1059,
+                31871,
+                1217,
+                322,
+                266,
+                3682,
+                6075,
+                31902,
+                13,
+                31849,
+                31871,
+                0,
+                0,
+                0,
+                0,
+            ]
+            + 48 * [0],
+        ]
+    )
+    assert tokens.shape[1] % model.cache.block_seq_stride == 0
+    seq_block_ids = torch.tensor(
+        [
+            [127, 0, 0, 0],
+        ]
+    )
+
+    # Important: Do not use a sequence length of 0 for empty batch slots
+    # as it will cause softmax to nan due to a mask of all -inf. This then
+    # propagates and causes badness.
+    seq_lens = torch.tensor([12])
+
+    attention_mask = model.attention_mask(
+        model.input_mask(seq_lens, tokens.shape[1]),
+    )
+
+    print(f"Step {start_index}")
+    logits = model.prefill(
+        tokens,
+        attention_mask=attention_mask,
+        seq_block_ids=seq_block_ids,
+        cache_state=cache_state,
+    )
+    # TODO: Normalize the output of extract_tokens_from_logits into tensor [bs, 1].
+    tokens = torch.tensor(model.extract_tokens_from_logits(logits, seq_lens)).unsqueeze(
+        1
+    )
+    print(f" : tokens = {tokens}")
+
+    # Decode a step.
+    print("Decoding...")
+    print(tokens.shape, tokens)
+    start_positions = torch.tensor([12])
+    seq_lens = seq_lens + 1
+    decode_attention_mask = model.decode_attention_mask(
+        model.input_mask(
+            seq_lens,
+            seq_block_ids.shape[1] * model.cache.block_seq_stride,
+        ),
+    )
+    logits = model.decode(
+        tokens,
+        attention_mask=decode_attention_mask,
+        start_positions=start_positions,
+        seq_block_ids=seq_block_ids,
+        cache_state=cache_state,
+    )
+    tokens = torch.tensor(model.extract_tokens_from_logits(logits, [1])).unsqueeze(1)
+    print(f" : tokens = {tokens}")
+
+def save_prefill_module(model):
+    from iree.compiler.extras.fx_importer import FxImporter
+    from iree.compiler.ir import AsmState
+
+    importer = FxImporter()
+
+    print("Generating FX graph")
+
+    class InferenceModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.add_module("prefill", model)
+
+        def forward(self, tokens, attention_mask, seq_block_ids, *cache_state):
+            return self.prefill.prefill(
+                tokens,
+                attention_mask=attention_mask,
+                seq_block_ids=seq_block_ids,
+                cache_state=list(cache_state),
+            )
+
+    infmod = InferenceModule()
+    prog = torch.export.export(
+        infmod, (tokens, attention_mask, seq_block_ids) + tuple(cache_state)
+    )
+
+    print(f"FX prog:", prog)
+    importer.import_program(prog, func_name="prefill")
+    output_file = "/tmp/prefill.mlirbc"
+    print("Saving to:", output_file)
+    with open(output_file, "wb") as f:
+        importer.module_op.write_bytecode(f)
+
+
+if __name__ == "__main__":
+    sys.exit(main(sys.argv[1:]))
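The script above runs one prefill over 12 real tokens (the rest of the 64-token sequence is padding) and then a single decode step. Conceptually, greedy extraction from the prefill logits just takes the argmax over the vocabulary at each sequence's last real position; here is a standalone sketch of that idea (an illustration only, not sharktank's extract_tokens_from_logits implementation):

import torch


def greedy_tokens(logits: torch.Tensor, seq_lens: list[int]) -> list[int]:
    # logits: [bs, seq_len, vocab_size]; pick the next token for each sequence
    # from the position of its last real (non-padding) token.
    return [int(logits[b, seq_lens[b] - 1].argmax()) for b in range(logits.shape[0])]


logits = torch.randn(1, 64, 32000)   # batch of 1, 64 positions, 32k-entry vocab
print(greedy_tokens(logits, [12]))   # next-token prediction after the 12 real tokens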
(new file, +48 lines)

@@ -0,0 +1,48 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+
+import torch
+
+from sharktank.layers import *
+from sharktank.types import *
+from sharktank.models.mixtral.mixtral_ref import *
+
+
+def main(args: list[str]):
+    from ..utils import cli
+
+    torch.no_grad().__enter__()
+
+    parser = cli.create_parser()
+    cli.add_input_dataset_options(parser)
+    args = cli.parse(parser)
+
+    dataset = cli.get_input_dataset(args)
+    hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
+    ref_llama_config = RefLlamaModelConfig(hp)
+    ref_llama_config.activation_dtype = torch.float16
+    model = DirectCacheMixtralModelV1(dataset.root_theta, ref_llama_config)
+
+    kv_cache = model.create_cache(bs=1)
+    start_index = 0
+    next_tokens = [1, 1059, 31871, 1217, 322, 266, 3682, 6075, 31902, 13, 31849, 31871]
+    print(f"Step {start_index}")
+    tokens = model.forward(
+        torch.tensor([next_tokens]), start_index=start_index, local_kv_cache=kv_cache
+    )
+    print(f" : tokens = {tokens}")
+
+    # Decode a step.
+    print("Decoding...")
+    print(tokens.shape, tokens)
+    decode_token = model.forward(tokens, start_index=12, local_kv_cache=kv_cache)
+    print(f" : decode tokens = {decode_token}")
+
+
+if __name__ == "__main__":
+    sys.exit(main(sys.argv[1:]))

sharktank/sharktank/layers/__init__.py

+5

@@ -12,5 +12,10 @@
 from .norm import RMSNormLayer
 from .rotary_embedding import RotaryEmbeddingLayer
 from .token_embedding import TokenEmbeddingLayer
+from .llama_attention_block import LlamaAttentionBlock
+from .paged_llama_attention_block import PagedLlamaAttentionBlock
+from .ffn_block import FFN
+from .ffn_moe_block import FFNMOE
+from .mixture_of_experts_block import SparseMoeBlock

 from . import configs
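Among the newly exported layers, SparseMoeBlock is the mixture-of-experts block this commit is named for. As a rough orientation: Mixtral-style sparse MoE routes each token through a small top-k subset of expert FFNs selected by a learned gate. The sketch below is a simplified, self-contained illustration of that general technique; it is not the sharktank SparseMoeBlock implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySparseMoe(nn.Module):
    def __init__(self, dim: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [tokens, dim]. Gate scores over experts, keep top-k, renormalize,
        # then sum the selected experts' outputs weighted by their gate scores.
        scores = F.softmax(self.gate(h), dim=-1)              # [tokens, num_experts]
        weights, expert_ids = torch.topk(scores, self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(h)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = expert_ids[:, k] == e
                if mask.any():
                    out[mask] += weights[mask, k : k + 1] * expert(h[mask])
        return out


x = torch.randn(5, 16)              # 5 tokens, model dim 16
print(TinySparseMoe(16)(x).shape)   # torch.Size([5, 16])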

sharktank/sharktank/layers/base.py

+1 -4

@@ -16,11 +16,8 @@
 from ..utils import debugging

 __all__ = [
-    "LinearLayer",
-    "RotaryEmbeddingLayer",
-    "RMSNormLayer",
+    "BaseLayer",
     "ThetaLayer",
-    "TokenEmbedding",
 ]

sharktank/sharktank/layers/configs/llm_configs.py

+26 -6

@@ -19,9 +19,7 @@

 import torch

-__all__ = [
-    "LlamaHParams",
-]
+__all__ = ["LlamaHParams"]


 @dataclass
@@ -36,14 +34,21 @@ class LlamaHParams:
     block_count: int
     feed_forward_length: int
     rope_dimension_count: int
+    rope_freq_base: float
     attention_head_count: int
     attn_head_dim: int
     attention_layer_norm_rms_epsilon: float
     attention_head_count_kv: int
+    expert_count: int
+    expert_used_count: int

     @staticmethod
     def from_gguf_props(p: dict[str, Any]):
+        default_expert_count = 0
+        default_expert_used_count = 0
+        default_rope_freq_base = 10000.0
         attention_head_count = _int_prop(p, "llama.attention.head_count")
+
         return LlamaHParams(
             context_length=_int_prop(p, "llama.context_length"),
             embedding_length=_int_prop(p, "llama.embedding_length"),
@@ -58,6 +63,15 @@ def from_gguf_props(p: dict[str, Any]):
             attention_head_count_kv=_optional_int_prop(
                 p, "llama.attention.head_count_kv", attention_head_count
             ),
+            rope_freq_base=_optional_float_prop(
+                p, "llama.rope.freq_base", default_rope_freq_base
+            ),
+            expert_count=_optional_int_prop(
+                p, "llama.expert_count", default_expert_count
+            ),
+            expert_used_count=_optional_int_prop(
+                p, "llama.expert_used_count", default_expert_used_count
+            ),
         )


@@ -79,10 +93,16 @@ def _int_prop(p: dict[str, Any], name: str) -> int:
     raise KeyError(f"Property '{name}' not found (among keys {p.keys()})")


+def _optional_float_prop(p: dict[str, Any], name: str, default_value: float) -> float:
+    value = p.get(name, default_value)
+    try:
+        return float(value)
+    except ValueError as e:
+        raise ValueError(f"Property '{name}' expected to be a float and was not") from e
+
+
 def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int:
-    value = p[name]
-    if value is None:
-        return default_value
+    value = p.get(name, default_value)
     try:
         return int(value)
     except ValueError as e:
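One reason the property lookups above switch from p[name] to p.get(name, default_value) is that dense Llama GGUF files simply do not carry the MoE keys; with the new defaults, expert_count comes back as 0 and the export scripts earlier in this diff keep taking the PagedLlamaModelV1 branch. A standalone sketch of that behavior (this re-implements the helper for illustration and is not the sharktank code itself):

from typing import Any


def optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int:
    # Missing keys fall back to the default instead of raising KeyError.
    value = p.get(name, default_value)
    try:
        return int(value)
    except ValueError as e:
        raise ValueError(f"Property '{name}' expected to be an int and was not") from e


dense_props = {"llama.attention.head_count": 32}                      # no MoE keys
moe_props = {"llama.expert_count": 8, "llama.expert_used_count": 2}   # Mixtral-style

print(optional_int_prop(dense_props, "llama.expert_count", 0))  # 0 -> dense Llama path
print(optional_int_prop(moe_props, "llama.expert_count", 0))    # 8 -> Mixtral path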
(new file, +38 lines)

@@ -0,0 +1,38 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from .base import Theta, ThetaLayer
+from .linear import LinearLayer
+
+__all__ = [
+    "FFN",
+]
+
+
+class FFN(ThetaLayer):
+    def __init__(
+        self,
+        theta: Theta,
+    ):
+        super().__init__(theta)
+
+        self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
+        self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
+        self.add_module("ffn_down", LinearLayer(theta("ffn_down")))
+
+    def forward(
+        self,
+        h: torch.Tensor,
+    ):
+        ffn_gate = F.silu(self.ffn_gate(h))
+        ffn_up = self.ffn_up(h)
+        ffn_down = self.ffn_down(ffn_gate * ffn_up)
+        return ffn_down
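The FFN layer above computes the gated (SwiGLU-style) feed-forward used throughout Llama-family models: down(silu(gate(h)) * up(h)). A standalone sketch of the same computation against plain torch.nn.Linear, for illustration only (the real layer is backed by Theta weights and sharktank's LinearLayer):

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedFFN(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
        self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # silu(gate(h)) * up(h), projected back down -- same dataflow as FFN.forward above.
        return self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h))


x = torch.randn(1, 4, 16)          # [batch, seq, dim]
print(GatedFFN(16, 64)(x).shape)   # torch.Size([1, 4, 16])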
