-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathbootstrap.py
More file actions
164 lines (140 loc) · 5.32 KB
/
bootstrap.py
File metadata and controls
164 lines (140 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python3
"""
ETL Container Launcher
This script bootstraps an ETL container by:
1. Deserializing a base64-encoded ETLServer subclass passed via the ETL_CLASS_PAYLOAD env var.
2. Determining the server type (FastAPI, Flask, or HTTPMultiThreaded).
3. Installing any required Python packages (via the PACKAGES env var) and system packages (via the OS_PACKAGES env var).
4. Starting the ETL server either in-process (for HTTPMultiThreaded) or by replacing this process with uvicorn or gunicorn.
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""
import os
import sys
import logging
import subprocess
import base64
from typing import Type
import cloudpickle
from aistore.sdk.etl.webserver.base_etl_server import ETLServer
from aistore.sdk.etl.webserver.fastapi_server import FastAPIServer
from aistore.sdk.etl.webserver.flask_server import FlaskServer
from aistore.sdk.etl.webserver.http_multi_threaded_server import HTTPMultiThreadedServer
# ------------------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------------------
# Worker count forwarded as `--workers` to uvicorn / gunicorn (defaults to 6).
NUM_WORKERS: int = int(os.environ.get("NUM_WORKERS", "6"))
# Base64-encoded, cloudpickle-serialized ETLServer subclass (required).
ETL_CLASS_PAYLOAD: str = os.environ.get("ETL_CLASS_PAYLOAD", "")
# Comma-separated pip packages to install before the server starts (optional).
PACKAGES: str = os.environ.get("PACKAGES", "")
# Comma-separated apk packages to install before the server starts (optional).
OS_PACKAGES: str = os.environ.get("OS_PACKAGES", "")

# Nothing can be served without a payload, so fail fast. Logging is not
# configured yet at this point, hence the plain write to stderr.
if ETL_CLASS_PAYLOAD == "":
    print("ERROR: ETL_CLASS_PAYLOAD is not set", file=sys.stderr)
    sys.exit(1)
# ------------------------------------------------------------------------------
# Logging
# ------------------------------------------------------------------------------
# Set up the root logger (INFO level, timestamped format), then create the
# named logger used by every helper in this script.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("bootstrap")
# ------------------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------------------
def install(package: str) -> None:
    """
    Install a single pip package with the running interpreter's pip.

    Args:
        package: Requirement string handed to `pip install` as one argument.

    Logs the failure and exits the process with status 1 if pip returns a
    non-zero exit code.
    """
    # Using `sys.executable -m pip` targets the interpreter running this script.
    pip_cmd = [sys.executable, "-m", "pip", "install", package]
    try:
        subprocess.check_call(pip_cmd)
    except subprocess.CalledProcessError as err:
        log.error("Failed to install package '%s': %s", package, err)
        sys.exit(1)
def install_system(pkgs: str) -> None:
    """
    Install system packages via apk (Alpine).

    Some python packages require system dependencies that must be installed
    via the system package manager (apk for Alpine Linux).
    This function installs the specified packages using `apk add --no-cache`.

    Args:
        pkgs: Comma-separated package names. Surrounding whitespace is
            stripped and blank entries are ignored; if nothing remains,
            the function returns without running apk.

    Logs the failure and exits the process with status 1 if apk returns a
    non-zero exit code or cannot be executed at all.
    """
    stripped = (p.strip() for p in pkgs.split(","))
    pkg_list = [p for p in stripped if p]
    if not pkg_list:
        return
    cmd = ["apk", "add", "--no-cache"] + pkg_list
    try:
        subprocess.check_call(cmd)
    # OSError covers the case where `apk` itself is missing or not executable
    # (e.g. a non-Alpine image), which check_call reports as
    # FileNotFoundError / PermissionError rather than CalledProcessError.
    except (subprocess.CalledProcessError, OSError) as e:
        log.error("Failed to install system packages '%s': %s", pkg_list, e)
        sys.exit(1)
def deserialize_class(payload: str) -> Type[ETLServer]:
    """
    Rebuild the ETLServer subclass carried in a base64 + cloudpickle payload.

    Args:
        payload: Base64 text wrapping a cloudpickle-serialized class.

    Returns:
        The deserialized class, guaranteed to be a subclass of ETLServer.

    Raises:
        TypeError: If the payload decodes to something that is not an
            ETLServer subclass.

    A decoding or unpickling failure is logged and exits the process with
    status 1.

    NOTE: unpickling can execute arbitrary code, so the payload must come
    from a trusted source.
    """
    try:
        pickled = base64.b64decode(payload.encode())
        candidate = cloudpickle.loads(pickled)
    except Exception as err:  # pylint: disable=broad-exception-caught
        log.error("Failed to deserialize ETL class: %s", err)
        sys.exit(1)

    # Check it is a class first: issubclass() raises on non-class objects.
    if isinstance(candidate, type) and issubclass(candidate, ETLServer):
        return candidate
    raise TypeError(f"{candidate!r} is not a subclass of ETLServer")
# ------------------------------------------------------------------------------
# Main
# ------------------------------------------------------------------------------
def main() -> None:
    """
    Entry point to set up and run the ETL server.

    Steps, in order:
      1. Install system packages (OS_PACKAGES), then pip packages (PACKAGES).
      2. Deserialize the ETLServer subclass from ETL_CLASS_PAYLOAD.
      3. Instantiate it.
      4. Start it: an HTTPMultiThreadedServer runs inside this process; a
         FastAPIServer or FlaskServer is started by replacing this process
         with uvicorn or gunicorn respectively.

    Exits with status 1 when dependency installation, instantiation, or
    launching the external server fails, or when the server type is not one
    of the three supported ones.
    """
    # 1) Install dependencies if specified.
    # System packages go first: pip packages may need system libraries or
    # build tools to already be present when they are installed.
    if OS_PACKAGES:
        log.info("Installing system packages: %s", OS_PACKAGES)
        install_system(OS_PACKAGES)

    if PACKAGES:
        log.info("Installing required packages: %s", PACKAGES)
        # Drop blank entries so a trailing or doubled comma in PACKAGES does
        # not become `pip install ""` and abort the bootstrap.
        for package in (p.strip() for p in PACKAGES.split(",")):
            if package:
                install(package)

    # 2) Deserialize ETL class
    etl_class = deserialize_class(ETL_CLASS_PAYLOAD)

    # 3) Instantiate ETL server
    try:
        server = etl_class()
    # pylint: disable=broad-exception-caught
    except Exception as e:
        log.error("Failed to instantiate ETLServer: %s", e)
        sys.exit(1)

    # 4) Start server
    if isinstance(server, HTTPMultiThreadedServer):
        log.info("Starting HTTP server in-process")
        server.start()
        return

    # For the remaining types the instance above is only used to pick the
    # launcher; the exec below replaces this process.
    # NOTE(review): "server:server.app" points at a separate `server` module,
    # which presumably rebuilds the ETL server on import — confirm.
    if isinstance(server, FastAPIServer):
        cmd = [
            "uvicorn",
            "server:server.app",
            "--host",
            "0.0.0.0",
            "--port",
            "8000",
            "--workers",
            str(NUM_WORKERS),
            "--log-level",
            "info",
            "--ws-max-size",
            "17179869184",  # 16 GiB
            "--ws-ping-interval",
            "0",
            "--ws-ping-timeout",
            "86400",  # 24 hours, in seconds
            "--no-access-log",
        ]
    elif isinstance(server, FlaskServer):
        cmd = [
            "gunicorn",
            "server:server.app",
            "--bind",
            "0.0.0.0:8000",
            "--workers",
            str(NUM_WORKERS),
            "--log-level",
            "debug",
        ]
    else:
        log.error("Unsupported server type: %s", server.__class__.__name__)
        sys.exit(1)

    log.info("Launching server: %s", " ".join(cmd))
    try:
        # Replaces the current process image; on success this never returns.
        os.execvp(cmd[0], cmd)
    except OSError as e:
        # E.g. the launcher executable is not installed or not on PATH.
        log.error("Failed to launch '%s': %s", cmd[0], e)
        sys.exit(1)
# Run the launcher only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()