initial grok #169
Changes from 47 commits
The first file in the diff adds a new export script for the MoE block (@@ -0,0 +1,80 @@):

```python
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from shark_turbine.aot import *
from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
from ..utils import cli


def main():
    parser = cli.create_parser()
    parser.add_argument(
        "--output-mlir",
        help="Output file path for exported MLIR file",
        default="/tmp/batch_llama_v1.mlir",
    )
    parser.add_argument(
        "--batch-size",
        "-bs",
        help="Batch size to generate, e.g. `4` or `2`",
        type=lambda arg: int(arg),
        default="2",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        help="Include verbose logging",
        action="store_true",
    )
    parser.add_argument(
        "--strict",
        help="Enables strictness during export",
        action="store_true",
    )
    parser.add_argument(
        "--use-grok",
        help="Enable to export Grok model's version of MOE block",
        action="store_true",
    )

    args = cli.parse(parser)

    bs = args.batch_size

    model = PreGatherMoeBlock(
        theta=make_moe_block_theta()("blk.0"),
        expert_count=8,
        expert_used_count=2,
        rms_epsilon=1e-5,
        use_grok=args.use_grok,
    )
    fxb = FxProgramsBuilder(model)
    input = make_rand_torch((bs, 32, 6144))

    @fxb.export_program(name="prefill_moe", args=(input,))
    def _(model, input: torch.Tensor) -> torch.Tensor:
        return model(input)

    input = make_rand_torch((bs, 1, 6144))

    @fxb.export_program(name="decode_moe", args=(input,))
    def _(model, input: torch.Tensor) -> torch.Tensor:
        return model(input)

    if args.verbose:
        for name, ep in fxb.programs.items():
            print(f"EXPORT {name}:\n{ep}")

    print("Exporting")
    output = export(fxb)
    print(f"Saving to '{args.output_mlir}'")
    output.save_mlir(args.output_mlir)


if __name__ == "__main__":
    main()
```
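The script registers two FX entry points, `prefill_moe` and `decode_moe`, on a single `FxProgramsBuilder` and saves them to one MLIR file; passing `--use-grok` selects the Grok variant of the MoE block. Below is a minimal sketch of the same export pattern applied to a trivial module. It assumes only the `FxProgramsBuilder` / `export_program` / `export` / `save_mlir` API exactly as used in the script above; the `Double` module and the output path are invented for illustration.

```python
# Minimal sketch of the FxProgramsBuilder export pattern used in the script
# above. Assumes shark_turbine.aot exposes FxProgramsBuilder and export as
# used there; the Double module and output path are illustrative only.
import torch
from shark_turbine.aot import FxProgramsBuilder, export


class Double(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


fxb = FxProgramsBuilder(Double())
example_arg = torch.randn(2, 16, 32)


# Each decorated function becomes a named program in the exported module,
# traced with the example argument given to args=(...).
@fxb.export_program(name="double", args=(example_arg,))
def _(model, x: torch.Tensor) -> torch.Tensor:
    return model(x)


output = export(fxb)
output.save_mlir("/tmp/double.mlir")
```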
The second file in the diff updates the config module that defines `LlamaHParams`:

```diff
@@ -16,10 +16,9 @@
 from dataclasses import dataclass
 from typing import Any, Optional

 import torch

-__all__ = ["LlamaHParams"]
+__all__ = ["LlamaHParams", "LlamaModelConfig"]


 @dataclass
@@ -29,48 +28,55 @@ class LlamaHParams:
     Comments are only provided if they differ from this source.
     """

+    model_arch: str
     context_length: int
     embedding_length: int
     block_count: int
     feed_forward_length: int
-    rope_dimension_count: int
-    rope_freq_base: float
     attention_head_count: int
     attn_head_dim: int
     attention_layer_norm_rms_epsilon: float
     attention_head_count_kv: int
-    expert_count: int
-    expert_used_count: int
+    rope_dimension_count: Optional[int] = None
+    rope_freq_base: Optional[float] = None
+    expert_count: Optional[int] = None
+    expert_used_count: Optional[int] = None

     @staticmethod
     def from_gguf_props(p: dict[str, Any]):
+        name_prefix = p["general.architecture"]
         default_expert_count = 0
         default_expert_used_count = 0
         default_rope_freq_base = 10000.0
-        attention_head_count = _int_prop(p, "llama.attention.head_count")
+        default_rope_dimension_count = 128
+        attention_head_count = _int_prop(p, f"{name_prefix}.attention.head_count")
+        rope_dimension_count = _optional_int_prop(
+            p, f"{name_prefix}.rope.dimension_count", default_rope_dimension_count
+        )

         return LlamaHParams(
-            context_length=_int_prop(p, "llama.context_length"),
-            embedding_length=_int_prop(p, "llama.embedding_length"),
-            block_count=_int_prop(p, "llama.block_count"),
-            feed_forward_length=_int_prop(p, "llama.feed_forward_length"),
-            attn_head_dim=_int_prop(p, "llama.rope.dimension_count"),
-            rope_dimension_count=_int_prop(p, "llama.rope.dimension_count"),
+            model_arch=name_prefix,
+            context_length=_int_prop(p, f"{name_prefix}.context_length"),
+            embedding_length=_int_prop(p, f"{name_prefix}.embedding_length"),
+            block_count=_int_prop(p, f"{name_prefix}.block_count"),
+            feed_forward_length=_int_prop(p, f"{name_prefix}.feed_forward_length"),
             attention_head_count=attention_head_count,
             attention_layer_norm_rms_epsilon=_float_prop(
-                p, "llama.attention.layer_norm_rms_epsilon"
+                p, f"{name_prefix}.attention.layer_norm_rms_epsilon"
             ),
             attention_head_count_kv=_optional_int_prop(
-                p, "llama.attention.head_count_kv", attention_head_count
+                p, f"{name_prefix}.attention.head_count_kv", attention_head_count
             ),
+            attn_head_dim=rope_dimension_count,
+            rope_dimension_count=rope_dimension_count,
             rope_freq_base=_optional_float_prop(
-                p, "llama.rope.freq_base", default_rope_freq_base
+                p, f"{name_prefix}.rope.freq_base", default_rope_freq_base
             ),
             expert_count=_optional_int_prop(
-                p, "llama.expert_count", default_expert_count
+                p, f"{name_prefix}.expert_count", default_expert_count
             ),
             expert_used_count=_optional_int_prop(
-                p, "llama.expert_used_count", default_expert_used_count
+                p, f"{name_prefix}.expert_used_count", default_expert_used_count
             ),
         )
```
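The behavioral change in `from_gguf_props` is that hyperparameter keys are now derived from the GGUF `general.architecture` value instead of being hard-coded with a `llama.` prefix, and the rope/expert fields fall back to defaults when absent. A rough illustration of the lookup follows; it assumes `LlamaHParams` from the module above is in scope, and the property values are invented for the example rather than taken from a real Grok GGUF file.

```python
# Rough illustration of the prefix-based lookup added in this diff.
# Property values below are made up; only the key shapes mirror the code.
props = {
    "general.architecture": "grok",
    "grok.context_length": 8192,
    "grok.embedding_length": 6144,
    "grok.block_count": 64,
    "grok.feed_forward_length": 32768,
    "grok.attention.head_count": 48,
    "grok.attention.layer_norm_rms_epsilon": 1e-5,
    "grok.expert_count": 8,
    "grok.expert_used_count": 2,
}

hp = LlamaHParams.from_gguf_props(props)
assert hp.model_arch == "grok"
# Optional properties fall back to the new defaults when missing.
assert hp.rope_freq_base == 10000.0
assert hp.attn_head_dim == hp.rope_dimension_count == 128
```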
```diff
@@ -107,3 +113,37 @@ def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int:
         return int(value)
     except ValueError as e:
         raise ValueError(f"Property '{name}' expected to be an int and was not") from e
```

Review comment: Adding LlamaModelConfig to the config file makes sense, but having a create_kv_cache() function in a config file feels odd. Is there a better way to refactor this?

Reply: It's a bit odd for a dataclass to have a function, but in this case it kind of makes sense because it just creates data. I'm ok with refactoring it out if you want.

Reply: I moved
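The `create_kv_cache()` helper being discussed is not part of this hunk, so the following is only a hypothetical sketch of the pattern under debate: a config dataclass exposing a small factory method that builds another object purely from its own fields. The `KVCache` type, its fields, and the method body here are invented for illustration and do not correspond to the actual sharktank API.

```python
# Hypothetical sketch of the "config dataclass with a factory method" pattern
# discussed in the thread above; KVCache and its fields are invented here.
from dataclasses import dataclass


@dataclass
class KVCache:
    kind: str
    block_seq_stride: int


@dataclass
class ModelConfigSketch:
    kv_cache_type: str = "paged"
    block_seq_stride: int = 16

    def create_kv_cache(self) -> KVCache:
        # The method only assembles data from the config's own fields,
        # which is the argument made above for keeping it on the dataclass.
        return KVCache(kind=self.kv_cache_type, block_seq_stride=self.block_seq_stride)
```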
The LlamaModelConfig addition itself:

```diff
+@dataclass
+class LlamaModelConfig:
+    hp: LlamaHParams
+
+    # Block sequence stride for a paged KV cache. This must divide evenly
+    # into the context length.
+    block_seq_stride: int = 16
+
+    # Either "paged" or "direct".
+    kv_cache_type: str = "paged"
+
+    # The device on which to place intermediate state.
+    device: Optional[torch.device] = None
+
+    # Dtype to use for general FP activations not otherwise configured.
+    activation_dtype: torch.dtype = torch.float16
+
+    # Dtype to use for attention.
+    attention_dtype: torch.dtype = torch.float16
+
+    # Indicates if running with HuggingFace implementation and ensures
+    # numerical equivalency to HuggingFace's LLaMa if true (by modifying
+    # rotary embedding).
+    use_hf: bool = False
+
+    # If true, then the model may pre-initialize certain tables during
+    # init. This can be better for eager execution but when capturing a program,
+    # it is often better to preserve the calculation explicitly and rely on
+    # the compiler to transform it to an initialization time step. This can
+    # be the difference of many gigabytes of static data being embedded in
+    # the program and not.
+    static_tables: bool = True
```
Review comment: nice catch
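For completeness, a hedged sketch of constructing the new `LlamaModelConfig` follows. It uses only the fields and defaults visible in the diff, reuses the invented `props` dict from the earlier example, and assumes `LlamaHParams` and `LlamaModelConfig` from the module above are in scope.

```python
# Hedged sketch: building a LlamaModelConfig from GGUF props, using only
# fields shown in the diff. LlamaHParams/LlamaModelConfig come from the
# module modified above; the values are illustrative.
import torch

hp = LlamaHParams.from_gguf_props(props)  # `props` as in the earlier sketch
config = LlamaModelConfig(
    hp=hp,
    block_seq_stride=16,        # must divide evenly into hp.context_length
    kv_cache_type="paged",      # or "direct"
    device=None,                # place intermediate state on the default device
    activation_dtype=torch.float16,
    attention_dtype=torch.float16,
    use_hf=False,
    static_tables=True,
)
```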