Commit d4d3046

Add Granite Speech 3.3 (#144)

1 parent 828ae02 commit d4d3046

6 files changed: +267 −50 lines changed

optimum/executorch/modeling.py

Lines changed: 41 additions & 13 deletions
@@ -45,7 +45,10 @@
 
 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
-from ..exporters.executorch.utils import apply_chat_template_with_fallback, verify_eos_tokens_in_pretrained_tokenizer
+from ..exporters.executorch.utils import (
+    process_conversation_inputs,
+    verify_eos_tokens_in_pretrained_tokenizer,
+)
 from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.file_utils import find_files_matching_pattern
 from .stats import Stats
@@ -88,7 +91,11 @@ class ExecuTorchModelBase(OptimizedModel, ABC):
 
     auto_model_class = None
 
-    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
+    def __init__(
+        self,
+        models: Dict[str, "ExecuTorchModule"],
+        config: "PretrainedConfig",
+    ):
         super().__init__(model=None, config=config)
 
         if self.__class__.auto_model_class is None:
@@ -444,7 +451,11 @@ class ExecuTorchModelForSeq2SeqLM(ExecuTorchModelBase):
 
     auto_model_class = AutoModelForSeq2SeqLM
 
-    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
+    def __init__(
+        self,
+        models: Dict[str, "ExecuTorchModule"],
+        config: "PretrainedConfig",
+    ):
         super().__init__(models=models, config=config)
         if not hasattr(self, "encoder"):
             raise AttributeError("Expected attribute 'encoder' not found in the instance.")
@@ -640,7 +651,11 @@ class ExecuTorchModelForCausalLM(ExecuTorchModelBase):
 
     auto_model_class = AutoModelForCausalLM
 
-    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
+    def __init__(
+        self,
+        models: Dict[str, "ExecuTorchModule"],
+        config: "PretrainedConfig",
+    ):
         super().__init__(models, config)
         if not hasattr(self, "model"):
             raise AttributeError("Expected attribute 'model' not found in the instance.")
@@ -862,7 +877,11 @@ class ExecuTorchModelForMaskedLM(ExecuTorchModelBase):
 
     auto_model_class = AutoModelForMaskedLM
 
-    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
+    def __init__(
+        self,
+        models: Dict[str, "ExecuTorchModule"],
+        config: "PretrainedConfig",
+    ):
         super().__init__(models, config)
         if not hasattr(self, "model"):
             raise AttributeError("Expected attribute 'model' not found in the instance.")
@@ -934,7 +953,11 @@ class ExecuTorchModelForImageClassification(ExecuTorchModelBase):
 
     auto_model_class = AutoModelForImageClassification
 
-    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
+    def __init__(
+        self,
+        models: Dict[str, "ExecuTorchModule"],
+        config: "PretrainedConfig",
+    ):
         super().__init__(models, config)
         if not hasattr(self, "model"):
             raise AttributeError("Expected attribute 'model' not found in the instance.")
@@ -993,7 +1016,11 @@ class ExecuTorchModelForSpeechSeq2Seq(ExecuTorchModelBase):
 
     auto_model_class = AutoModelForSpeechSeq2Seq
 
-    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
+    def __init__(
+        self,
+        models: Dict[str, "ExecuTorchModule"],
+        config: "PretrainedConfig",
+    ):
         super().__init__(models=models, config=config)
         if not hasattr(self, "encoder"):
             raise AttributeError("Expected attribute 'encoder' not found in the instance.")
@@ -1172,7 +1199,11 @@ class ExecuTorchModelForMultiModalToText(ExecuTorchModelBase):
     # task type. For MultiModal, we should always be specifying the task type anyways.
     auto_model_class = AutoModel
 
-    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
+    def __init__(
+        self,
+        models: Dict[str, "ExecuTorchModule"],
+        config: "PretrainedConfig",
+    ):
         super().__init__(models=models, config=config)
         required_methods = ["text_decoder", "token_embedding"]
         for required_method in required_methods:
@@ -1329,13 +1360,10 @@ def text_generation(
         self.stats.reset()
         self.stats.on_inference_start()
 
-        inputs = apply_chat_template_with_fallback(
+        inputs = process_conversation_inputs(
             processor,
+            tokenizer,
             input_conversation,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
        )
 
         self.stats.on_token_encode_end()
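
text_generation now delegates all prompt formatting to the new process_conversation_inputs helper (defined in optimum/exporters/executorch/utils.py, shown further down). A minimal usage sketch of that call path; the checkpoint id and audio URL are illustrative assumptions, not values from this commit:

from transformers import AutoProcessor, AutoTokenizer

from optimum.exporters.executorch.utils import process_conversation_inputs

model_id = "ibm-granite/granite-speech-3.3-2b"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

conversation = [
    # For GraniteSpeechProcessor, the helper expects a top-level audio item
    # and removes it from the conversation before applying the chat template.
    {"type": "audio", "content": "https://example.com/sample.wav"},  # illustrative URL
    {"role": "user", "content": [{"type": "text", "text": "Transcribe the audio."}]},
]

inputs = process_conversation_inputs(processor, tokenizer, conversation)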

optimum/exporters/executorch/integrations.py

Lines changed: 42 additions & 22 deletions
@@ -20,6 +20,7 @@
 from torch.export import ExportedProgram
 from torch.nn.attention import SDPBackend
 from transformers import (
+    AutoConfig,
     AutoProcessor,
     PreTrainedModel,
     StaticCache,
@@ -88,30 +89,47 @@ def prepare_export_inputs(self):
         # 1. Get export inputs
         model_id = self.model.config.name_or_path
         processor = AutoProcessor.from_pretrained(model_id)
-        sample_conversation_with_audio = [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "audio",
-                        "url": "https://huggingface.co/datasets/eustlb/audio-samples/resolve/main/dude_where_is_my_car.wav",
-                    },
-                ],
-            }
-        ]
-        processed_inputs = apply_chat_template_with_fallback(
-            processor,
-            sample_conversation_with_audio,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        )
+        config = AutoConfig.from_pretrained(model_id)
+
+        if config.model_type == "granite_speech":
+            import torchaudio
+            from huggingface_hub import hf_hub_download
+
+            audio_path = hf_hub_download(repo_id=model_id, filename="10226_10111_000000.wav")
+            wav, _sampling_rate = torchaudio.load(audio_path, normalize=True)
+            processed_inputs = processor(
+                "",  # No text needed.
+                wav,
+                return_tensors="pt",
+            )
+        else:
+            sample_conversation_with_audio = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "audio",
+                            "url": "https://huggingface.co/datasets/eustlb/audio-samples/resolve/main/dude_where_is_my_car.wav",
+                        },
+                    ],
+                }
+            ]
+            processed_inputs = apply_chat_template_with_fallback(
+                processor,
+                sample_conversation_with_audio,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+            )
         if "input_features" not in processed_inputs:
             raise ValueError(
                 f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'input_features' key: {processed_inputs}"
             )
         export_inputs = processed_inputs["input_features"]
+        # Make sure the export inputs has a batch size > 1 so that it doesn't 0/1 specialize.
+        if export_inputs.shape[0] == 1:
+            export_inputs = export_inputs.repeat(2, 1, 1)
 
         # 2. Get export dynamic shapes
         # For certain models like Voxtral, each 30 seconds represent one batch. So theoretically this caps
@@ -129,7 +147,11 @@ def forward(
         self,
         input_features: torch.FloatTensor,
     ):
-        audio_embeds = self.model.get_audio_embeds(input_features)
+        # TODO: remove on next Transformers pin bump.
+        if hasattr(self.model, "get_audio_embeds"):
+            audio_embeds = self.model.get_audio_embeds(input_features)
+        else:
+            audio_embeds = self.model.get_audio_features(input_features)
         return audio_embeds.unsqueeze(0)
 
 
@@ -164,8 +186,6 @@ def __init__(
     ):
         super().__init__()
 
-        if modality not in encoder_name:
-            raise ValueError(f'encoder_name "{encoder_name}" does not match specified modality "{modality}".')
         if not hasattr(model, encoder_name):
             raise ValueError(f'Model does not contain encoder "{encoder_name}".')
 
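
The new batch-padding guard in prepare_export_inputs exists because torch.export specializes dimensions whose example value is 0 or 1, which would freeze the exported audio encoder at batch size 1. A self-contained sketch of the idea; the module and tensor shapes are illustrative, not from this commit:

import torch
from torch.export import Dim, export


class TinyEncoder(torch.nn.Module):
    def forward(self, input_features):
        return input_features * 2.0


example = torch.randn(1, 3000, 128)  # a batch of 1 would get 0/1-specialized
if example.shape[0] == 1:
    example = example.repeat(2, 1, 1)  # pad the batch dim to 2, as in the diff above

# With batch >= 2 in the example input, the batch dimension can be kept dynamic.
ep = export(
    TinyEncoder(),
    (example,),
    dynamic_shapes={"input_features": {0: Dim("batch")}},
)
print(ep)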

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 4 additions & 6 deletions
@@ -39,6 +39,7 @@ def _validate_multimodal_components(model):
     "text_model",
 ]
 POTENTIAL_AUDIO_ENCODER_NAMES = [
+    "encoder",  # Here mainly for Granite Speech.
     "audio_tower",
     "audio_model",
 ]
@@ -146,12 +147,9 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     except (OSError, json.JSONDecodeError):
         processor_config = None
 
-    # Make sure config has text_config and vision_config:
-    if not (hasattr(config, "text_config") and (hasattr(config, "vision_config") or hasattr(config, "audio_config"))):
-        raise ValueError(
-            f"The model {model_name_or_path} does not have a `text_config` or `vision_config`/`audio_config` attribute in its config. "
-            "This is required for multimodal text-to-text models."
-        )
+    # Make sure config has text_config.
+    if not (hasattr(config, "text_config")):
+        raise ValueError(f"The model {model_name_or_path} does not have a `text_config`.")
 
     if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
         # NOTE: Avoid hitting the data-dependent control flow in _longrope_frequency_update.
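
Adding "encoder" to the candidate list lets Granite Speech's plainly named encoder attribute be discovered. Presumably the surrounding validation scans the list and takes the first attribute present on the model; a sketch of that pattern under that assumption (the repo's actual helper may differ):

from typing import Optional

POTENTIAL_AUDIO_ENCODER_NAMES = [
    "encoder",  # Here mainly for Granite Speech.
    "audio_tower",
    "audio_model",
]


def find_audio_encoder_name(model) -> Optional[str]:
    # Return the first candidate attribute that exists on the model, so the
    # generic "encoder" name resolves for Granite Speech before the
    # audio-specific names are tried.
    for name in POTENTIAL_AUDIO_ENCODER_NAMES:
        if hasattr(model, name):
            return name
    return None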

optimum/exporters/executorch/utils.py

Lines changed: 77 additions & 1 deletion
@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Set
+import copy
+import io
+import logging
+from typing import Any, Dict, List, Optional, Set
 
 import torch
+import transformers
 from transformers import GenerationConfig, PretrainedConfig
+from transformers.processing_utils import ProcessorMixin
 from transformers.tokenization_utils import PreTrainedTokenizer
 
 
@@ -126,3 +131,74 @@ def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenize
     is_valid = any(model_id in candidate_eos_ids for model_id in model_eos_ids)
 
     return is_valid
+
+
+def process_conversation_inputs(
+    processor: ProcessorMixin,
+    tokenizer: PreTrainedTokenizer,
+    input_conversation: List[Dict[str, Any]],
+):
+    """
+    Process input conversation for multimodal models.
+
+    This function handles the preprocessing of conversation inputs, with special handling for
+    GraniteSpeechProcessor, which requires extracting and processing audio content from conversations
+    prior to feeding into the processor.
+
+    Args:
+        processor: The processor to use for input processing
+        tokenizer: The tokenizer to use for text processing
+        input_conversation: List of conversation messages, may contain audio content
+
+    Returns:
+        Processed inputs ready for model consumption
+    """
+    if isinstance(processor, transformers.models.granite_speech.processing_granite_speech.GraniteSpeechProcessor):
+        import requests
+        import torchaudio
+
+        conversation = copy.deepcopy(input_conversation)
+        audio_path = None
+
+        # Extract audio content and remove from conversation
+        audio_items = [(i, item) for i, item in enumerate(conversation) if item.get("type") == "audio"]
+        if audio_items:
+            idx, audio_item = audio_items[0]
+            audio_path = audio_item["content"]
+            # Remove the audio content from the input conversation since it
+            # is handled outside for Granite.
+            del conversation[idx]
+        else:
+            raise ValueError("No audio content found in conversation")
+
+        # Download and process audio
+        try:
+            resp = requests.get(audio_path)
+            resp.raise_for_status()
+            buf = io.BytesIO(resp.content)
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError("Could not download input audio file.") from e
+
+        wav, sampling_rate = torchaudio.load(buf, normalize=True)
+        if wav.shape[0] != 1:
+            wav = wav.mean(dim=0, keepdim=True)  # Convert stereo to mono.
+            logging.warning("Downmixed audio from stereo to mono")
+        if sampling_rate != 16000:
+            wav = torchaudio.functional.resample(wav, sampling_rate, 16000)
+            logging.warning(f"Resampled audio from {sampling_rate}Hz to 16000Hz")
+
+        # Generate text prompt and process with audio
+        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        inputs = processor(prompt, wav, return_tensors="pt")
+    else:
+        # Standard processing for other processors
+        inputs = apply_chat_template_with_fallback(
+            processor,
+            input_conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
+    return inputs
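
The Granite branch above normalizes arbitrary input audio to mono 16 kHz before handing it to the processor. For reference, the same steps in isolation as a runnable sketch; the WAV URL is illustrative, not from this commit:

import io

import requests
import torchaudio

url = "https://example.com/sample.wav"  # illustrative URL
resp = requests.get(url)
resp.raise_for_status()

wav, sampling_rate = torchaudio.load(io.BytesIO(resp.content), normalize=True)
if wav.shape[0] != 1:
    wav = wav.mean(dim=0, keepdim=True)  # downmix multi-channel audio to mono
if sampling_rate != 16000:
    wav = torchaudio.functional.resample(wav, sampling_rate, 16000)  # Granite Speech expects 16 kHz

print(wav.shape)  # torch.Size([1, num_samples]) at 16 kHz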
