|
| 1 | +import asyncio |
| 2 | +from collections.abc import Callable, Coroutine |
| 3 | +import logging |
| 4 | +import logging.config |
| 5 | +from pathlib import Path |
| 6 | +import time |
| 7 | + |
| 8 | +from httpx import AsyncClient |
| 9 | +from openai import AsyncOpenAI |
| 10 | +from pydantic import SecretStr |
| 11 | +from pydantic_settings import BaseSettings |
| 12 | + |
| 13 | +INPUT_TEXT_1 = """You can now select additional permissions when creating an API key to use in any third-party libraries or software that integrate with Immich. This mechanism will give you better control over what the other applications or libraries can do with your Immich’s instance.""" # noqa: RUF001 |
| 14 | +INPUT_TEXT_2 = """I figured that surely, someone has had this idea and built it before. On eBay you'll find cubes of resin embedding various random components from mechanical watches, but they are typically sold as "steampunk art" and bear little resemblance to the proper assembly of a mechanical watch movement. Sometimes, you'll find resin castings showing every component of a movement spread out in a plane like a buffet---very cool, but not what I'm looking for. Despite my best efforts, I haven't found anyone who makes what I'm after, and I have a sneaking suspicion as to why that is. Building an exploded view of a mechanical watch movement is undoubtedly very fiddly work and requires working knowledge about how a mechanical watch is assembled. People with that skillset are called watchmakers. Maker, not "destroyer for the sake of art". I guess it falls to me, then, to give this project an honest shot. """ |
| 15 | + |
| 16 | + |
| 17 | +class Config(BaseSettings): |
| 18 | + api_key: SecretStr = SecretStr("does-not-matter") |
| 19 | + log_level: str = "debug" |
| 20 | + """ |
| 21 | + Logging level. One of: 'debug', 'info', 'warning', 'error', 'critical'. |
| 22 | + """ |
| 23 | + logs_directory: str = "logs" |
| 24 | + |
| 25 | + speaches_base_url: SecretStr = SecretStr("http://localhost:8000") |
| 26 | + speech_model_id: str = "speaches-ai/Kokoro-82M-v1.0-ONNX" |
| 27 | + voice_id: str = "af_heart" |
| 28 | + input_text: str = INPUT_TEXT_2 |
| 29 | + |
| 30 | + iterations: int = 5 |
| 31 | + """ |
| 32 | + The number of iterations to run the performance test. |
| 33 | + """ |
| 34 | + concurrency: int = 1 |
| 35 | + """ |
| 36 | + Maximum number of concurrent requests made to the API. |
| 37 | + """ |
| 38 | + |
| 39 | + |
| 40 | +def limit_concurrency[**P, R]( |
| 41 | + coro: Callable[P, Coroutine[None, None, R]], limit: int |
| 42 | +) -> Callable[P, Coroutine[None, None, R]]: |
| 43 | + semaphore = asyncio.Semaphore(limit) |
| 44 | + |
| 45 | + async def wrapped_coro(*args: P.args, **kwargs: P.kwargs) -> R: |
| 46 | + async with semaphore: |
| 47 | + return await coro(*args, **kwargs) |
| 48 | + |
| 49 | + return wrapped_coro |
| 50 | + |
| 51 | + |
| 52 | +async def main(config: Config) -> None: |
| 53 | + logger = logging.getLogger(__name__) |
| 54 | + logger.debug("Config: %s", config.model_dump_json()) |
| 55 | + client = AsyncClient( |
| 56 | + base_url=config.speaches_base_url.get_secret_value(), |
| 57 | + headers={"Authorization": f"Bearer {config.api_key.get_secret_value()}"}, |
| 58 | + ) |
| 59 | + oai_client = AsyncOpenAI( |
| 60 | + api_key=config.api_key.get_secret_value(), |
| 61 | + base_url=f"{config.speaches_base_url.get_secret_value()}/v1", |
| 62 | + http_client=client, |
| 63 | + ) |
| 64 | + |
| 65 | + logger.debug(f"Attempting to pull model {config.speech_model_id}") |
| 66 | + res = await client.post(f"{config.speaches_base_url}/v1/models/{config.speech_model_id}") |
| 67 | + logger.info(f"Finished attempting to pull model {config.speech_model_id}. Response: {res.text}") |
| 68 | + |
| 69 | + # INFO: Make initial request so that the model is loaded into memory. |
| 70 | + await oai_client.audio.speech.create( |
| 71 | + input="Hello", |
| 72 | + model=config.speech_model_id, |
| 73 | + voice=config.voice_id, # type: ignore # noqa: PGH003 |
| 74 | + ) |
| 75 | + |
| 76 | + async def create_speech() -> None: |
| 77 | + async with oai_client.audio.speech.with_streaming_response.create( |
| 78 | + input=config.input_text, |
| 79 | + model=config.speech_model_id, |
| 80 | + voice=config.voice_id, # type: ignore # noqa: PGH003 |
| 81 | + ) as res: |
| 82 | + chunk_times: list[float] = [] |
| 83 | + start = time.perf_counter() |
| 84 | + prev_chunk_time = time.perf_counter() |
| 85 | + async for _ in res.iter_bytes(): |
| 86 | + chunk_times.append(time.perf_counter() - prev_chunk_time) |
| 87 | + prev_chunk_time = time.perf_counter() |
| 88 | + stats = { |
| 89 | + "time_to_first_token": chunk_times[0], |
| 90 | + "average_chunk_time": sum(chunk_times) / len(chunk_times), |
| 91 | + "total_chunks": len(chunk_times), |
| 92 | + "total_time": time.perf_counter() - start, |
| 93 | + } |
| 94 | + logger.debug(stats) |
| 95 | + |
| 96 | + create_speech_with_limited_concurrency = limit_concurrency(create_speech, config.concurrency) |
| 97 | + |
| 98 | + start = time.perf_counter() |
| 99 | + |
| 100 | + async with asyncio.TaskGroup() as tg: |
| 101 | + tasks = [tg.create_task(create_speech_with_limited_concurrency()) for _ in range(config.iterations)] |
| 102 | + start = time.perf_counter() |
| 103 | + await asyncio.gather(*tasks) |
| 104 | + logger.info(f"All tasks completed in {time.perf_counter() - start:.2f} seconds") |
| 105 | + |
| 106 | + |
| 107 | +if __name__ == "__main__": |
| 108 | + config = Config() |
| 109 | + |
| 110 | + Path.mkdir(Path(config.logs_directory), exist_ok=True) |
| 111 | + logging_config = { |
| 112 | + "version": 1, # required |
| 113 | + "disable_existing_loggers": False, |
| 114 | + "formatters": { |
| 115 | + "simple": {"format": "%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s"}, |
| 116 | + }, |
| 117 | + "handlers": { |
| 118 | + "stdout": { |
| 119 | + "class": "logging.StreamHandler", |
| 120 | + "formatter": "simple", |
| 121 | + "stream": "ext://sys.stdout", |
| 122 | + }, |
| 123 | + "file": { |
| 124 | + "class": "logging.FileHandler", |
| 125 | + "filename": f"{config.logs_directory}/{time.strftime('%Y-%m-%d_%H-%M-%S')}_tts_performance_test.log", # TODO: there's a better way to do this, but this is good enough for now |
| 126 | + "formatter": "simple", |
| 127 | + }, |
| 128 | + }, |
| 129 | + "loggers": { |
| 130 | + "root": { |
| 131 | + "level": config.log_level.upper(), |
| 132 | + "handlers": ["stdout", "file"], |
| 133 | + }, |
| 134 | + "asyncio": { |
| 135 | + "level": "INFO", |
| 136 | + "handlers": ["stdout"], |
| 137 | + }, |
| 138 | + "httpx": { |
| 139 | + "level": "WARNING", |
| 140 | + "handlers": ["stdout"], |
| 141 | + }, |
| 142 | + "python_multipart": { |
| 143 | + "level": "INFO", |
| 144 | + "handlers": ["stdout"], |
| 145 | + }, |
| 146 | + "httpcore": { |
| 147 | + "level": "INFO", |
| 148 | + "handlers": ["stdout"], |
| 149 | + }, |
| 150 | + "openai": { |
| 151 | + "level": "INFO", |
| 152 | + "handlers": ["stdout"], |
| 153 | + }, |
| 154 | + }, |
| 155 | + } |
| 156 | + |
| 157 | + logging.config.dictConfig(logging_config) |
| 158 | + logging.basicConfig(level=config.log_level.upper()) |
| 159 | + asyncio.run(main(config)) |
0 commit comments