 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
-from ..exporters.executorch.utils import verify_eos_tokens_in_pretrained_tokenizer
+from ..exporters.executorch.utils import apply_chat_template_with_fallback, verify_eos_tokens_in_pretrained_tokenizer
 from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.file_utils import find_files_matching_pattern
 from .stats import Stats
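The new import swaps the direct `processor.apply_chat_template` call (see the `text_generation` hunk below) for `apply_chat_template_with_fallback`. Its implementation lives in `..exporters.executorch.utils` and is not shown in this diff; as a rough illustration only, a helper of this shape typically tries the processor's chat template first and falls back to the underlying tokenizer:

```python
# Illustrative sketch only -- the real helper in optimum.exporters.executorch.utils
# may differ. Assumption: `processor` is a transformers processor or tokenizer that
# may or may not define its own chat template.
def apply_chat_template_with_fallback(processor, conversation, **kwargs):
    try:
        # Preferred path: the (multimodal) processor templates and tokenizes directly.
        return processor.apply_chat_template(conversation, **kwargs)
    except (AttributeError, ValueError):
        # Fallback: reuse the wrapped tokenizer's chat template instead.
        tokenizer = getattr(processor, "tokenizer", processor)
        return tokenizer.apply_chat_template(conversation, **kwargs)
```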
@@ -1167,8 +1167,9 @@ class ExecuTorchModelForMultiModalToText(ExecuTorchModelBase):
             Size of the model vocabulary.
     """

-    # Using general `AutoModel` since it usually routes to the correct model variant and there is no
-    # auto model class that captures both audio and image.
+    # Using `AutoModel` since there is no auto model class that captures both audio and image.
+    # This is not too important since it's just used for automatically inferring the
+    # task type. For MultiModal, we should always be specifying the task type anyways.
     auto_model_class = AutoModel

     def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
@@ -1180,6 +1181,7 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
                 f"Exported .pte file needs to contain the following required methods: {required_methods}"
             )

+        # Multimodal-related metadata.
         self.encoder_name = None
         for method_name in self.model.method_names():
             if method_name == "audio_encoder":
@@ -1192,38 +1194,28 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
             raise ValueError(
                 'Exported .pte file needs to contain either an an "audio_encoder" or a "vision_encoder" in its methods.'
             )
+        self.modality = self.model.run_method("modality")[0]

+        # Decoder-related metadata.
         metadata = self.model.method_names()
-        if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
         if "get_max_seq_len" in metadata:
             self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
-        if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
-        if "get_dtype" in metadata:
-            self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
             self.bos_token_id = self.model.run_method("get_bos_id")[0]
         if "get_eos_id" in metadata:
             self.eos_token_id = self.model.run_method("get_eos_id")[0]
-        if "get_vocab_size" in metadata:
-            self.vocab_size = self.model.run_method("get_vocab_size")[0]
-        if "max_hidden_seq_length" in metadata:
-            self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
-        if "decoder_start_token_id" in metadata:
-            self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]

     def forward(
         self,
         input_ids: torch.Tensor,
         cache_position: torch.Tensor,
-        input_features: Optional[torch.Tensor] = None,
+        multimodal_features: Optional[torch.Tensor] = None,
     ):
         token_embeddings = self.model.run_method("token_embedding", (input_ids,))[0]
-        if input_features is not None:
+        if multimodal_features is not None:
             encoder_embeddings = self.model.run_method(
                 self.encoder_name,
-                (input_features,),
+                (multimodal_features,),
             )[0]
             encoder_token_mask = input_ids == self.encoder_token_id
             token_embeddings[encoder_token_mask] = encoder_embeddings
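The `forward` change above only renames the argument; the splice logic is unchanged: encoder outputs are written into the token-embedding tensor at the positions of the placeholder encoder token. A minimal self-contained illustration of that masking step (toy shapes, plain PyTorch, not the actual exported graph):

```python
import torch

# Toy illustration of the embedding splice in forward(): positions holding the
# placeholder encoder token get overwritten with encoder embeddings.
encoder_token_id = 99
input_ids = torch.tensor([[1, 99, 99, 2]])      # one sequence with two image/audio slots
token_embeddings = torch.zeros(1, 4, 8)         # (batch, seq, hidden), stand-in values
encoder_embeddings = torch.ones(2, 8)           # one embedding per placeholder position

mask = input_ids == encoder_token_id            # (1, 4) boolean mask
token_embeddings[mask] = encoder_embeddings     # boolean indexing selects the two slots
```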
@@ -1242,7 +1234,7 @@ def generate(
         echo: bool = False,
         pos_base: int = 0,
         max_seq_len: Optional[int] = None,
-        input_features: Optional[torch.Tensor] = None,
+        multimodal_features: Optional[torch.Tensor] = None,
     ) -> List[int]:
         self.device = torch.device("cpu")
         if max_seq_len is None:
@@ -1259,7 +1251,7 @@ def generate(
         logits = self.forward(
             input_ids=prompt_tokens,
             cache_position=torch.arange(prompt_tokens.size(1), dtype=torch.long, device=self.device),
-            input_features=input_features,
+            multimodal_features=multimodal_features,
         )
         self.stats.on_sampling_end()
         self.stats.on_prompt_eval_end()
@@ -1279,7 +1271,7 @@ def generate(
                     dtype=torch.long,
                     device=self.device,
                 ),
-                input_features=None,
+                multimodal_features=None,
             )
             self.stats.on_sampling_end()
             if not first_token_generated:
@@ -1337,13 +1329,26 @@ def text_generation(
         self.stats.reset()
         self.stats.on_inference_start()

-        inputs = processor.apply_chat_template(input_conversation)
+        inputs = apply_chat_template_with_fallback(
+            processor,
+            input_conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
         self.stats.on_token_encode_end()
         self.stats.set_num_prompt_tokens(len(inputs["input_ids"][0]))

+        multimodal_features = None
+        if self.modality == "vision":
+            multimodal_features = inputs.get("pixel_values", None)
+        elif self.modality == "audio":
+            multimodal_features = inputs.get("input_features", None)
         generated_tokens = self.generate(
             prompt_tokens=inputs["input_ids"],
-            input_features=inputs["input_features"],
+            multimodal_features=multimodal_features,
             echo=echo,
             max_seq_len=len(inputs["input_ids"][0]) + max_seq_len,
         )
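With these changes, `text_generation` picks the right feature tensor from the processor output based on the exported modality, so a caller no longer needs to know whether the checkpoint expects `pixel_values` or `input_features`. A hedged end-to-end sketch; the exact `from_pretrained` arguments, import path, and `text_generation` signature are not shown in this diff and are assumed here:

```python
# Usage sketch under assumptions: the import path and the keyword arguments of
# from_pretrained()/text_generation() are inferred from the hunks above, not
# confirmed by this diff.
from transformers import AutoProcessor

from optimum.executorch import ExecuTorchModelForMultiModalToText  # assumed import path

processor = AutoProcessor.from_pretrained("<model-id>")
model = ExecuTorchModelForMultiModalToText.from_pretrained("<path-to-exported-pte-repo>")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.png"},
            {"type": "text", "text": "What is in this image?"},
        ],
    }
]

# text_generation() applies the chat template, selects pixel_values or input_features
# based on model.modality, and forwards them as multimodal_features to generate().
output = model.text_generation(
    processor=processor,
    input_conversation=conversation,
    max_seq_len=128,
)
```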