From 36c0859a0d68da7532b3660618783ed2b4b1dc19 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 10 Dec 2024 15:40:49 -0800 Subject: [PATCH] Add exporting and numerics verification for CLIP Large text model with IREE (#664) Add exporting to MLIR and IRRE parameters. We don't make the context length dynamic since the maximum is only 77 anyway, so the token sequences are padded to 77. We could explore later making this dynamic. This adds comparison of IREE execution of float32, bfloat16 model variants against float32 torch eager. For bfloat16 results are close up to 1.43e-2 using cosine similarity. Toy-sized model comparison for float32 and bfloat16 is also provided. --- .../sharktank/layers/configs/llm_configs.py | 32 +- sharktank/sharktank/models/clip/clip.py | 47 ++- sharktank/sharktank/models/clip/export.py | 74 +++- sharktank/sharktank/models/clip/testing.py | 37 ++ sharktank/sharktank/types/theta.py | 4 +- sharktank/sharktank/utils/math.py | 2 +- sharktank/sharktank/utils/testing.py | 31 ++ sharktank/tests/models/clip/clip_test.py | 334 +++++++++++++++--- 8 files changed, 482 insertions(+), 79 deletions(-) create mode 100644 sharktank/sharktank/models/clip/testing.py diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 1513f364a..8a443e6ca 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -18,6 +18,8 @@ from typing import Any, Optional import torch +from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name + __all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"] @@ -287,9 +289,10 @@ class ClipTextConfig: output_attentions: bool = False output_hidden_states: bool = False use_return_dict: bool = True + dtype: torch.dtype = torch.float32 @staticmethod - def from_transformers_clip_text_config( + def from_hugging_face_clip_text_model_config( config: "transformers.CLIPTextConfig", ) -> "ClipTextConfig": return ClipTextConfig( @@ -308,7 +311,30 @@ def from_transformers_clip_text_config( output_attentions=config.output_attentions, output_hidden_states=config.output_hidden_states, use_return_dict=config.use_return_dict, + dtype=config.torch_dtype or torch.float32, ) - def as_properties(self) -> dict[str, Any]: - return asdict(self) + def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig": + kwargs = self.to_properties() + kwargs["torch_dtype"] = kwargs["dtype"] + del kwargs["dtype"] + kwargs["return_dict"] = kwargs["use_return_dict"] + del kwargs["use_return_dict"] + from transformers import CLIPTextConfig + + return CLIPTextConfig(**kwargs) + + @staticmethod + def from_properties(properties: dict[str, Any]) -> "ClipTextConfig": + kwargs = dict(properties) + kwargs.pop("SHARK_DATASET_VERSION") + if "dtype" in kwargs and kwargs["dtype"] is not None: + kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"]) + + return ClipTextConfig(**kwargs) + + def to_properties(self) -> dict[str, Any]: + res = asdict(self) + if self.dtype is not None: + res["dtype"] = dtype_to_serialized_name(self.dtype) + return res diff --git a/sharktank/sharktank/models/clip/clip.py b/sharktank/sharktank/models/clip/clip.py index 29734e9f1..0593c940b 100644 --- a/sharktank/sharktank/models/clip/clip.py +++ b/sharktank/sharktank/models/clip/clip.py @@ -21,10 +21,10 @@ ) from collections import OrderedDict -from ...layers import BaseLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer +from ...layers import ThetaLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer from ... import ops from ...types.theta import Theta, Dataset -from ...types.tensors import DefaultPrimitiveTensor +from ...types.tensors import AnyTensor, DefaultPrimitiveTensor from ...layers.configs import ClipTextConfig from ...layers.activations import ACT2FN @@ -68,11 +68,11 @@ def forward( return embeddings -class ClipAttention(BaseLayer): +class ClipAttention(ThetaLayer): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads @@ -182,9 +182,9 @@ def forward( return attn_output, attn_weights_reshaped -class ClipMlp(BaseLayer): +class ClipMlp(ThetaLayer): def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = LinearLayer(theta("fc1")) @@ -197,9 +197,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class ClipEncoderLayer(BaseLayer): +class ClipEncoderLayer(ThetaLayer): def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.embed_dim = config.hidden_size self.self_attn = ClipAttention(theta=theta("self_attn"), config=config) self.layer_norm1 = LayerNorm( @@ -251,14 +251,14 @@ def forward( return outputs -class ClipEncoder(BaseLayer): +class ClipEncoder(ThetaLayer): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`ClipEncoderLayer`]. """ def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.config = config self.layers = nn.ModuleList( [ @@ -356,9 +356,9 @@ def forward( ) -class ClipTextTransformer(nn.Module): +class ClipTextTransformer(ThetaLayer): def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.config = config embed_dim = config.hidden_size self.embeddings = ClipTextEmbeddings(theta=theta("embeddings"), config=config) @@ -475,9 +475,9 @@ def forward( ) -class ClipTextModel(BaseLayer): +class ClipTextModel(ThetaLayer): def __init__(self, theta: Theta, config: ClipTextConfig): - super().__init__() + super().__init__(theta) self.config = config self.text_model = ClipTextTransformer(theta=theta("text_model"), config=config) @@ -487,6 +487,25 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value + def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]: + input_ids = ( + torch.arange( + start=0, + end=batch_size * self.config.max_position_embeddings, + dtype=torch.long, + ) + % self.config.vocab_size + ) + input_ids = input_ids.reshape([batch_size, self.config.max_position_embeddings]) + return OrderedDict( + [ + ( + "input_ids", + input_ids, + ) + ] + ) + def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py index 83bda2cbe..3cae3f4c4 100644 --- a/sharktank/sharktank/models/clip/export.py +++ b/sharktank/sharktank/models/clip/export.py @@ -4,54 +4,98 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Union +from typing import Optional, Union import transformers from transformers.models.clip.modeling_clip import ( - CLIPAttention as TransformersCLIPAttention, - CLIPEncoderLayer as TransformersCLIPEncoderLayer, - CLIPEncoder as TransformersCLIPEncoder, + CLIPAttention as HfCLIPAttention, + CLIPEncoderLayer as HfCLIPEncoderLayer, + CLIPEncoder as HfCLIPEncoder, ) from os import PathLike import torch from ...types.theta import Theta, Dataset, torch_module_to_theta -from ...types.tensors import DefaultPrimitiveTensor from ...layers.configs import ClipTextConfig +from .clip import ClipTextModel +from iree.turbine.aot import FxProgramsBuilder, export -def transformers_clip_attention_to_theta(model: TransformersCLIPAttention) -> Theta: +def hugging_face_clip_attention_to_theta(model: HfCLIPAttention) -> Theta: return torch_module_to_theta(model) -def transformers_clip_encoder_layer_to_theta(model: TransformersCLIPEncoder) -> Theta: +def hugging_face_clip_encoder_layer_to_theta(model: HfCLIPEncoder) -> Theta: return torch_module_to_theta(model) -def transformers_clip_encoder_to_theta(model: TransformersCLIPEncoderLayer) -> Theta: +def hugging_face_clip_encoder_to_theta(model: HfCLIPEncoderLayer) -> Theta: return torch_module_to_theta(model) -def transformers_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta: +def hugging_face_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta: return torch_module_to_theta(model) -def transformers_clip_text_model_to_dataset( +def hugging_face_clip_text_model_to_dataset( model: transformers.CLIPTextModel, ) -> Dataset: - config = ClipTextConfig.from_transformers_clip_text_config(model.config) - properties = config.as_properties() - theta = transformers_clip_text_model_to_theta(model) + config = ClipTextConfig.from_hugging_face_clip_text_model_config(model.config) + properties = config.to_properties() + theta = hugging_face_clip_text_model_to_theta(model) theta.rename_tensors_to_paths() return Dataset(properties, theta) +def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset: + return Dataset(properties=model.config.to_properties(), root_theta=model.theta) + + def export_clip_text_model_dataset_from_hugging_face( model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel], output_path: Union[str, PathLike], + dtype: Optional[torch.dtype] = None, ): if isinstance(model_or_name_or_path, transformers.CLIPTextModel): + assert dtype is None model = model_or_name_or_path else: - model = transformers.CLIPTextModel.from_pretrained(model_or_name_or_path) - dataset = transformers_clip_text_model_to_dataset(model) + model = transformers.CLIPTextModel.from_pretrained( + model_or_name_or_path, torch_dtype=dtype + ) + dataset = hugging_face_clip_text_model_to_dataset(model) dataset.save(output_path) + + +def export_clip_text_model_mlir( + model: Union[ClipTextModel, PathLike], + batch_sizes: list[int], + mlir_output_path: str, +): + """ + Args: + model: either the torch module or path to GGUF/IRPA. + """ + if not isinstance(model, ClipTextModel): + dataset = Dataset.load(model) + config = ClipTextConfig.from_properties(dataset.properties) + model = ClipTextModel(theta=dataset.root_theta, config=config) + + fxb = FxProgramsBuilder(model) + + for batch_size in batch_sizes: + sample_inputs = model.sample_inputs(batch_size) + + @fxb.export_program( + name=f"forward_bs{batch_size}", + args=tuple(sample_inputs.values()), + dynamic_shapes=None, + strict=False, + ) + def _( + model, + input_ids, + ): + return model(input_ids) + + output = export(fxb, import_symbolic_shape_expressions=True) + output.save_mlir(mlir_output_path) diff --git a/sharktank/sharktank/models/clip/testing.py b/sharktank/sharktank/models/clip/testing.py new file mode 100644 index 000000000..87634c220 --- /dev/null +++ b/sharktank/sharktank/models/clip/testing.py @@ -0,0 +1,37 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...layers.configs.llm_configs import ClipTextConfig +from ...types.theta import Theta +from .export import hugging_face_clip_text_model_to_theta +import torch + + +def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta: + from transformers import CLIPTextConfig as HfCLIPTextConfig + from transformers import CLIPTextModel as HfCLIPTextModel + + hf_config = config.to_hugging_face_clip_text_model_config() + model = HfCLIPTextModel(hf_config) + return hugging_face_clip_text_model_to_theta(model) + + +def make_random_input_token_sequences( + batch_size: int, config: ClipTextConfig +) -> torch.LongTensor: + sequence_lens = torch.randint( + low=1, high=config.max_position_embeddings + 1, size=(batch_size,) + ) + sequences = torch.full( + size=(batch_size, config.max_position_embeddings), + fill_value=config.eos_token_id, + dtype=torch.long, + ) + for batch_idx, l in enumerate(sequence_lens): + sequences[batch_idx][0:l] = torch.randint( + low=0, high=config.vocab_size - 1, size=(l,), dtype=torch.long + ) + return sequences diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 143ede184..021925169 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -214,12 +214,14 @@ def rename_tensors_to_paths(self): def torch_module_to_theta(module: torch.nn.Module) -> Theta: - return Theta( + res = Theta( { name: DefaultPrimitiveTensor(data=param) for name, param in module.named_parameters() } ) + res.rename_tensors_to_paths() + return res def flat_to_nested_dict(flat: dict[str, Any]) -> dict[str, Any]: diff --git a/sharktank/sharktank/utils/math.py b/sharktank/sharktank/utils/math.py index 3723f67dd..639f559d2 100644 --- a/sharktank/sharktank/utils/math.py +++ b/sharktank/sharktank/utils/math.py @@ -19,7 +19,7 @@ def round_up_to_multiple_of(x: Number, multiple: Number) -> Number: def cosine_similarity( a: torch.Tensor, b: torch.Tensor, /, *, dim: Optional[Union[int, tuple[int]]] = None -) -> float: +) -> torch.Tensor: """Compute cosine similarity over dimensions dim. If dim is none computes over all dimensions.""" dot_product = torch.sum(a * b, dim=dim) diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py index 6c81acf9e..d3cf08fd6 100644 --- a/sharktank/sharktank/utils/testing.py +++ b/sharktank/sharktank/utils/testing.py @@ -18,6 +18,7 @@ import gc from ..types import * +from .math import cosine_similarity # Range of torch.rand() is [0,1) # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values @@ -184,6 +185,36 @@ def assert_iterables_equal( ), f"Iterables not equal at index {i} for elements {v1} and {v2}" +def assert_text_encoder_state_close( + actual: torch.Tensor, expected: torch.Tensor, atol: float +): + """The cosine similarity has been suggested to compare encoder states. + + Dehua Peng, Zhipeng Gui, Huayi Wu - + Interpreting the Curse of Dimensionality from Distance Concentration and Manifold + Effect (2023) + + shows that cosine and all Minkowski distances suffer from the curse of + dimensionality. + The cosine similarity ignores the vector magnitudes. We can probably come up with a + better metric, but this is maybe good enough. + + The functions expects that the last dimension is the features per token. + It will compute the cosine similarity for each token. + """ + cosine_similarity_per_token = cosine_similarity( + actual, + expected, + dim=-1, + ) + torch.testing.assert_close( + cosine_similarity_per_token, + torch.ones_like(cosine_similarity_per_token), + atol=atol, + rtol=0, + ) + + SHARKTANK_TEST_SKIP_ENV_VAR = "SHARKTANK_TEST_SKIP" diff --git a/sharktank/tests/models/clip/clip_test.py b/sharktank/tests/models/clip/clip_test.py index 409999797..99af4ba6f 100644 --- a/sharktank/tests/models/clip/clip_test.py +++ b/sharktank/tests/models/clip/clip_test.py @@ -4,37 +4,61 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from collections import OrderedDict import functools +import iree.compiler +import os from parameterized import parameterized +from copy import copy import pytest import torch from torch.utils._pytree import tree_map from typing import Optional from unittest import TestCase import transformers -from transformers import CLIPTextModel as TransformersCLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel as HfCLIPTextModel, CLIPTokenizer from transformers.models.clip.modeling_clip import ( - CLIPAttention as TransformersCLIPAttention, - CLIPEncoderLayer as TransformersCLIPEncoderLayer, - CLIPEncoder as TransformersCLIPEncoder, + CLIPAttention as HfCLIPAttention, + CLIPEncoderLayer as HfCLIPEncoderLayer, + CLIPEncoder as HfCLIPEncoder, ) -from sharktank.types import DefaultPrimitiveTensor +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + flatten_for_iree_signature, + iree_to_torch, +) +from sharktank.types import ( + DefaultPrimitiveTensor, + dtype_to_serialized_short_name, + Dataset, +) from sharktank.transforms.dataset import set_float_dtype from sharktank.utils.hf_datasets import get_dataset -from sharktank.utils.math import cosine_similarity from sharktank.utils.testing import ( + assert_text_encoder_state_close, make_rand_torch, make_random_mask, TempDirTestBase, test_prompts, ) from sharktank.models.clip.export import ( + export_clip_text_model_mlir, export_clip_text_model_dataset_from_hugging_face, - transformers_clip_attention_to_theta, - transformers_clip_encoder_layer_to_theta, - transformers_clip_encoder_to_theta, - transformers_clip_text_model_to_theta, + hugging_face_clip_attention_to_theta, + hugging_face_clip_encoder_layer_to_theta, + hugging_face_clip_encoder_to_theta, + hugging_face_clip_text_model_to_dataset, + hugging_face_clip_text_model_to_theta, + clip_text_model_to_dataset, +) +from sharktank.models.clip.testing import ( + make_random_input_token_sequences, + make_clip_text_model_random_theta, ) from sharktank.models.clip import ( ClipAttention, @@ -48,21 +72,244 @@ with_clip_data = pytest.mark.skipif("not config.getoption('with_clip_data')") -@pytest.mark.usefixtures("path_prefix") -class ClipExportTest(TempDirTestBase): +@pytest.mark.usefixtures("caching", "path_prefix") +class ClipTextIreeTest(TempDirTestBase): def setUp(self): super().setUp() + torch.random.manual_seed(12345) if self.path_prefix is None: self.path_prefix = f"{self._temp_dir}/" @with_clip_data def testSmokeExportLargeF32FromHuggingFace(self): - repo_id = "openai/clip-vit-large-patch14" + huggingface_repo_id = "openai/clip-vit-large-patch14" + huggingface_repo_id_as_path = ( + f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" + ) + get_dataset( + huggingface_repo_id, + ).download() + target_dtype_name = dtype_to_serialized_short_name(torch.float32) + target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_text_model_{target_dtype_name}" + output_path = f"{target_model_path_prefix}.irpa" + export_clip_text_model_dataset_from_hugging_face( + huggingface_repo_id, output_path + ) + + @with_clip_data + def testCompareLargeIreeF32AgainstTorchEagerF32(self): + self.runTestCompareIreeAgainstPretrainedTorchEager( + "openai/clip-vit-large-patch14", + reference_dtype=torch.float32, + target_dtype=torch.float32, + atol=1e-5, + ) + + @with_clip_data + def testCompareLargeIreeBf16AgainstTorchEagerF32(self): + self.runTestCompareIreeAgainstPretrainedTorchEager( + "openai/clip-vit-large-patch14", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + # The observed error is 1.43e-2. We leave a bit of margin. + atol=3e-3, + ) + + @with_clip_data + def testCompareToyModelIreeF32AgainstTorchEagerF32(self): + self.runTestCompareToyModelIreeAgainstTorch( + reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-5 + ) + + @with_clip_data + def testCompareToyModelIreeBf16AgainstTorchEagerF32(self): + self.runTestCompareToyModelIreeAgainstTorch( + reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1e-3 + ) + + @torch.no_grad() + def runTestCompareIreeAgainstTorchEagerWithInputTokens( + self, + reference_model: ClipTextModel, + target_dtype: torch.dtype, + input_ids: torch.LongTensor, + atol: float, + file_artifact_prefix_name: str, + ): + reference_dtype_name = dtype_to_serialized_short_name( + reference_model.config.dtype + ) + target_dtype_name = dtype_to_serialized_short_name(target_dtype) + reference_model_path_prefix = ( + f"{self.path_prefix}{file_artifact_prefix_name}_{reference_dtype_name}" + ) + target_model_path_prefix = ( + f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}" + ) + + target_config = copy(reference_model.config) + target_config.dtype = target_dtype + reference_dataset = clip_text_model_to_dataset(reference_model) + target_dataset = Dataset( + root_theta=reference_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=target_config.dtype) + ), + properties=target_config.to_properties(), + ) + + parameters_path = f"{target_model_path_prefix}.irpa" + if not self.caching or not os.path.exists(parameters_path): + target_dataset.save(parameters_path) + + dataset = Dataset.load(parameters_path) + target_config = ClipTextConfig.from_properties(dataset.properties) + input_args = OrderedDict([("input_ids", input_ids)]) + batch_size = input_ids.shape[0] + + mlir_path = f"{target_model_path_prefix}.mlir" + if not self.caching or not os.path.exists(mlir_path): + export_clip_text_model_mlir( + parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path + ) + iree_module_path = f"{target_model_path_prefix}.vmfb" + if not self.caching or not os.path.exists(iree_module_path): + iree.compiler.compile_file( + mlir_path, + output_file=iree_module_path, + extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"], + ) + + reference_result_dict = call_torch_module_function( + module=reference_model, + function_name="forward", + kwargs=input_args, + trace_path_prefix=f"{reference_model_path_prefix}_torch_", + ) + expected_outputs = flatten_for_iree_signature(reference_result_dict) + + iree_devices = get_iree_devices(driver="hip", device_count=1) + iree_module, iree_vm_context, iree_vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=parameters_path, + ) + iree_args = prepare_iree_module_function_args( + args=flatten_for_iree_signature(input_args), devices=iree_devices + ) + iree_result = iree_to_torch( + *run_iree_module_function( + module=iree_module, + vm_context=iree_vm_context, + args=iree_args, + driver="hip", + function_name=f"forward_bs{batch_size}", + trace_path_prefix=f"{target_model_path_prefix}_iree_", + ) + ) + actual_outputs = [ + ops.to(iree_result[i], dtype=expected_outputs[i].dtype) + for i in range(len(expected_outputs)) + ] + + actual_last_hidden_states = actual_outputs[0] + expected_last_hidden_states = expected_outputs[0] + + assert_text_encoder_state_close( + actual_last_hidden_states, expected_last_hidden_states, atol + ) + + def runTestCompareRandomModelIreeAgainstTorch( + self, + reference_config: ClipTextConfig, + target_dtype: torch.dtype, + batch_size: int, + atol: float, + file_artifact_prefix_name: str, + ): + input_ids = make_random_input_token_sequences( + batch_size=batch_size, config=reference_config + ) + reference_theta = make_clip_text_model_random_theta(reference_config) + reference_model = ClipTextModel(theta=reference_theta, config=reference_config) + self.runTestCompareIreeAgainstTorchEagerWithInputTokens( + reference_model=reference_model, + target_dtype=target_dtype, + input_ids=input_ids, + atol=atol, + file_artifact_prefix_name=file_artifact_prefix_name, + ) + + def runTestCompareToyModelIreeAgainstTorch( + self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float + ): + batch_size = 4 + num_attention_heads = 5 + vocab_size = 11 + reference_config = ClipTextConfig( + vocab_size=vocab_size, + hidden_size=13 * num_attention_heads, + intermediate_size=7, + projection_dim=3, + num_attention_heads=num_attention_heads, + max_position_embeddings=17, + layer_norm_eps=1e-4, + num_hidden_layers=2, + bos_token_id=vocab_size - 2, + eos_token_id=vocab_size - 1, + dtype=reference_dtype, + ) + file_artifact_prefix_name = "clip_text_model_toy" + self.runTestCompareRandomModelIreeAgainstTorch( + reference_config=reference_config, + target_dtype=target_dtype, + batch_size=batch_size, + atol=atol, + file_artifact_prefix_name=file_artifact_prefix_name, + ) + + def runTestCompareIreeAgainstPretrainedTorchEager( + self, + huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + ): get_dataset( - repo_id, + huggingface_repo_id, ).download() - output_path = f"{self.path_prefix}{repo_id.replace('/', '--')}.irpa" - export_clip_text_model_dataset_from_hugging_face(repo_id, output_path) + + huggingface_repo_id_as_path = ( + f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" + ) + file_artifact_prefix_name = f"{huggingface_repo_id_as_path}_text_model" + + hf_model: HfCLIPTextModel = HfCLIPTextModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype + ) + reference_dataset = hugging_face_clip_text_model_to_dataset(hf_model) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + hf_model.config + ) + reference_model = ClipTextModel( + theta=reference_dataset.root_theta, config=config + ) + + tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(huggingface_repo_id) + input_ids = tokenizer( + test_prompts, + truncation=True, + max_length=reference_model.config.max_position_embeddings, + padding="max_length", + return_tensors="pt", + )["input_ids"] + + self.runTestCompareIreeAgainstTorchEagerWithInputTokens( + reference_model=reference_model, + target_dtype=target_dtype, + input_ids=input_ids, + atol=atol, + file_artifact_prefix_name=file_artifact_prefix_name, + ) @pytest.mark.usefixtures("get_model_artifacts") @@ -70,7 +317,6 @@ class ClipTextEagerTest(TestCase): def setUp(self): super().setUp() torch.random.manual_seed(12345) - torch.no_grad() def runTestCompareTorchEagerAgainstHuggingFace( self, @@ -86,16 +332,14 @@ def runTestCompareTorchEagerAgainstHuggingFace( huggingface_repo_id, ).download() - reference_model: TransformersCLIPTextModel = ( - TransformersCLIPTextModel.from_pretrained( - huggingface_repo_id, torch_dtype=reference_dtype - ) + reference_model: HfCLIPTextModel = HfCLIPTextModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype ) - theta = transformers_clip_text_model_to_theta(reference_model) + theta = hugging_face_clip_text_model_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config( + config = ClipTextConfig.from_hugging_face_clip_text_model_config( reference_model.config ) model = ClipTextModel(theta, config) @@ -119,16 +363,10 @@ def runTestCompareTorchEagerAgainstHuggingFace( actual_outputs, ) - cosine_similarity_per_token = cosine_similarity( + assert_text_encoder_state_close( actual_outputs["last_hidden_state"], expected_outputs["last_hidden_state"], - dim=-1, - ) - torch.testing.assert_close( - cosine_similarity_per_token, - torch.ones_like(cosine_similarity_per_token), atol=atol, - rtol=0, ) @with_clip_data @@ -146,6 +384,7 @@ def testLargeCompareTorchEagerBf16AgainstHuggingFaceF32(self): "openai/clip-vit-large-patch14", reference_dtype=torch.float32, target_dtype=torch.bfloat16, + # The observed error is 3.66e-4. We leave a bit of margin. atol=1e-3, ) @@ -180,15 +419,17 @@ def testCompareEagerToySizedModelAgainstTransformers( bos_token_id=vocab_size - 2, eos_token_id=vocab_size - 1, ) - reference_model = TransformersCLIPTextModel( + reference_model = HfCLIPTextModel( reference_config, ) reference_model.eval() - theta = transformers_clip_text_model_to_theta(reference_model) + theta = hugging_face_clip_text_model_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + reference_config + ) model = ClipTextModel(theta, config) input_ids = torch.randint(low=0, high=vocab_size, size=[batch_size, tgt_len]) @@ -210,7 +451,6 @@ class ClipAttentionTest(TestCase): def setUp(self): super().setUp() torch.random.manual_seed(12345) - torch.no_grad() @parameterized.expand( [ @@ -241,15 +481,17 @@ def testCompareEagerToySizedModelAgainstTransformers( projection_dim=3, num_attention_heads=num_attention_heads, ) - reference_model = TransformersCLIPAttention( + reference_model = HfCLIPAttention( reference_config, ) reference_model.eval() - theta = transformers_clip_attention_to_theta(reference_model) + theta = hugging_face_clip_attention_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + reference_config + ) model = ClipAttention(theta, config) reference_hidden_states = make_rand_torch( @@ -292,7 +534,6 @@ class ClipEncoderLayerTest(TestCase): def setUp(self): super().setUp() torch.random.manual_seed(12345) - torch.no_grad() @parameterized.expand( [ @@ -321,15 +562,17 @@ def testCompareEagerToySizedModelAgainstTransformers( num_attention_heads=num_attention_heads, layer_norm_eps=1e-4, ) - reference_model = TransformersCLIPEncoderLayer( + reference_model = HfCLIPEncoderLayer( reference_config, ) reference_model.eval() - theta = transformers_clip_encoder_layer_to_theta(reference_model) + theta = hugging_face_clip_encoder_layer_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + reference_config + ) model = ClipEncoderLayer(theta, config) reference_hidden_states = make_rand_torch( @@ -372,7 +615,6 @@ class ClipEncoderTest(TestCase): def setUp(self): super().setUp() torch.random.manual_seed(12345) - torch.no_grad() @parameterized.expand( [ @@ -402,15 +644,17 @@ def testCompareEagerToySizedModelAgainstTransformers( layer_norm_eps=1e-4, num_hidden_layers=2, ) - reference_model = TransformersCLIPEncoder( + reference_model = HfCLIPEncoder( reference_config, ) reference_model.eval() - theta = transformers_clip_encoder_to_theta(reference_model) + theta = hugging_face_clip_encoder_to_theta(reference_model) theta.rename_tensors_to_paths() theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype)) - config = ClipTextConfig.from_transformers_clip_text_config(reference_config) + config = ClipTextConfig.from_hugging_face_clip_text_model_config( + reference_config + ) model = ClipEncoder(theta, config) reference_inputs_embeds = make_rand_torch(