Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/faster_whisper_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,14 @@ class Config(BaseSettings):
Controls how many latest seconds of audio are being passed through VAD.
Should be greater than `max_inactivity_seconds`
"""

chat_completion_base_url: str = "https://api.openai.com/v1"
chat_completion_api_key: str | None = None

speech_base_url: str | None = None
speech_api_key: str | None = None
speech_model: str = "piper"
speech_extra_body: dict = {"sample_rate": 24000}

transcription_base_url: str | None = None
transcription_api_key: str | None = None
57 changes: 57 additions & 0 deletions src/faster_whisper_server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from httpx import ASGITransport, AsyncClient
from openai import AsyncOpenAI
from openai.resources.audio import AsyncSpeech, AsyncTranscriptions
from openai.resources.chat.completions import AsyncCompletions

from faster_whisper_server.config import Config
from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
Expand Down Expand Up @@ -45,3 +49,56 @@ async def verify_api_key(


ApiKeyDependency = Depends(verify_api_key)


@lru_cache
def get_completion_client() -> AsyncCompletions:
config = get_config() # HACK
oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
return oai_client.chat.completions


CompletionClientDependency = Annotated[AsyncCompletions, Depends(get_completion_client)]


@lru_cache
def get_speech_client() -> AsyncSpeech:
config = get_config() # HACK
if config.speech_base_url is None:
# this might not work as expected if the `speech_router` won't have shared state with the main FastAPI `app`. TODO: verify # noqa: E501
from faster_whisper_server.routers.speech import (
router as speech_router,
)

http_client = AsyncClient(
transport=ASGITransport(speech_router), base_url="http://test/v1"
) # NOTE: "test" can be replaced with any other value
oai_client = AsyncOpenAI(http_client=http_client, api_key=config.speech_api_key)
else:
oai_client = AsyncOpenAI(base_url=config.speech_base_url, api_key=config.speech_api_key)
return oai_client.audio.speech


SpeechClientDependency = Annotated[AsyncSpeech, Depends(get_speech_client)]


@lru_cache
def get_transcription_client() -> AsyncTranscriptions:
config = get_config()
if config.transcription_base_url is None:
# this might not work as expected if the `transcription_router` won't have shared state with the main FastAPI `app`. TODO: verify # noqa: E501
from faster_whisper_server.routers.stt import (
router as stt_router,
)

http_client = AsyncClient(
transport=ASGITransport(stt_router), base_url="http://test/v1"
) # NOTE: "test" can be replaced with any other value

oai_client = AsyncOpenAI(http_client=http_client, api_key=config.transcription_api_key)
else:
oai_client = AsyncOpenAI(base_url=config.transcription_base_url, api_key=config.transcription_api_key)
return oai_client.audio.transcriptions


TranscriptionClientDependency = Annotated[AsyncTranscriptions, Depends(get_transcription_client)]
3 changes: 2 additions & 1 deletion src/faster_whisper_server/routers/speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ class CreateSpeechRequestBody(BaseModel):
# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice
speed: float = Field(1.0, ge=0.25, le=4.0)
"""The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default."""
sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE) # TODO: document
sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE)
"""Desired sample rate to convert the generated audio to. If not provided, the model's default sample rate will be used.""" # noqa: E501

# TODO: move into `Voice`
@model_validator(mode="after")
Expand Down
Loading