|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from collections.abc import AsyncGenerator |
| 4 | +from contextlib import asynccontextmanager |
3 | 5 | import logging |
4 | 6 | import os |
5 | 7 | import uuid |
|
19 | 21 | from starlette.exceptions import HTTPException as StarletteHTTPException |
20 | 22 | from starlette.responses import RedirectResponse |
21 | 23 |
|
22 | | -from speaches.dependencies import ApiKeyDependency, get_config |
| 24 | +from speaches.dependencies import ApiKeyDependency, get_config, get_executor_registry |
23 | 25 | from speaches.logger import setup_logger |
24 | 26 | from speaches.routers.chat import ( |
25 | 27 | router as chat_router, |
|
66 | 68 | ] |
67 | 69 |
|
68 | 70 |
|
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    """Lifespan hook that preloads the configured models at startup.

    When ``preload_models`` in the config is non-empty, each listed model id
    is passed in turn to the executor registry's ``download_model_by_id``.
    A failure for one model is logged with its traceback and does not stop
    the remaining models from being attempted, nor the application from
    starting.

    Nothing runs after the ``yield``, so there is no shutdown work.

    Args:
        _app: The FastAPI application instance; unused.
    """
    log = logging.getLogger(__name__)
    preload_ids = get_config().preload_models

    if preload_ids:
        log.info(f"Preloading {len(preload_ids)} models on startup")
        registry = get_executor_registry()

        for model_id in preload_ids:
            try:
                log.info(f"Downloading model: {model_id}")
                registry.download_model_by_id(model_id)
                log.info(f"Successfully downloaded model: {model_id}")
            except Exception:
                # Best-effort preload: record the failure and move on to the
                # next model rather than aborting startup.
                log.exception(f"Failed to download model {model_id}")

    # Hand control to the running application.
    yield
| 89 | + |
| 90 | + |
69 | 91 | def create_app() -> FastAPI: |
70 | 92 | config = get_config() # HACK |
71 | 93 | setup_logger(config.log_level) |
@@ -94,6 +116,7 @@ def create_app() -> FastAPI: |
94 | 116 | version="0.8.3", # TODO: update this on release |
95 | 117 | license_info={"name": "MIT License", "identifier": "MIT"}, |
96 | 118 | openapi_tags=TAGS_METADATA, |
| 119 | + lifespan=lifespan, |
97 | 120 | ) |
98 | 121 |
|
99 | 122 | # Instrument FastAPI app if telemetry is enabled |
|
0 commit comments