# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-from ...layers.configs.llm_configs import ClipTextConfig
-from ...types.theta import Theta
-from .export import hugging_face_clip_text_model_to_theta
+import functools
import torch
+from os import PathLike, makedirs
+from typing import Union, Optional
+from copy import copy
+from iree.turbine.aot.params import ParameterArchiveBuilder
+
+from ...layers.configs.llm_configs import ClipTextConfig
+from .clip import ClipTextModel
+from ...types.theta import Theta, Dataset
+from ...types.tensors import dtype_to_serialized_short_name
+from ...utils.io import save_tensor_as_irpa
+from .export import (
+    clip_text_model_to_dataset,
+    hugging_face_clip_text_model_to_theta,
+    export_clip_text_model_to_iree,
+)
+from ...transforms.dataset import set_float_dtype
+
+
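+# Toy-sized CLIP text model configuration; the dimensions are intentionally tiny so
+# that test data export and IREE compilation stay fast.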
+def clip_toy_text_model_config(dtype: Optional[torch.dtype] = None) -> ClipTextConfig:
+    num_attention_heads = 5
+    vocab_size = 11
+    return ClipTextConfig(
+        vocab_size=vocab_size,
+        hidden_size=13 * num_attention_heads,
+        intermediate_size=7,
+        projection_dim=3,
+        num_attention_heads=num_attention_heads,
+        max_position_embeddings=17,
+        layer_norm_eps=1e-4,
+        num_hidden_layers=2,
+        bos_token_id=vocab_size - 2,
+        eos_token_id=vocab_size - 1,
+        dtype=dtype,
+    )
+
+
+def export_clip_toy_text_model_default_iree_test_data(output_dir: PathLike):
+    makedirs(output_dir, exist_ok=True)
+
+    # We want to always export the same data without interfering with the RNG state
+    # for the rest of the program.
+    rng_state = torch.get_rng_state()
+    torch.random.manual_seed(12345)
+
+    reference_dtype = torch.float32
+    target_dtypes = [torch.float32, torch.bfloat16]
+    target_iree_parameters_output_paths = []
+    target_mlir_output_paths = []
+    batch_size = 4
+    for dtype in target_dtypes:
+        prefix = output_dir / f"{dtype_to_serialized_short_name(dtype)}"
+        target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa")
+        target_mlir_output_paths.append(f"{prefix}.mlir")
+    call_prefix = output_dir / f"forward_bs{batch_size}"
+    input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa"
+    expected_last_hidden_state_output_path = (
+        f"{call_prefix}_expected_result0_last_hidden_state_"
+        f"{dtype_to_serialized_short_name(reference_dtype)}.irpa"
+    )
+    export_clip_toy_text_model_iree_test_data(
+        reference_dtype=reference_dtype,
+        target_dtypes=target_dtypes,
+        batch_size=batch_size,
+        input_ids_output_path=input_ids_output_path,
+        expected_last_hidden_state_output_path=expected_last_hidden_state_output_path,
+        target_iree_parameters_output_paths=target_iree_parameters_output_paths,
+        target_mlir_output_paths=target_mlir_output_paths,
+    )
+
+    torch.set_rng_state(rng_state)
+
+
+def export_clip_toy_text_model_iree_test_data(
+    reference_dtype: torch.dtype,
+    target_dtypes: list[torch.dtype],
+    batch_size: int,
+    target_iree_parameters_output_paths: list[PathLike],
+    target_mlir_output_paths: list[PathLike],
+    input_ids_output_path: PathLike,
+    expected_last_hidden_state_output_path: PathLike,
+):
+    reference_config = clip_toy_text_model_config(reference_dtype)
+    input_ids = make_random_input_token_sequences(
+        batch_size=batch_size, config=reference_config
+    )
+    reference_theta = make_clip_text_model_random_theta(reference_config)
+    reference_model = ClipTextModel(theta=reference_theta, config=reference_config)
+    for i, (
+        target_dtype,
+        target_iree_parameters_output_path,
+        target_mlir_output_path,
+    ) in enumerate(
+        zip(
+            target_dtypes,
+            target_iree_parameters_output_paths,
+            target_mlir_output_paths,
+            strict=True,
+        )
+    ):
+        current_input_ids_output_path = None
+        current_expected_last_hidden_state_output_path = None
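+        # The input IDs and the reference expected result do not depend on the target
+        # dtype, so they are only written out for the first target.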
+        if i == 0:
+            current_input_ids_output_path = input_ids_output_path
+            current_expected_last_hidden_state_output_path = (
+                expected_last_hidden_state_output_path
+            )
+        export_clip_text_model_iree_test_data(
+            reference_model=reference_model,
+            target_dtype=target_dtype,
+            input_ids=input_ids,
+            target_iree_parameters_output_path=target_iree_parameters_output_path,
+            target_mlir_output_path=target_mlir_output_path,
+            input_ids_output_path=current_input_ids_output_path,
+            expected_last_hidden_state_output_path=current_expected_last_hidden_state_output_path,
+        )
+
+
+def export_clip_text_model_iree_test_data(
+    reference_model: ClipTextModel,
+    target_dtype: torch.dtype,
+    input_ids: torch.LongTensor,
+    target_mlir_output_path: PathLike,
+    target_iree_parameters_output_path: PathLike,
+    input_ids_output_path: Optional[PathLike] = None,
+    expected_last_hidden_state_output_path: Optional[PathLike] = None,
+):
+    batch_size = input_ids.shape[0]
+    reference_dataset = clip_text_model_to_dataset(reference_model)
+    target_config = copy(reference_model.config)
+    target_config.dtype = target_dtype
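+    # Cast the reference parameters to the target dtype before exporting the model.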
+    target_dataset = Dataset(
+        root_theta=reference_dataset.root_theta.transform(
+            functools.partial(set_float_dtype, dtype=target_dtype)
+        ),
+        properties=target_config.to_properties(),
+    )
+    target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config)
+    export_clip_text_model_to_iree(
+        target_model,
+        batch_sizes=[batch_size],
+        mlir_output_path=target_mlir_output_path,
+        parameters_output_path=target_iree_parameters_output_path,
+    )
+
+    if input_ids_output_path is not None:
+        save_tensor_as_irpa(input_ids, input_ids_output_path)
+
+    if expected_last_hidden_state_output_path is None:
+        return
+
+    expected_last_hidden_state = reference_model(input_ids=input_ids)[
+        "last_hidden_state"
+    ]
+    save_tensor_as_irpa(
+        expected_last_hidden_state, expected_last_hidden_state_output_path
+    )


def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
-    from transformers import CLIPTextConfig as HfCLIPTextConfig
    from transformers import CLIPTextModel as HfCLIPTextModel

    hf_config = config.to_hugging_face_clip_text_model_config()