From 36c0859a0d68da7532b3660618783ed2b4b1dc19 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 10 Dec 2024 15:40:49 -0800 Subject: [PATCH 1/4] Add exporting and numerics verification for CLIP Large text model with IREE (#664) Add exporting to MLIR and IRRE parameters. We don't make the context length dynamic since the maximum is only 77 anyway, so the token sequences are padded to 77. We could explore later making this dynamic. This adds comparison of IREE execution of float32, bfloat16 model variants against float32 torch eager. For bfloat16 results are close up to 1.43e-2 using cosine similarity. Toy-sized model comparison for float32 and bfloat16 is also provided. --- .../sharktank/layers/configs/llm_configs.py | 32 +- sharktank/sharktank/models/clip/clip.py | 47 ++- sharktank/sharktank/models/clip/export.py | 74 +++- sharktank/sharktank/models/clip/testing.py | 37 ++ sharktank/sharktank/types/theta.py | 4 +- sharktank/sharktank/utils/math.py | 2 +- sharktank/sharktank/utils/testing.py | 31 ++ sharktank/tests/models/clip/clip_test.py | 334 +++++++++++++++--- 8 files changed, 482 insertions(+), 79 deletions(-) create mode 100644 sharktank/sharktank/models/clip/testing.py diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 1513f364a..8a443e6ca 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -18,6 +18,8 @@ from typing import Any, Optional import torch +from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name + __all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"] @@ -287,9 +289,10 @@ class ClipTextConfig: output_attentions: bool = False output_hidden_states: bool = False use_return_dict: bool = True + dtype: torch.dtype = torch.float32 @staticmethod - def from_transformers_clip_text_config( + def from_hugging_face_clip_text_model_config( config: "transformers.CLIPTextConfig", ) -> "ClipTextConfig": return ClipTextConfig( @@ -308,7 +311,30 @@ def from_transformers_clip_text_config( output_attentions=config.output_attentions, output_hidden_states=config.output_hidden_states, use_return_dict=config.use_return_dict, + dtype=config.torch_dtype or torch.float32, ) - def as_properties(self) -> dict[str, Any]: - return asdict(self) + def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig": + kwargs = self.to_properties() + kwargs["torch_dtype"] = kwargs["dtype"] + del kwargs["dtype"] + kwargs["return_dict"] = kwargs["use_return_dict"] + del kwargs["use_return_dict"] + from transformers import CLIPTextConfig + + return CLIPTextConfig(**kwargs) + + @staticmethod + def from_properties(properties: dict[str, Any]) -> "ClipTextConfig": + kwargs = dict(properties) + kwargs.pop("SHARK_DATASET_VERSION") + if "dtype" in kwargs and kwargs["dtype"] is not None: + kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"]) + + return ClipTextConfig(**kwargs) + + def to_properties(self) -> dict[str, Any]: + res = asdict(self) + if self.dtype is not None: + res["dtype"] = dtype_to_serialized_name(self.dtype) + return res diff --git a/sharktank/sharktank/models/clip/clip.py b/sharktank/sharktank/models/clip/clip.py index 29734e9f1..0593c940b 100644 --- a/sharktank/sharktank/models/clip/clip.py +++ b/sharktank/sharktank/models/clip/clip.py @@ -21,10 +21,10 @@ ) from collections import OrderedDict -from ...layers import BaseLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer +from ...layers 
import ThetaLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer from ... import ops from ...types.theta import Theta, Dataset -from ...types.tensors import DefaultPrimitiveTensor +from ...types.tensors import AnyTensor, DefaultPrimitiveTensor from ...layers.configs import ClipTextConfig from ...layers.activations import ACT2FN @@ -68,11 +68,11 @@ def forward( return embeddings -class ClipAttention(BaseLayer): +class ClipAttention(ThetaLayer): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads @@ -182,9 +182,9 @@ def forward( return attn_output, attn_weights_reshaped -class ClipMlp(BaseLayer): +class ClipMlp(ThetaLayer): def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = LinearLayer(theta("fc1")) @@ -197,9 +197,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class ClipEncoderLayer(BaseLayer): +class ClipEncoderLayer(ThetaLayer): def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.embed_dim = config.hidden_size self.self_attn = ClipAttention(theta=theta("self_attn"), config=config) self.layer_norm1 = LayerNorm( @@ -251,14 +251,14 @@ def forward( return outputs -class ClipEncoder(BaseLayer): +class ClipEncoder(ThetaLayer): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`ClipEncoderLayer`]. """ def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.config = config self.layers = nn.ModuleList( [ @@ -356,9 +356,9 @@ def forward( ) -class ClipTextTransformer(nn.Module): +class ClipTextTransformer(ThetaLayer): def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.config = config embed_dim = config.hidden_size self.embeddings = ClipTextEmbeddings(theta=theta("embeddings"), config=config) @@ -475,9 +475,9 @@ def forward( ) -class ClipTextModel(BaseLayer): +class ClipTextModel(ThetaLayer): def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.config = config self.text_model = ClipTextTransformer(theta=theta("text_model"), config=config) @@ -487,6 +487,25 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value + def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]: + input_ids = ( + torch.arange( + start=0, + end=batch_size * self.config.max_position_embeddings, + dtype=torch.long, + ) + % self.config.vocab_size + ) + input_ids = input_ids.reshape([batch_size, self.config.max_position_embeddings]) + return OrderedDict( + [ + ( + "input_ids", + input_ids, + ) + ] + ) + def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py index 83bda2cbe..3cae3f4c4 100644 --- a/sharktank/sharktank/models/clip/export.py +++ b/sharktank/sharktank/models/clip/export.py @@ -4,54 +4,98 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Union +from typing import Optional, Union import transformers from transformers.models.clip.modeling_clip import ( - CLIPAttention as TransformersCLIPAttention, - CLIPEncoderLayer as TransformersCLIPEncoderLayer, - CLIPEncoder as TransformersCLIPEncoder, + CLIPAttention as HfCLIPAttention, + CLIPEncoderLayer as HfCLIPEncoderLayer, + CLIPEncoder as HfCLIPEncoder, ) from os import PathLike import torch from ...types.theta import Theta, Dataset, torch_module_to_theta -from ...types.tensors import DefaultPrimitiveTensor from ...layers.configs import ClipTextConfig +from .clip import ClipTextModel +from iree.turbine.aot import FxProgramsBuilder, export -def transformers_clip_attention_to_theta(model: TransformersCLIPAttention) -> Theta: +def hugging_face_clip_attention_to_theta(model: HfCLIPAttention) -> Theta: return torch_module_to_theta(model) -def transformers_clip_encoder_layer_to_theta(model: TransformersCLIPEncoder) -> Theta: +def hugging_face_clip_encoder_layer_to_theta(model: HfCLIPEncoder) -> Theta: return torch_module_to_theta(model) -def transformers_clip_encoder_to_theta(model: TransformersCLIPEncoderLayer) -> Theta: +def hugging_face_clip_encoder_to_theta(model: HfCLIPEncoderLayer) -> Theta: return torch_module_to_theta(model) -def transformers_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta: +def hugging_face_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta: return torch_module_to_theta(model) -def transformers_clip_text_model_to_dataset( +def hugging_face_clip_text_model_to_dataset( model: transformers.CLIPTextModel, ) -> Dataset: - config = ClipTextConfig.from_transformers_clip_text_config(model.config) - properties = config.as_properties() - theta = transformers_clip_text_model_to_theta(model) + config = ClipTextConfig.from_hugging_face_clip_text_model_config(model.config) + properties = config.to_properties() + theta = hugging_face_clip_text_model_to_theta(model) theta.rename_tensors_to_paths() return Dataset(properties, theta) +def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset: + return Dataset(properties=model.config.to_properties(), root_theta=model.theta) + + def export_clip_text_model_dataset_from_hugging_face( model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel], output_path: Union[str, PathLike], + dtype: Optional[torch.dtype] = None, ): if isinstance(model_or_name_or_path, transformers.CLIPTextModel): + assert dtype is None model = model_or_name_or_path else: - model = transformers.CLIPTextModel.from_pretrained(model_or_name_or_path) - dataset = transformers_clip_text_model_to_dataset(model) + model = transformers.CLIPTextModel.from_pretrained( + model_or_name_or_path, torch_dtype=dtype + ) + dataset = hugging_face_clip_text_model_to_dataset(model) dataset.save(output_path) + + +def export_clip_text_model_mlir( + model: Union[ClipTextModel, PathLike], + batch_sizes: list[int], + mlir_output_path: str, +): + """ + Args: + model: either the torch module or path to GGUF/IRPA. 
+ """ + if not isinstance(model, ClipTextModel): + dataset = Dataset.load(model) + config = ClipTextConfig.from_properties(dataset.properties) + model = ClipTextModel(theta=dataset.root_theta, config=config) + + fxb = FxProgramsBuilder(model) + + for batch_size in batch_sizes: + sample_inputs = model.sample_inputs(batch_size) + + @fxb.export_program( + name=f"forward_bs{batch_size}", + args=tuple(sample_inputs.values()), + dynamic_shapes=None, + strict=False, + ) + def _( + model, + input_ids, + ): + return model(input_ids) + + output = export(fxb, import_symbolic_shape_expressions=True) + output.save_mlir(mlir_output_path) diff --git a/sharktank/sharktank/models/clip/testing.py b/sharktank/sharktank/models/clip/testing.py new file mode 100644 index 000000000..87634c220 --- /dev/null +++ b/sharktank/sharktank/models/clip/testing.py @@ -0,0 +1,37 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...layers.configs.llm_configs import ClipTextConfig +from ...types.theta import Theta +from .export import hugging_face_clip_text_model_to_theta +import torch + + +def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta: + from transformers import CLIPTextConfig as HfCLIPTextConfig + from transformers import CLIPTextModel as HfCLIPTextModel + + hf_config = config.to_hugging_face_clip_text_model_config() + model = HfCLIPTextModel(hf_config) + return hugging_face_clip_text_model_to_theta(model) + + +def make_random_input_token_sequences( + batch_size: int, config: ClipTextConfig +) -> torch.LongTensor: + sequence_lens = torch.randint( + low=1, high=config.max_position_embeddings + 1, size=(batch_size,) + ) + sequences = torch.full( + size=(batch_size, config.max_position_embeddings), + fill_value=config.eos_token_id, + dtype=torch.long, + ) + for batch_idx, l in enumerate(sequence_lens): + sequences[batch_idx][0:l] = torch.randint( + low=0, high=config.vocab_size - 1, size=(l,), dtype=torch.long + ) + return sequences diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 143ede184..021925169 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -214,12 +214,14 @@ def rename_tensors_to_paths(self): def torch_module_to_theta(module: torch.nn.Module) -> Theta: - return Theta( + res = Theta( { name: DefaultPrimitiveTensor(data=param) for name, param in module.named_parameters() } ) + res.rename_tensors_to_paths() + return res def flat_to_nested_dict(flat: dict[str, Any]) -> dict[str, Any]: diff --git a/sharktank/sharktank/utils/math.py b/sharktank/sharktank/utils/math.py index 3723f67dd..639f559d2 100644 --- a/sharktank/sharktank/utils/math.py +++ b/sharktank/sharktank/utils/math.py @@ -19,7 +19,7 @@ def round_up_to_multiple_of(x: Number, multiple: Number) -> Number: def cosine_similarity( a: torch.Tensor, b: torch.Tensor, /, *, dim: Optional[Union[int, tuple[int]]] = None -) -> float: +) -> torch.Tensor: """Compute cosine similarity over dimensions dim. 
If dim is none computes over all dimensions.""" dot_product = torch.sum(a * b, dim=dim) diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py index 6c81acf9e..d3cf08fd6 100644 --- a/sharktank/sharktank/utils/testing.py +++ b/sharktank/sharktank/utils/testing.py @@ -18,6 +18,7 @@ import gc from ..types import * +from .math import cosine_similarity # Range of torch.rand() is [0,1) # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values @@ -184,6 +185,36 @@ def assert_iterables_equal( ), f"Iterables not equal at index {i} for elements {v1} and {v2}" +def assert_text_encoder_state_close( + actual: torch.Tensor, expected: torch.Tensor, atol: float +): + """The cosine similarity has been suggested to compare encoder states. + + Dehua Peng, Zhipeng Gui, Huayi Wu - + Interpreting the Curse of Dimensionality from Distance Concentration and Manifold + Effect (2023) + + shows that cosine and all Minkowski distances suffer from the curse of + dimensionality. + The cosine similarity ignores the vector magnitudes. We can probably come up with a + better metric, but this is maybe good enough. + + The functions expects that the last dimension is the features per token. + It will compute the cosine similarity for each token. + """ + cosine_similarity_per_token = cosine_similarity( + actual, + expected, + dim=-1, + ) + torch.testing.assert_close( + cosine_similarity_per_token, + torch.ones_like(cosine_similarity_per_token), + atol=atol, + rtol=0, + ) + + SHARKTANK_TEST_SKIP_ENV_VAR = "SHARKTANK_TEST_SKIP" diff --git a/sharktank/tests/models/clip/clip_test.py b/sharktank/tests/models/clip/clip_test.py index 409999797..99af4ba6f 100644 --- a/sharktank/tests/models/clip/clip_test.py +++ b/sharktank/tests/models/clip/clip_test.py @@ -4,37 +4,61 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from collections import OrderedDict import functools +import iree.compiler +import os from parameterized import parameterized +from copy import copy import pytest import torch from torch.utils._pytree import tree_map from typing import Optional from unittest import TestCase import transformers -from transformers import CLIPTextModel as TransformersCLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel as HfCLIPTextModel, CLIPTokenizer from transformers.models.clip.modeling_clip import ( - CLIPAttention as TransformersCLIPAttention, - CLIPEncoderLayer as TransformersCLIPEncoderLayer, - CLIPEncoder as TransformersCLIPEncoder, + CLIPAttention as HfCLIPAttention, + CLIPEncoderLayer as HfCLIPEncoderLayer, + CLIPEncoder as HfCLIPEncoder, ) -from sharktank.types import DefaultPrimitiveTensor +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + flatten_for_iree_signature, + iree_to_torch, +) +from sharktank.types import ( + DefaultPrimitiveTensor, + dtype_to_serialized_short_name, + Dataset, +) from sharktank.transforms.dataset import set_float_dtype from sharktank.utils.hf_datasets import get_dataset -from sharktank.utils.math import cosine_similarity from sharktank.utils.testing import ( + assert_text_encoder_state_close, make_rand_torch, make_random_mask, TempDirTestBase, test_prompts, ) from sharktank.models.clip.export import ( + export_clip_text_model_mlir, export_clip_text_model_dataset_from_hugging_face, - transformers_clip_attention_to_theta, - transformers_clip_encoder_layer_to_theta, - transformers_clip_encoder_to_theta, - transformers_clip_text_model_to_theta, + hugging_face_clip_attention_to_theta, + hugging_face_clip_encoder_layer_to_theta, + hugging_face_clip_encoder_to_theta, + hugging_face_clip_text_model_to_dataset, + hugging_face_clip_text_model_to_theta, + clip_text_model_to_dataset, +) +from sharktank.models.clip.testing import ( + make_random_input_token_sequences, + make_clip_text_model_random_theta, ) from sharktank.models.clip import ( ClipAttention, @@ -48,21 +72,244 @@ with_clip_data = pytest.mark.skipif("not config.getoption('with_clip_data')") -@pytest.mark.usefixtures("path_prefix") -class ClipExportTest(TempDirTestBase): +@pytest.mark.usefixtures("caching", "path_prefix") +class ClipTextIreeTest(TempDirTestBase): def setUp(self): super().setUp() + torch.random.manual_seed(12345) if self.path_prefix is None: self.path_prefix = f"{self._temp_dir}/" @with_clip_data def testSmokeExportLargeF32FromHuggingFace(self): - repo_id = "openai/clip-vit-large-patch14" + huggingface_repo_id = "openai/clip-vit-large-patch14" + huggingface_repo_id_as_path = ( + f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" + ) + get_dataset( + huggingface_repo_id, + ).download() + target_dtype_name = dtype_to_serialized_short_name(torch.float32) + target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_text_model_{target_dtype_name}" + output_path = f"{target_model_path_prefix}.irpa" + export_clip_text_model_dataset_from_hugging_face( + huggingface_repo_id, output_path + ) + + @with_clip_data + def testCompareLargeIreeF32AgainstTorchEagerF32(self): + self.runTestCompareIreeAgainstPretrainedTorchEager( + "openai/clip-vit-large-patch14", + reference_dtype=torch.float32, + target_dtype=torch.float32, + atol=1e-5, + ) + + @with_clip_data + def 
testCompareLargeIreeBf16AgainstTorchEagerF32(self): + self.runTestCompareIreeAgainstPretrainedTorchEager( + "openai/clip-vit-large-patch14", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + # The observed error is 1.43e-2. We leave a bit of margin. + atol=3e-3, + ) + + @with_clip_data + def testCompareToyModelIreeF32AgainstTorchEagerF32(self): + self.runTestCompareToyModelIreeAgainstTorch( + reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-5 + ) + + @with_clip_data + def testCompareToyModelIreeBf16AgainstTorchEagerF32(self): + self.runTestCompareToyModelIreeAgainstTorch( + reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1e-3 + ) + + @torch.no_grad() + def runTestCompareIreeAgainstTorchEagerWithInputTokens( + self, + reference_model: ClipTextModel, + target_dtype: torch.dtype, + input_ids: torch.LongTensor, + atol: float, + file_artifact_prefix_name: str, + ): + reference_dtype_name = dtype_to_serialized_short_name( + reference_model.config.dtype + ) + target_dtype_name = dtype_to_serialized_short_name(target_dtype) + reference_model_path_prefix = ( + f"{self.path_prefix}{file_artifact_prefix_name}_{reference_dtype_name}" + ) + target_model_path_prefix = ( + f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}" + ) + + target_config = copy(reference_model.config) + target_config.dtype = target_dtype + reference_dataset = clip_text_model_to_dataset(reference_model) + target_dataset = Dataset( + root_theta=reference_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=target_config.dtype) + ), + properties=target_config.to_properties(), + ) + + parameters_path = f"{target_model_path_prefix}.irpa" + if not self.caching or not os.path.exists(parameters_path): + target_dataset.save(parameters_path) + + dataset = Dataset.load(parameters_path) + target_config = ClipTextConfig.from_properties(dataset.properties) + input_args = OrderedDict([("input_ids", input_ids)]) + batch_size = input_ids.shape[0] + + mlir_path = f"{target_model_path_prefix}.mlir" + if not self.caching or not os.path.exists(mlir_path): + export_clip_text_model_mlir( + parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path + ) + iree_module_path = f"{target_model_path_prefix}.vmfb" + if not self.caching or not os.path.exists(iree_module_path): + iree.compiler.compile_file( + mlir_path, + output_file=iree_module_path, + extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"], + ) + + reference_result_dict = call_torch_module_function( + module=reference_model, + function_name="forward", + kwargs=input_args, + trace_path_prefix=f"{reference_model_path_prefix}_torch_", + ) + expected_outputs = flatten_for_iree_signature(reference_result_dict) + + iree_devices = get_iree_devices(driver="hip", device_count=1) + iree_module, iree_vm_context, iree_vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=parameters_path, + ) + iree_args = prepare_iree_module_function_args( + args=flatten_for_iree_signature(input_args), devices=iree_devices + ) + iree_result = iree_to_torch( + *run_iree_module_function( + module=iree_module, + vm_context=iree_vm_context, + args=iree_args, + driver="hip", + function_name=f"forward_bs{batch_size}", + trace_path_prefix=f"{target_model_path_prefix}_iree_", + ) + ) + actual_outputs = [ + ops.to(iree_result[i], dtype=expected_outputs[i].dtype) + for i in range(len(expected_outputs)) + ] + + actual_last_hidden_states = actual_outputs[0] + 
expected_last_hidden_states = expected_outputs[0] + + assert_text_encoder_state_close( + actual_last_hidden_states, expected_last_hidden_states, atol + ) + + def runTestCompareRandomModelIreeAgainstTorch( + self, + reference_config: ClipTextConfig, + target_dtype: torch.dtype, + batch_size: int, + atol: float, + file_artifact_prefix_name: str, + ): + input_ids = make_random_input_token_sequences( + batch_size=batch_size, config=reference_config + ) + reference_theta = make_clip_text_model_random_theta(reference_config) + reference_model = ClipTextModel(theta=reference_theta, config=reference_config) + self.runTestCompareIreeAgainstTorchEagerWithInputTokens( + reference_model=reference_model, + target_dtype=target_dtype, + input_ids=input_ids, + atol=atol, + file_artifact_prefix_name=file_artifact_prefix_name, + ) + + def runTestCompareToyModelIreeAgainstTorch( + self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float + ): + batch_size = 4 + num_attention_heads = 5 + vocab_size = 11 + reference_config = ClipTextConfig( + vocab_size=vocab_size, + hidden_size=13 * num_attention_heads, + intermediate_size=7, + projection_dim=3, + num_attention_heads=num_attention_heads, + max_position_embeddings=17, + layer_norm_eps=1e-4, + num_hidden_layers=2, + bos_token_id=vocab_size - 2, + eos_token_id=vocab_size - 1, + dtype=reference_dtype, + ) + file_artifact_prefix_name = "clip_text_model_toy" + self.runTestCompareRandomModelIreeAgainstTorch( + reference_config=reference_config, + target_dtype=target_dtype, + batch_size=batch_size, + atol=atol, + file_artifact_prefix_name=file_artifact_prefix_name, + ) + + def runTestCompareIreeAgainstPretrainedTorchEager( + self, + huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + ): get_dataset( - repo_id, + huggingface_repo_id, ).download() - output_path = f"{self.path_prefix}{repo_id.replace('/', '--')}.irpa" - export_clip_text_model_dataset_from_hugging_face(repo_id, output_path) + + huggingface_repo_id_as_path = ( + f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" + ) + file_artifact_prefix_name = f"{huggingface_repo_id_as_path}_text_model" + + hf_model: HfCLIPTextModel = HfCLIPTextModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype + ) + reference_dataset = hugging_face_clip_text_model_to_dataset(hf_model) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + hf_model.config + ) + reference_model = ClipTextModel( + theta=reference_dataset.root_theta, config=config + ) + + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(huggingface_repo_id) + input_ids = tokenizer( + test_prompts, + truncation=True, + max_length=reference_model.config.max_position_embeddings, + padding="max_length", + return_tensors="pt", + )["input_ids"] + + self.runTestCompareIreeAgainstTorchEagerWithInputTokens( + reference_model=reference_model, + target_dtype=target_dtype, + input_ids=input_ids, + atol=atol, + file_artifact_prefix_name=file_artifact_prefix_name, + ) @pytest.mark.usefixtures("get_model_artifacts") @@ -70,7 +317,6 @@ class ClipTextEagerTest(TestCase): def setUp(self): super().setUp() torch.random.manual_seed(12345) - torch.no_grad() def runTestCompareTorchEagerAgainstHuggingFace( self, @@ -86,16 +332,14 @@ def runTestCompareTorchEagerAgainstHuggingFace( huggingface_repo_id, ).download() - reference_model: TransformersCLIPTextModel = ( - TransformersCLIPTextModel.from_pretrained( - huggingface_repo_id, torch_dtype=reference_dtype 
- ) + reference_model: HfCLIPTextModel = HfCLIPTextModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype ) - theta = transformers_clip_text_model_to_theta(reference_model) + theta = hugging_face_clip_text_model_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config( + config = ClipTextConfig.from_hugging_face_clip_text_model_config( reference_model.config ) model = ClipTextModel(theta, config) @@ -119,16 +363,10 @@ def runTestCompareTorchEagerAgainstHuggingFace( actual_outputs, ) - cosine_similarity_per_token = cosine_similarity( + assert_text_encoder_state_close( actual_outputs["last_hidden_state"], expected_outputs["last_hidden_state"], - dim=-1, - ) - torch.testing.assert_close( - cosine_similarity_per_token, - torch.ones_like(cosine_similarity_per_token), atol=atol, - rtol=0, ) @with_clip_data @@ -146,6 +384,7 @@ def testLargeCompareTorchEagerBf16AgainstHuggingFaceF32(self): "openai/clip-vit-large-patch14", reference_dtype=torch.float32, target_dtype=torch.bfloat16, + # The observed error is 3.66e-4. We leave a bit of margin. atol=1e-3, ) @@ -180,15 +419,17 @@ def testCompareEagerToySizedModelAgainstTransformers( bos_token_id=vocab_size - 2, eos_token_id=vocab_size - 1, ) - reference_model = TransformersCLIPTextModel( + reference_model = HfCLIPTextModel( reference_config, ) reference_model.eval() - theta = transformers_clip_text_model_to_theta(reference_model) + theta = hugging_face_clip_text_model_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + reference_config + ) model = ClipTextModel(theta, config) input_ids = torch.randint(low=0, high=vocab_size, size=[batch_size, tgt_len]) @@ -210,7 +451,6 @@ class ClipAttentionTest(TestCase): def setUp(self): super().setUp() torch.random.manual_seed(12345) - torch.no_grad() @parameterized.expand( [ @@ -241,15 +481,17 @@ def testCompareEagerToySizedModelAgainstTransformers( projection_dim=3, num_attention_heads=num_attention_heads, ) - reference_model = TransformersCLIPAttention( + reference_model = HfCLIPAttention( reference_config, ) reference_model.eval() - theta = transformers_clip_attention_to_theta(reference_model) + theta = hugging_face_clip_attention_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + reference_config + ) model = ClipAttention(theta, config) reference_hidden_states = make_rand_torch( @@ -292,7 +534,6 @@ class ClipEncoderLayerTest(TestCase): def setUp(self): super().setUp() torch.random.manual_seed(12345) - torch.no_grad() @parameterized.expand( [ @@ -321,15 +562,17 @@ def testCompareEagerToySizedModelAgainstTransformers( num_attention_heads=num_attention_heads, layer_norm_eps=1e-4, ) - reference_model = TransformersCLIPEncoderLayer( + reference_model = HfCLIPEncoderLayer( reference_config, ) reference_model.eval() - theta = transformers_clip_encoder_layer_to_theta(reference_model) + theta = hugging_face_clip_encoder_layer_to_theta(reference_model) theta.rename_tensors_to_paths() theta = 
theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + reference_config + ) model = ClipEncoderLayer(theta, config) reference_hidden_states = make_rand_torch( @@ -372,7 +615,6 @@ class ClipEncoderTest(TestCase): def setUp(self): super().setUp() torch.random.manual_seed(12345) - torch.no_grad() @parameterized.expand( [ @@ -402,15 +644,17 @@ def testCompareEagerToySizedModelAgainstTransformers( layer_norm_eps=1e-4, num_hidden_layers=2, ) - reference_model = TransformersCLIPEncoder( + reference_model = HfCLIPEncoder( reference_config, ) reference_model.eval() - theta = transformers_clip_encoder_to_theta(reference_model) + theta = hugging_face_clip_encoder_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + reference_config + ) model = ClipEncoder(theta, config) reference_inputs_embeds = make_rand_torch( From 63edf36feb76912ba6f5a4a64cac7d59d4e734f0 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 11 Dec 2024 03:47:47 -0500 Subject: [PATCH 2/4] Make config.json consistent between shortfin and sharktank (#487) And remove the adaption layer in buidl_tools/integration_tests/llm/conftest.py --- .../llm/sglang_benchmarks/conftest.py | 38 +++++++++---------- .../integration_tests/llm/sglang/conftest.py | 15 -------- .../llm/shortfin/conftest.py | 34 +++++++++-------- .../sharktank/examples/export_paged_llm_v1.py | 19 ++++++++-- .../llm/components/config_struct.py | 33 ++++++++++------ 5 files changed, 75 insertions(+), 64 deletions(-) diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py index 27bfddfa2..1a2633b0e 100644 --- a/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py @@ -68,35 +68,35 @@ def write_config(request, pre_process_model): batch_sizes = request.param["batch_sizes"] prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"] - logger.info("Writing config file..." 
+ start_log_group("Writing config file")) - + # Construct the new config filename config_path = ( pre_process_model / f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json" ) - config = { - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 131072, - "attn_head_count": 8, - "attn_head_dim": 128, - "prefill_batch_sizes": batch_sizes, - "decode_batch_sizes": batch_sizes, - "transformer_block_count": 32, - "paged_kv_cache": { - "block_seq_stride": 16, - "device_block_count": 256, - "prefix_sharing_algorithm": prefix_sharing_algorithm, - }, - } + # Read the base config file + base_config_path = pre_process_model / "config.json" + with open(base_config_path, "r") as f: + config = json.load(f) + + # Override specific fields + config.update( + { + "prefill_batch_sizes": batch_sizes, + "decode_batch_sizes": batch_sizes, + "paged_kv_cache": { + **config.get( + "paged_kv_cache", {} + ), # Preserve other paged_kv_cache settings + "prefix_sharing_algorithm": prefix_sharing_algorithm, + }, + } + ) logger.info(f"Saving edited config to: {config_path}\n") logger.info(f"Config: {json.dumps(config, indent=2)}") with open(config_path, "w") as f: json.dump(config, f) - - logger.info("Config file successfully written" + end_log_group()) yield config_path diff --git a/app_tests/integration_tests/llm/sglang/conftest.py b/app_tests/integration_tests/llm/sglang/conftest.py index 8543708da..cc79fc365 100644 --- a/app_tests/integration_tests/llm/sglang/conftest.py +++ b/app_tests/integration_tests/llm/sglang/conftest.py @@ -64,21 +64,6 @@ def pre_process_model(request, tmp_path_factory): device_settings, ) - config = { - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 131072, - "attn_head_count": 8, - "attn_head_dim": 128, - "prefill_batch_sizes": [1, 4], - "decode_batch_sizes": [1, 4], - "transformer_block_count": 32, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, - } - config_path = tmp_dir / "config.json" - with open(config_path, "w") as f: - json.dump(config, f) - return tmp_dir diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py index 42e541506..55c9e8bdc 100644 --- a/app_tests/integration_tests/llm/shortfin/conftest.py +++ b/app_tests/integration_tests/llm/shortfin/conftest.py @@ -87,26 +87,30 @@ def write_config(request, model_test_dir): batch_sizes = request.param["batch_sizes"] prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"] + # Construct the new config filename config_path = ( model_test_dir / f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json" ) - config = { - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 2048, - "attn_head_count": 32, - "attn_head_dim": 100, - "prefill_batch_sizes": batch_sizes, - "decode_batch_sizes": batch_sizes, - "transformer_block_count": 26, - "paged_kv_cache": { - "block_seq_stride": 16, - "device_block_count": 256, - "prefix_sharing_algorithm": prefix_sharing_algorithm, - }, - } + # Read the base config file + base_config_path = model_test_dir / "config.json" + with open(base_config_path, "r") as f: + config = json.load(f) + + # Override specific fields + config.update( + { + "prefill_batch_sizes": batch_sizes, + "decode_batch_sizes": batch_sizes, + "paged_kv_cache": { + **config.get( + "paged_kv_cache", {} + ), # Preserve other paged_kv_cache settings + "prefix_sharing_algorithm": prefix_sharing_algorithm, + }, + } + ) logger.info(f"Saving edited config to: 
{config_path}\n") logger.info(f"Config: {json.dumps(config, indent=2)}") with open(config_path, "w") as f: diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 6dd9785c3..900c1a9ae 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -7,6 +7,7 @@ """Export support for the PagedLLMV1 protocol of models.""" import json +from typing import Any, Dict import torch from iree.turbine.aot import * @@ -86,17 +87,29 @@ def main(): else: model = PagedLlamaModelV1(dataset.root_theta, llama_config) - def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]): + def generate_params_json( + hp: LlamaHParams, prefill_bs: list[int], decode_bs: list[int] + ) -> Dict[str, Any]: + """ + Generate config.json for shortfin. + + + For shortfin, we only write attention_head_count_kv because that's all shortfin needs. + Note that this is different from hp.attn_head_count when grouped attention shares kvcache between heads. + """ return { "module_name": "module", "module_abi_version": 1, "max_seq_len": hp.context_length, - "attn_head_count": hp.attention_head_count, "attn_head_dim": hp.attn_head_dim, "prefill_batch_sizes": prefill_bs, "decode_batch_sizes": decode_bs, "transformer_block_count": hp.block_count, - "block_seq_stride": llama_config.block_seq_stride, + "paged_kv_cache": { + "attention_head_count_kv": hp.attention_head_count_kv, + "block_seq_stride": llama_config.block_seq_stride, + "device_block_count": 256, # so that this makes its way into the config file & can be edited. + }, } # Unrolling cache updates by batch row makes dynamo sad without an diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 7caed5d07..8fefa0a12 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -11,20 +11,20 @@ In a typical transformer model, the KV cache is organized similar to (mapped to our parameter names below): k = tensor.empty(transformer_block_count, batch_size, seq, - attn_head_count, attn_head_dim) + attn_head_count_kv, attn_head_dim) v = ... For context, a popular model has parameters of: attn_dtype_size = 2 # (fp16) max_seq_len = 2048 transformer_block_count = 32 - attn_head_count = 32 + attn_head_count_kv = 32 attn_head_dim = 128 # (dim / head_count) If paging, then we primarily care about the organization of a single block, where a block represents a single position in the sequence for a single item in the batch. Therefore, it will be organized like: - block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim) + block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim) In this scenario, we declare that one block holds the KV cache for all transformer block layers because it reduces the accounting. As such, for the above example, @@ -80,10 +80,15 @@ def _decode_dtype(name: str) -> sfnp.DType: class PagedKVCacheParams: """Parameters for the paged KV cache.""" - # Position stride per attention block + # Tokens per page. block_seq_stride: int + # Number of attention heads per block. This can be different from the model's + # attention head count due to sharing. + attention_head_count_kv: int + # Size of the cache on each device. 
+ # Default: 256 device_block_count: int prefix_sharing_algorithm: str = "none" # currently supporting none and trie @@ -92,19 +97,23 @@ class PagedKVCacheParams: @dataclass_json(undefined=Undefined.RAISE) @dataclass class ModelParams: - """Parameters for a specific compiled model, sufficient to do cache planning and - invocations.""" + """ + Parameters for a specific compiled model, sufficient to do cache planning and + invocations. + + Compatibility should be maintained with function generate_params_json in + + sharktank/sharktank/examples/export_paged_llm_v1.py + """ # Maximum length of a sequence including prompt and output. max_seq_len: int - # Number of transformer blocks. + # Number of transformer layers (aka attention blocks / transformer blocks). transformer_block_count: int - # Number of attention heads per block. - attn_head_count: int - - # Dimensionality of each attention head + # Dimensionality of each attention head. This is the dimensionality of the + # key and value vectors. AKA rope_dimension_count from the GGUF props. attn_head_dim: int # Batch sizes that the prefill stage is compiled for. These are expected to be @@ -159,7 +168,7 @@ def paged_kv_unit_size_elements(self) -> int: size = 1 size *= self.transformer_block_count size *= 2 # K and V cache line - size *= self.attn_head_count + size *= self.paged_kv_cache.attention_head_count_kv size *= self.attn_head_dim return size From 690274a4db2e272458653afb1be99badadb7d121 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Wed, 11 Dec 2024 17:03:55 +0100 Subject: [PATCH 3/4] Fix install instructions for nightly releases (#675) Without `--pre`, stable releases get pulled. --- docs/nightly_releases.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/nightly_releases.md b/docs/nightly_releases.md index a7c182cd4..1fbc20bde 100644 --- a/docs/nightly_releases.md +++ b/docs/nightly_releases.md @@ -58,7 +58,7 @@ python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate # Install 'sharktank' package from nightly releases. -pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels +pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels --pre # Test the installation. python -c "from sharktank import ops; print('Sanity check passed')" @@ -75,7 +75,7 @@ python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate # Install 'shortfin' package from nightly releases. -pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels +pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels --pre # Test the installation. python -c "import shortfin as sf; print('Sanity check passed')" From 6c62ed1f151d0a1aed2663ab48452626175a0998 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Wed, 11 Dec 2024 17:05:13 +0100 Subject: [PATCH 4/4] Pin OS to specific versions (#673) With this, runner images are pinned to specific versions instead of referring to `latest`. This allows to have a controlled upgrade path instead instead of just waiting for when a change in the reference affects the workflows. 
--- .github/workflows/ci-sharktank.yml | 2 +- .github/workflows/ci-tuner.yml | 2 +- .github/workflows/pre-commit.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 4fdc5775a..f169eb2b2 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -27,7 +27,7 @@ jobs: strategy: matrix: version: [3.11] - os: [ubuntu-latest, windows-latest] + os: [ubuntu-24.04, windows-2022] fail-fast: false runs-on: ${{matrix.os}} defaults: diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index 82c5f3514..64bbc9a57 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -28,7 +28,7 @@ permissions: jobs: test: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 8ec1e8d55..07d89e0e8 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -7,7 +7,7 @@ on: jobs: pre-commit: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
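
A minimal end-to-end sketch of how the CLIP export entry points added in PATCH 1/4 fit together (not part of the patches themselves; it assumes the sharktank and iree-compiler packages from this series are installed and targets the same HIP/gfx942 device used in clip_test.py — the repo id, file paths, and batch size below are illustrative only):

# Sketch only: exercises the export path introduced in PATCH 1/4.
# Assumptions: sharktank, transformers, and iree-compiler are installed,
# and gfx942 (HIP/ROCm) is the desired compile target, as in clip_test.py.
from sharktank.models.clip.export import (
    export_clip_text_model_dataset_from_hugging_face,
    export_clip_text_model_mlir,
)
import iree.compiler

repo_id = "openai/clip-vit-large-patch14"
irpa_path = "clip_text_model_f32.irpa"   # illustrative output paths
mlir_path = "clip_text_model_f32.mlir"
vmfb_path = "clip_text_model_f32.vmfb"

# 1. Fetch the Hugging Face CLIP text model and save its parameters as an IRPA dataset.
export_clip_text_model_dataset_from_hugging_face(repo_id, irpa_path)

# 2. Export MLIR for the forward pass. The context length is fixed at
#    max_position_embeddings (77), so only the batch size varies.
export_clip_text_model_mlir(irpa_path, batch_sizes=[4], mlir_output_path=mlir_path)

# 3. Compile for a HIP/ROCm device, mirroring the flags used in clip_test.py.
iree.compiler.compile_file(
    mlir_path,
    output_file=vmfb_path,
    extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
)

The resulting .vmfb exposes a forward_bs4 entry point (forward_bs{batch_size} per exported batch size) that can be loaded and run against the .irpa parameters, as the IREE-vs-eager comparison in clip_test.py does.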