just testing #485

Open · wants to merge 3 commits into main
58 changes: 58 additions & 0 deletions bench.py
@@ -0,0 +1,58 @@
import aiohttp
import asyncio
import time
import json
import os

async def measure_performance(api_endpoint: str, prompt: str = "Who are you?"):
    async with aiohttp.ClientSession() as session:
        request = {
            "model": "llama-3.2-3b",
            "messages": [{"role": "user", "content": prompt}],
            "stream": True
        }

        start_time = time.time()
        first_token_time = None
        total_tokens = 0

        print(f"Sending request to {api_endpoint}...")

        async with session.post(api_endpoint, json=request) as response:
            async for line in response.content:
                if not line.strip():
                    continue

                line = line.decode('utf-8')
                if line.startswith('data: '):
                    line = line[6:]  # Remove 'data: ' prefix
                if line == '[DONE]':
                    break

                try:
                    chunk = json.loads(line)
                    if chunk.get('choices') and chunk['choices'][0].get('delta', {}).get('content'):
                        if first_token_time is None:
                            first_token_time = time.time()
                            ttft = first_token_time - start_time
                            print(f"Time to first token: {ttft:.3f}s")

                        total_tokens += 1

                except json.JSONDecodeError:
                    continue

        end_time = time.time()
        total_time = end_time - start_time

        if total_tokens > 0:
            tps = total_tokens / total_time
            print(f"Tokens per second: {tps:.1f}")
            print(f"Total tokens generated: {total_tokens}")
            print(f"Total time: {total_time:.3f}s")
        else:
            print("No tokens were generated")

if __name__ == "__main__":
    API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:52415/v1/chat/completions")
    asyncio.run(measure_performance(API_ENDPOINT, prompt="Write an essay about life, the universe, and everything."))
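
A note on the numbers (editor's sketch, not part of this diff): the `tps` above divides all generated tokens by the full wall-clock time, so it includes the wait for the first token. If a decode-only rate is also wanted, the timestamps the script already records are enough; the helper name below is illustrative:

def decode_only_tps(first_token_time, end_time, total_tokens):
    # Sketch only: throughput excluding the time to first token, using the
    # same values measure_performance already tracks.
    if first_token_time is None or total_tokens < 2:
        return None
    return (total_tokens - 1) / (end_time - first_token_time)

The script itself can be run as-is with `python bench.py`, pointing `API_ENDPOINT` at a different server if needed.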
9 changes: 0 additions & 9 deletions exo/inference/mlx/stateful_model.py
@@ -1,7 +1,5 @@
from typing import Dict, Tuple
from collections import OrderedDict

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import make_prompt_cache

@@ -16,12 +14,6 @@ def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
    self.caches = OrderedDict()

  def init_cache(self, request_id: str):
    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
    # if self.max_kv_size is not None:
    #   cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
    #   cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
    # else:
    #   cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
    cache = make_prompt_cache(self.model)

    if len(self.caches) >= self.max_caches:
@@ -39,4 +31,3 @@ def __call__(self, x, request_id: str):

    y = self.model(x, cache=cache)
    return y
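
Context for the deletion above (editor's note, not part of the PR): the removed lines built per-layer KV caches by hand from `n_kv_heads`, while the surviving `make_prompt_cache(self.model)` call delegates that to mlx_lm. A minimal sketch of what the helper is being relied on for, under the assumption that it returns one cache object per transformer layer (models that define their own `make_cache` may differ):

from mlx_lm.models.cache import make_prompt_cache

def build_cache(model):
    # Sketch: let mlx_lm inspect the model and build its KV caches;
    # typically one cache entry per layer, so no manual head/layer
    # bookkeeping is needed on the exo side.
    return make_prompt_cache(model)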

11 changes: 8 additions & 3 deletions exo/main.py
@@ -53,6 +53,7 @@
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
parser.add_argument("--stream", action=argparse.BooleanOptionalAction, help="Stream the output of running a model")
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
@@ -165,7 +166,7 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):

shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)

async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str, stream_output: bool):
  inference_class = inference_engine.__class__.__name__
  shard = build_base_shard(model_name, inference_class)
  if not shard:
@@ -183,7 +184,11 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name
print(f"Processing prompt: {prompt}")
await node.process_prompt(shard, prompt, request_id=request_id)

_, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
def _on_token(_request_id, tokens, is_finished):
if stream_output:
print(tokenizer.decode(tokens), end='\r', flush=True)
return _request_id == request_id and is_finished
_, tokens, _ = await callback.wait(_on_token, timeout=300)

print("\nGenerated response:")
print(tokenizer.decode(tokens))
@@ -230,7 +235,7 @@ def handle_exit():
    if not model_name:
      print("Error: Model name is required when using 'run' command or --run-model")
      return
    await run_model_cli(node, inference_engine, model_name, args.prompt)
    await run_model_cli(node, inference_engine, model_name, args.prompt, args.stream)
  else:
    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
    await asyncio.Event().wait()
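
One observation on the new `--stream` path (editor's note, not part of this diff): `_on_token` decodes the full token list on every callback and prints it with `end='\r'`, so the terminal line is redrawn from scratch each time and grows with the response. A sketch of an alternative that prints only the newly decoded suffix, assuming the same `tokenizer`, `request_id` and `stream_output` names from `run_model_cli`:

printed_so_far = [""]  # mutable holder so the nested callback can update it

def _on_token(_request_id, tokens, is_finished):
    if stream_output:
        text = tokenizer.decode(tokens)
        # Print only what has not been shown yet; token-by-token decoding can
        # occasionally shift earlier characters, which this sketch ignores.
        print(text[len(printed_so_far[0]):], end='', flush=True)
        printed_so_far[0] = text
    return _request_id == request_id and is_finished

Either way, with this change an invocation along the lines of `exo run llama-3.2-3b --stream` (or `--run-model` plus `--stream`) should print tokens as they arrive rather than only the final response.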