diff --git a/src/faster_whisper_server/config.py b/src/faster_whisper_server/config.py
index 4f37de47..9e462d54 100644
--- a/src/faster_whisper_server/config.py
+++ b/src/faster_whisper_server/config.py
@@ -236,3 +236,14 @@ class Config(BaseSettings):
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
+
+    chat_completion_base_url: str = "https://api.openai.com/v1"
+    chat_completion_api_key: str | None = None
+
+    speech_base_url: str | None = None
+    speech_api_key: str | None = None
+    speech_model: str = "piper"
+    speech_extra_body: dict = {"sample_rate": 24000}  # NOTE(review): mutable class-level default — presumably safe under pydantic (defaults are deep-copied per instance), confirm
+
+    transcription_base_url: str | None = None
+    transcription_api_key: str | None = None
diff --git a/src/faster_whisper_server/dependencies.py b/src/faster_whisper_server/dependencies.py
index 6b273ddd..3266b145 100644
--- a/src/faster_whisper_server/dependencies.py
+++ b/src/faster_whisper_server/dependencies.py
@@ -3,6 +3,10 @@
 from fastapi import Depends, HTTPException, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from httpx import ASGITransport, AsyncClient
+from openai import AsyncOpenAI
+from openai.resources.audio import AsyncSpeech, AsyncTranscriptions
+from openai.resources.chat.completions import AsyncCompletions
 
 from faster_whisper_server.config import Config
 from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
@@ -45,3 +49,56 @@
 
 
 ApiKeyDependency = Depends(verify_api_key)
+
+
+@lru_cache
+def get_completion_client() -> AsyncCompletions:  # cached: one client per process
+    config = get_config()  # HACK
+    oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
+    return oai_client.chat.completions
+
+
+CompletionClientDependency = Annotated[AsyncCompletions, Depends(get_completion_client)]
+
+
+@lru_cache
+def get_speech_client() -> AsyncSpeech:  # cached: one client per process
+    config = get_config()  # HACK
+    if config.speech_base_url is None:
+        # this might not work as expected if the `speech_router` won't have shared state with the main FastAPI `app`. TODO: verify # noqa: E501
+        from faster_whisper_server.routers.speech import (
+            router as speech_router,
+        )
+
+        http_client = AsyncClient(
+            transport=ASGITransport(speech_router), base_url="http://test/v1"
+        )  # NOTE: "test" can be replaced with any other value
+        oai_client = AsyncOpenAI(http_client=http_client, api_key=config.speech_api_key)  # NOTE(review): api_key may be None here — confirm AsyncOpenAI tolerates that when OPENAI_API_KEY is unset
+    else:
+        oai_client = AsyncOpenAI(base_url=config.speech_base_url, api_key=config.speech_api_key)
+    return oai_client.audio.speech
+
+
+SpeechClientDependency = Annotated[AsyncSpeech, Depends(get_speech_client)]
+
+
+@lru_cache
+def get_transcription_client() -> AsyncTranscriptions:  # cached: one client per process
+    config = get_config()  # HACK
+    if config.transcription_base_url is None:
+        # this might not work as expected if the `transcription_router` won't have shared state with the main FastAPI `app`. TODO: verify # noqa: E501
+        from faster_whisper_server.routers.stt import (
+            router as stt_router,
+        )
+
+        http_client = AsyncClient(
+            transport=ASGITransport(stt_router), base_url="http://test/v1"
+        )  # NOTE: "test" can be replaced with any other value
+
+        oai_client = AsyncOpenAI(http_client=http_client, api_key=config.transcription_api_key)  # NOTE(review): same api_key-None concern as in get_speech_client — confirm
+    else:
+        oai_client = AsyncOpenAI(base_url=config.transcription_base_url, api_key=config.transcription_api_key)
+    return oai_client.audio.transcriptions
+
+
+TranscriptionClientDependency = Annotated[AsyncTranscriptions, Depends(get_transcription_client)]