@@ -5,6 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import requests
+import torch
 
 from pathlib import Path
 
@@ -14,11 +15,6 @@
 
 from transformers import (
     AutoTokenizer,
-    BartTokenizer,
-    BertTokenizer,
-    PhobertTokenizer,
-    RobertaTokenizer,
-    XLMRobertaTokenizer,
 )
 
 from torchvision import transforms
@@ -41,6 +37,7 @@
     "object-detection",
     "image-segmentation",
     "semantic-segmentation",
+    "audio-classification",
 ]
 
 # These are NLP model names that have a mismatch between tokenizer
@@ -148,26 +145,13 @@
 
 
 def get_tokenizer_from_model_path(model_repo_path: str, cache_dir: str | Path):
-    name = model_repo_path.split("/")[-1]
-    if "deberta" in name.lower():
-        return AutoTokenizer.from_pretrained(model_repo_path, cache_dir=cache_dir)
-
-    if "bart" in name.lower():
-        return BartTokenizer.from_pretrained(model_repo_path, cache_dir=cache_dir)
-
-    if "xlm" in name.lower() and "roberta" in name.lower():
-        return XLMRobertaTokenizer.from_pretrained(model_repo_path, cache_dir=cache_dir)
+    trust_remote_code = False
 
-    if "roberta" in name.lower():
-        return RobertaTokenizer.from_pretrained(model_repo_path, cache_dir=cache_dir)
-
-    if "phobert" in name.lower():
-        return PhobertTokenizer.from_pretrained(model_repo_path, cache_dir=cache_dir)
-
-    if "bert" in name.lower():
-        return BertTokenizer.from_pretrained(model_repo_path, cache_dir=cache_dir)
+    name = model_repo_path.split("/")[-1]
+    if "kobert" in name.lower():
+        trust_remote_code = True  # KoBERT checkpoints ship a custom tokenizer implementation
 
-    return AutoTokenizer.from_pretrained(model_repo_path, cache_dir=cache_dir)
+    return AutoTokenizer.from_pretrained(model_repo_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code)
 
 
 def build_repo_to_model_map():
@@ -215,6 +199,13 @@ def build_repo_to_model_map():
         )
     )
 
+# Meta constructor for models that take random input tensors.
+meta_constructor_random_input = lambda m_name: (
+    lambda *args, **kwargs: HfModelWithRandomInput(
+        model_repo_map[m_name][0], model_repo_map[m_name][1], *args, **kwargs
+    )
+)
+
 # Meta constructor for all multiple choice models.
 meta_constructor_multiple_choice = lambda m_name: (
     lambda *args, **kwargs: HfModelMultipleChoice(
@@ -245,6 +236,21 @@ def construct_inputs(self):
         return test_tensors
 
 
+class HfModelWithRandomInput(HfDownloadableModel):
+    def export_model(self, optim_level: str | None = None):
+        # We won't need optim_level.
+        del optim_level
+        super().export_model("O1" if self.name in basic_opt else None)
+
+    def construct_inputs(self):
+        inputs = torch.randn(1, 4, 16000)  # random stand-in for an audio waveform
+
+        self.input_name_to_shape_map = {"input_ids": torch.Size([16000, 4]), "attention_mask": torch.Size([16000, 4])}
+
+        test_tensors = TestTensors(inputs)
+        return test_tensors
+
+
 class HfModelMultipleChoice(HfDownloadableModel):
     def export_model(self, optim_level: str | None = None):
         # We won't need optim_level.
@@ -334,6 +340,8 @@ def setup_test_image(height=224, width=224):
            | "semantic-segmentation"
        ):
            register_test(meta_constructor_cv(t), t)
+        case "audio-classification":
+            register_test(meta_constructor_random_input(t), t)
        case "multiple-choice":
            register_test(meta_constructor_multiple_choice(t), t)
        case _:
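
For context (not part of the diff): meta_constructor_random_input is a factory of factories, like the existing meta constructors. A rough, hand-expanded equivalent is sketched below; it assumes model_repo_map[m_name] holds a (repo_path, task) pair, which is an inference from how the other constructors index it, not something stated in this change.

def make_random_input_constructor(m_name):
    # Sketch of what meta_constructor_random_input(m_name) returns, for readability only.
    def constructor(*args, **kwargs):
        repo_path, task = model_repo_map[m_name]  # assumed (repo_path, task) pair
        return HfModelWithRandomInput(repo_path, task, *args, **kwargs)
    return constructor

The extra level of indirection presumably lets register_test hold a cheap callable and defer model construction until the corresponding test actually runs.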