Add CLIP text model #643

Merged · 4 commits · Dec 9, 2024
4 changes: 4 additions & 0 deletions .github/workflows/ci-sharktank.yml
@@ -122,9 +122,13 @@ jobs:
iree-base-runtime

- name: Run tests
# TODO: unify with-t5-data and with-clip-data flags into a single flag
# and make it possible to run only tests that require data.
run: |
pytest \
--with-clip-data \
--with-t5-data \
sharktank/tests/models/clip/clip_test.py \
sharktank/tests/models/t5/t5_test.py \
--durations=0

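For local reproduction, the same test selection can be driven through pytest's Python entry point; a minimal sketch, assuming the repository root as the working directory and that the CLIP and T5 model data has already been fetched (not part of this change):

# Sketch only: mirrors the CI step above via pytest's programmatic API.
import pytest

exit_code = pytest.main(
    [
        "--with-clip-data",
        "--with-t5-data",
        "sharktank/tests/models/clip/clip_test.py",
        "sharktank/tests/models/t5/t5_test.py",
        "--durations=0",
    ]
)
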
9 changes: 9 additions & 0 deletions sharktank/conftest.py
@@ -88,6 +88,15 @@ def pytest_addoption(parser):
help="Enable all llama benchmarking tests",
)

parser.addoption(
"--with-clip-data",
action="store_true",
default=False,
help=(
"Enable tests that use CLIP data, such as models, that is not part of the "
"source code. The user is expected to provide this data."
),
)
parser.addoption(
"--with-t5-data",
action="store_true",
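
For context, a common pattern for consuming such an option on the test side is sketched below; the fixture and test names are illustrative and not part of this change:

# Illustrative only: gate a test on the new --with-clip-data flag by reading
# it from the pytest config and skipping when it is absent.
import pytest

@pytest.fixture
def clip_data_enabled(request):
    if not request.config.getoption("--with-clip-data"):
        pytest.skip("CLIP data tests disabled; pass --with-clip-data to enable")

def test_clip_text_model(clip_data_enabled):
    ...  # would load the user-provided CLIP model data here
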
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -9,7 +9,7 @@
from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
from .causal_llm import BaseCausalLMModel
from .linear import LinearLayer
-from .norm import RMSNormLayer
+from .norm import RMSNormLayer, LayerNorm
from .rotary_embedding import RotaryEmbeddingLayer
from .token_embedding import TokenEmbeddingLayer
from .llama_attention_block import LlamaAttentionBlock
16 changes: 16 additions & 0 deletions sharktank/sharktank/layers/activations.py
@@ -0,0 +1,16 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from torch import nn
from .. import ops

# TODO: don't use nn.functional directly.
ACT2FN = {
"gelu": nn.functional.gelu,
"gelu_new": ops.gelu_tanh_approximation,
"relu": nn.functional.relu,
"quick_gelu": ops.gelu_sigmoid_approximation,
}
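
The table is keyed by the activation name carried in a model config (e.g. hidden_act); a minimal lookup sketch, assuming the module path matches the file location above:

# Illustrative: resolve an activation function by its config name and apply it.
import torch
from sharktank.layers.activations import ACT2FN

act_fn = ACT2FN["quick_gelu"]
y = act_fn(torch.randn(2, 77, 512))
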
50 changes: 48 additions & 2 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -14,11 +14,11 @@
(and indeed, can bootstrap these off of GGUF files).
"""

-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
from typing import Any, Optional
import torch

-__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"]
+__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]


@dataclass
@@ -266,3 +266,49 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
all_kwargs.update(kwargs)

return T5Config(**all_kwargs)


@dataclass
class ClipTextConfig:
vocab_size: int = 49408
hidden_size: int = 512
intermediate_size: int = 2048
projection_dim: int = 512
num_hidden_layers: int = 12
num_attention_heads: int = 8
max_position_embeddings: int = 77
hidden_act: str = "quick_gelu"
layer_norm_eps: float = 1e-5
# This differs from `CLIPTokenizer`'s default and from openai/clip
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
pad_token_id: int = 1
bos_token_id: int = 49406
eos_token_id: int = 49407
output_attentions: bool = False
output_hidden_states: bool = False
use_return_dict: bool = True

@staticmethod
def from_transformers_clip_text_config(
config: "transformers.CLIPTextConfig",
) -> "ClipTextConfig":
return ClipTextConfig(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
projection_dim=config.projection_dim,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
hidden_act=config.hidden_act,
layer_norm_eps=config.layer_norm_eps,
pad_token_id=config.pad_token_id,
bos_token_id=config.bos_token_id,
eos_token_id=config.eos_token_id,
output_attentions=config.output_attentions,
output_hidden_states=config.output_hidden_states,
use_return_dict=config.use_return_dict,
)

def as_properties(self) -> dict[str, Any]:
return asdict(self)
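
A short usage sketch for the new config, converting from a Hugging Face transformers.CLIPTextConfig and dumping back to a plain dict; the import path follows the file location above and the example is illustrative, not taken from this PR's tests:

# Sketch: round-trip a Hugging Face CLIP text config through ClipTextConfig.
import transformers
from sharktank.layers.configs.llm_configs import ClipTextConfig

hf_config = transformers.CLIPTextConfig()  # defaults describe the small CLIP text model
config = ClipTextConfig.from_transformers_clip_text_config(hf_config)
assert config.hidden_act == "quick_gelu"
properties = config.as_properties()  # plain dict via dataclasses.asdict
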
20 changes: 20 additions & 0 deletions sharktank/sharktank/layers/norm.py
@@ -39,3 +39,23 @@ def forward(self, x: torch.Tensor):
# often in higher precision. Downcast back to expected.
norm = ops.to(norm, orig_dtype)
return norm


class LayerNorm(ThetaLayer):
def __init__(
self,
theta: Theta,
*,
weight_name: str = "weight",
bias_name: str = "bias",
eps: float = 1e-05,
):
super().__init__(theta)
self.weight = self.theta_tensor(weight_name)
self.bias = None
if bias_name in self.theta.keys:
self.bias = self.theta_tensor(bias_name)
self.eps = eps

def forward(self, x: torch.Tensor):
return ops.layer_norm(x, weight=self.weight, bias=self.bias, eps=self.eps)
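
The forward pass defers to ops.layer_norm; below is an illustrative reference for the expected math, assuming semantics equivalent to torch.nn.functional.layer_norm over the trailing feature dimension (shapes are arbitrary and not part of this change):

# Illustrative numerics: normalize over the last dimension, then apply the
# affine weight and bias, matching torch.nn.functional.layer_norm.
import torch

x = torch.randn(2, 77, 512)
weight, bias, eps = torch.randn(512), torch.randn(512), 1e-5

mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
expected = (x - mean) / torch.sqrt(var + eps) * weight + bias

actual = torch.nn.functional.layer_norm(x, (512,), weight, bias, eps)
assert torch.allclose(actual, expected, atol=1e-5)
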
8 changes: 8 additions & 0 deletions sharktank/sharktank/models/clip/__init__.py
@@ -0,0 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .clip import *
from .export import *