Skip to content

Commit 886550a

Browse files
amd-vivekag (Vivek Agrawal) and others authored
Skip tokenizer checks in favor of AutoTokenizer (#442)
Removes if conditions based on model name to decide Tokenizer class. Instead, directly uses AutoTokenizer class. Co-authored-by: Vinayak Dev <[email protected]> --------- Co-authored-by: Vivek Agrawal <[email protected]>
1 parent b297f17 commit 886550a

File tree

6 files changed

+47
-28
lines changed

6 files changed

+47
-28
lines changed

alt_e2eshark/onnx_tests/helper_classes.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def export_model(self, optim_level: str | None = None):
8686
optimize=optim_level,
8787
)
8888

89+
def __repr__(self):
    """Debug representation: class name, HF model path, task, test name,
    and the directory holding the exported ONNX model."""
    class_name = type(self).__name__
    # "org/model" -> "org/hf_model", matching the hf_-prefixed path convention.
    full_model_path = "/hf_".join(self.model_repo_path.split("/"))
    return (
        f"{class_name} (full_model_path={full_model_path}, task_name={self.task}, "
        f"name={self.name}, onnx_model_path={os.path.dirname(self.model)})"
    )
94+
8995
def construct_model(self):
9096
model_dir = str(Path(self.model).parent)
9197

@@ -100,7 +106,12 @@ def find_models(model_dir):
100106
found_models = find_models(model_dir)
101107

102108
if len(found_models) == 0:
103-
self.export_model()
109+
try:
110+
self.export_model()
111+
except:
112+
#print(self.__repr__())
113+
raise RuntimeError("Failed to Export class: ", self)
114+
104115
found_models = find_models(model_dir)
105116
if len(found_models) == 1:
106117
self.model = found_models[0]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ntu-spml/hf_distilhubert
2+
microsoft/hf_wavlm-base-plus

alt_e2eshark/onnx_tests/models/external_lists/hf-model-paths/hf-feature-extraction-model-list.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ BAAI/hf_bge-large-en-v1.5
1313
mixedbread-ai/hf_mxbai-embed-large-v1
1414
BAAI/hf_bge-base-en-v1.5
1515
facebook/hf_bart-base
16-
ntu-spml/hf_distilhubert
1716
cointegrated/hf_rubert-tiny
1817
sentence-transformers/hf_paraphrase-multilingual-mpnet-base-v2
1918
BAAI/hf_bge-large-zh-v1.5
@@ -46,7 +45,6 @@ sentence-transformers/hf_msmarco-distilbert-base-v4
4645
avsolatorio/hf_GIST-Embedding-v0
4746
sentence-transformers/hf_msmarco-distilbert-base-tas-b
4847
sentence-transformers/hf_paraphrase-mpnet-base-v2
49-
microsoft/hf_wavlm-base-plus
5048
avsolatorio/hf_GIST-large-Embedding-v0
5149
Supabase/hf_gte-small
5250
sentence-transformers/hf_paraphrase-MiniLM-L3-v2
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
hf_distilhubert
2+
hf_wavlm-base-plus

alt_e2eshark/onnx_tests/models/external_lists/hf-model-shards/hf-feature-extraction-shard.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ hf_bge-large-en-v1.5
1313
hf_mxbai-embed-large-v1
1414
hf_bge-base-en-v1.5
1515
hf_bart-base
16-
hf_distilhubert
1716
hf_rubert-tiny
1817
hf_paraphrase-multilingual-mpnet-base-v2
1918
hf_bge-large-zh-v1.5
@@ -46,7 +45,6 @@ hf_msmarco-distilbert-base-v4
4645
hf_GIST-Embedding-v0
4746
hf_msmarco-distilbert-base-tas-b
4847
hf_paraphrase-mpnet-base-v2
49-
hf_wavlm-base-plus
5048
hf_GIST-large-Embedding-v0
5149
hf_gte-small
5250
hf_paraphrase-MiniLM-L3-v2

alt_e2eshark/onnx_tests/models/hf_models.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
import requests
8+
import torch
89

910
from pathlib import Path
1011

@@ -14,11 +15,6 @@
1415

1516
from transformers import (
1617
AutoTokenizer,
17-
BartTokenizer,
18-
BertTokenizer,
19-
PhobertTokenizer,
20-
RobertaTokenizer,
21-
XLMRobertaTokenizer,
2218
)
2319

2420
from torchvision import transforms
@@ -41,6 +37,7 @@
4137
"object-detection",
4238
"image-segmentation",
4339
"semantic-segmentation",
40+
"audio-classification",
4441
]
4542

4643
# These are NLP model names that have a mismatch between tokenizer
@@ -148,26 +145,13 @@
148145

149146

150147
def get_tokenizer_from_model_path(model_repo_path: str, cache_dir: str | Path):
    """Load the tokenizer for ``model_repo_path`` via ``AutoTokenizer``.

    Args:
        model_repo_path: Hugging Face repo path, e.g. "org/hf_model-name".
        cache_dir: directory used as the transformers download cache.

    Returns:
        The tokenizer instance produced by ``AutoTokenizer.from_pretrained``.

    Remote code execution is enabled only for model families (currently
    kobert) whose tokenizer implementation is not shipped with transformers.
    """
    trust_remote_code = False

    name = model_repo_path.split("/")[-1]
    if 'kobert' in name.lower():
        trust_remote_code = True

    # Bug fix: the call previously hard-coded trust_remote_code=True,
    # ignoring the flag computed above and enabling remote code for
    # every model instead of only kobert.
    return AutoTokenizer.from_pretrained(
        model_repo_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code
    )
171155

172156

173157
def build_repo_to_model_map():
@@ -215,6 +199,13 @@ def build_repo_to_model_map():
215199
)
216200
)
217201

202+
# Meta constructor for all random-input (audio-classification) models.
203+
meta_constructor_random_input = lambda m_name: (
204+
lambda *args, **kwargs: HfModelWithRandomInput(
205+
model_repo_map[m_name][0], model_repo_map[m_name][1], *args, **kwargs
206+
)
207+
)
208+
218209
# Meta constructor for all multiple choice models.
219210
meta_constructor_multiple_choice = lambda m_name: (
220211
lambda *args, **kwargs: HfModelMultipleChoice(
@@ -245,6 +236,21 @@ def construct_inputs(self):
245236
return test_tensors
246237

247238

239+
class HfModelWithRandomInput(HfDownloadableModel):
    """Test wrapper for models (registered for the audio-classification task)
    that are driven by a random input tensor rather than tokenized text."""

    def export_model(self, optim_level: str | None = None):
        # We won't need optim_level.
        del optim_level
        # "O1" optimization is applied only for models listed in basic_opt.
        super().export_model("O1" if self.name in basic_opt else None)

    def construct_inputs(self):
        # Random audio-like input: 1 batch x 4 channels x 16000 samples.
        inputs = torch.randn(1, 4, 16000)

        # NOTE(review): the keys 'input_ids'/'attention_mask' and shape
        # torch.Size([16000, 4]) do not match the (1, 4, 16000) tensor built
        # above, and audio models normally take 'input_values' — confirm what
        # downstream consumers of input_name_to_shape_map actually expect.
        self.input_name_to_shape_map = {'input_ids': torch.Size([16000, 4]), 'attention_mask': torch.Size([16000, 4])}

        test_tensors = TestTensors(inputs)
        return test_tensors
252+
253+
248254
class HfModelMultipleChoice(HfDownloadableModel):
249255
def export_model(self, optim_level: str | None = None):
250256
# We won't need optim_level.
@@ -334,6 +340,8 @@ def setup_test_image(height=224, width=224):
334340
| "semantic-segmentation"
335341
):
336342
register_test(meta_constructor_cv(t), t)
343+
case "audio-classification":
344+
register_test(meta_constructor_random_input(t), t)
337345
case "multiple-choice":
338346
register_test(meta_constructor_multiple_choice(t), t)
339347
case _:

0 commit comments

Comments
 (0)