initial grok #169

Merged: 52 commits, Sep 26, 2024
Changes from 47 commits

Commits:
ba87a04
initial grok
dan-garvey Sep 5, 2024
de9842b
use name prefix instead of new dataclass
dan-garvey Sep 5, 2024
b7965b1
some hacks
dan-garvey Sep 6, 2024
6d3d261
more hack
dan-garvey Sep 6, 2024
b5f535d
fix moe-ffn
dan-garvey Sep 6, 2024
4095db0
Add in some missing grok specific model structure and constants
KyleHerndon Sep 9, 2024
e71630a
Add attn_output_norm layer
archana-ramalingam Sep 12, 2024
5772a3d
Update MOE block in decode
archana-ramalingam Sep 12, 2024
3f2914a
Some fixes to the grok model
KyleHerndon Sep 12, 2024
7c2e133
Merge branch 'main' into grokstar
archana-ramalingam Sep 12, 2024
e1261f5
Revert "Merge branch 'main' into grokstar"
archana-ramalingam Sep 12, 2024
a242bde
Fix merging main changes
archana-ramalingam Sep 12, 2024
bb40d12
Update tensor trace names
archana-ramalingam Sep 12, 2024
cfa8420
Update moe block test
archana-ramalingam Sep 12, 2024
325696f
Update paged attention block with grok changes
archana-ramalingam Sep 12, 2024
48fce0c
Update paged attention block with grok changes
archana-ramalingam Sep 12, 2024
d9e787c
Add use_grok to MOE block
archana-ramalingam Sep 12, 2024
ab084cc
Use use_grok in MOE block
archana-ramalingam Sep 13, 2024
29e3603
Change MOE activation from silu to gelu for Grok
archana-ramalingam Sep 13, 2024
0670e1d
Allow router weight norm for all MOEs
archana-ramalingam Sep 13, 2024
a4be20b
Update llm_configs to support llama and grok architectures
archana-ramalingam Sep 13, 2024
3049f87
Remove comment
archana-ramalingam Sep 13, 2024
b8240c8
Add optional params for Grok
archana-ramalingam Sep 13, 2024
5bf30e0
Add all models supported in sharktank
archana-ramalingam Sep 13, 2024
d970944
Make rope_freq_base mandatory param
archana-ramalingam Sep 13, 2024
b1fd818
small refactor/cleanup
dan-garvey Sep 24, 2024
85e2f87
more cleanup
dan-garvey Sep 24, 2024
7ed9a23
this shouldn't have been unrebased??
dan-garvey Sep 24, 2024
3510634
fix use_hf args
dan-garvey Sep 24, 2024
a4ff36a
Make use_grok optional in MOE and Attention blocks
archana-ramalingam Sep 24, 2024
940db2f
Add use_grok to moe_block_test
archana-ramalingam Sep 24, 2024
bb2f5a1
fix kv cache test
dan-garvey Sep 24, 2024
b6e52eb
Add PreGatherMoeBlock to import from layers
archana-ramalingam Sep 24, 2024
b790cb5
Add MOE block export for prefill + decode
archana-ramalingam Sep 24, 2024
19218f3
Fix architecture variable
archana-ramalingam Sep 24, 2024
7deb42a
Fix imports
archana-ramalingam Sep 24, 2024
1b6cb6d
Fix rope_freq_base
archana-ramalingam Sep 24, 2024
43b20c4
fix flaky test
dan-garvey Sep 24, 2024
d938a08
Merge branch 'main' into grokstar
archana-ramalingam Sep 24, 2024
6aeeb4f
Add short versions for args
archana-ramalingam Sep 24, 2024
cac489c
Remove use_hf and use_grok options from llama
archana-ramalingam Sep 24, 2024
d5c27fe
Move create_kv_cache to utils folder
archana-ramalingam Sep 24, 2024
10d6c87
Fix error
archana-ramalingam Sep 24, 2024
4816c93
Merge branch 'main' into grokstar
archana-ramalingam Sep 25, 2024
124503f
revert addition of dtype arg
dan-garvey Sep 25, 2024
46c6eb6
Merge branch 'main' into grokstar
dan-garvey Sep 25, 2024
f3a8fb1
Remove attention_dtype
archana-ramalingam Sep 25, 2024
dcc1e8f
Merge branch 'main' into grokstar
dan-garvey Sep 25, 2024
88e38e2
fix missing parenth
dan-garvey Sep 25, 2024
430045b
correctly rebase T_T
dan-garvey Sep 25, 2024
f0a3e31
nonstrict
dan-garvey Sep 26, 2024
e5dc9e9
Merge branch 'main' into grokstar
dan-garvey Sep 26, 2024
7 changes: 6 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -17,6 +17,7 @@
# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *


def main():
@@ -61,8 +62,12 @@ def main():
llama_config.use_hf = False
llama_config.static_tables = False # Rely on the compiler for hoisting tables.
llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"

if llama_config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, llama_config)
if llama_config.hp.model_arch == "grok":
model = PagedGrokModelV1(dataset.root_theta, llama_config)
else:
model = PagedMixtralModelV1(dataset.root_theta, llama_config)
else:
model = PagedLlamaModelV1(dataset.root_theta, llama_config)

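Editor's note: the model dispatch added above (and repeated in paged_llm_v1.py below) can be read as the following standalone sketch. The `choose_model` helper and the explicit class imports are assumptions for illustration; the actual scripts use wildcard imports and inline `if` blocks.

```python
# Hypothetical helper, not part of the PR: same architecture dispatch as the diff.
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from sharktank.models.mixtral.mixtral import PagedMixtralModelV1
from sharktank.models.grok.grok import PagedGrokModelV1


def choose_model(root_theta, config: LlamaModelConfig):
    """Pick the model class from the GGUF hyperparameters."""
    if config.hp.expert_count:
        if config.hp.model_arch == "grok":
            # Expert-bearing checkpoints that report "grok" use the Grok model.
            return PagedGrokModelV1(root_theta, config)
        # Any other MoE checkpoint falls back to Mixtral.
        return PagedMixtralModelV1(root_theta, config)
    # Dense (non-MoE) checkpoints use the plain Llama model.
    return PagedLlamaModelV1(root_theta, config)
```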
18 changes: 9 additions & 9 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -20,6 +20,7 @@

# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from ..models.llama.llama import *
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer
@@ -221,11 +222,6 @@ def main():
help="DType to use for activations in the model",
default="float32",
)
parser.add_argument(
"--attention-dtype",
help="DType to use for attention in the model",
default="float16",
)
parser.add_argument(
"--use-hf",
action="store_true",
@@ -237,9 +233,8 @@

device = torch.device(args.device) if args.device else None
activation_dtype = getattr(torch, args.activation_dtype)
attention_dtype = getattr(torch, args.attention_dtype)
Member Author: nice catch

assert isinstance(activation_dtype, torch.dtype)
assert isinstance(attention_dtype, torch.dtype)

dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
@@ -250,14 +245,18 @@ def main():
kv_cache_type=args.kv_cache_type,
device=device,
activation_dtype=activation_dtype,
attention_dtype=attention_dtype,
attention_dtype=activation_dtype,
use_hf=args.use_hf,
)

if config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, config)
if config.hp.model_arch == "grok":
model = PagedGrokModelV1(dataset.root_theta, config)
else:
model = PagedMixtralModelV1(dataset.root_theta, config)
else:
model = PagedLlamaModelV1(dataset.root_theta, config)

if args.save_intermediates_path:
from ..utils.patching import SaveModuleResultTensorsPatch

@@ -287,6 +286,7 @@
)
print(f":: Result tokens: {batch.results}")
batch.print_current_results()
counter += 1


if __name__ == "__main__":
80 changes: 80 additions & 0 deletions sharktank/sharktank/export_layer/export_moe.py
@@ -0,0 +1,80 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from shark_turbine.aot import *
from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
from ..utils import cli


def main():
    parser = cli.create_parser()
    parser.add_argument(
        "--output-mlir",
        help="Output file path for exported MLIR file",
        default="/tmp/batch_llama_v1.mlir",
    )
    parser.add_argument(
        "--batch-size",
        "-bs",
        help="Batch size to generate, e.g. `4` or `2`",
        type=lambda arg: int(arg),
        default="2",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        help="Include verbose logging",
        action="store_true",
    )
    parser.add_argument(
        "--strict",
        help="Enables strictness during export",
        action="store_true",
    )
    parser.add_argument(
        "--use-grok",
        help="Enable to export Grok model's version of MOE block",
        action="store_true",
    )

    args = cli.parse(parser)

    bs = args.batch_size

    model = PreGatherMoeBlock(
        theta=make_moe_block_theta()("blk.0"),
        expert_count=8,
        expert_used_count=2,
        rms_epsilon=1e-5,
        use_grok=args.use_grok,
    )
    fxb = FxProgramsBuilder(model)
    input = make_rand_torch((bs, 32, 6144))

    @fxb.export_program(name="prefill_moe", args=(input,))
    def _(model, input: torch.Tensor) -> torch.Tensor:
        return model(input)

    input = make_rand_torch((bs, 1, 6144))

    @fxb.export_program(name="decode_moe", args=(input,))
    def _(model, input: torch.Tensor) -> torch.Tensor:
        return model(input)

    if args.verbose:
        for name, ep in fxb.programs.items():
            print(f"EXPORT {name}:\n{ep}")

    print("Exporting")
    output = export(fxb)
    print(f"Saving to '{args.output_mlir}'")
    output.save_mlir(args.output_mlir)


if __name__ == "__main__":
    main()
4 changes: 2 additions & 2 deletions sharktank/sharktank/layers/__init__.py
@@ -16,6 +16,6 @@
from .paged_llama_attention_block import PagedLlamaAttentionBlock
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import SparseMoeBlock
from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock

from . import configs
from .configs import *
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/configs/__init__.py
@@ -4,4 +4,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .llm_configs import LlamaHParams
from .llm_configs import *
76 changes: 58 additions & 18 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -16,10 +16,9 @@

from dataclasses import dataclass
from typing import Any, Optional

import torch

__all__ = ["LlamaHParams"]
__all__ = ["LlamaHParams", "LlamaModelConfig"]


@dataclass
@@ -29,48 +28,55 @@ class LlamaHParams:
Comments are only provided if they differ from this source.
"""

model_arch: str
context_length: int
embedding_length: int
block_count: int
feed_forward_length: int
rope_dimension_count: int
rope_freq_base: float
attention_head_count: int
attn_head_dim: int
attention_layer_norm_rms_epsilon: float
attention_head_count_kv: int
expert_count: int
expert_used_count: int
rope_dimension_count: Optional[int] = None
rope_freq_base: Optional[float] = None
expert_count: Optional[int] = None
expert_used_count: Optional[int] = None

@staticmethod
def from_gguf_props(p: dict[str, Any]):
name_prefix = p["general.architecture"]
default_expert_count = 0
default_expert_used_count = 0
default_rope_freq_base = 10000.0
attention_head_count = _int_prop(p, "llama.attention.head_count")
default_rope_dimension_count = 128
attention_head_count = _int_prop(p, f"{name_prefix}.attention.head_count")
rope_dimension_count = _optional_int_prop(
p, f"{name_prefix}.rope.dimension_count", default_rope_dimension_count
)

return LlamaHParams(
context_length=_int_prop(p, "llama.context_length"),
embedding_length=_int_prop(p, "llama.embedding_length"),
block_count=_int_prop(p, "llama.block_count"),
feed_forward_length=_int_prop(p, "llama.feed_forward_length"),
attn_head_dim=_int_prop(p, "llama.rope.dimension_count"),
rope_dimension_count=_int_prop(p, "llama.rope.dimension_count"),
model_arch=name_prefix,
context_length=_int_prop(p, f"{name_prefix}.context_length"),
embedding_length=_int_prop(p, f"{name_prefix}.embedding_length"),
block_count=_int_prop(p, f"{name_prefix}.block_count"),
feed_forward_length=_int_prop(p, f"{name_prefix}.feed_forward_length"),
attention_head_count=attention_head_count,
attention_layer_norm_rms_epsilon=_float_prop(
p, "llama.attention.layer_norm_rms_epsilon"
p, f"{name_prefix}.attention.layer_norm_rms_epsilon"
),
attention_head_count_kv=_optional_int_prop(
p, "llama.attention.head_count_kv", attention_head_count
p, f"{name_prefix}.attention.head_count_kv", attention_head_count
),
attn_head_dim=rope_dimension_count,
rope_dimension_count=rope_dimension_count,
rope_freq_base=_optional_float_prop(
p, "llama.rope.freq_base", default_rope_freq_base
p, f"{name_prefix}.rope.freq_base", default_rope_freq_base
),
expert_count=_optional_int_prop(
p, "llama.expert_count", default_expert_count
p, f"{name_prefix}.expert_count", default_expert_count
),
expert_used_count=_optional_int_prop(
p, "llama.expert_used_count", default_expert_used_count
p, f"{name_prefix}.expert_used_count", default_expert_used_count
),
)
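Editor's note: the key change in this hunk is that every GGUF property name is now prefixed with the value of "general.architecture" instead of a hard-coded "llama." prefix. A minimal illustration is sketched below; the property values are made up for the example and are not claimed to be Grok-1's real hyperparameters.

```python
# Illustrative sketch: architecture-prefixed GGUF properties feeding from_gguf_props.
from sharktank.layers.configs import LlamaHParams

props = {
    "general.architecture": "grok",
    "grok.context_length": 8192,
    "grok.embedding_length": 6144,
    "grok.block_count": 64,
    "grok.feed_forward_length": 32768,
    "grok.attention.head_count": 48,
    "grok.attention.head_count_kv": 8,
    "grok.attention.layer_norm_rms_epsilon": 1e-5,
    "grok.expert_count": 8,
    "grok.expert_used_count": 2,
    # "grok.rope.dimension_count" and "grok.rope.freq_base" are optional and
    # fall back to the defaults of 128 and 10000.0.
}

hp = LlamaHParams.from_gguf_props(props)
print(hp.model_arch, hp.attn_head_dim, hp.rope_freq_base)  # grok 128 10000.0
```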

@@ -107,3 +113,37 @@ def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int:
return int(value)
except ValueError as e:
raise ValueError(f"Property '{name}' expected to be an int and was not") from e


Collaborator: Adding LlamaModelConfig to the config file makes sense, but having a create_kv_cache() function in a config file feels odd. Is there a better way to refactor this?

Member Author: It's a bit odd for a dataclass to have a function, but in this case it kind of makes sense because it just creates data. I'm okay with refactoring it out if you want.

Collaborator: I moved the create_kv_cache function to the utils folder, as that seemed to be the best place for it.

@dataclass
class LlamaModelConfig:
    hp: LlamaHParams

    # Block sequence stride for a paged KV cache. This must divide evenly
    # into the context length.
    block_seq_stride: int = 16

    # Either "paged" or "direct".
    kv_cache_type: str = "paged"

    # The device on which to place intermediate state.
    device: Optional[torch.device] = None

    # Dtype to use for general FP activations not otherwise configured.
    activation_dtype: torch.dtype = torch.float16

    # Dtype to use for attention.
    attention_dtype: torch.dtype = torch.float16

    # Indicates if running with HuggingFace implementation and ensures
    # numerical equivalency to HuggingFace's LLaMa if true (by modifying
    # rotary embedding).
    use_hf: bool = False

    # If true, then the model may pre-initialize certain tables during
    # init. This can be better for eager execution but when capturing a program,
    # it is often better to preserve the calculation explicitly and rely on
    # the compiler to transform it to an initialization time step. This can
    # be the difference of many gigabytes of static data being embedded in
    # the program and not.
    static_tables: bool = True
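Editor's note: a minimal sketch of constructing this config, mirroring how paged_llm_v1.py now builds it; `hp` is the LlamaHParams value from the previous sketch, and the keyword values are placeholders.

```python
import torch
from sharktank.layers.configs import LlamaModelConfig

config = LlamaModelConfig(
    hp,
    kv_cache_type="paged",
    device=None,
    activation_dtype=torch.float32,
    # The separate --attention-dtype flag was removed in this PR, so the
    # attention dtype now simply follows the activation dtype.
    attention_dtype=torch.float32,
    use_hf=False,
)
```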
61 changes: 61 additions & 0 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -15,9 +15,70 @@

__all__ = [
"FFNMOE",
"PreGatherFFNMOE",
]


class PreGatherFFNMOE(ThetaLayer):
    def __init__(
        self,
        theta: Theta,
        use_grok: bool = False,
    ):
        super().__init__(theta)
        self.use_grok = use_grok

        self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
        self.ffn_up = theta.tensor("ffn_up_exps", "weight")
        self.ffn_down = theta.tensor("ffn_down_exps", "weight")

    def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
        inputs = inputs[:, :]
        weights = weights[experts, :, :]
        matmul = torch.einsum(einstring, inputs, weights.float())
        return matmul

    def bigger_mmg(self, inputs, weights, experts):
        inputs = inputs[:, :]
        weights = weights[experts, :, :]
        matmul = torch.einsum("mek,menk->men", inputs, weights.float())
        return matmul

    def one_hot_matmul(self, inputs, weights, experts):
        matmul = torch.einsum("mk,bnk->bmn", inputs, weights)
        # Post mix the experts
        oh = (
            torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8)
            .transpose(0, 1)
            .to(torch.float32)
        )
        output = torch.einsum("bm,bmn->mn", oh, matmul)
        return output

    def forward(
        self,
        h: torch.Tensor,
        experts: torch.Tensor,
        expert_gate: torch.Tensor,
    ):
        if self.use_grok:
            ffn_gate = F.gelu(
                self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts)
            )
        else:
            ffn_gate = F.silu(
                self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts)
            )

        ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
        ffn_down = self.pre_matmul_gather(
            ffn_gate * ffn_up, self.ffn_down, experts, einstring="mek,menk->men"
        )
        ffn_down = torch.einsum("me,men->men", expert_gate, ffn_down)
        return torch.sum(ffn_down, dim=1)


class FFNMOE(ThetaLayer):
    def __init__(
        self,
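Editor's note: the einsum subscripts in PreGatherFFNMOE are terse, so here is a standalone shape walkthrough of the pre-gather pattern. It is a sketch with made-up sizes using the SiLU (non-Grok) path; it mirrors the forward pass above without importing the sharktank classes.

```python
import torch
import torch.nn.functional as F

# Made-up sizes: m tokens, model dim k, FFN dim n, 8 experts, top-2 routing.
m, k, n, num_experts, top_k = 4, 16, 32, 8, 2

h = torch.randn(m, k)                            # flattened token activations
ffn_gate_w = torch.randn(num_experts, n, k)      # per-expert gate weights
ffn_up_w = torch.randn(num_experts, n, k)        # per-expert up-projection
ffn_down_w = torch.randn(num_experts, k, n)      # per-expert down-projection
experts = torch.randint(0, num_experts, (m, top_k))        # routed expert ids
expert_gate = torch.softmax(torch.randn(m, top_k), dim=1)  # router weights

# pre_matmul_gather: index the chosen experts' weights, then contract over k.
gate = torch.einsum("mk,menk->men", h, ffn_gate_w[experts])   # (m, top_k, n)
up = torch.einsum("mk,menk->men", h, ffn_up_w[experts])       # (m, top_k, n)
act = F.silu(gate) * up                                       # (m, top_k, n)

# The down-projection uses the "mek,menk->men" variant because its input is 3-D.
down = torch.einsum("mek,menk->men", act, ffn_down_w[experts])  # (m, top_k, k)

# Scale each expert's output by its router weight, then sum over the experts.
out = torch.einsum("me,men->men", expert_gate, down).sum(dim=1)  # (m, k)
print(out.shape)  # torch.Size([4, 16])
```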