Skip to content

Commit

Permalink
Add CLIP text model (#643)
Browse files Browse the repository at this point in the history
Ports the CLIP text model from Hugging Face. This is the first iteration
so not much is changed from the original model. Things like dropout and
checkpointing are removed.
Add numeric verification tests for the various components of the stack
when executing in eager mode. Verifications are made for float32 and
bfloat16. There are tests for toy-sized components and the whole model
as well as the Large pretrained variant.
These tests do not include testing with IREE.

Functionality for mask creation is not yet ported.
  • Loading branch information
sogartar authored Dec 9, 2024
1 parent 2f5bfab commit c9cb226
Show file tree
Hide file tree
Showing 20 changed files with 1,209 additions and 35 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci-sharktank.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,13 @@ jobs:
iree-base-runtime
- name: Run tests
# TODO: unify with-t5-data and with-clip-data flags into a single flag
# and make it possible to run only tests that require data.
run: |
pytest \
--with-clip-data \
--with-t5-data \
sharktank/tests/models/clip/clip_test.py \
sharktank/tests/models/t5/t5_test.py \
--durations=0
Expand Down
9 changes: 9 additions & 0 deletions sharktank/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ def pytest_addoption(parser):
help="Enable all llama benchmarking tests",
)

parser.addoption(
"--with-clip-data",
action="store_true",
default=False,
help=(
"Enable tests that use CLIP data like models that is not a part of the source "
"code. The user is expected to provide the data"
),
)
parser.addoption(
"--with-t5-data",
action="store_true",
Expand Down
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
from .causal_llm import BaseCausalLMModel
from .linear import LinearLayer
from .norm import RMSNormLayer
from .norm import RMSNormLayer, LayerNorm
from .rotary_embedding import RotaryEmbeddingLayer
from .token_embedding import TokenEmbeddingLayer
from .llama_attention_block import LlamaAttentionBlock
Expand Down
16 changes: 16 additions & 0 deletions sharktank/sharktank/layers/activations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from torch import nn
from .. import ops

# Lookup table from an activation-function name to the callable implementing it.
# The keys are plain strings (e.g. "quick_gelu"); presumably they follow the
# Hugging Face `hidden_act` naming used by the ported models - confirm against
# callers before adding new entries.
# TODO: don't use nn.functional directly.
ACT2FN = dict(
    gelu=nn.functional.gelu,
    gelu_new=ops.gelu_tanh_approximation,
    relu=nn.functional.relu,
    quick_gelu=ops.gelu_sigmoid_approximation,
)
50 changes: 48 additions & 2 deletions sharktank/sharktank/layers/configs/llm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
(and indeed, can bootstrap these off of GGUF files).
"""

from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from typing import Any, Optional
import torch

__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"]
__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]


@dataclass
Expand Down Expand Up @@ -266,3 +266,49 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
all_kwargs.update(kwargs)

return T5Config(**all_kwargs)


@dataclass
class ClipTextConfig:
    """Configuration values for the CLIP text model.

    Every field has a default, so ``ClipTextConfig()`` is a valid configuration.
    The field names match the attributes read from a Hugging Face
    ``transformers.CLIPTextConfig`` in ``from_transformers_clip_text_config``.
    """

    vocab_size: int = 49408
    hidden_size: int = 512
    intermediate_size: int = 2048
    projection_dim: int = 512
    num_hidden_layers: int = 12
    num_attention_heads: int = 8
    max_position_embeddings: int = 77
    hidden_act: str = "quick_gelu"
    layer_norm_eps: float = 1e-5
    # This differs from `CLIPTokenizer`'s default and from openai/clip
    # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
    pad_token_id: int = 1
    bos_token_id: int = 49406
    eos_token_id: int = 49407
    output_attentions: bool = False
    output_hidden_states: bool = False
    use_return_dict: bool = True

    @staticmethod
    def from_transformers_clip_text_config(
        config: "transformers.CLIPTextConfig",
    ) -> "ClipTextConfig":
        """Build a ``ClipTextConfig`` by copying attributes from ``config``.

        Args:
            config: Object exposing the same attribute names as this dataclass's
                fields (a ``transformers.CLIPTextConfig`` per the annotation).

        Returns:
            A new ``ClipTextConfig`` whose fields hold the values read from
            ``config``.

        Raises:
            AttributeError: If ``config`` lacks one of the copied attributes.
        """
        # Listed explicitly, in field order, so exactly these attributes are
        # read from the source object.
        copied_attributes = (
            "vocab_size",
            "hidden_size",
            "intermediate_size",
            "projection_dim",
            "num_hidden_layers",
            "num_attention_heads",
            "max_position_embeddings",
            "hidden_act",
            "layer_norm_eps",
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
            "output_attentions",
            "output_hidden_states",
            "use_return_dict",
        )
        kwargs = {attr: getattr(config, attr) for attr in copied_attributes}
        return ClipTextConfig(**kwargs)

    def as_properties(self) -> dict[str, Any]:
        """Return all fields as a plain ``dict`` keyed by field name."""
        properties = asdict(self)
        return properties
20 changes: 20 additions & 0 deletions sharktank/sharktank/layers/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,23 @@ def forward(self, x: torch.Tensor):
# often in higher precision. Downcast back to expected.
norm = ops.to(norm, orig_dtype)
return norm


class LayerNorm(ThetaLayer):
    """Layer normalization whose parameters are taken from a ``Theta``.

    The weight is always fetched by name. The bias is fetched only when an
    entry called ``bias_name`` is present in the theta; otherwise ``None`` is
    forwarded to ``ops.layer_norm``.
    """

    def __init__(
        self,
        theta: Theta,
        *,
        weight_name: str = "weight",
        bias_name: str = "bias",
        eps: float = 1e-05,
    ):
        """Bind the weight, the optional bias and ``eps`` for later use.

        Args:
            theta: Parameter container handed to the ``ThetaLayer`` base class.
            weight_name: Name of the weight tensor to fetch.
            bias_name: Name of the bias tensor; used only if present in
                ``theta``.
            eps: Epsilon passed through to ``ops.layer_norm`` (presumably the
                usual numerical-stability term - confirm against the op).
        """
        super().__init__(theta)
        self.weight = self.theta_tensor(weight_name)
        # The bias is optional: fall back to None when the theta has no such key.
        has_bias = bias_name in self.theta.keys
        self.bias = self.theta_tensor(bias_name) if has_bias else None
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Apply ``ops.layer_norm`` to ``x`` with the bound weight, bias and eps."""
        return ops.layer_norm(
            x,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
8 changes: 8 additions & 0 deletions sharktank/sharktank/models/clip/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .clip import *
from .export import *
Loading

0 comments on commit c9cb226

Please sign in to comment.