Skip to content

Commit d35f4c6

Browse files
committed
Add prompt and temperature args to OpenAI and Groq hosted Whisper STT services
1 parent 0a990b2 commit d35f4c6

File tree

3 files changed

+59
-11
lines changed

3 files changed

+59
-11
lines changed

Diff for: src/pipecat/services/base_whisper.py

+6
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ class BaseWhisperSTTService(SegmentedSTTService):
111111
api_key: Service API key. Defaults to None.
112112
base_url: Service API base URL. Defaults to None.
113113
language: Language of the audio input. Defaults to English.
114+
prompt: Optional text to guide the model's style or continue a previous segment.
115+
temperature: Sampling temperature between 0 and 1. Defaults to 0.0.
114116
**kwargs: Additional arguments passed to SegmentedSTTService.
115117
"""
116118

@@ -121,12 +123,16 @@ def __init__(
121123
api_key: Optional[str] = None,
122124
base_url: Optional[str] = None,
123125
language: Optional[Language] = Language.EN,
126+
prompt: Optional[str] = None,
127+
temperature: float = 0.0,
124128
**kwargs,
125129
):
126130
super().__init__(**kwargs)
127131
self.set_model_name(model)
128132
self._client = self._create_client(api_key, base_url)
129133
self._language = self.language_to_service_language(language or Language.EN)
134+
self._prompt = prompt
135+
self._temperature = temperature
130136

131137
def _create_client(self, api_key: Optional[str], base_url: Optional[str]):
132138
return AsyncOpenAI(api_key=api_key, base_url=base_url)

Diff for: src/pipecat/services/groq.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class GroqSTTService(BaseWhisperSTTService):
5454
api_key: Groq API key. Defaults to None.
5555
base_url: API base URL. Defaults to "https://api.groq.com/openai/v1".
5656
language: Language of the audio input. Defaults to English.
57+
prompt: Optional text to guide the model's style or continue a previous segment.
58+
temperature: Sampling temperature between 0 and 1. Defaults to 0.0.
5759
**kwargs: Additional arguments passed to BaseWhisperSTTService.
5860
"""
5961

@@ -64,17 +66,35 @@ def __init__(
6466
api_key: Optional[str] = None,
6567
base_url: str = "https://api.groq.com/openai/v1",
6668
language: Optional[Language] = Language.EN,
69+
prompt: Optional[str] = None,
70+
temperature: float = 0.0,
6771
**kwargs,
6872
):
6973
super().__init__(
70-
model=model, api_key=api_key, base_url=base_url, language=language, **kwargs
74+
model=model,
75+
api_key=api_key,
76+
base_url=base_url,
77+
language=language,
78+
prompt=prompt,
79+
temperature=temperature,
80+
**kwargs,
7181
)
7282

7383
async def _transcribe(self, audio: bytes) -> Transcription:
7484
assert self._language is not None # Assigned in the BaseWhisperSTTService class
75-
return await self._client.audio.transcriptions.create(
76-
file=("audio.wav", audio, "audio/wav"),
77-
model=self.model_name,
78-
response_format="json",
79-
language=self._language,
80-
)
85+
86+
# Build kwargs dict with only set parameters
87+
kwargs = {
88+
"file": ("audio.wav", audio, "audio/wav"),
89+
"model": self.model_name,
90+
"response_format": "json",
91+
"language": self._language,
92+
}
93+
94+
if self._prompt is not None:
95+
kwargs["prompt"] = self._prompt
96+
97+
if self._temperature is not None:
98+
kwargs["temperature"] = self._temperature
99+
100+
return await self._client.audio.transcriptions.create(**kwargs)

Diff for: src/pipecat/services/openai.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,8 @@ class OpenAISTTService(BaseWhisperSTTService):
408408
api_key: OpenAI API key. Defaults to None.
409409
base_url: API base URL. Defaults to None.
410410
language: Language of the audio input. Defaults to English.
411+
prompt: Optional text to guide the model's style or continue a previous segment.
412+
temperature: Sampling temperature between 0 and 1. Defaults to 0.0.
411413
**kwargs: Additional arguments passed to BaseWhisperSTTService.
412414
"""
413415

@@ -418,17 +420,37 @@ def __init__(
418420
api_key: Optional[str] = None,
419421
base_url: Optional[str] = None,
420422
language: Optional[Language] = Language.EN,
423+
prompt: Optional[str] = None,
424+
temperature: float = 0.0,
421425
**kwargs,
422426
):
423427
super().__init__(
424-
model=model, api_key=api_key, base_url=base_url, language=language, **kwargs
428+
model=model,
429+
api_key=api_key,
430+
base_url=base_url,
431+
language=language,
432+
prompt=prompt,
433+
temperature=temperature,
434+
**kwargs,
425435
)
426436

427437
async def _transcribe(self, audio: bytes) -> Transcription:
428438
assert self._language is not None # Assigned in the BaseWhisperSTTService class
429-
return await self._client.audio.transcriptions.create(
430-
file=("audio.wav", audio, "audio/wav"), model=self.model_name, language=self._language
431-
)
439+
440+
# Build kwargs dict with only set parameters
441+
kwargs = {
442+
"file": ("audio.wav", audio, "audio/wav"),
443+
"model": self.model_name,
444+
"language": self._language,
445+
}
446+
447+
if self._prompt is not None:
448+
kwargs["prompt"] = self._prompt
449+
450+
if self._temperature is not None:
451+
kwargs["temperature"] = self._temperature
452+
453+
return await self._client.audio.transcriptions.create(**kwargs)
432454

433455

434456
class OpenAITTSService(TTSService):

0 commit comments

Comments
 (0)