Skip to content

Commit 829d87b

Browse files
Fedir Zadniprovskyifedirz
authored andcommitted
feat: add TTS performance benchmarking script
1 parent 31be1c1 commit 829d87b

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed

scripts/tts_performance_test.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import asyncio
2+
from collections.abc import Callable, Coroutine
3+
import logging
4+
import logging.config
5+
from pathlib import Path
6+
import time
7+
8+
from httpx import AsyncClient
9+
from openai import AsyncOpenAI
10+
from pydantic import SecretStr
11+
from pydantic_settings import BaseSettings
12+
13+
INPUT_TEXT_1 = """You can now select additional permissions when creating an API key to use in any third-party libraries or software that integrate with Immich. This mechanism will give you better control over what the other applications or libraries can do with your Immich’s instance.""" # noqa: RUF001
14+
INPUT_TEXT_2 = """I figured that surely, someone has had this idea and built it before. On eBay you'll find cubes of resin embedding various random components from mechanical watches, but they are typically sold as "steampunk art" and bear little resemblance to the proper assembly of a mechanical watch movement. Sometimes, you'll find resin castings showing every component of a movement spread out in a plane like a buffet---very cool, but not what I'm looking for. Despite my best efforts, I haven't found anyone who makes what I'm after, and I have a sneaking suspicion as to why that is. Building an exploded view of a mechanical watch movement is undoubtedly very fiddly work and requires working knowledge about how a mechanical watch is assembled. People with that skillset are called watchmakers. Maker, not "destroyer for the sake of art". I guess it falls to me, then, to give this project an honest shot. """
15+
16+
17+
class Config(BaseSettings):
18+
api_key: SecretStr = SecretStr("does-not-matter")
19+
log_level: str = "debug"
20+
"""
21+
Logging level. One of: 'debug', 'info', 'warning', 'error', 'critical'.
22+
"""
23+
logs_directory: str = "logs"
24+
25+
speaches_base_url: SecretStr = SecretStr("http://localhost:8000")
26+
speech_model_id: str = "speaches-ai/Kokoro-82M-v1.0-ONNX"
27+
voice_id: str = "af_heart"
28+
input_text: str = INPUT_TEXT_2
29+
30+
iterations: int = 5
31+
"""
32+
The number of iterations to run the performance test.
33+
"""
34+
concurrency: int = 1
35+
"""
36+
Maximum number of concurrent requests made to the API.
37+
"""
38+
39+
40+
def limit_concurrency[**P, R](
41+
coro: Callable[P, Coroutine[None, None, R]], limit: int
42+
) -> Callable[P, Coroutine[None, None, R]]:
43+
semaphore = asyncio.Semaphore(limit)
44+
45+
async def wrapped_coro(*args: P.args, **kwargs: P.kwargs) -> R:
46+
async with semaphore:
47+
return await coro(*args, **kwargs)
48+
49+
return wrapped_coro
50+
51+
52+
async def main(config: Config) -> None:
53+
logger = logging.getLogger(__name__)
54+
logger.debug("Config: %s", config.model_dump_json())
55+
client = AsyncClient(
56+
base_url=config.speaches_base_url.get_secret_value(),
57+
headers={"Authorization": f"Bearer {config.api_key.get_secret_value()}"},
58+
)
59+
oai_client = AsyncOpenAI(
60+
api_key=config.api_key.get_secret_value(),
61+
base_url=f"{config.speaches_base_url.get_secret_value()}/v1",
62+
http_client=client,
63+
)
64+
65+
logger.debug(f"Attempting to pull model {config.speech_model_id}")
66+
res = await client.post(f"{config.speaches_base_url}/v1/models/{config.speech_model_id}")
67+
logger.info(f"Finished attempting to pull model {config.speech_model_id}. Response: {res.text}")
68+
69+
# INFO: Make initial request so that the model is loaded into memory.
70+
await oai_client.audio.speech.create(
71+
input="Hello",
72+
model=config.speech_model_id,
73+
voice=config.voice_id, # type: ignore # noqa: PGH003
74+
)
75+
76+
async def create_speech() -> None:
77+
async with oai_client.audio.speech.with_streaming_response.create(
78+
input=config.input_text,
79+
model=config.speech_model_id,
80+
voice=config.voice_id, # type: ignore # noqa: PGH003
81+
) as res:
82+
chunk_times: list[float] = []
83+
start = time.perf_counter()
84+
prev_chunk_time = time.perf_counter()
85+
async for _ in res.iter_bytes():
86+
chunk_times.append(time.perf_counter() - prev_chunk_time)
87+
prev_chunk_time = time.perf_counter()
88+
stats = {
89+
"time_to_first_token": chunk_times[0],
90+
"average_chunk_time": sum(chunk_times) / len(chunk_times),
91+
"total_chunks": len(chunk_times),
92+
"total_time": time.perf_counter() - start,
93+
}
94+
logger.debug(stats)
95+
96+
create_speech_with_limited_concurrency = limit_concurrency(create_speech, config.concurrency)
97+
98+
start = time.perf_counter()
99+
100+
async with asyncio.TaskGroup() as tg:
101+
tasks = [tg.create_task(create_speech_with_limited_concurrency()) for _ in range(config.iterations)]
102+
start = time.perf_counter()
103+
await asyncio.gather(*tasks)
104+
logger.info(f"All tasks completed in {time.perf_counter() - start:.2f} seconds")
105+
106+
107+
if __name__ == "__main__":
108+
config = Config()
109+
110+
Path.mkdir(Path(config.logs_directory), exist_ok=True)
111+
logging_config = {
112+
"version": 1, # required
113+
"disable_existing_loggers": False,
114+
"formatters": {
115+
"simple": {"format": "%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s"},
116+
},
117+
"handlers": {
118+
"stdout": {
119+
"class": "logging.StreamHandler",
120+
"formatter": "simple",
121+
"stream": "ext://sys.stdout",
122+
},
123+
"file": {
124+
"class": "logging.FileHandler",
125+
"filename": f"{config.logs_directory}/{time.strftime('%Y-%m-%d_%H-%M-%S')}_tts_performance_test.log", # TODO: there's a better way to do this, but this is good enough for now
126+
"formatter": "simple",
127+
},
128+
},
129+
"loggers": {
130+
"root": {
131+
"level": config.log_level.upper(),
132+
"handlers": ["stdout", "file"],
133+
},
134+
"asyncio": {
135+
"level": "INFO",
136+
"handlers": ["stdout"],
137+
},
138+
"httpx": {
139+
"level": "WARNING",
140+
"handlers": ["stdout"],
141+
},
142+
"python_multipart": {
143+
"level": "INFO",
144+
"handlers": ["stdout"],
145+
},
146+
"httpcore": {
147+
"level": "INFO",
148+
"handlers": ["stdout"],
149+
},
150+
"openai": {
151+
"level": "INFO",
152+
"handlers": ["stdout"],
153+
},
154+
},
155+
}
156+
157+
logging.config.dictConfig(logging_config)
158+
logging.basicConfig(level=config.log_level.upper())
159+
asyncio.run(main(config))

0 commit comments

Comments
 (0)