Commit a1e632a

Add CLI script exporting CLIP Toy model IREE test data (#672)

This is required to have an easy way of exporting test data that will be used in IREE to guard against regressions. E.g.

```
python -m sharktank.models.clip.export_toy_text_model_iree_test_data \
    --output-dir=clip_toy_text_model
```

This also refactors some of the existing tests to reuse the new export logic.

1 parent ffb0dd2 commit a1e632a

5 files changed: +280 −105 lines changed

sharktank/sharktank/models/clip/export.py

+22 −3

```diff
@@ -11,8 +11,8 @@
     CLIPEncoderLayer as HfCLIPEncoderLayer,
     CLIPEncoder as HfCLIPEncoder,
 )
-from os import PathLike
 import torch
+from os import PathLike
 
 from ...types.theta import Theta, Dataset, torch_module_to_theta
 from ...layers.configs import ClipTextConfig
@@ -50,9 +50,14 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
     return Dataset(properties=model.config.to_properties(), root_theta=model.theta)
 
 
+def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike):
+    dataset = clip_text_model_to_dataset(model)
+    dataset.save(output_path)
+
+
 def export_clip_text_model_dataset_from_hugging_face(
-    model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel],
-    output_path: Union[str, PathLike],
+    model_or_name_or_path: Union[PathLike, transformers.CLIPTextModel],
+    output_path: PathLike,
     dtype: Optional[torch.dtype] = None,
 ):
     if isinstance(model_or_name_or_path, transformers.CLIPTextModel):
@@ -99,3 +104,17 @@ def _(
 
     output = export(fxb, import_symbolic_shape_expressions=True)
     output.save_mlir(mlir_output_path)
+
+
+def export_clip_text_model_to_iree(
+    model: ClipTextModel,
+    batch_sizes: list[int],
+    mlir_output_path: PathLike,
+    parameters_output_path: PathLike,
+):
+    export_clip_text_model_iree_parameters(model, parameters_output_path)
+    export_clip_text_model_mlir(
+        model=parameters_output_path,
+        batch_sizes=batch_sizes,
+        mlir_output_path=mlir_output_path,
+    )
```
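
The new `export_clip_text_model_to_iree` helper gives a two-step export: parameters to an IRPA archive first, then MLIR generated against that archive. A minimal usage sketch; module paths are inferred from the relative imports in the diffs, and the output paths are illustrative:

```python
# A minimal sketch of the new two-step export API. Module paths are inferred
# from the relative imports in the diffs above; output paths are illustrative.
import torch

from sharktank.models.clip.clip import ClipTextModel
from sharktank.models.clip.export import export_clip_text_model_to_iree
from sharktank.models.clip.testing import (
    clip_toy_text_model_config,
    make_clip_text_model_random_theta,
)

config = clip_toy_text_model_config(torch.float32)
model = ClipTextModel(theta=make_clip_text_model_random_theta(config), config=config)

# Saves the IRPA parameter archive first, then exports MLIR against it.
export_clip_text_model_to_iree(
    model,
    batch_sizes=[4],
    mlir_output_path="clip_toy_text_model.mlir",
    parameters_output_path="clip_toy_text_model_parameters.irpa",
)
```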

sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py

+29 −0

```diff
@@ -0,0 +1,29 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from argparse import ArgumentParser
+from typing import Optional
+from pathlib import Path
+
+from .testing import export_clip_toy_text_model_default_iree_test_data
+
+
+def main(args: Optional[list[str]] = None):
+    parser = ArgumentParser(
+        description=(
+            "Export test data for the toy-sized CLIP text model."
+            " This program exports MLIR, parameters, a sample input, and the expected output."
+            " Exports float32 and bfloat16 model variants."
+            " The expected output is always in float32 precision."
+        )
+    )
+    parser.add_argument("--output-dir", type=str, default="clip_toy_text_model")
+    args = parser.parse_args(args=args)
+    export_clip_toy_text_model_default_iree_test_data(Path(args.output_dir))
+
+
+if __name__ == "__main__":
+    main()
```
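
Because `main` accepts an optional argument list, the exporter can also be driven programmatically; a small sketch equivalent to the `python -m` invocation from the commit message:

```python
# Programmatic equivalent of:
#   python -m sharktank.models.clip.export_toy_text_model_iree_test_data \
#       --output-dir=clip_toy_text_model
# The module path follows the `python -m` invocation in the commit message.
from sharktank.models.clip.export_toy_text_model_iree_test_data import main

main(["--output-dir=clip_toy_text_model"])
```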

sharktank/sharktank/models/clip/testing.py

+157 −4

```diff
@@ -4,14 +4,167 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ...layers.configs.llm_configs import ClipTextConfig
-from ...types.theta import Theta
-from .export import hugging_face_clip_text_model_to_theta
+import functools
 import torch
+from os import PathLike, makedirs
+from typing import Union, Optional
+from copy import copy
+from iree.turbine.aot.params import ParameterArchiveBuilder
+
+from ...layers.configs.llm_configs import ClipTextConfig
+from .clip import ClipTextModel
+from ...types.theta import Theta, Dataset
+from ...types.tensors import dtype_to_serialized_short_name
+from ...utils.io import save_tensor_as_irpa
+from .export import (
+    clip_text_model_to_dataset,
+    hugging_face_clip_text_model_to_theta,
+    export_clip_text_model_to_iree,
+)
+from ...transforms.dataset import set_float_dtype
+
+
+def clip_toy_text_model_config(dtype: Optional[torch.dtype] = None) -> ClipTextConfig:
+    num_attention_heads = 5
+    vocab_size = 11
+    return ClipTextConfig(
+        vocab_size=vocab_size,
+        hidden_size=13 * num_attention_heads,
+        intermediate_size=7,
+        projection_dim=3,
+        num_attention_heads=num_attention_heads,
+        max_position_embeddings=17,
+        layer_norm_eps=1e-4,
+        num_hidden_layers=2,
+        bos_token_id=vocab_size - 2,
+        eos_token_id=vocab_size - 1,
+        dtype=dtype,
+    )
+
+
+def export_clip_toy_text_model_default_iree_test_data(output_dir: PathLike):
+    makedirs(output_dir, exist_ok=True)
+
+    # We want to always export the same data without interfering with the RNG
+    # state for the rest of the program.
+    rng_state = torch.get_rng_state()
+    torch.random.manual_seed(12345)
+
+    reference_dtype = torch.float32
+    target_dtypes = [torch.float32, torch.bfloat16]
+    target_iree_parameters_output_paths = []
+    target_mlir_output_paths = []
+    batch_size = 4
+    for dtype in target_dtypes:
+        prefix = output_dir / f"{dtype_to_serialized_short_name(dtype)}"
+        target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa")
+        target_mlir_output_paths.append(f"{prefix}.mlir")
+    call_prefix = output_dir / f"forward_bs{batch_size}"
+    input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa"
+    expected_last_hidden_state_output_path = (
+        f"{call_prefix}_expected_result0_last_hidden_state_"
+        f"{dtype_to_serialized_short_name(reference_dtype)}.irpa"
+    )
+    export_clip_toy_text_model_iree_test_data(
+        reference_dtype=reference_dtype,
+        target_dtypes=target_dtypes,
+        batch_size=batch_size,
+        input_ids_output_path=input_ids_output_path,
+        expected_last_hidden_state_output_path=expected_last_hidden_state_output_path,
+        target_iree_parameters_output_paths=target_iree_parameters_output_paths,
+        target_mlir_output_paths=target_mlir_output_paths,
+    )
+
+    torch.set_rng_state(rng_state)
+
+
+def export_clip_toy_text_model_iree_test_data(
+    reference_dtype: torch.dtype,
+    target_dtypes: list[torch.dtype],
+    batch_size: int,
+    target_iree_parameters_output_paths: list[PathLike],
+    target_mlir_output_paths: list[PathLike],
+    input_ids_output_path: PathLike,
+    expected_last_hidden_state_output_path: PathLike,
+):
+    reference_config = clip_toy_text_model_config(reference_dtype)
+    input_ids = make_random_input_token_sequences(
+        batch_size=batch_size, config=reference_config
+    )
+    reference_theta = make_clip_text_model_random_theta(reference_config)
+    reference_model = ClipTextModel(theta=reference_theta, config=reference_config)
+    for i, (
+        target_dtype,
+        target_iree_parameters_output_path,
+        target_mlir_output_path,
+    ) in enumerate(
+        zip(
+            target_dtypes,
+            target_iree_parameters_output_paths,
+            target_mlir_output_paths,
+            strict=True,
+        )
+    ):
+        current_input_ids_output_path = None
+        current_expected_last_hidden_state_output_path = None
+        if i == 0:
+            current_input_ids_output_path = input_ids_output_path
+            current_expected_last_hidden_state_output_path = (
+                expected_last_hidden_state_output_path
+            )
+        export_clip_text_model_iree_test_data(
+            reference_model=reference_model,
+            target_dtype=target_dtype,
+            input_ids=input_ids,
+            target_iree_parameters_output_path=target_iree_parameters_output_path,
+            target_mlir_output_path=target_mlir_output_path,
+            input_ids_output_path=current_input_ids_output_path,
+            expected_last_hidden_state_output_path=current_expected_last_hidden_state_output_path,
+        )
+
+
+def export_clip_text_model_iree_test_data(
+    reference_model: ClipTextModel,
+    target_dtype: torch.dtype,
+    input_ids: torch.LongTensor,
+    target_mlir_output_path: PathLike,
+    target_iree_parameters_output_path: PathLike,
+    input_ids_output_path: Optional[PathLike] = None,
+    expected_last_hidden_state_output_path: Optional[PathLike] = None,
+):
+    batch_size = input_ids.shape[0]
+    reference_dataset = clip_text_model_to_dataset(reference_model)
+    target_config = copy(reference_model.config)
+    target_config.dtype = target_dtype
+    target_dataset = Dataset(
+        root_theta=reference_dataset.root_theta.transform(
+            functools.partial(set_float_dtype, dtype=target_dtype)
+        ),
+        properties=target_config.to_properties(),
+    )
+    target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config)
+    export_clip_text_model_to_iree(
+        target_model,
+        batch_sizes=[batch_size],
+        mlir_output_path=target_mlir_output_path,
+        parameters_output_path=target_iree_parameters_output_path,
+    )
+
+    if input_ids_output_path is not None:
+        save_tensor_as_irpa(input_ids, input_ids_output_path)
+
+    if expected_last_hidden_state_output_path is None:
+        return
+
+    expected_last_hidden_state = reference_model(input_ids=input_ids)[
+        "last_hidden_state"
+    ]
+    save_tensor_as_irpa(
+        expected_last_hidden_state, expected_last_hidden_state_output_path
+    )
 
 
 def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
-    from transformers import CLIPTextConfig as HfCLIPTextConfig
     from transformers import CLIPTextModel as HfCLIPTextModel
 
     hf_config = config.to_hugging_face_clip_text_model_config()
```
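
The dtype handling in `export_clip_text_model_iree_test_data` is worth calling out: the reference theta is cloned into the target precision with `set_float_dtype`, rather than re-randomizing the weights, so all target variants share the same underlying parameters. A minimal sketch of that pattern in isolation (absolute module paths inferred from the relative imports above):

```python
# A minimal sketch of the dtype-conversion pattern used above: clone the
# float32 reference theta into bfloat16 with set_float_dtype instead of
# re-randomizing the weights. Module paths are inferred from the relative
# imports in the diff.
import functools
from copy import copy

import torch

from sharktank.models.clip.clip import ClipTextModel
from sharktank.models.clip.testing import (
    clip_toy_text_model_config,
    make_clip_text_model_random_theta,
)
from sharktank.transforms.dataset import set_float_dtype

reference_config = clip_toy_text_model_config(torch.float32)
reference_theta = make_clip_text_model_random_theta(reference_config)

target_config = copy(reference_config)
target_config.dtype = torch.bfloat16
# Theta.transform applies the callback across the theta's tensors, as used
# in export_clip_text_model_iree_test_data above.
target_theta = reference_theta.transform(
    functools.partial(set_float_dtype, dtype=torch.bfloat16)
)
target_model = ClipTextModel(theta=target_theta, config=target_config)
```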

sharktank/sharktank/utils/io.py

+22 −3

```diff
@@ -5,10 +5,10 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from pathlib import Path
+import torch
+from os import PathLike
 
-from iree.turbine.aot import (
-    ParameterArchiveBuilder,
-)
+from iree.turbine.aot import ParameterArchiveBuilder, ParameterArchive
 
 
 class ShardedArchiveBuilder(ParameterArchiveBuilder):
@@ -49,3 +49,22 @@ def path_for_rank(path: Path, rank: int):
         /tmp/foobar.rank0.irpa
     """
     return path.with_suffix(f".rank{rank}{path.suffix}")
+
+
+def save_tensor_as_irpa(tensor: torch.Tensor, path: PathLike):
+    """Save a single tensor into an IRPA file."""
+    param_builder = ParameterArchiveBuilder()
+    param_builder.add_tensor("", tensor)
+    param_builder.save(path)
+
+
+def load_irpa_as_tensor(tensor: torch.Tensor, path: PathLike, **kwargs):
+    """Load a tensor from an IRPA file that holds only one tensor."""
+    params = ParameterArchive(path, **kwargs)
+    items = params.items()
+    if len(items) != 1:
+        raise ValueError(
+            f'Expected exactly one tensor in IRPA file "{path}",'
+            f" but found {len(items)} items."
+        )
+    return items[0][1].as_tensor()
```
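
A quick round-trip sketch of the two new helpers; note that the first parameter of `load_irpa_as_tensor` is unused by the implementation above, so `None` is passed for it:

```python
# Round-trip sketch for the new IRPA helpers. The first parameter of
# load_irpa_as_tensor is unused by the implementation above, so None is
# passed for it; the output file name is illustrative.
import torch

from sharktank.utils.io import load_irpa_as_tensor, save_tensor_as_irpa

tensor = torch.full([3, 4], 2.0)
save_tensor_as_irpa(tensor, "single_tensor.irpa")

roundtripped = load_irpa_as_tensor(None, "single_tensor.irpa")
torch.testing.assert_close(roundtripped, tensor)
```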
