diff --git a/src/faster_whisper_server/config.py b/src/faster_whisper_server/config.py index 14b4230c..1feefdaa 100644 --- a/src/faster_whisper_server/config.py +++ b/src/faster_whisper_server/config.py @@ -180,6 +180,7 @@ class Config(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") + api_key: str | None = None log_level: str = "debug" host: str = Field(alias="UVICORN_HOST", default="0.0.0.0") port: int = Field(alias="UVICORN_PORT", default=8000) diff --git a/src/faster_whisper_server/dependencies.py b/src/faster_whisper_server/dependencies.py index 985585c0..6b273ddd 100644 --- a/src/faster_whisper_server/dependencies.py +++ b/src/faster_whisper_server/dependencies.py @@ -1,7 +1,8 @@ from functools import lru_cache from typing import Annotated -from fastapi import Depends +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from faster_whisper_server.config import Config from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager @@ -31,3 +32,16 @@ def get_piper_model_manager() -> PiperModelManager: PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)] + + +security = HTTPBearer() + + +async def verify_api_key( + config: ConfigDependency, credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)] +) -> None: + if credentials.credentials != config.api_key: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + +ApiKeyDependency = Depends(verify_api_key) diff --git a/src/faster_whisper_server/main.py b/src/faster_whisper_server/main.py index 6936460e..f59036bd 100644 --- a/src/faster_whisper_server/main.py +++ b/src/faster_whisper_server/main.py @@ -10,7 +10,7 @@ ) from fastapi.middleware.cors import CORSMiddleware -from faster_whisper_server.dependencies import get_config, get_model_manager +from faster_whisper_server.dependencies import get_config, get_model_manager, verify_api_key from faster_whisper_server.logger import setup_logger from faster_whisper_server.routers.list_models import ( router as list_models_router, @@ -50,7 +50,11 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: model_manager.load_model(model_name) yield - app = FastAPI(lifespan=lifespan) + dependencies = [] + if config.api_key is not None: + dependencies.append(verify_api_key) + + app = FastAPI(lifespan=lifespan, dependencies=dependencies) app.include_router(stt_router) app.include_router(list_models_router)