@@ -9,10 +9,10 @@
 import asyncio
 from enum import Enum
 from typing import AsyncGenerator, Optional
-from typing_extensions import TYPE_CHECKING, override
 
 import numpy as np
 from loguru import logger
+from typing_extensions import TYPE_CHECKING, override
 
 from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
 from pipecat.services.ai_services import SegmentedSTTService
@@ -26,7 +26,7 @@
     logger.error(f"Exception: {e}")
     logger.error("In order to use Whisper, you need to `pip install pipecat-ai[whisper]`.")
     raise Exception(f"Missing module: {e}")
-
+
 try:
     import mlx_whisper
 except ModuleNotFoundError as e:
@@ -332,6 +332,7 @@ def _load(self):
         """
         try:
             from faster_whisper import WhisperModel
+
             logger.debug("Loading Whisper model...")
             self._model = WhisperModel(
                 self.model_name, device=self._device, compute_type=self._compute_type
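For context on the hunk above: `_load()` wraps a `faster_whisper.WhisperModel` built with the configured device and compute type. A minimal standalone sketch of that API, where the model size, device, compute type, and audio path are illustrative placeholders rather than values taken from this diff:

```python
# Standalone sketch of the faster-whisper call that _load()/run_stt wrap above.
# "base", "cpu", "int8" and "sample.wav" are placeholder values.
from faster_whisper import WhisperModel

model = WhisperModel("base", device="cpu", compute_type="int8")
segments, info = model.transcribe("sample.wav")
for segment in segments:
    print(segment.start, segment.end, segment.text)
```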
@@ -414,22 +415,22 @@ def __init__(
     ):
         # Skip WhisperSTTService.__init__ and call its parent directly
         SegmentedSTTService.__init__(self, **kwargs)
-
+
         self.set_model_name(model if isinstance(model, str) else model.value)
         self._no_speech_prob = no_speech_prob
         self._temperature = temperature
 
         self._settings = {
             "language": language,
         }
-
+
         # No need to call _load() as MLX Whisper loads models on demand
 
     @override
     def _load(self):
         """MLX Whisper loads models on demand, so this is a no-op."""
         pass
-
+
     @override
     async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
         """Transcribes given audio using MLX Whisper.
@@ -447,7 +448,7 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
         """
         try:
             import mlx_whisper
-
+
             await self.start_processing_metrics()
             await self.start_ttfb_metrics()
 
@@ -456,10 +457,11 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
 
             whisper_lang = self.language_to_service_language(self._settings["language"])
             chunk = await asyncio.to_thread(
-                mlx_whisper.transcribe, audio_float,
+                mlx_whisper.transcribe,
+                audio_float,
                 path_or_hf_repo=self.model_name,
                 temperature=self._temperature,
-                language=whisper_lang
+                language=whisper_lang,
             )
             text: str = ""
             for segment in chunk.get("segments", []):
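The reflowed call above is the standard `asyncio.to_thread(func, *args, **kwargs)` pattern for keeping a blocking transcription call off the event loop. A self-contained sketch with a stand-in for `mlx_whisper.transcribe` (the stub function and its argument values are assumptions, shaped to match the keywords in the hunk above):

```python
# Sketch of running a blocking, Whisper-style transcribe call in a worker
# thread. fake_transcribe stands in for mlx_whisper.transcribe and returns a
# result dict shaped like the one consumed above ("segments" with text).
import asyncio

def fake_transcribe(audio, *, path_or_hf_repo, temperature, language):
    return {"segments": [{"text": " hello world", "no_speech_prob": 0.01}]}

async def main():
    chunk = await asyncio.to_thread(
        fake_transcribe,
        [0.0] * 16000,                        # stand-in for the float32 audio array
        path_or_hf_repo="placeholder/model",  # assumption: any MLX Whisper repo id
        temperature=0.0,
        language="en",
    )
    text = "".join(s["text"] for s in chunk.get("segments", []))
    print(text.strip())

asyncio.run(main())
```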
@@ -475,11 +477,11 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
 
             await self.stop_ttfb_metrics()
             await self.stop_processing_metrics()
-
+
             if text:
                 logger.debug(f"Transcription: [{text}]")
                 yield TranscriptionFrame(text, "", time_now_iso8601(), self._settings["language"])
-
+
         except Exception as e:
             logger.exception(f"MLX Whisper transcription error: {e}")
             yield ErrorFrame(f"MLX Whisper transcription error: {str(e)}")
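Since `run_stt` is an async generator that yields a `TranscriptionFrame` on success and an `ErrorFrame` on failure rather than raising, a caller drains it with `async for`. A hedged sketch of that consumption (the `stt` instance, the raw audio bytes, and the `.text` attribute access are assumptions inferred from the constructor calls above):

```python
# Sketch of draining the run_stt async generator above; frame handling is
# simplified and not taken from this diff.
from pipecat.frames.frames import ErrorFrame, TranscriptionFrame

async def transcribe_once(stt, audio: bytes) -> str:
    """Collect any transcribed text yielded by run_stt."""
    text = ""
    async for frame in stt.run_stt(audio):
        if isinstance(frame, TranscriptionFrame):
            text += frame.text  # first constructor argument above
        elif isinstance(frame, ErrorFrame):
            print(frame)  # errors surface as frames instead of exceptions
    return text
```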