Skip to content

Commit 65aa453

Browse files
amd-vivekagVivek Agrawal
andauthored
Fixes native inference input size mismatch issue (#447)
Adds extra input during construct_input phase to avoid input vs session_input size mismatch issue during native inference. --------- Co-authored-by: Vivek Agrawal <[email protected]>
1 parent 886550a commit 65aa453

File tree

1 file changed

+62
-7
lines changed

1 file changed

+62
-7
lines changed

alt_e2eshark/onnx_tests/models/hf_models.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
# model task. If a huge number of models that can be grouped
4646
# in a common category fall under this list, a new meta_constructor
4747
# should be created for them.
48-
update_tokenizer_input_names = [
48+
models_with_input_names_2 = {
4949
"hf_paraphrase-multilingual-MiniLM-L12-v2",
5050
"hf_all-MiniLM-L6-v2",
5151
"hf_jina-embeddings-v2-small-en",
@@ -138,7 +138,27 @@
138138
"hf_distilbert-base-nli-mean-tokens",
139139
"hf_distilbert-base-multilingual-cased",
140140
"hf_distilbert-base-cased",
141-
]
141+
}
142+
143+
models_with_input_names_3 = {
144+
"hf_bart-base",
145+
"hf_gpt2-small-spanish",
146+
"hf_opt-125m",
147+
"hf_Qwen1.5-0.5B-Chat",
148+
"hf_Qwen2-0.5B",
149+
"hf_Qwen2.5-0.5B-Instruct",
150+
"hf_really-tiny-falcon-testing",
151+
"hf_tiny-dummy-qwen2",
152+
"hf_tiny-Qwen2ForCausalLM-2.5",
153+
"hf_tiny-random-GemmaForCausalLM",
154+
"hf_tiny-random-LlamaForCausalLM",
155+
"hf_tiny-random-mt5",
156+
"hf_tiny-random-Phi3ForCausalLM",
157+
}
158+
159+
models_with_input_names_4 = {
160+
"hf_ivila-row-layoutlm-finetuned-s2vl-v2",
161+
}
142162

143163
# Add a basic_opt list to apply O1 to the models.
144164
basic_opt = []
@@ -224,14 +244,49 @@ def construct_inputs(self):
224244
prompt = ["Deeds will not be less valiant because they are unpraised."]
225245

226246
tokenizer = get_tokenizer_from_model_path(self.model_repo_path, self.cache_dir)
227-
if self.name in update_tokenizer_input_names:
228-
tokenizer.model_input_names = ["input_ids", "attention_mask"]
229247

230248
tokens = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
231-
inputs = (*list(tokens.values()),)
232-
233249
self.input_name_to_shape_map = {k: v.shape for (k, v) in tokens.items()}
234250

251+
if self.name in models_with_input_names_2:
252+
# Handles 2 inputs
253+
tokenizer.model_input_names = ["input_ids", "attention_mask"]
254+
inputs = (*list(tokens.values()), )
255+
else:
256+
self.input_name_to_shape_map["position_ids"] = self.input_name_to_shape_map["input_ids"]
257+
zeros = torch.zeros(*(self.input_name_to_shape_map["input_ids"]), dtype=int)
258+
if self.name in models_with_input_names_3:
259+
# Handles 3 inputs
260+
tokenizer.model_input_names = ["input_ids", "attention_mask", "position_ids"]
261+
elif self.name in models_with_input_names_4:
262+
tokenizer.model_input_names = ["input_ids", "bbox", "attention_mask", "position_ids"]
263+
264+
# Handles 4 inputs
265+
# Tokenizer is returning tokens dict with key "token_type_ids" instead of "bbox".
266+
# For now, "token_type_ids" will be reused as bbox in this case
267+
# bbox is a bounding box with size [?, ?, 4]
268+
# where each 4 numbers represent x_min, y_min, x_max, y_max
269+
tokens["token_type_ids"] = tokens["token_type_ids"].unsqueeze(-1).repeat(1, 1, 4)
270+
else:
271+
raise RuntimeError(f"Model: {self.name} not found in any of the registry lists.")
272+
273+
inputs = (*list(tokens.values()), zeros)
274+
275+
test_tensors = TestTensors(inputs)
276+
return test_tensors
277+
278+
279+
class HfModelWithRandomInput(HfDownloadableModel):
280+
def export_model(self, optim_level: str | None = None):
281+
# We won't need optim_level.
282+
del optim_level
283+
super().export_model("O1" if self.name in basic_opt else None)
284+
285+
def construct_inputs(self):
286+
inputs = torch.randn(1, 4, 16000)
287+
288+
self.input_name_to_shape_map = {'input_ids': torch.Size([16000, 4]), 'attention_mask': torch.Size([16000, 4])}
289+
235290
test_tensors = TestTensors(inputs)
236291
return test_tensors
237292

@@ -269,7 +324,7 @@ def construct_inputs(self):
269324
if (
270325
"deberta" in self.name
271326
or "roberta" in self.name
272-
or self.name in update_tokenizer_input_names
327+
or self.name in models_with_input_names_2
273328
):
274329
tokenizer.model_input_names = ["input_ids", "attention_mask"]
275330

0 commit comments

Comments
 (0)