From 64a7d5fbbb73c545b60b7855be3af9947e3efe71 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 28 Nov 2024 00:06:19 +0000 Subject: [PATCH 1/4] Add CLIP text model Ports the CLIP text model from Hugging Face. Add numeric verification tests for the various components of the stack when executing in eager mode. Verifications are made for float32 and bfloat16. There are tests for toy-sized components and the whole model as well as the Large pretrained variant. These tests does not include testing with IREE. Functionalities for mask creation are not yet ported. --- .github/workflows/ci-sharktank.yml | 4 + sharktank/conftest.py | 9 + sharktank/sharktank/layers/__init__.py | 2 +- sharktank/sharktank/layers/activations.py | 16 + .../sharktank/layers/configs/llm_configs.py | 50 +- sharktank/sharktank/layers/norm.py | 21 + sharktank/sharktank/models/clip/__init__.py | 8 + sharktank/sharktank/models/clip/clip.py | 502 ++++++++++++++++++ sharktank/sharktank/models/clip/export.py | 57 ++ sharktank/sharktank/models/t5/t5.py | 8 +- sharktank/sharktank/ops/default_impls.py | 4 +- sharktank/sharktank/ops/sharded_impls.py | 4 +- sharktank/sharktank/ops/signatures.py | 15 +- sharktank/sharktank/types/tensors.py | 6 + sharktank/sharktank/types/theta.py | 16 +- sharktank/sharktank/utils/hf_datasets.py | 19 + sharktank/sharktank/utils/math.py | 13 + sharktank/sharktank/utils/testing.py | 14 + sharktank/tests/models/clip/clip_test.py | 449 ++++++++++++++++ sharktank/tests/models/t5/t5_test.py | 21 +- 20 files changed, 1203 insertions(+), 35 deletions(-) create mode 100644 sharktank/sharktank/layers/activations.py create mode 100644 sharktank/sharktank/models/clip/__init__.py create mode 100644 sharktank/sharktank/models/clip/clip.py create mode 100644 sharktank/sharktank/models/clip/export.py create mode 100644 sharktank/tests/models/clip/clip_test.py diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index bf616bee6..fc3d50bb6 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -122,9 +122,13 @@ jobs: iree-base-runtime - name: Run tests + # TODO: unify with-t5-data and with-clip-data flags into a single flag + # and make it possible to run only tests that require data. run: | pytest \ + --with-clip-data \ --with-t5-data \ + sharktank/tests/models/clip/clip_test.py \ sharktank/tests/models/t5/t5_test.py \ --durations=0 diff --git a/sharktank/conftest.py b/sharktank/conftest.py index ddd371198..9d6257513 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -88,6 +88,15 @@ def pytest_addoption(parser): help="Enable all llama benchmarking tests", ) + parser.addoption( + "--with-clip-data", + action="store_true", + default=False, + help=( + "Enable tests that use CLIP data like models that is not a part of the source " + "code. The user is expected to provide the data" + ), + ) parser.addoption( "--with-t5-data", action="store_true", diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index 5828d2dd3..3caf7631d 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -9,7 +9,7 @@ from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache from .causal_llm import BaseCausalLMModel from .linear import LinearLayer -from .norm import RMSNormLayer +from .norm import RMSNormLayer, LayerNorm from .rotary_embedding import RotaryEmbeddingLayer from .token_embedding import TokenEmbeddingLayer from .llama_attention_block import LlamaAttentionBlock diff --git a/sharktank/sharktank/layers/activations.py b/sharktank/sharktank/layers/activations.py new file mode 100644 index 000000000..40ff94a23 --- /dev/null +++ b/sharktank/sharktank/layers/activations.py @@ -0,0 +1,16 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from torch import nn +from .. import ops + +# TODO: don't use nn.functional directly. +ACT2FN = { + "gelu": nn.functional.gelu, + "gelu_new": ops.gelu_tanh_approximation, + "relu": nn.functional.relu, + "quick_gelu": ops.gelu_sigmoid_approximation, +} diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 996a92152..1513f364a 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -14,11 +14,11 @@ (and indeed, can bootstrap these off of GGUF files). """ -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from typing import Any, Optional import torch -__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"] +__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"] @dataclass @@ -266,3 +266,49 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs): all_kwargs.update(kwargs) return T5Config(**all_kwargs) + + +@dataclass +class ClipTextConfig: + vocab_size: int = 49408 + hidden_size: int = 512 + intermediate_size: int = 2048 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 8 + max_position_embeddings: int = 77 + hidden_act: str = "quick_gelu" + layer_norm_eps: float = 1e-5 + # This differs from `CLIPTokenizer`'s default and from openai/clip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id: int = 1 + bos_token_id: int = 49406 + eos_token_id: int = 49407 + output_attentions: bool = False + output_hidden_states: bool = False + use_return_dict: bool = True + + @staticmethod + def from_transformers_clip_text_config( + config: "transformers.CLIPTextConfig", + ) -> "ClipTextConfig": + return ClipTextConfig( + vocab_size=config.vocab_size, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + projection_dim=config.projection_dim, + num_hidden_layers=config.num_hidden_layers, + num_attention_heads=config.num_attention_heads, + max_position_embeddings=config.max_position_embeddings, + hidden_act=config.hidden_act, + layer_norm_eps=config.layer_norm_eps, + pad_token_id=config.pad_token_id, + bos_token_id=config.bos_token_id, + eos_token_id=config.eos_token_id, + output_attentions=config.output_attentions, + output_hidden_states=config.output_hidden_states, + use_return_dict=config.use_return_dict, + ) + + def as_properties(self) -> dict[str, Any]: + return asdict(self) diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py index 4fa08050a..619faf4b4 100644 --- a/sharktank/sharktank/layers/norm.py +++ b/sharktank/sharktank/layers/norm.py @@ -39,3 +39,24 @@ def forward(self, x: torch.Tensor): # often in higher precision. Downcast back to expected. norm = ops.to(norm, orig_dtype) return norm + + +class LayerNorm(ThetaLayer): + def __init__( + self, + theta: Theta, + *, + weight_name: str = "weight", + bias_name: str = "bias", + eps: float = 1e-05, + ): + super().__init__(theta) + self.weight = self.theta_tensor(weight_name) + if bias_name in self.theta.keys: + self.bias = self.theta_tensor(bias_name) + else: + self.bias = None + self.eps = eps + + def forward(self, x: torch.Tensor): + return ops.layer_norm(x, weight=self.weight, bias=self.bias, eps=self.eps) diff --git a/sharktank/sharktank/models/clip/__init__.py b/sharktank/sharktank/models/clip/__init__.py new file mode 100644 index 000000000..cbe52c953 --- /dev/null +++ b/sharktank/sharktank/models/clip/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .clip import * +from .export import * diff --git a/sharktank/sharktank/models/clip/clip.py b/sharktank/sharktank/models/clip/clip.py new file mode 100644 index 000000000..b3dfb4791 --- /dev/null +++ b/sharktank/sharktank/models/clip/clip.py @@ -0,0 +1,502 @@ +# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any, Optional, Tuple, Union +import torch +from torch import nn +import transformers + +# TODO: port _prepare_4d_attention_mask and _create_4d_causal_attention_mask to sharktank +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _create_4d_causal_attention_mask, +) +from collections import OrderedDict + +from ...layers import BaseLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer +from ... import ops +from ...types.theta import Theta, Dataset +from ...types.tensors import DefaultPrimitiveTensor +from ...layers.configs import ClipTextConfig +from ...layers.activations import ACT2FN + + +class ClipTextEmbeddings(nn.Module): + def __init__(self, theta: Theta, config: ClipTextConfig): + super().__init__() + self.token_embedding = TokenEmbeddingLayer( + theta=theta("token_embedding"), dtype=None + ) + self.position_embedding = TokenEmbeddingLayer( + theta=theta("position_embedding"), dtype=None + ) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = ( + input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class ClipAttention(BaseLayer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, theta: Theta, config: ClipTextConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + + self.k_proj = LinearLayer(theta("k_proj")) + self.v_proj = LinearLayer(theta("v_proj")) + self.q_proj = LinearLayer(theta("q_proj")) + self.out_proj = LinearLayer(theta("out_proj")) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = ops.matmul(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + causal_attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = ops.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + + attn_output = ops.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class ClipMlp(BaseLayer): + def __init__(self, theta: Theta, config: ClipTextConfig): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = LinearLayer(theta("fc1")) + self.fc2 = LinearLayer(theta("fc2")) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class ClipEncoderLayer(BaseLayer): + def __init__(self, theta: Theta, config: ClipTextConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = ClipAttention(theta=theta("self_attn"), config=config) + self.layer_norm1 = LayerNorm( + theta=theta("layer_norm1"), eps=config.layer_norm_eps + ) + self.mlp = ClipMlp(theta=theta("mlp"), config=config) + self.layer_norm2 = LayerNorm( + theta=theta("layer_norm2"), eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class ClipEncoder(BaseLayer): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`ClipEncoderLayer`]. + """ + + def __init__(self, theta: Theta, config: ClipTextConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + ClipEncoderLayer(theta=theta(f"layers.{i}"), config=config) + for i in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, dict[str, Any]]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a dict instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) + return OrderedDict( + (k, v) + for k, v in [ + ("last_hidden_state", hidden_states), + ("hidden_states", encoder_states), + ("attentions", all_attentions), + ] + if v is not None + ) + + +class ClipTextTransformer(nn.Module): + def __init__(self, theta: Theta, config: ClipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = ClipTextEmbeddings(theta=theta("embeddings"), config=config) + self.encoder = ClipEncoder(theta=theta("encoder"), config=config) + self.final_layer_norm = LayerNorm( + theta=theta("final_layer_norm"), eps=config.layer_norm_eps + ) + + # For `pooled_output` computation + self.eos_token_id = config.eos_token_id + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, dict[str, Any]]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = ( + encoder_outputs["last_hidden_state"] if return_dict else encoder_outputs[0] + ) + last_hidden_state = self.final_layer_norm(last_hidden_state) + + if self.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. + # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0]), + input_ids.argmax(dim=-1), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0]), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + (input_ids == self.eos_token_id).int().argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return OrderedDict( + (k, v) + for k, v in [ + ("last_hidden_state", last_hidden_state), + ("pooler_output", pooled_output), + ( + "hidden_states", + encoder_outputs["hidden_states"] + if "hidden_states" in encoder_outputs + else None, + ), + ( + "attentions", + encoder_outputs["attentions"] + if "attentions" in encoder_outputs + else None, + ), + ] + if v is not None + ) + + +class ClipTextModel(BaseLayer): + def __init__(self, theta: Theta, config: ClipTextConfig): + super().__init__() + self.config = config + self.text_model = ClipTextTransformer(theta=theta("text_model"), config=config) + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, dict[str, Any]]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py new file mode 100644 index 000000000..83bda2cbe --- /dev/null +++ b/sharktank/sharktank/models/clip/export.py @@ -0,0 +1,57 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Union +import transformers +from transformers.models.clip.modeling_clip import ( + CLIPAttention as TransformersCLIPAttention, + CLIPEncoderLayer as TransformersCLIPEncoderLayer, + CLIPEncoder as TransformersCLIPEncoder, +) +from os import PathLike +import torch + +from ...types.theta import Theta, Dataset, torch_module_to_theta +from ...types.tensors import DefaultPrimitiveTensor +from ...layers.configs import ClipTextConfig + + +def transformers_clip_attention_to_theta(model: TransformersCLIPAttention) -> Theta: + return torch_module_to_theta(model) + + +def transformers_clip_encoder_layer_to_theta(model: TransformersCLIPEncoder) -> Theta: + return torch_module_to_theta(model) + + +def transformers_clip_encoder_to_theta(model: TransformersCLIPEncoderLayer) -> Theta: + return torch_module_to_theta(model) + + +def transformers_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta: + return torch_module_to_theta(model) + + +def transformers_clip_text_model_to_dataset( + model: transformers.CLIPTextModel, +) -> Dataset: + config = ClipTextConfig.from_transformers_clip_text_config(model.config) + properties = config.as_properties() + theta = transformers_clip_text_model_to_theta(model) + theta.rename_tensors_to_paths() + return Dataset(properties, theta) + + +def export_clip_text_model_dataset_from_hugging_face( + model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel], + output_path: Union[str, PathLike], +): + if isinstance(model_or_name_or_path, transformers.CLIPTextModel): + model = model_or_name_or_path + else: + model = transformers.CLIPTextModel.from_pretrained(model_or_name_or_path) + dataset = transformers_clip_text_model_to_dataset(model) + dataset.save(output_path) diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py index 88472db1d..a2a8958be 100644 --- a/sharktank/sharktank/models/t5/t5.py +++ b/sharktank/sharktank/models/t5/t5.py @@ -28,6 +28,7 @@ from ...types.theta import Theta from ...types.tensors import AnyTensor from ...layers import FFN, T5Config +from ...layers.activations import ACT2FN __all__ = [ "T5Config", @@ -43,13 +44,6 @@ logger = logging.getLogger(__name__) -ACT2FN = { - "gelu": nn.functional.gelu, - "gelu_new": ops.gelu_tanh_approximation, - "relu": nn.functional.relu, -} - - class T5LayerFF(nn.Module): def __init__( self, diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 47e737fb1..d66e97233 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -158,13 +158,13 @@ def elementwise_ternary(operator, x, y, z, *args, **kwargs): # Embedding Lookup @embedding_lookup.override(Tensor, Tensor) -def embedding_lookup_default(input, embedding_matrix, dtype: dtype): +def embedding_lookup_default(input, embedding_matrix, dtype: Optional[dtype]): return F.embedding(unbox_tensor(input), unbox_tensor(embedding_matrix).to(dtype)) @embedding_lookup.override(Tensor, QuantizedTensor) def embedding_lookup_Tensor_QuantizedTensor( - input, embedding_matrix: QuantizedTensor, dtype: dtype + input, embedding_matrix: QuantizedTensor, dtype: Optional[dtype] ): dequant = embedding_matrix.unpack().dequant(dtype=dtype) return F.embedding(unbox_tensor(input), dequant) diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 07f466f5b..015e88a4b 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -397,7 +397,9 @@ def elementwise_binary_replicated_lhs_unsharded_rhs( # Embedding Lookup @embedding_lookup.override(ReplicatedTensor, ReplicatedTensor) def embedding_lookup_default( - input: ReplicatedTensor, embedding_matrix: ReplicatedTensor, dtype: torch.dtype + input: ReplicatedTensor, + embedding_matrix: ReplicatedTensor, + dtype: Optional[torch.dtype], ): assert input.shard_count == embedding_matrix.shard_count shards = [ diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index dc7fb108a..cbe959d28 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -29,6 +29,7 @@ "expand", "flatten", "gather", + "gelu_sigmoid_approximation", "gelu_tanh_approximation", "get_index", "gemm", @@ -228,7 +229,7 @@ def _elementwise_trampoline(d: SignatureDispatcher, operator, *args, **kwargs): @overridable def embedding_lookup( - input: AnyTensor, embedding_matrix: AnyTensor, dtype: dtype + input: AnyTensor, embedding_matrix: AnyTensor, dtype: Optional[dtype] ) -> AnyTensor: """Performs the equivalent of F.embedding(input, embedding_matrix). @@ -241,7 +242,10 @@ def embedding_lookup( @embedding_lookup.trampoline def _embedding_lookup_trampoline( - d: SignatureDispatcher, input: AnyTensor, embedding_matrix: AnyTensor, dtype: dtype + d: SignatureDispatcher, + input: AnyTensor, + embedding_matrix: AnyTensor, + dtype: Optional[dtype], ): tensors = (input, embedding_matrix) for override in d.find_overrides(tensors): @@ -376,6 +380,13 @@ def _gather_trampoline( d.fail(dispatch_args) +def gelu_sigmoid_approximation(input: AnyTensor) -> AnyTensor: + """Applies GELU approximation that is fast but somewhat inaccurate. + See: https://github.com/hendrycks/GELUs + """ + return input * elementwise(torch.sigmoid, 1.702 * input) + + def gelu_tanh_approximation(input: AnyTensor) -> AnyTensor: """Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 Approximation with tanh""" diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 2c267ac49..3a38fdce4 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -378,6 +378,12 @@ def reshape(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": shape = args[0] return reshape(self, shape) + def size(self, dim: Optional[int] = None) -> tuple[int]: + if dim is None: + return tuple(self.shape) + else: + return self.shape[dim] + def transpose(self, dim0: int, dim1: int) -> "AnyTensor": from ..ops import transpose diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 29bc29bb8..143ede184 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -28,14 +28,11 @@ PrimitiveTensor, QuantizedTensor, InferenceTensorMetadata, + DefaultPrimitiveTensor, REGISTERED_INFERENCE_TENSOR_CLASSES, ) -__all__ = [ - "Dataset", - "flat_to_nested_dict", - "Theta", -] +__all__ = ["Dataset", "flat_to_nested_dict", "Theta", "torch_module_to_theta"] IOReportCallback = Callable[[str], None] @@ -216,6 +213,15 @@ def rename_tensors_to_paths(self): tensor.name = path +def torch_module_to_theta(module: torch.nn.Module) -> Theta: + return Theta( + { + name: DefaultPrimitiveTensor(data=param) + for name, param in module.named_parameters() + } + ) + + def flat_to_nested_dict(flat: dict[str, Any]) -> dict[str, Any]: """Nest a flat or semi-flat dictionary. diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index 0562d5854..c6a799404 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -344,6 +344,25 @@ def alias_dataset(from_name: str, to_name: str): ), ) +Dataset( + "openai/clip-vit-large-patch14", + ( + RemoteFile( + "config", + "openai/clip-vit-large-patch14", + "config.json", + extra_filenames=[ + "model.safetensors", + "preprocessor_config.json", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json", + "vocab.json", + ], + ), + ), +) + ################################################################################ # Tool entrypoint ################################################################################ diff --git a/sharktank/sharktank/utils/math.py b/sharktank/sharktank/utils/math.py index 3f32ac952..3723f67dd 100644 --- a/sharktank/sharktank/utils/math.py +++ b/sharktank/sharktank/utils/math.py @@ -4,7 +4,9 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Union, Optional from numbers import Number +import torch def ceildiv(a: int | float, b: int | float) -> int | float: @@ -13,3 +15,14 @@ def ceildiv(a: int | float, b: int | float) -> int | float: def round_up_to_multiple_of(x: Number, multiple: Number) -> Number: return x + (-x % multiple) + + +def cosine_similarity( + a: torch.Tensor, b: torch.Tensor, /, *, dim: Optional[Union[int, tuple[int]]] = None +) -> float: + """Compute cosine similarity over dimensions dim. + If dim is none computes over all dimensions.""" + dot_product = torch.sum(a * b, dim=dim) + norm_a = a.pow(2).sum(dim=dim).sqrt() + norm_b = b.pow(2).sum(dim=dim).sqrt() + return dot_product / (norm_a * norm_b) diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py index 32acec8ac..6c81acf9e 100644 --- a/sharktank/sharktank/utils/testing.py +++ b/sharktank/sharktank/utils/testing.py @@ -25,6 +25,12 @@ def make_rand_torch(shape: list[int], dtype: Optional[torch.dtype] = torch.float return torch.rand(shape, dtype=dtype) * 2 - 1 +def make_random_mask(shape: tuple[int], dtype: Optional[torch.dtype] = None): + mask = make_rand_torch(shape=shape, dtype=dtype) + mask = (mask >= 0).to(dtype=dtype) + return mask + + class TempDirTestBase(unittest.TestCase): def setUp(self): self._temp_dir = Path(tempfile.mkdtemp(type(self).__qualname__)) @@ -196,3 +202,11 @@ def decorator(test_item: Callable): return test_item return decorator + + +test_prompts = [ + "Studies have been shown that owning a dog is good for you", + "The horse went into the river", + "We need at least one sentence long enough so that it spans more than one padding block which by default is of size 16.", + "Make the batch size 4", +] diff --git a/sharktank/tests/models/clip/clip_test.py b/sharktank/tests/models/clip/clip_test.py new file mode 100644 index 000000000..409999797 --- /dev/null +++ b/sharktank/tests/models/clip/clip_test.py @@ -0,0 +1,449 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import functools +from parameterized import parameterized +import pytest +import torch +from torch.utils._pytree import tree_map +from typing import Optional +from unittest import TestCase +import transformers +from transformers import CLIPTextModel as TransformersCLIPTextModel, CLIPTokenizer +from transformers.models.clip.modeling_clip import ( + CLIPAttention as TransformersCLIPAttention, + CLIPEncoderLayer as TransformersCLIPEncoderLayer, + CLIPEncoder as TransformersCLIPEncoder, +) + +from sharktank.types import DefaultPrimitiveTensor +from sharktank.transforms.dataset import set_float_dtype +from sharktank.utils.hf_datasets import get_dataset +from sharktank.utils.math import cosine_similarity +from sharktank.utils.testing import ( + make_rand_torch, + make_random_mask, + TempDirTestBase, + test_prompts, +) +from sharktank.models.clip.export import ( + export_clip_text_model_dataset_from_hugging_face, + transformers_clip_attention_to_theta, + transformers_clip_encoder_layer_to_theta, + transformers_clip_encoder_to_theta, + transformers_clip_text_model_to_theta, +) +from sharktank.models.clip import ( + ClipAttention, + ClipEncoderLayer, + ClipEncoder, + ClipTextModel, +) +from sharktank.layers.configs.llm_configs import ClipTextConfig +from sharktank import ops + +with_clip_data = pytest.mark.skipif("not config.getoption('with_clip_data')") + + +@pytest.mark.usefixtures("path_prefix") +class ClipExportTest(TempDirTestBase): + def setUp(self): + super().setUp() + if self.path_prefix is None: + self.path_prefix = f"{self._temp_dir}/" + + @with_clip_data + def testSmokeExportLargeF32FromHuggingFace(self): + repo_id = "openai/clip-vit-large-patch14" + get_dataset( + repo_id, + ).download() + output_path = f"{self.path_prefix}{repo_id.replace('/', '--')}.irpa" + export_clip_text_model_dataset_from_hugging_face(repo_id, output_path) + + +@pytest.mark.usefixtures("get_model_artifacts") +class ClipTextEagerTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + def runTestCompareTorchEagerAgainstHuggingFace( + self, + huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: float, + ): + """Compares the last hidden states with the cosine similarity metric. + This metric is sensible as the outputs are the result of layer normalization. + The angle between the vectors would indicate how close they are.""" + get_dataset( + huggingface_repo_id, + ).download() + + reference_model: TransformersCLIPTextModel = ( + TransformersCLIPTextModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype + ) + ) + + theta = transformers_clip_text_model_to_theta(reference_model) + theta.rename_tensors_to_paths() + theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) + config = ClipTextConfig.from_transformers_clip_text_config( + reference_model.config + ) + model = ClipTextModel(theta, config) + + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( + huggingface_repo_id, + max_length=reference_model.config.max_position_embeddings, + ) + input_ids = tokenizer( + test_prompts, + truncation=True, + max_length=reference_model.config.max_position_embeddings, + padding="max_length", + return_tensors="pt", + )["input_ids"] + + expected_outputs = reference_model(input_ids=input_ids) + actual_outputs = model(input_ids=DefaultPrimitiveTensor(data=input_ids)) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) + + cosine_similarity_per_token = cosine_similarity( + actual_outputs["last_hidden_state"], + expected_outputs["last_hidden_state"], + dim=-1, + ) + torch.testing.assert_close( + cosine_similarity_per_token, + torch.ones_like(cosine_similarity_per_token), + atol=atol, + rtol=0, + ) + + @with_clip_data + def testLargeCompareTorchEagerF32AgainstHuggingFaceF32(self): + self.runTestCompareTorchEagerAgainstHuggingFace( + "openai/clip-vit-large-patch14", + reference_dtype=torch.float32, + target_dtype=torch.float32, + atol=1e-5, + ) + + @with_clip_data + def testLargeCompareTorchEagerBf16AgainstHuggingFaceF32(self): + self.runTestCompareTorchEagerAgainstHuggingFace( + "openai/clip-vit-large-patch14", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-3, + ) + + @parameterized.expand( + [ + [torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16, 4e-2, 1.6e-2], + [torch.float32, torch.bfloat16, 4e-2, 1.6e-2], + ] + ) + def testCompareEagerToySizedModelAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) + batch_size = 19 + tgt_len = 23 + num_attention_heads = 5 + vocab_size = 11 + reference_config = transformers.CLIPTextConfig( + vocab_size=vocab_size, + hidden_size=13 * num_attention_heads, + intermediate_size=7, + projection_dim=3, + num_attention_heads=num_attention_heads, + layer_norm_eps=1e-4, + num_hidden_layers=2, + final_layer_norm=1e-3, + bos_token_id=vocab_size - 2, + eos_token_id=vocab_size - 1, + ) + reference_model = TransformersCLIPTextModel( + reference_config, + ) + reference_model.eval() + + theta = transformers_clip_text_model_to_theta(reference_model) + theta.rename_tensors_to_paths() + theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) + config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + model = ClipTextModel(theta, config) + + input_ids = torch.randint(low=0, high=vocab_size, size=[batch_size, tgt_len]) + + expected_outputs = reference_model(input_ids=input_ids) + + actual_outputs = model(input_ids=DefaultPrimitiveTensor(data=input_ids)) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + +class ClipAttentionTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + @parameterized.expand( + [ + [torch.float32, torch.float32], + # Default values are not enough because torch.nn.Linear does fused + # multiply-add, while our implementation is decomposed. + # There may be other source of discrepancy. + [torch.bfloat16, torch.bfloat16, 0.5e-2, 1.6e-2], + [torch.float32, torch.bfloat16, 1e-2, 1.6e-2], + ] + ) + def testCompareEagerToySizedModelAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) + batch_size = 19 + tgt_len = 23 + src_len = tgt_len + num_attention_heads = 2 + reference_config = transformers.CLIPTextConfig( + vocab_size=11, + hidden_size=13 * num_attention_heads, + intermediate_size=7, + projection_dim=3, + num_attention_heads=num_attention_heads, + ) + reference_model = TransformersCLIPAttention( + reference_config, + ) + reference_model.eval() + + theta = transformers_clip_attention_to_theta(reference_model) + theta.rename_tensors_to_paths() + theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) + config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + model = ClipAttention(theta, config) + + reference_hidden_states = make_rand_torch( + shape=[batch_size, tgt_len, reference_config.hidden_size], + dtype=reference_dtype, + ) + reference_attention_mask = make_random_mask( + shape=[batch_size, 1, tgt_len, src_len], dtype=reference_dtype + ) + reference_causal_attention_mask = make_random_mask( + shape=[batch_size, 1, tgt_len, src_len], dtype=reference_dtype + ) + expected_outputs = reference_model( + hidden_states=reference_hidden_states, + attention_mask=reference_attention_mask, + causal_attention_mask=reference_causal_attention_mask, + ) + + hidden_states = ops.to(reference_hidden_states, dtype=target_dtype) + attention_mask = ops.to(reference_attention_mask, dtype=target_dtype) + causal_attention_mask = ops.to( + reference_causal_attention_mask, dtype=target_dtype + ) + actual_outputs = model( + hidden_states=DefaultPrimitiveTensor(data=hidden_states), + attention_mask=DefaultPrimitiveTensor(data=attention_mask), + causal_attention_mask=DefaultPrimitiveTensor(data=causal_attention_mask), + ) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + +class ClipEncoderLayerTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + @parameterized.expand( + [ + [torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16, 1e-2, 1.6e-2], + [torch.float32, torch.bfloat16, 1e-2, 1.6e-2], + ] + ) + def testCompareEagerToySizedModelAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) + batch_size = 19 + tgt_len = 23 + src_len = tgt_len + num_attention_heads = 2 + reference_config = transformers.CLIPTextConfig( + vocab_size=11, + hidden_size=13 * num_attention_heads, + intermediate_size=7, + projection_dim=3, + num_attention_heads=num_attention_heads, + layer_norm_eps=1e-4, + ) + reference_model = TransformersCLIPEncoderLayer( + reference_config, + ) + reference_model.eval() + + theta = transformers_clip_encoder_layer_to_theta(reference_model) + theta.rename_tensors_to_paths() + theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) + config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + model = ClipEncoderLayer(theta, config) + + reference_hidden_states = make_rand_torch( + shape=[batch_size, tgt_len, reference_config.hidden_size], + dtype=reference_dtype, + ) + reference_attention_mask = make_random_mask( + shape=[batch_size, 1, tgt_len, src_len], dtype=reference_dtype + ) + reference_causal_attention_mask = make_random_mask( + shape=[batch_size, 1, tgt_len, src_len], dtype=reference_dtype + ) + expected_outputs = reference_model( + hidden_states=reference_hidden_states, + attention_mask=reference_attention_mask, + causal_attention_mask=reference_causal_attention_mask, + ) + + hidden_states = ops.to(reference_hidden_states, dtype=target_dtype) + attention_mask = ops.to(reference_attention_mask, dtype=target_dtype) + causal_attention_mask = ops.to( + reference_causal_attention_mask, dtype=target_dtype + ) + actual_outputs = model( + hidden_states=DefaultPrimitiveTensor(data=hidden_states), + attention_mask=DefaultPrimitiveTensor(data=attention_mask), + causal_attention_mask=DefaultPrimitiveTensor(data=causal_attention_mask), + ) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + +class ClipEncoderTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + @parameterized.expand( + [ + [torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16, 2e-2, 1.6e-2], + [torch.float32, torch.bfloat16, 2e-2, 1.6e-2], + ] + ) + def testCompareEagerToySizedModelAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) + batch_size = 19 + tgt_len = 23 + src_len = tgt_len + num_attention_heads = 5 + reference_config = transformers.CLIPTextConfig( + vocab_size=11, + hidden_size=13 * num_attention_heads, + intermediate_size=7, + projection_dim=3, + num_attention_heads=num_attention_heads, + layer_norm_eps=1e-4, + num_hidden_layers=2, + ) + reference_model = TransformersCLIPEncoder( + reference_config, + ) + reference_model.eval() + + theta = transformers_clip_encoder_to_theta(reference_model) + theta.rename_tensors_to_paths() + theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) + config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + model = ClipEncoder(theta, config) + + reference_inputs_embeds = make_rand_torch( + shape=[batch_size, tgt_len, reference_config.hidden_size], + dtype=reference_dtype, + ) + reference_attention_mask = make_random_mask( + shape=[batch_size, 1, tgt_len, src_len], dtype=reference_dtype + ) + reference_causal_attention_mask = make_random_mask( + shape=[batch_size, 1, tgt_len, src_len], dtype=reference_dtype + ) + expected_outputs = reference_model( + inputs_embeds=reference_inputs_embeds, + attention_mask=reference_attention_mask, + causal_attention_mask=reference_causal_attention_mask, + ) + + inputs_embeds = ops.to(reference_inputs_embeds, dtype=target_dtype) + attention_mask = ops.to(reference_attention_mask, dtype=target_dtype) + causal_attention_mask = ops.to( + reference_causal_attention_mask, dtype=target_dtype + ) + actual_outputs = model( + inputs_embeds=DefaultPrimitiveTensor(data=inputs_embeds), + attention_mask=DefaultPrimitiveTensor(data=attention_mask), + causal_attention_mask=DefaultPrimitiveTensor(data=causal_attention_mask), + ) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py index 1a696ba57..8ec7619fe 100644 --- a/sharktank/tests/models/t5/t5_test.py +++ b/sharktank/tests/models/t5/t5_test.py @@ -39,7 +39,12 @@ export_encoder_mlir, export_encoder_iree_parameters, ) -from sharktank.utils.testing import make_rand_torch, TempDirTestBase +from sharktank.utils.testing import ( + make_rand_torch, + make_random_mask, + TempDirTestBase, + test_prompts, +) from sharktank.utils.hf_datasets import get_dataset from sharktank.utils.iree import ( get_iree_devices, @@ -57,20 +62,6 @@ with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')") -def make_random_mask(shape: tuple[int], dtype: torch.dtype): - mask = make_rand_torch(shape=shape, dtype=dtype) - mask = (mask >= 0).to(dtype=dtype) - return mask - - -test_prompts = [ - "Studies have been shown that owning a dog is good for you", - "The horse went into the river", - "We need at least one sentence long enough so that it spans more than one padding block which by default is of size 16.", - "Make the batch size 4", -] - - @pytest.mark.usefixtures("get_model_artifacts") class T5EncoderEagerTest(TestCase): def setUp(self): From 4ede0535c8d2c105658c19efe3958f2274ac138c Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 4 Dec 2024 15:20:30 +0000 Subject: [PATCH 2/4] Address some PR comments --- sharktank/sharktank/layers/norm.py | 3 +-- sharktank/sharktank/types/tensors.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py index 619faf4b4..a93649b29 100644 --- a/sharktank/sharktank/layers/norm.py +++ b/sharktank/sharktank/layers/norm.py @@ -52,10 +52,9 @@ def __init__( ): super().__init__(theta) self.weight = self.theta_tensor(weight_name) + self.bias = None if bias_name in self.theta.keys: self.bias = self.theta_tensor(bias_name) - else: - self.bias = None self.eps = eps def forward(self, x: torch.Tensor): diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 3a38fdce4..153a5d753 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -381,8 +381,7 @@ def reshape(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": def size(self, dim: Optional[int] = None) -> tuple[int]: if dim is None: return tuple(self.shape) - else: - return self.shape[dim] + return self.shape[dim] def transpose(self, dim0: int, dim1: int) -> "AnyTensor": from ..ops import transpose From cb0b703f41fc7ac797d44268849457eefca1342c Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 5 Dec 2024 02:18:58 +0000 Subject: [PATCH 3/4] Add comment and remove support of EOS=2 --- sharktank/sharktank/models/clip/clip.py | 30 ++++++++++--------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/sharktank/sharktank/models/clip/clip.py b/sharktank/sharktank/models/clip/clip.py index b3dfb4791..5f611178d 100644 --- a/sharktank/sharktank/models/clip/clip.py +++ b/sharktank/sharktank/models/clip/clip.py @@ -5,6 +5,10 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# CLIP text model +# It is based on +# https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/clip/modeling_clip.py + from typing import Any, Optional, Tuple, Union import torch from torch import nn @@ -423,24 +427,14 @@ def forward( ) last_hidden_state = self.final_layer_norm(last_hidden_state) - if self.eos_token_id == 2: - # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. - # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added - # ------------------------------------------------------------ - # text_embeds.shape = [batch_size, sequence_length, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 - pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0]), - input_ids.argmax(dim=-1), - ] - else: - # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) - pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0]), - # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) - (input_ids == self.eos_token_id).int().argmax(dim=-1), - ] + # We don't support this variant. + assert self.eos_token_id != 2 + + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0]), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + (input_ids == self.eos_token_id).int().argmax(dim=-1), + ] if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] From 0d30d623e67c56a5cfdc849ddf5b136f2fa559d8 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 5 Dec 2024 02:48:07 +0000 Subject: [PATCH 4/4] Put back in EOS=2 case as it actually gets hit --- sharktank/sharktank/models/clip/clip.py | 30 ++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/sharktank/sharktank/models/clip/clip.py b/sharktank/sharktank/models/clip/clip.py index 5f611178d..29734e9f1 100644 --- a/sharktank/sharktank/models/clip/clip.py +++ b/sharktank/sharktank/models/clip/clip.py @@ -427,14 +427,28 @@ def forward( ) last_hidden_state = self.final_layer_norm(last_hidden_state) - # We don't support this variant. - assert self.eos_token_id != 2 - - pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0]), - # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) - (input_ids == self.eos_token_id).int().argmax(dim=-1), - ] + if self.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR + # https://github.com/huggingface/transformers/pull/24773 + # Let's keep what have been done here. + # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0]), + input_ids.argmax(dim=-1), + ] + else: + # The config gets updated `eos_token_id` from PR + # https://github.com/huggingface/transformers/pull/24773 + # (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0]), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + (input_ids == self.eos_token_id).int().argmax(dim=-1), + ] if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:]