Skip to content

Commit c995042

Browse files
committed
Download models on startup
1 parent 0469a9e commit c995042

File tree

5 files changed

+58
-5
lines changed

5 files changed

+58
-5
lines changed

src/speaches/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ class Config(BaseSettings):
118118

119119
unstable_ort_opts: OrtOptions = OrtOptions()
120120

121-
otel_exporter_otlp_endpoint: str | None = None
122121
"""
123122
OpenTelemetry OTLP exporter endpoint. If set, telemetry will be enabled.
124123
Example: 'http://localhost:4317'
@@ -130,3 +129,5 @@ class Config(BaseSettings):
130129
OpenTelemetry service name for identifying this application in traces.
131130
Shadows OTEL_SERVICE_NAME environment variable.
132131
"""
132+
133+
preload_models: list[str] = []

src/speaches/executors/shared/registry.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
pyannote_speaker_embedding_model_registry,
1414
)
1515
from speaches.executors.shared.executor import Executor
16-
from speaches.executors.silero_vad_v5 import SileroVADModelManager, silero_vad_model_registry
16+
from speaches.executors.silero_vad_v5 import (
17+
SileroVADModelManager,
18+
silero_vad_model_registry,
19+
)
1720
from speaches.executors.whisper import WhisperModelManager, whisper_model_registry
1821

1922

@@ -85,3 +88,9 @@ def all_executors(self): # noqa: ANN201
8588
self._pyannote_executor,
8689
self._vad_executor,
8790
)
91+
92+
def download_model_by_id(self, model_id: str) -> bool:
93+
for executor in self.all_executors():
94+
if model_id in [model.id for model in executor.model_registry.list_remote_models()]:
95+
return executor.model_registry.download_model_files_if_not_exist(model_id)
96+
return False

src/speaches/main.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncGenerator
4+
from contextlib import asynccontextmanager
35
import logging
46
import os
57
import uuid
@@ -19,7 +21,7 @@
1921
from starlette.exceptions import HTTPException as StarletteHTTPException
2022
from starlette.responses import RedirectResponse
2123

22-
from speaches.dependencies import ApiKeyDependency, get_config
24+
from speaches.dependencies import ApiKeyDependency, get_config, get_executor_registry
2325
from speaches.logger import setup_logger
2426
from speaches.routers.chat import (
2527
router as chat_router,
@@ -66,6 +68,26 @@
6668
]
6769

6870

71+
@asynccontextmanager
72+
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
73+
logger = logging.getLogger(__name__)
74+
config = get_config()
75+
76+
if config.preload_models:
77+
logger.info(f"Preloading {len(config.preload_models)} models on startup")
78+
executor_registry = get_executor_registry()
79+
80+
for model_id in config.preload_models:
81+
try:
82+
logger.info(f"Downloading model: {model_id}")
83+
executor_registry.download_model_by_id(model_id)
84+
logger.info(f"Successfully downloaded model: {model_id}")
85+
except Exception:
86+
logger.exception(f"Failed to download model {model_id}")
87+
88+
yield
89+
90+
6991
def create_app() -> FastAPI:
7092
config = get_config() # HACK
7193
setup_logger(config.log_level)
@@ -94,6 +116,7 @@ def create_app() -> FastAPI:
94116
version="0.8.3", # TODO: update this on release
95117
license_info={"name": "MIT License", "identifier": "MIT"},
96118
openapi_tags=TAGS_METADATA,
119+
lifespan=lifespan,
97120
)
98121

99122
# Instrument FastAPI app if telemetry is enabled

src/speaches/registry_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
def download_model_by_id(model_id: str) -> bool:
2+
from speaches.executors.kokoro.utils import model_registry as kokoro_model_registry
3+
from speaches.executors.parakeet.utils import model_registry as parakeet_model_registry
4+
from speaches.executors.piper.utils import model_registry as piper_model_registry
5+
from speaches.executors.whisper.utils import model_registry as whisper_model_registry
6+
7+
if model_id in [model.id for model in kokoro_model_registry.list_remote_models()]:
8+
return kokoro_model_registry.download_model_files_if_not_exist(model_id)
9+
elif model_id in [model.id for model in piper_model_registry.list_remote_models()]:
10+
return piper_model_registry.download_model_files_if_not_exist(model_id)
11+
elif model_id in [model.id for model in whisper_model_registry.list_remote_models()]:
12+
return whisper_model_registry.download_model_files_if_not_exist(model_id)
13+
elif model_id in [model.id for model in parakeet_model_registry.list_remote_models()]:
14+
return parakeet_model_registry.download_model_files_if_not_exist(model_id)
15+
else:
16+
raise ValueError(f"Model '{model_id}' not found in registry")

src/speaches/routers/models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ class ListAudioModelsResponse(BaseModel):
4040

4141
# HACK: returning ListModelsResponse directly causes extra `Model` fields to be omitted
4242
@router.get("/v1/audio/models", response_model=ListAudioModelsResponse)
43-
def list_local_audio_models(executor_registry: ExecutorRegistryDependency) -> JSONResponse:
43+
def list_local_audio_models(
44+
executor_registry: ExecutorRegistryDependency,
45+
) -> JSONResponse:
4446
models: list[Model] = []
4547
for executor in executor_registry.text_to_speech:
4648
models.extend(list(executor.model_registry.list_local_models()))
@@ -53,7 +55,9 @@ class ListVoicesResponse(BaseModel):
5355

5456
# HACK: returning ListModelsResponse directly causes extra `Model` fields to be omitted
5557
@router.get("/v1/audio/voices", response_model=ListModelsResponse)
56-
def list_local_audio_voices(executor_registry: ExecutorRegistryDependency) -> JSONResponse:
58+
def list_local_audio_voices(
59+
executor_registry: ExecutorRegistryDependency,
60+
) -> JSONResponse:
5761
models: list[KokoroModel | PiperModel] = []
5862
for executor in executor_registry.text_to_speech:
5963
models.extend(list(executor.model_registry.list_local_models()))

0 commit comments

Comments
 (0)