Merge branch 'main' into fiber-dist-tweak
monorimet authored Dec 11, 2024
2 parents 379ba57 + 6c62ed1 commit 59cef3a
Showing 17 changed files with 562 additions and 148 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-sharktank.yml
@@ -27,7 +27,7 @@ jobs:
strategy:
matrix:
version: [3.11]
os: [ubuntu-latest, windows-latest]
os: [ubuntu-24.04, windows-2022]
fail-fast: false
runs-on: ${{matrix.os}}
defaults:
2 changes: 1 addition & 1 deletion .github/workflows/ci-tuner.yml
@@ -28,7 +28,7 @@ permissions:

jobs:
test:
runs-on: ubuntu-latest
runs-on: ubuntu-24.04

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yaml
@@ -7,7 +7,7 @@ on:

jobs:
pre-commit:
runs-on: ubuntu-latest
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
38 changes: 19 additions & 19 deletions app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
@@ -68,35 +68,35 @@ def write_config(request, pre_process_model):
batch_sizes = request.param["batch_sizes"]
prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]

logger.info("Writing config file..." + start_log_group("Writing config file"))

# Construct the new config filename
config_path = (
pre_process_model
/ f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
)

config = {
"module_name": "module",
"module_abi_version": 1,
"max_seq_len": 131072,
"attn_head_count": 8,
"attn_head_dim": 128,
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 32,
"paged_kv_cache": {
"block_seq_stride": 16,
"device_block_count": 256,
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
# Read the base config file
base_config_path = pre_process_model / "config.json"
with open(base_config_path, "r") as f:
config = json.load(f)

# Override specific fields
config.update(
{
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"paged_kv_cache": {
**config.get(
"paged_kv_cache", {}
), # Preserve other paged_kv_cache settings
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
)

logger.info(f"Saving edited config to: {config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
with open(config_path, "w") as f:
json.dump(config, f)

logger.info("Config file successfully written" + end_log_group())
yield config_path


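The write_config fixture above no longer hard-codes model hyperparameters; it loads the config.json produced during export and overrides only the batch sizes and the prefix-sharing algorithm. As a rough, hypothetical illustration (not code from this commit) of how such a fixture is typically driven, a benchmark test might parametrize it indirectly; the dict keys mirror what the fixture reads from request.param, and the algorithm names are purely illustrative values:

import pytest

# Hypothetical example, not part of this commit: drive the write_config fixture
# through pytest's indirect parametrization.
@pytest.mark.parametrize(
    "write_config",
    [
        {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "none"},
        {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "trie"},
    ],
    indirect=True,
)
def test_benchmark_config(write_config):
    # The fixture yields the path of the edited config file.
    assert write_config.exists()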
15 changes: 0 additions & 15 deletions app_tests/integration_tests/llm/sglang/conftest.py
@@ -64,21 +64,6 @@ def pre_process_model(request, tmp_path_factory):
device_settings,
)

config = {
"module_name": "module",
"module_abi_version": 1,
"max_seq_len": 131072,
"attn_head_count": 8,
"attn_head_dim": 128,
"prefill_batch_sizes": [1, 4],
"decode_batch_sizes": [1, 4],
"transformer_block_count": 32,
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
}
config_path = tmp_dir / "config.json"
with open(config_path, "w") as f:
json.dump(config, f)

return tmp_dir


34 changes: 19 additions & 15 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -87,26 +87,30 @@ def write_config(request, model_test_dir):
batch_sizes = request.param["batch_sizes"]
prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]

# Construct the new config filename
config_path = (
model_test_dir
/ f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
)

config = {
"module_name": "module",
"module_abi_version": 1,
"max_seq_len": 2048,
"attn_head_count": 32,
"attn_head_dim": 100,
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 26,
"paged_kv_cache": {
"block_seq_stride": 16,
"device_block_count": 256,
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
# Read the base config file
base_config_path = model_test_dir / "config.json"
with open(base_config_path, "r") as f:
config = json.load(f)

# Override specific fields
config.update(
{
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"paged_kv_cache": {
**config.get(
"paged_kv_cache", {}
), # Preserve other paged_kv_cache settings
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
)
logger.info(f"Saving edited config to: {config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
with open(config_path, "w") as f:
4 changes: 2 additions & 2 deletions docs/nightly_releases.md
@@ -58,7 +58,7 @@ python3.11 -m venv 3.11.venv
source 3.11.venv/bin/activate

# Install 'sharktank' package from nightly releases.
pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels --pre

# Test the installation.
python -c "from sharktank import ops; print('Sanity check passed')"
@@ -75,7 +75,7 @@ python3.11 -m venv 3.11.venv
source 3.11.venv/bin/activate

# Install 'shortfin' package from nightly releases.
pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels --pre

# Test the installation.
python -c "import shortfin as sf; print('Sanity check passed')"
19 changes: 16 additions & 3 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -7,6 +7,7 @@
"""Export support for the PagedLLMV1 protocol of models."""

import json
from typing import Any, Dict
import torch

from iree.turbine.aot import *
@@ -86,17 +87,29 @@ def main():
else:
model = PagedLlamaModelV1(dataset.root_theta, llama_config)

def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
def generate_params_json(
hp: LlamaHParams, prefill_bs: list[int], decode_bs: list[int]
) -> Dict[str, Any]:
"""
Generate config.json for shortfin.
For shortfin, we only write attention_head_count_kv because that's all shortfin needs.
Note that this is different from hp.attn_head_count when grouped attention shares kvcache between heads.
"""
return {
"module_name": "module",
"module_abi_version": 1,
"max_seq_len": hp.context_length,
"attn_head_count": hp.attention_head_count,
"attn_head_dim": hp.attn_head_dim,
"prefill_batch_sizes": prefill_bs,
"decode_batch_sizes": decode_bs,
"transformer_block_count": hp.block_count,
"block_seq_stride": llama_config.block_seq_stride,
"paged_kv_cache": {
"attention_head_count_kv": hp.attention_head_count_kv,
"block_seq_stride": llama_config.block_seq_stride,
"device_block_count": 256, # so that this makes its way into the config file & can be edited.
},
}

# Unrolling cache updates by batch row makes dynamo sad without an
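With this change the exporter surfaces attention_head_count_kv and block_seq_stride inside paged_kv_cache, plus a default device_block_count of 256 that can be edited after export. A minimal, hypothetical helper (not part of shortfin or this commit) shows how those fields combine into a rough paged KV-cache size estimate, assuming a 2-byte cache dtype:

import json

# Hypothetical sketch, not shortfin's actual loader: estimate the paged KV-cache
# footprint implied by an exported config.json. Field names match the dict built
# by generate_params_json above; the 2-byte dtype is an assumption.
def estimate_kv_cache_bytes(config_path: str, dtype_bytes: int = 2) -> int:
    with open(config_path, "r") as f:
        cfg = json.load(f)
    paged = cfg["paged_kv_cache"]
    bytes_per_block = (
        2  # K and V planes
        * cfg["transformer_block_count"]
        * paged["attention_head_count_kv"]
        * cfg["attn_head_dim"]
        * paged["block_seq_stride"]
        * dtype_bytes
    )
    return bytes_per_block * paged["device_block_count"]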
32 changes: 29 additions & 3 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -18,6 +18,8 @@
from typing import Any, Optional
import torch

from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name

__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]


@@ -287,9 +289,10 @@ class ClipTextConfig:
output_attentions: bool = False
output_hidden_states: bool = False
use_return_dict: bool = True
dtype: torch.dtype = torch.float32

@staticmethod
def from_transformers_clip_text_config(
def from_hugging_face_clip_text_model_config(
config: "transformers.CLIPTextConfig",
) -> "ClipTextConfig":
return ClipTextConfig(
@@ -308,7 +311,30 @@ def from_transformers_clip_text_config(
output_attentions=config.output_attentions,
output_hidden_states=config.output_hidden_states,
use_return_dict=config.use_return_dict,
dtype=config.torch_dtype or torch.float32,
)

def as_properties(self) -> dict[str, Any]:
return asdict(self)
def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig":
kwargs = self.to_properties()
kwargs["torch_dtype"] = kwargs["dtype"]
del kwargs["dtype"]
kwargs["return_dict"] = kwargs["use_return_dict"]
del kwargs["use_return_dict"]
from transformers import CLIPTextConfig

return CLIPTextConfig(**kwargs)

@staticmethod
def from_properties(properties: dict[str, Any]) -> "ClipTextConfig":
kwargs = dict(properties)
kwargs.pop("SHARK_DATASET_VERSION")
if "dtype" in kwargs and kwargs["dtype"] is not None:
kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"])

return ClipTextConfig(**kwargs)

def to_properties(self) -> dict[str, Any]:
res = asdict(self)
if self.dtype is not None:
res["dtype"] = dtype_to_serialized_name(self.dtype)
return res
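A short usage sketch (not part of this commit) of the new conversion helpers; it assumes transformers is installed and uses only the method names added above:

import transformers
from sharktank.layers.configs import ClipTextConfig

# Hedged sketch: round-trip a Hugging Face CLIP text config through ClipTextConfig.
hf_config = transformers.CLIPTextConfig()  # default values
config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_config)
hf_roundtrip = config.to_hugging_face_clip_text_model_config()
assert hf_roundtrip.hidden_size == hf_config.hidden_size

# to_properties() serializes dtype by name (e.g. "float32"); from_properties()
# expects a SHARK_DATASET_VERSION key, which it pops before rebuilding the dataclass.
props = config.to_properties()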
47 changes: 33 additions & 14 deletions sharktank/sharktank/models/clip/clip.py
@@ -21,10 +21,10 @@
)
from collections import OrderedDict

from ...layers import BaseLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer
from ...layers import ThetaLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer
from ... import ops
from ...types.theta import Theta, Dataset
from ...types.tensors import DefaultPrimitiveTensor
from ...types.tensors import AnyTensor, DefaultPrimitiveTensor
from ...layers.configs import ClipTextConfig
from ...layers.activations import ACT2FN

@@ -68,11 +68,11 @@ def forward(
return embeddings


class ClipAttention(BaseLayer):
class ClipAttention(ThetaLayer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
@@ -182,9 +182,9 @@ def forward(
return attn_output, attn_weights_reshaped


class ClipMlp(BaseLayer):
class ClipMlp(ThetaLayer):
def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = LinearLayer(theta("fc1"))
@@ -197,9 +197,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class ClipEncoderLayer(BaseLayer):
class ClipEncoderLayer(ThetaLayer):
def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.embed_dim = config.hidden_size
self.self_attn = ClipAttention(theta=theta("self_attn"), config=config)
self.layer_norm1 = LayerNorm(
@@ -251,14 +251,14 @@ def forward(
return outputs


class ClipEncoder(BaseLayer):
class ClipEncoder(ThetaLayer):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`ClipEncoderLayer`].
"""

def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.config = config
self.layers = nn.ModuleList(
[
@@ -356,9 +356,9 @@ def forward(
)


class ClipTextTransformer(nn.Module):
class ClipTextTransformer(ThetaLayer):
def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.config = config
embed_dim = config.hidden_size
self.embeddings = ClipTextEmbeddings(theta=theta("embeddings"), config=config)
@@ -475,9 +475,9 @@ def forward(
)


class ClipTextModel(BaseLayer):
class ClipTextModel(ThetaLayer):
def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.config = config
self.text_model = ClipTextTransformer(theta=theta("text_model"), config=config)

@@ -487,6 +487,25 @@ def get_input_embeddings(self) -> nn.Module:
def set_input_embeddings(self, value):
self.text_model.embeddings.token_embedding = value

def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
input_ids = (
torch.arange(
start=0,
end=batch_size * self.config.max_position_embeddings,
dtype=torch.long,
)
% self.config.vocab_size
)
input_ids = input_ids.reshape([batch_size, self.config.max_position_embeddings])
return OrderedDict(
[
(
"input_ids",
input_ids,
)
]
)

def forward(
self,
input_ids: Optional[torch.Tensor] = None,
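The new sample_inputs method builds deterministic input_ids of shape [batch_size, max_position_embeddings], which is convenient for smoke tests and export tracing. A hedged sketch (not from this commit), assuming a CLIP text theta and config have already been loaded elsewhere:

from sharktank.models.clip.clip import ClipTextModel

# Hedged sketch: `theta` and `config` are assumed to come from an already-loaded
# CLIP text dataset; only sample_inputs() and forward() are from this commit.
def run_clip_text_sanity_check(theta, config):
    model = ClipTextModel(theta=theta, config=config)
    batch = model.sample_inputs(batch_size=2)     # OrderedDict with a single "input_ids" entry
    return model(input_ids=batch["input_ids"])    # input_ids shape: [2, config.max_position_embeddings]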