Commit
Fixes native inference input size mismatch issue (#447)
Adds an extra input during the construct_inputs phase to avoid an input vs.
session_input size mismatch during native inference.
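To make the failure mode concrete, here is a minimal sketch (the model path, sequence length, and the session-handling code are illustrative assumptions, not the test suite's actual harness): a decoder-style ONNX model exported with three declared inputs receives only the two tensors the tokenizer produces, so the number of constructed inputs no longer matches the session's inputs; appending a zero-filled position_ids tensor restores the match.

import onnxruntime as ort
import torch

# Illustrative only: an exported decoder model that declares
# input_ids, attention_mask, and position_ids as inputs.
session = ort.InferenceSession("model.onnx")

tokens = {
    "input_ids": torch.ones(1, 8, dtype=torch.int64),
    "attention_mask": torch.ones(1, 8, dtype=torch.int64),
}

expected = [inp.name for inp in session.get_inputs()]  # three declared inputs
# Without the fix, len(tokens) == 2 while len(expected) == 3, so native inference fails.

# The fix mirrored by this commit: append an extra zero tensor for position_ids
# so the constructed inputs line up with the session's declared inputs.
if "position_ids" in expected and "position_ids" not in tokens:
    tokens["position_ids"] = torch.zeros_like(tokens["input_ids"])

outputs = session.run(None, {k: v.numpy() for k, v in tokens.items()})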

---------

Co-authored-by: Vivek Agrawal <[email protected]>
amd-vivekag and Vivek Agrawal authored Feb 24, 2025
1 parent 886550a commit 65aa453
Showing 1 changed file with 62 additions and 7 deletions.
69 changes: 62 additions & 7 deletions alt_e2eshark/onnx_tests/models/hf_models.py
@@ -45,7 +45,7 @@
# model task. If a huge number of models that can be grouped
# in a common category fall under this list, a new meta_constructor
# should be created for them.
update_tokenizer_input_names = [
models_with_input_names_2 = {
"hf_paraphrase-multilingual-MiniLM-L12-v2",
"hf_all-MiniLM-L6-v2",
"hf_jina-embeddings-v2-small-en",
@@ -138,7 +138,27 @@
"hf_distilbert-base-nli-mean-tokens",
"hf_distilbert-base-multilingual-cased",
"hf_distilbert-base-cased",
]
}

models_with_input_names_3 = {
"hf_bart-base",
"hf_gpt2-small-spanish",
"hf_opt-125m",
"hf_Qwen1.5-0.5B-Chat",
"hf_Qwen2-0.5B",
"hf_Qwen2.5-0.5B-Instruct",
"hf_really-tiny-falcon-testing",
"hf_tiny-dummy-qwen2",
"hf_tiny-Qwen2ForCausalLM-2.5",
"hf_tiny-random-GemmaForCausalLM",
"hf_tiny-random-LlamaForCausalLM",
"hf_tiny-random-mt5",
"hf_tiny-random-Phi3ForCausalLM",
}

models_with_input_names_4 = {
"hf_ivila-row-layoutlm-finetuned-s2vl-v2",
}

# Add a basic_opt list to apply O1 to the models.
basic_opt = []
@@ -224,14 +244,49 @@ def construct_inputs(self):
        prompt = ["Deeds will not be less valiant because they are unpraised."]

        tokenizer = get_tokenizer_from_model_path(self.model_repo_path, self.cache_dir)
        if self.name in update_tokenizer_input_names:
            tokenizer.model_input_names = ["input_ids", "attention_mask"]

        tokens = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        inputs = (*list(tokens.values()),)

        self.input_name_to_shape_map = {k: v.shape for (k, v) in tokens.items()}

        if self.name in models_with_input_names_2:
            # Handles 2 inputs
            tokenizer.model_input_names = ["input_ids", "attention_mask"]
            inputs = (*list(tokens.values()), )
        else:
            self.input_name_to_shape_map["position_ids"] = self.input_name_to_shape_map["input_ids"]
            zeros = torch.zeros(*(self.input_name_to_shape_map["input_ids"]), dtype=int)
            if self.name in models_with_input_names_3:
                # Handles 3 inputs
                tokenizer.model_input_names = ["input_ids", "attention_mask", "position_ids"]
            elif self.name in models_with_input_names_4:
                # Handles 4 inputs
                tokenizer.model_input_names = ["input_ids", "bbox", "attention_mask", "position_ids"]

                # The tokenizer returns a tokens dict with the key "token_type_ids" instead of "bbox".
                # For now, "token_type_ids" is reused as the bbox input in this case.
                # bbox is a bounding-box tensor of shape [?, ?, 4],
                # where the four values represent x_min, y_min, x_max, y_max.
                tokens["token_type_ids"] = tokens["token_type_ids"].unsqueeze(-1).repeat(1, 1, 4)
            else:
                raise RuntimeError(f"Model: {self.name} not found in any of the registry lists.")

            inputs = (*list(tokens.values()), zeros)

        test_tensors = TestTensors(inputs)
        return test_tensors


class HfModelWithRandomInput(HfDownloadableModel):
    def export_model(self, optim_level: str | None = None):
        # We won't need optim_level.
        del optim_level
        super().export_model("O1" if self.name in basic_opt else None)

    def construct_inputs(self):
        inputs = torch.randn(1, 4, 16000)

        self.input_name_to_shape_map = {'input_ids': torch.Size([16000, 4]), 'attention_mask': torch.Size([16000, 4])}

        test_tensors = TestTensors(inputs)
        return test_tensors

@@ -269,7 +324,7 @@ def construct_inputs(self):
        if (
            "deberta" in self.name
            or "roberta" in self.name
            or self.name in update_tokenizer_input_names
            or self.name in models_with_input_names_2
        ):
            tokenizer.model_input_names = ["input_ids", "attention_mask"]
