Skip to content

Commit fa1e0d5

Browse files
yotamN authored and fedirz committed
Download models on startup
1 parent 0469a9e commit fa1e0d5

File tree

4 files changed

+53
-12
lines changed

4 files changed

+53
-12
lines changed

src/speaches/config.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -130,3 +130,11 @@ class Config(BaseSettings):
130130
OpenTelemetry service name for identifying this application in traces.
131131
Shadows OTEL_SERVICE_NAME environment variable.
132132
"""
133+
134+
preload_models: list[str] = []
135+
"""
136+
List of model IDs to download during application startup.
137+
Models will be downloaded sequentially if they do not already exist locally.
138+
Application will exit if any model fails to download or is not found in the registry.
139+
Example: ["Systran/faster-whisper-tiny", "rhasspy/piper-voices"]
140+
"""

src/speaches/executors/shared/registry.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,10 @@
1313
pyannote_speaker_embedding_model_registry,
1414
)
1515
from speaches.executors.shared.executor import Executor
16-
from speaches.executors.silero_vad_v5 import SileroVADModelManager, silero_vad_model_registry
16+
from speaches.executors.silero_vad_v5 import (
17+
SileroVADModelManager,
18+
silero_vad_model_registry,
19+
)
1720
from speaches.executors.whisper import WhisperModelManager, whisper_model_registry
1821

1922

@@ -85,3 +88,9 @@ def all_executors(self): # noqa: ANN201
8588
self._pyannote_executor,
8689
self._vad_executor,
8790
)
91+
92+
def download_model_by_id(self, model_id: str) -> bool:
93+
for executor in self.all_executors():
94+
if model_id in [model.id for model in executor.model_registry.list_remote_models()]:
95+
return executor.model_registry.download_model_files_if_not_exist(model_id)
96+
raise ValueError(f"Model '{model_id}' not found")

src/speaches/main.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncGenerator
4+
from contextlib import asynccontextmanager
35
import logging
46
import os
57
import uuid
@@ -19,7 +21,7 @@
1921
from starlette.exceptions import HTTPException as StarletteHTTPException
2022
from starlette.responses import RedirectResponse
2123

22-
from speaches.dependencies import ApiKeyDependency, get_config
24+
from speaches.dependencies import ApiKeyDependency, get_config, get_executor_registry
2325
from speaches.logger import setup_logger
2426
from speaches.routers.chat import (
2527
router as chat_router,
@@ -66,6 +68,23 @@
6668
]
6769

6870

71+
@asynccontextmanager
72+
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
73+
logger = logging.getLogger(__name__)
74+
config = get_config()
75+
76+
if config.preload_models:
77+
logger.info(f"Preloading {len(config.preload_models)} models on startup")
78+
executor_registry = get_executor_registry()
79+
80+
for model_id in config.preload_models:
81+
logger.info(f"Downloading model: {model_id}")
82+
executor_registry.download_model_by_id(model_id)
83+
logger.info(f"Successfully downloaded model: {model_id}")
84+
85+
yield
86+
87+
6988
def create_app() -> FastAPI:
7089
config = get_config() # HACK
7190
setup_logger(config.log_level)
@@ -94,6 +113,7 @@ def create_app() -> FastAPI:
94113
version="0.8.3", # TODO: update this on release
95114
license_info={"name": "MIT License", "identifier": "MIT"},
96115
openapi_tags=TAGS_METADATA,
116+
lifespan=lifespan,
97117
)
98118

99119
# Instrument FastAPI app if telemetry is enabled

src/speaches/routers/models.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,9 @@ class ListAudioModelsResponse(BaseModel):
4040

4141
# HACK: returning ListModelsResponse directly causes extra `Model` fields to be omitted
4242
@router.get("/v1/audio/models", response_model=ListAudioModelsResponse)
43-
def list_local_audio_models(executor_registry: ExecutorRegistryDependency) -> JSONResponse:
43+
def list_local_audio_models(
44+
executor_registry: ExecutorRegistryDependency,
45+
) -> JSONResponse:
4446
models: list[Model] = []
4547
for executor in executor_registry.text_to_speech:
4648
models.extend(list(executor.model_registry.list_local_models()))
@@ -53,7 +55,9 @@ class ListVoicesResponse(BaseModel):
5355

5456
# HACK: returning ListModelsResponse directly causes extra `Model` fields to be omitted
5557
@router.get("/v1/audio/voices", response_model=ListModelsResponse)
56-
def list_local_audio_voices(executor_registry: ExecutorRegistryDependency) -> JSONResponse:
58+
def list_local_audio_voices(
59+
executor_registry: ExecutorRegistryDependency,
60+
) -> JSONResponse:
5761
models: list[KokoroModel | PiperModel] = []
5862
for executor in executor_registry.text_to_speech:
5963
models.extend(list(executor.model_registry.list_local_models()))
@@ -77,14 +81,14 @@ def get_local_model(executor_registry: ExecutorRegistryDependency, model_id: Mod
7781
# NOTE: without `response_model` and `JSONResponse` extra fields aren't included in the response
7882
@router.post("/v1/models/{model_id:path}")
7983
def download_remote_model(executor_registry: ExecutorRegistryDependency, model_id: ModelId) -> Response:
80-
for executor in executor_registry.all_executors():
81-
if model_id in [model.id for model in executor.model_registry.list_remote_models()]:
82-
was_downloaded = executor.model_registry.download_model_files_if_not_exist(model_id)
83-
if was_downloaded:
84-
return Response(status_code=200, content=f"Model '{model_id}' downloaded")
85-
else:
86-
return Response(status_code=201, content=f"Model '{model_id}' already exists")
87-
raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")
84+
try:
85+
was_downloaded = executor_registry.download_model_by_id(model_id)
86+
if was_downloaded:
87+
return Response(status_code=200, content=f"Model '{model_id}' downloaded")
88+
else:
89+
return Response(status_code=201, content=f"Model '{model_id}' already exists")
90+
except ValueError as error:
91+
raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found") from error
8892

8993

9094
# TODO: document that any model will be deleted regardless if it's supported speaches or not

0 commit comments

Comments
 (0)