
Commit 8109f39

sogartar authored and IanNod committed
Add exporting and numerics verification for CLIP Large text model with IREE (nod-ai#664)
Add exporting to MLIR and IREE parameters. The context length is not made dynamic, since the maximum is only 77; token sequences are instead padded to 77. Making it dynamic could be explored later. This adds comparison of IREE execution of the float32 and bfloat16 model variants against float32 torch eager. For bfloat16 the results agree to within 1.43e-2 under cosine similarity. Toy-sized model comparisons for float32 and bfloat16 are also provided.
1 parent 0c890c4 commit 8109f39
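
A rough usage sketch of the two new export entry points (not part of the commit; the Hugging Face model id and file names below are illustrative):

import torch
from sharktank.models.clip.export import (
    export_clip_text_model_dataset_from_hugging_face,
    export_clip_text_model_mlir,
)

# Save the CLIP text model weights as IREE parameters (IRPA).
export_clip_text_model_dataset_from_hugging_face(
    "openai/clip-vit-large-patch14",  # illustrative model id
    "clip_text_bf16.irpa",
    dtype=torch.bfloat16,
)

# Export MLIR with a static context length (token sequences padded to 77),
# one entry point per requested batch size.
export_clip_text_model_mlir(
    "clip_text_bf16.irpa",
    batch_sizes=[1],
    mlir_output_path="clip_text_bf16.mlir",
)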

File tree

8 files changed: +482 -79 lines changed

sharktank/sharktank/layers/configs/llm_configs.py

+29-3
@@ -18,6 +18,8 @@
 from typing import Any, Optional
 import torch

+from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name
+
 __all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]


@@ -287,9 +289,10 @@ class ClipTextConfig:
     output_attentions: bool = False
     output_hidden_states: bool = False
     use_return_dict: bool = True
+    dtype: torch.dtype = torch.float32

     @staticmethod
-    def from_transformers_clip_text_config(
+    def from_hugging_face_clip_text_model_config(
         config: "transformers.CLIPTextConfig",
     ) -> "ClipTextConfig":
         return ClipTextConfig(
@@ -308,7 +311,30 @@ def from_transformers_clip_text_config(
             output_attentions=config.output_attentions,
             output_hidden_states=config.output_hidden_states,
             use_return_dict=config.use_return_dict,
+            dtype=config.torch_dtype or torch.float32,
         )

-    def as_properties(self) -> dict[str, Any]:
-        return asdict(self)
+    def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig":
+        kwargs = self.to_properties()
+        kwargs["torch_dtype"] = kwargs["dtype"]
+        del kwargs["dtype"]
+        kwargs["return_dict"] = kwargs["use_return_dict"]
+        del kwargs["use_return_dict"]
+        from transformers import CLIPTextConfig
+
+        return CLIPTextConfig(**kwargs)
+
+    @staticmethod
+    def from_properties(properties: dict[str, Any]) -> "ClipTextConfig":
+        kwargs = dict(properties)
+        kwargs.pop("SHARK_DATASET_VERSION")
+        if "dtype" in kwargs and kwargs["dtype"] is not None:
+            kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"])
+
+        return ClipTextConfig(**kwargs)
+
+    def to_properties(self) -> dict[str, Any]:
+        res = asdict(self)
+        if self.dtype is not None:
+            res["dtype"] = dtype_to_serialized_name(self.dtype)
+        return res
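
A minimal sketch of the new config conversions (illustrative only; it uses the library defaults of transformers.CLIPTextConfig):

import transformers
from sharktank.layers.configs.llm_configs import ClipTextConfig

hf_config = transformers.CLIPTextConfig()
config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_config)

# The dtype is serialized by name so it survives a properties round trip.
props = config.to_properties()
print(props["dtype"])  # e.g. "float32"

# Back to a Hugging Face config, e.g. for constructing reference models.
hf_roundtrip = config.to_hugging_face_clip_text_model_config()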

sharktank/sharktank/models/clip/clip.py

+33-14
@@ -21,10 +21,10 @@
 )
 from collections import OrderedDict

-from ...layers import BaseLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer
+from ...layers import ThetaLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer
 from ... import ops
 from ...types.theta import Theta, Dataset
-from ...types.tensors import DefaultPrimitiveTensor
+from ...types.tensors import AnyTensor, DefaultPrimitiveTensor
 from ...layers.configs import ClipTextConfig
 from ...layers.activations import ACT2FN

@@ -68,11 +68,11 @@ def forward(
         return embeddings


-class ClipAttention(BaseLayer):
+class ClipAttention(ThetaLayer):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
@@ -182,9 +182,9 @@ def forward(
         return attn_output, attn_weights_reshaped


-class ClipMlp(BaseLayer):
+class ClipMlp(ThetaLayer):
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.config = config
         self.activation_fn = ACT2FN[config.hidden_act]
         self.fc1 = LinearLayer(theta("fc1"))
@@ -197,9 +197,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class ClipEncoderLayer(BaseLayer):
+class ClipEncoderLayer(ThetaLayer):
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.embed_dim = config.hidden_size
         self.self_attn = ClipAttention(theta=theta("self_attn"), config=config)
         self.layer_norm1 = LayerNorm(
@@ -251,14 +251,14 @@ def forward(
         return outputs


-class ClipEncoder(BaseLayer):
+class ClipEncoder(ThetaLayer):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
     [`ClipEncoderLayer`].
     """

     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.config = config
         self.layers = nn.ModuleList(
             [
@@ -356,9 +356,9 @@ def forward(
         )


-class ClipTextTransformer(nn.Module):
+class ClipTextTransformer(ThetaLayer):
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.config = config
         embed_dim = config.hidden_size
         self.embeddings = ClipTextEmbeddings(theta=theta("embeddings"), config=config)
@@ -475,9 +475,9 @@ def forward(
         )


-class ClipTextModel(BaseLayer):
+class ClipTextModel(ThetaLayer):
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.config = config
         self.text_model = ClipTextTransformer(theta=theta("text_model"), config=config)

@@ -487,6 +487,25 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value

+    def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
+        input_ids = (
+            torch.arange(
+                start=0,
+                end=batch_size * self.config.max_position_embeddings,
+                dtype=torch.long,
+            )
+            % self.config.vocab_size
+        )
+        input_ids = input_ids.reshape([batch_size, self.config.max_position_embeddings])
+        return OrderedDict(
+            [
+                (
+                    "input_ids",
+                    input_ids,
+                )
+            ]
+        )
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,

sharktank/sharktank/models/clip/export.py

+59-15
@@ -4,54 +4,98 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-from typing import Union
+from typing import Optional, Union
 import transformers
 from transformers.models.clip.modeling_clip import (
-    CLIPAttention as TransformersCLIPAttention,
-    CLIPEncoderLayer as TransformersCLIPEncoderLayer,
-    CLIPEncoder as TransformersCLIPEncoder,
+    CLIPAttention as HfCLIPAttention,
+    CLIPEncoderLayer as HfCLIPEncoderLayer,
+    CLIPEncoder as HfCLIPEncoder,
 )
 from os import PathLike
 import torch

 from ...types.theta import Theta, Dataset, torch_module_to_theta
-from ...types.tensors import DefaultPrimitiveTensor
 from ...layers.configs import ClipTextConfig
+from .clip import ClipTextModel
+from iree.turbine.aot import FxProgramsBuilder, export


-def transformers_clip_attention_to_theta(model: TransformersCLIPAttention) -> Theta:
+def hugging_face_clip_attention_to_theta(model: HfCLIPAttention) -> Theta:
     return torch_module_to_theta(model)


-def transformers_clip_encoder_layer_to_theta(model: TransformersCLIPEncoder) -> Theta:
+def hugging_face_clip_encoder_layer_to_theta(model: HfCLIPEncoder) -> Theta:
     return torch_module_to_theta(model)


-def transformers_clip_encoder_to_theta(model: TransformersCLIPEncoderLayer) -> Theta:
+def hugging_face_clip_encoder_to_theta(model: HfCLIPEncoderLayer) -> Theta:
     return torch_module_to_theta(model)


-def transformers_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta:
+def hugging_face_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta:
     return torch_module_to_theta(model)


-def transformers_clip_text_model_to_dataset(
+def hugging_face_clip_text_model_to_dataset(
     model: transformers.CLIPTextModel,
 ) -> Dataset:
-    config = ClipTextConfig.from_transformers_clip_text_config(model.config)
-    properties = config.as_properties()
-    theta = transformers_clip_text_model_to_theta(model)
+    config = ClipTextConfig.from_hugging_face_clip_text_model_config(model.config)
+    properties = config.to_properties()
+    theta = hugging_face_clip_text_model_to_theta(model)
     theta.rename_tensors_to_paths()
     return Dataset(properties, theta)


+def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
+    return Dataset(properties=model.config.to_properties(), root_theta=model.theta)
+
+
 def export_clip_text_model_dataset_from_hugging_face(
     model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel],
     output_path: Union[str, PathLike],
+    dtype: Optional[torch.dtype] = None,
 ):
     if isinstance(model_or_name_or_path, transformers.CLIPTextModel):
+        assert dtype is None
         model = model_or_name_or_path
     else:
-        model = transformers.CLIPTextModel.from_pretrained(model_or_name_or_path)
-    dataset = transformers_clip_text_model_to_dataset(model)
+        model = transformers.CLIPTextModel.from_pretrained(
+            model_or_name_or_path, torch_dtype=dtype
+        )
+    dataset = hugging_face_clip_text_model_to_dataset(model)
     dataset.save(output_path)
+
+
+def export_clip_text_model_mlir(
+    model: Union[ClipTextModel, PathLike],
+    batch_sizes: list[int],
+    mlir_output_path: str,
+):
+    """
+    Args:
+      model: either the torch module or path to GGUF/IRPA.
+    """
+    if not isinstance(model, ClipTextModel):
+        dataset = Dataset.load(model)
+        config = ClipTextConfig.from_properties(dataset.properties)
+        model = ClipTextModel(theta=dataset.root_theta, config=config)
+
+    fxb = FxProgramsBuilder(model)
+
+    for batch_size in batch_sizes:
+        sample_inputs = model.sample_inputs(batch_size)
+
+        @fxb.export_program(
+            name=f"forward_bs{batch_size}",
+            args=tuple(sample_inputs.values()),
+            dynamic_shapes=None,
+            strict=False,
+        )
+        def _(
+            model,
+            input_ids,
+        ):
+            return model(input_ids)
+
+    output = export(fxb, import_symbolic_shape_expressions=True)
+    output.save_mlir(mlir_output_path)
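
A hedged sketch of the dataset round trip that export_clip_text_model_mlir performs internally (not from the commit; the file name is illustrative and the Hugging Face model is built from default config values rather than downloaded weights):

import transformers
from sharktank.types.theta import Dataset
from sharktank.layers.configs import ClipTextConfig
from sharktank.models.clip.clip import ClipTextModel
from sharktank.models.clip.export import hugging_face_clip_text_model_to_dataset

# Save a Hugging Face CLIP text model as an IRPA dataset.
hf_model = transformers.CLIPTextModel(transformers.CLIPTextConfig())
hugging_face_clip_text_model_to_dataset(hf_model).save("clip_text.irpa")

# Reload it the way export_clip_text_model_mlir does when given a path.
dataset = Dataset.load("clip_text.irpa")
config = ClipTextConfig.from_properties(dataset.properties)
model = ClipTextModel(theta=dataset.root_theta, config=config)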
+37
@@ -0,0 +1,37 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ...layers.configs.llm_configs import ClipTextConfig
+from ...types.theta import Theta
+from .export import hugging_face_clip_text_model_to_theta
+import torch
+
+
+def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
+    from transformers import CLIPTextConfig as HfCLIPTextConfig
+    from transformers import CLIPTextModel as HfCLIPTextModel
+
+    hf_config = config.to_hugging_face_clip_text_model_config()
+    model = HfCLIPTextModel(hf_config)
+    return hugging_face_clip_text_model_to_theta(model)
+
+
+def make_random_input_token_sequences(
+    batch_size: int, config: ClipTextConfig
+) -> torch.LongTensor:
+    sequence_lens = torch.randint(
+        low=1, high=config.max_position_embeddings + 1, size=(batch_size,)
+    )
+    sequences = torch.full(
+        size=(batch_size, config.max_position_embeddings),
+        fill_value=config.eos_token_id,
+        dtype=torch.long,
+    )
+    for batch_idx, l in enumerate(sequence_lens):
+        sequences[batch_idx][0:l] = torch.randint(
+            low=0, high=config.vocab_size - 1, size=(l,), dtype=torch.long
+        )
+    return sequences
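
A hedged sketch of how these helpers could drive a toy-sized numerics check; the import path of the new module is an assumption (the file path is not shown in this view), while the other names come from the diffs above:

import transformers
from sharktank.layers.configs import ClipTextConfig
from sharktank.models.clip.clip import ClipTextModel
# Assumed (hypothetical) import path for the new helpers.
from sharktank.models.clip.testing import (
    make_clip_text_model_random_theta,
    make_random_input_token_sequences,
)

config = ClipTextConfig.from_hugging_face_clip_text_model_config(
    transformers.CLIPTextConfig()
)
model = ClipTextModel(theta=make_clip_text_model_random_theta(config), config=config)

# Random sequences are padded with eos_token_id up to the full context length,
# matching the static shapes used for export.
input_ids = make_random_input_token_sequences(batch_size=2, config=config)
assert input_ids.shape == (2, config.max_position_embeddings)
outputs = model(input_ids)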

sharktank/sharktank/types/theta.py

+3-1
@@ -214,12 +214,14 @@ def rename_tensors_to_paths(self):


 def torch_module_to_theta(module: torch.nn.Module) -> Theta:
-    return Theta(
+    res = Theta(
         {
             name: DefaultPrimitiveTensor(data=param)
             for name, param in module.named_parameters()
         }
     )
+    res.rename_tensors_to_paths()
+    return res


 def flat_to_nested_dict(flat: dict[str, Any]) -> dict[str, Any]:

sharktank/sharktank/utils/math.py

+1-1
@@ -19,7 +19,7 @@ def round_up_to_multiple_of(x: Number, multiple: Number) -> Number:

 def cosine_similarity(
     a: torch.Tensor, b: torch.Tensor, /, *, dim: Optional[Union[int, tuple[int]]] = None
-) -> float:
+) -> torch.Tensor:
     """Compute cosine similarity over dimensions dim.
     If dim is none computes over all dimensions."""
     dot_product = torch.sum(a * b, dim=dim)
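
For illustration, the per-token use of cosine_similarity that the new assertion helper below relies on (shapes are illustrative):

import torch
from sharktank.utils.math import cosine_similarity

a = torch.randn(2, 77, 768)         # [batch, tokens, features]
b = a + 1e-3 * torch.randn_like(a)  # slightly perturbed copy

# With dim=-1 the similarity is computed per token over the feature axis,
# yielding a tensor of shape [2, 77]; the annotation change reflects this.
sim = cosine_similarity(a, b, dim=-1)
print(sim.shape)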

sharktank/sharktank/utils/testing.py

+31
@@ -18,6 +18,7 @@
 import gc

 from ..types import *
+from .math import cosine_similarity

 # Range of torch.rand() is [0,1)
 # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
@@ -184,6 +185,36 @@ def assert_iterables_equal(
     ), f"Iterables not equal at index {i} for elements {v1} and {v2}"


+def assert_text_encoder_state_close(
+    actual: torch.Tensor, expected: torch.Tensor, atol: float
+):
+    """The cosine similarity has been suggested to compare encoder states.
+
+    Dehua Peng, Zhipeng Gui, Huayi Wu -
+    Interpreting the Curse of Dimensionality from Distance Concentration and Manifold
+    Effect (2023)
+
+    shows that cosine and all Minkowski distances suffer from the curse of
+    dimensionality.
+    The cosine similarity ignores the vector magnitudes. We can probably come up with a
+    better metric, but this is maybe good enough.
+
+    The function expects that the last dimension is the features per token.
+    It will compute the cosine similarity for each token.
+    """
+    cosine_similarity_per_token = cosine_similarity(
+        actual,
+        expected,
+        dim=-1,
+    )
+    torch.testing.assert_close(
+        cosine_similarity_per_token,
+        torch.ones_like(cosine_similarity_per_token),
+        atol=atol,
+        rtol=0,
+    )
+
+
 SHARKTANK_TEST_SKIP_ENV_VAR = "SHARKTANK_TEST_SKIP"
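
A hedged usage sketch of the new assertion with the bfloat16 tolerance quoted in the commit message; the tensors below are placeholders standing in for the float32 torch-eager reference and the IREE result:

import torch
from sharktank.utils.testing import assert_text_encoder_state_close

# Placeholder stand-ins; a real test would obtain these from the two execution paths.
expected = torch.randn(1, 77, 768)
actual = expected + 1e-3 * torch.randn_like(expected)

# Passes if the per-token cosine similarity is within atol of 1.0.
assert_text_encoder_state_close(actual, expected, atol=1.43e-2)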
0 commit comments
