
Commit e9692c6

Add vision multimodal support and Gemma3 vision (#139)
Add support for exporting vision multimodal models and enable Gemma3 vision as the first example.
1 parent 216160a commit e9692c6
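
For orientation, here is a rough end-to-end sketch of what this commit enables. The checkpoint id, the `from_pretrained` arguments, and the exact `text_generation` call shape are assumptions for illustration only; the model class, the "vision" modality handling, and the conversation format come from the changes below.

from transformers import AutoProcessor

from optimum.executorch.modeling import ExecuTorchModelForMultiModalToText

model_id = "google/gemma-3-4b-it"  # assumed Gemma3 vision checkpoint

# Assumed invocation: export the multimodal model to ExecuTorch on the fly.
model = ExecuTorchModelForMultiModalToText.from_pretrained(model_id, recipe="xnnpack")  # recipe kwarg assumed
processor = AutoProcessor.from_pretrained(model_id)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# text_generation() applies the chat template, picks up pixel_values for the
# "vision" modality, and routes them to the vision encoder inside the .pte.
# Call shape below is assumed for illustration.
output = model.text_generation(processor, conversation, max_seq_len=128)
print(output)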

File tree: 7 files changed, +216 -58 lines

optimum/executorch/modeling.py

Lines changed: 28 additions & 23 deletions
@@ -45,7 +45,7 @@
 
 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
-from ..exporters.executorch.utils import verify_eos_tokens_in_pretrained_tokenizer
+from ..exporters.executorch.utils import apply_chat_template_with_fallback, verify_eos_tokens_in_pretrained_tokenizer
 from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.file_utils import find_files_matching_pattern
 from .stats import Stats
@@ -1167,8 +1167,9 @@ class ExecuTorchModelForMultiModalToText(ExecuTorchModelBase):
            Size of the model vocabulary.
    """

-    # Using general `AutoModel` since it usually routes to the correct model variant and there is no
-    # auto model class that captures both audio and image.
+    # Using `AutoModel` since there is no auto model class that captures both audio and image.
+    # This is not too important since it is only used to automatically infer the task type,
+    # and for multimodal models the task type should always be specified explicitly.
    auto_model_class = AutoModel

    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
@@ -1180,6 +1181,7 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
                f"Exported .pte file needs to contain the following required methods: {required_methods}"
            )

+        # Multimodal-related metadata.
        self.encoder_name = None
        for method_name in self.model.method_names():
            if method_name == "audio_encoder":
@@ -1192,38 +1194,28 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
            raise ValueError(
                'Exported .pte file needs to contain either an "audio_encoder" or a "vision_encoder" in its methods.'
            )
+        self.modality = self.model.run_method("modality")[0]

+        # Decoder-related metadata.
        metadata = self.model.method_names()
-        if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
        if "get_max_seq_len" in metadata:
            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
-        if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
-        if "get_dtype" in metadata:
-            self.dtype = self.model.run_method("get_dtype")[0]
        if "get_bos_id" in metadata:
            self.bos_token_id = self.model.run_method("get_bos_id")[0]
        if "get_eos_id" in metadata:
            self.eos_token_id = self.model.run_method("get_eos_id")[0]
-        if "get_vocab_size" in metadata:
-            self.vocab_size = self.model.run_method("get_vocab_size")[0]
-        if "max_hidden_seq_length" in metadata:
-            self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
-        if "decoder_start_token_id" in metadata:
-            self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]

    def forward(
        self,
        input_ids: torch.Tensor,
        cache_position: torch.Tensor,
-        input_features: Optional[torch.Tensor] = None,
+        multimodal_features: Optional[torch.Tensor] = None,
    ):
        token_embeddings = self.model.run_method("token_embedding", (input_ids,))[0]
-        if input_features is not None:
+        if multimodal_features is not None:
            encoder_embeddings = self.model.run_method(
                self.encoder_name,
-                (input_features,),
+                (multimodal_features,),
            )[0]
            encoder_token_mask = input_ids == self.encoder_token_id
            token_embeddings[encoder_token_mask] = encoder_embeddings
@@ -1242,7 +1234,7 @@ def generate(
        echo: bool = False,
        pos_base: int = 0,
        max_seq_len: Optional[int] = None,
-        input_features: Optional[torch.Tensor] = None,
+        multimodal_features: Optional[torch.Tensor] = None,
    ) -> List[int]:
        self.device = torch.device("cpu")
        if max_seq_len is None:
@@ -1259,7 +1251,7 @@ def generate(
        logits = self.forward(
            input_ids=prompt_tokens,
            cache_position=torch.arange(prompt_tokens.size(1), dtype=torch.long, device=self.device),
-            input_features=input_features,
+            multimodal_features=multimodal_features,
        )
        self.stats.on_sampling_end()
        self.stats.on_prompt_eval_end()
@@ -1279,7 +1271,7 @@ def generate(
                    dtype=torch.long,
                    device=self.device,
                ),
-                input_features=None,
+                multimodal_features=None,
            )
            self.stats.on_sampling_end()
            if not first_token_generated:
@@ -1337,13 +1329,26 @@ def text_generation(
        self.stats.reset()
        self.stats.on_inference_start()

-        inputs = processor.apply_chat_template(input_conversation)
+        inputs = apply_chat_template_with_fallback(
+            processor,
+            input_conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
        self.stats.on_token_encode_end()
        self.stats.set_num_prompt_tokens(len(inputs["input_ids"][0]))

+        multimodal_features = None
+        if self.modality == "vision":
+            multimodal_features = inputs.get("pixel_values", None)
+        elif self.modality == "audio":
+            multimodal_features = inputs.get("input_features", None)
        generated_tokens = self.generate(
            prompt_tokens=inputs["input_ids"],
-            input_features=inputs["input_features"],
+            multimodal_features=multimodal_features,
            echo=echo,
            max_seq_len=len(inputs["input_ids"][0]) + max_seq_len,
        )
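
Net effect of the modeling.py changes: the processor output is inspected per modality, and the encoder features travel under the generic `multimodal_features` name all the way into `generate()`. A condensed sketch of the new call path follows; `model` stands for an already loaded `ExecuTorchModelForMultiModalToText`, and the Gemma3 checkpoint id is an assumption.

from transformers import AutoProcessor

from optimum.exporters.executorch.utils import apply_chat_template_with_fallback

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")  # assumed checkpoint
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
            {"type": "text", "text": "What is in this picture?"},
        ],
    }
]

# Same preprocessing that text_generation() now performs internally.
inputs = apply_chat_template_with_fallback(
    processor,
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

# Modality-aware feature selection: vision processors expose "pixel_values",
# audio processors expose "input_features"; both flow as multimodal_features.
modality = "vision"
multimodal_features = inputs.get("pixel_values") if modality == "vision" else inputs.get("input_features")

generated = model.generate(  # `model` assumed loaded beforehand
    prompt_tokens=inputs["input_ids"],
    multimodal_features=multimodal_features,
    echo=False,
    max_seq_len=len(inputs["input_ids"][0]) + 64,
)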

optimum/exporters/executorch/integrations.py

Lines changed: 65 additions & 15 deletions
@@ -33,7 +33,50 @@

 from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache

-from .utils import save_config_to_constant_methods
+from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods
+
+
+class VisionExportableModule(torch.nn.Module):
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self.model = model
+
+    def prepare_export_inputs(self):
+        # 1. Get export inputs.
+        model_id = self.model.config.name_or_path
+        processor = AutoProcessor.from_pretrained(model_id)
+        sample_conversation_with_image = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
+                ],
+            },
+        ]
+        processed_inputs = processor.apply_chat_template(
+            sample_conversation_with_image,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        if "pixel_values" not in processed_inputs:
+            raise ValueError(
+                f"Unable to obtain sample vision encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'pixel_values' key: {processed_inputs}"
+            )
+        export_inputs = processed_inputs["pixel_values"]
+
+        # 2. Get export dynamic shapes.
+        dynamic_shapes = None  # No batching for now.
+
+        return export_inputs, dynamic_shapes
+
+    def forward(
+        self,
+        input_features: torch.FloatTensor,
+    ):
+        image_embeds = self.model.get_image_features(input_features)
+        return image_embeds.unsqueeze(0)


 class AudioExportableModule(torch.nn.Module):
@@ -56,7 +99,14 @@ def prepare_export_inputs(self):
                ],
            }
        ]
-        processed_inputs = processor.apply_chat_template(sample_conversation_with_audio)
+        processed_inputs = apply_chat_template_with_fallback(
+            processor,
+            sample_conversation_with_audio,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
        if "input_features" not in processed_inputs:
            raise ValueError(
                f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'input_features' key: {processed_inputs}"
@@ -129,9 +179,13 @@ def __init__(
        self.processor_config = processor_config
        self.use_custom_kv_cache = use_custom_kv_cache
        self.use_custom_sdpa = use_custom_sdpa
-        modality_token_placeholder_id_kwargs = {f"{modality}_token_id": getattr(self.config, f"{modality}_token_id")}
+        additional_metadata_kwargs = {"modality": modality}
+        if modality == "audio":
+            additional_metadata_kwargs[f"{modality}_token_id"] = getattr(self.config, "audio_token_id")
+        elif modality == "vision":
+            additional_metadata_kwargs[f"{modality}_token_id"] = getattr(self.config, "image_token_id")
        self.metadata = save_config_to_constant_methods(
-            model.config.text_config, model.generation_config, processor_config, **modality_token_placeholder_id_kwargs
+            model.config.text_config, model.generation_config, processor_config, **additional_metadata_kwargs
        )
        logging.info(f"Metadata to be recorded in PTE: {self.metadata}")

@@ -148,9 +202,9 @@ def _prepare_text_embedding_export_inputs(self, max_seq_len: int):
        example_input_ids = torch.zeros((1, seq_length), dtype=torch.long)

        seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_len)
-        dynamic_shapes = {
-            "input": {1: seq_len_dim},
-        }  # nn.Embedding forward() args are here - https://github.com/pytorch/pytorch/blob/febf3c475e6fe369b41ef009f3598659a6df0911/torch/nn/modules/sparse.py#L15.
+        # Don't use named dynamic shapes since embedding modules can have different arg
+        # names for the input, e.g. nn.Embedding vs. embedding modules defined in Transformers.
+        dynamic_shapes = ({1: seq_len_dim},)

        return example_input_ids, dynamic_shapes

@@ -213,7 +267,7 @@ def export(

        # 1. Export text decoder.
        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            getattr(self.model, self.decoder_name),
+            self.model,
        )
        exported_programs = {}

@@ -267,7 +321,7 @@ def export(
        )

        token_embedding_exported_program = torch.export.export(
-            getattr(self.model, self.decoder_name).get_input_embeddings(),
+            self.model.get_input_embeddings(),
            args=(input_ids,),
            kwargs={},
            dynamic_shapes=dynamic_shapes,
@@ -281,13 +335,13 @@ def export(

        if self.modality == "audio":
            encoder = AudioExportableModule(self.model)
-            input_features, dynamic_shapes = encoder.prepare_export_inputs()
        elif self.modality == "vision":
-            raise ValueError("Vision is not yet supported, this will be available soon.")
+            encoder = VisionExportableModule(self.model)
        else:
            raise ValueError(
                f"{self.model.config.name_or_path} has a modality that is not yet supported in Optimum - please file an issue."
            )
+        input_features, dynamic_shapes = encoder.prepare_export_inputs()

        logging.info(
            f"Exporting {self.modality} encoder using input_features({input_features.shape}), dynamic_shapes={dynamic_shapes}"
@@ -388,10 +442,6 @@ def export(
            f"Exporting using input_ids({input_ids.shape})={input_ids}, cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}"
        )

-        from transformers.integrations.executorch import (
-            TorchExportableModuleForDecoderOnlyLM,
-        )
-
        exportable_module = TorchExportableModuleForDecoderOnlyLM(
            self.model,
            max_batch_size=1,
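
To make the new vision branch concrete, here is a standalone sketch of how `VisionExportableModule` drives `torch.export` on its own; the Gemma3 checkpoint id is an assumption, and the full export in the repo runs through `MultiModalTextToTextExportableModule.export()` rather than this bare call.

import torch
from transformers import AutoModelForPreTraining

from optimum.exporters.executorch.integrations import VisionExportableModule

# Assumed Gemma3 vision checkpoint; AutoModelForPreTraining resolves it to the
# top-level <Model>ForConditionalGeneration variant (see the tasks change below).
model = AutoModelForPreTraining.from_pretrained("google/gemma-3-4b-it")

encoder = VisionExportableModule(model)
# Runs the processor on a sample image conversation and returns pixel_values
# plus dynamic shapes (currently None, i.e. no batching).
input_features, dynamic_shapes = encoder.prepare_export_inputs()

with torch.no_grad():
    vision_encoder_ep = torch.export.export(
        encoder,
        args=(input_features,),
        kwargs={},
        dynamic_shapes=dynamic_shapes,
    )
print(vision_encoder_ep)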

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 17 additions & 11 deletions
@@ -17,7 +17,7 @@
 import os.path

 import torchao
-from transformers import AutoConfig, AutoModel, GenerationConfig
+from transformers import AutoConfig, AutoModelForPreTraining, GenerationConfig

 from ..integrations import MultiModalTextToTextExportableModule
 from ..quantization import quantize_model_
@@ -159,21 +159,27 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
    if hasattr(config, "use_cache") and config.use_cache is False:
        config.use_cache = True

-    eager_model = AutoModel.from_pretrained(
+    # Using `AutoModelForPreTraining` since it usually routes to the correct model variant and there is no
+    # auto model class that captures both audio and image.
+    # The correct model variant we are looking for is <Model>ForConditionalGeneration, since it is the top-level
+    # model and thus will always contain the necessary model components. As an example of why this is needed,
+    # if you just use Gemma3Model instead of Gemma3ForConditionalGeneration, Gemma3Model (which is the decoder part)
+    # will not contain the LM head, which is only applied in the latter.
+    eager_model = AutoModelForPreTraining.from_pretrained(
        model_name_or_path,
        device_map=device,
        torch_dtype=dtype,
        config=config,
        attn_implementation=attn_implementation,
-        generation_config=GenerationConfig(
-            use_cache=True,
-            cache_implementation=cache_implementation,
-            max_length=max_length,
-            cache_config={
-                "batch_size": batch_size,
-                "max_cache_len": max_length,
-            },
-        ),
+    )
+    eager_model.generation_config = GenerationConfig(
+        use_cache=True,
+        cache_implementation=cache_implementation,
+        max_length=max_length,
+        cache_config={
+            "batch_size": batch_size,
+            "max_cache_len": max_length,
+        },
    )
    decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
    encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
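
A small sketch of why the auto-class swap matters and how the generation config is now attached after loading. The checkpoint id and the concrete cache values are illustrative assumptions; the class-resolution behavior is the one described in the comment above.

from transformers import AutoModel, AutoModelForPreTraining, GenerationConfig

model_id = "google/gemma-3-4b-it"  # assumed Gemma3 vision checkpoint

# AutoModel resolves to Gemma3Model, which lacks the LM head, while
# AutoModelForPreTraining resolves to Gemma3ForConditionalGeneration,
# the top-level model that carries every component needed for export.
decoder_part = AutoModel.from_pretrained(model_id)
full_model = AutoModelForPreTraining.from_pretrained(model_id)
print(type(decoder_part).__name__)  # Gemma3Model
print(type(full_model).__name__)    # Gemma3ForConditionalGeneration

# The generation config is now assigned after from_pretrained() instead of
# being passed in; the values below are placeholders for the real kwargs.
full_model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    max_length=1024,
    cache_config={"batch_size": 1, "max_cache_len": 1024},
)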

optimum/exporters/executorch/utils.py

Lines changed: 24 additions & 0 deletions
@@ -71,6 +71,30 @@ def save_config_to_constant_methods(
     return {k: v for k, v in {**metadata, **kwargs}.items() if v is not None}


+def apply_chat_template_with_fallback(processor, conversation, **kwargs):
+    """
+    Apply a chat template, with a fallback for external processors.
+
+    Duck-typed external processors that aren't defined in Transformers, e.g.
+    Voxtral's processor which is defined in mistral-common, aren't guaranteed
+    to accept kwargs such as "add_generation_prompt", so fall back to calling
+    apply_chat_template with the conversation alone.
+
+    Args:
+        processor: The processor instance
+        conversation: The conversation to process
+        **kwargs: Additional keyword arguments to pass to apply_chat_template
+
+    Returns:
+        The processed inputs from apply_chat_template
+    """
+    try:
+        return processor.apply_chat_template(conversation, **kwargs)
+    except ValueError:
+        # Fallback for external processors - just pass the conversation.
+        return processor.apply_chat_template(conversation)
+
+
 def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: PreTrainedTokenizer) -> bool:
     """
     Verifies that the model's EOS token IDs are present in the tokenizer's
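
To illustrate the fallback contract, a tiny self-contained example with a stand-in duck-typed processor; the `StrictExternalProcessor` class is hypothetical and exists only for demonstration.

from optimum.exporters.executorch.utils import apply_chat_template_with_fallback


class StrictExternalProcessor:
    """Stand-in for an external processor (mistral-common style) that
    rejects Transformers-only kwargs such as add_generation_prompt."""

    def apply_chat_template(self, conversation, **kwargs):
        if kwargs:
            raise ValueError(f"Unexpected kwargs: {sorted(kwargs)}")
        return {"input_ids": [[1, 2, 3]]}


processor = StrictExternalProcessor()
conversation = [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]

# The first attempt passes the full kwargs; the ValueError triggers the bare retry.
inputs = apply_chat_template_with_fallback(
    processor,
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
print(inputs["input_ids"])  # [[1, 2, 3]]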
