From d0b78e6861bbc422b22cba7849af833035e70674 Mon Sep 17 00:00:00 2001 From: Tom Stesco Date: Wed, 19 Feb 2025 21:02:01 +0000 Subject: [PATCH 1/6] adding best_of and logprobs patch to benchmarking script for compatability --- benchmarks/backend_request_func.py | 2 -- benchmarks/benchmark_serving.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 4813fde27f0bc..0cb3e72e4fb64 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -235,9 +235,7 @@ async def async_request_openai_completions( "model": request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, - "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, - "logprobs": request_func_input.logprobs, "stream": True, "ignore_eos": request_func_input.ignore_eos, } diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index c1a396c81f666..964ae9b798a54 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -417,7 +417,7 @@ async def benchmark( prompt_len=test_prompt_len, output_len=test_output_len, logprobs=logprobs, - best_of=best_of, + best_of=None, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, ) From fa7ab76b3d405e0bbba9434f3ba7d77896c5ce99 Mon Sep 17 00:00:00 2001 From: Tom Stesco Date: Wed, 19 Feb 2025 21:10:06 +0000 Subject: [PATCH 2/6] uplift benchmarks/benchmark_serving.py from 45186834a0d9f101dd29fac0e7ccbdb245f27645 --- benchmarks/benchmark_serving.py | 421 +++++++++++++++++++++++++++++--- 1 file changed, 393 insertions(+), 28 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 964ae9b798a54..3b188803b0dcc 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -25,6 +25,7 @@ import argparse import asyncio import base64 +import gc import io import json import os @@ -36,6 +37,7 @@ from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple import numpy as np +import pandas as pd from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) from datasets import load_dataset @@ -53,6 +55,10 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser +from benchmark_utils import convert_to_pytorch_benchmark_format + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + @dataclass class BenchmarkMetrics: @@ -60,6 +66,7 @@ class BenchmarkMetrics: total_input: int total_output: int request_throughput: float + request_goodput: float output_throughput: float total_token_throughput: float mean_ttft_ms: float @@ -126,6 +133,35 @@ def sample_sharegpt_requests( return filtered_dataset +def sample_burstgpt_requests( + dataset_path: str, + num_requests: int, + random_seed: int, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, int, int, None]]: + df = pd.read_csv(dataset_path) + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove the failed requests (i.e., response length is 0) + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Randomly sample num_requests from the dataset + if num_requests <= len(gpt4_df): + gpt4_df = gpt4_df.sample(n=num_requests, random_state=random_seed) + else: + gpt4_df = gpt4_df.sample(n=num_requests, + random_state=random_seed, + replace=True) + # Convert the dataframe to a list of tuples + dataset = gpt4_df.values.tolist() + input_requests = [] + for i in range(num_requests): + input_len = int(dataset[i][2]) + 
output_len = int(dataset[i][3]) + prompt = tokenizer.decode([(i + j) % tokenizer.vocab_size + for j in range(input_len)]) + input_requests.append((prompt, input_len, output_len, None)) + return input_requests + + def sample_sonnet_requests( dataset_path: str, num_requests: int, @@ -196,22 +232,80 @@ def sample_sonnet_requests( return sampled_requests +def sample_vision_arena_requests( + dataset, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + sampled_requests: List[Tuple[str, int, int, Dict[str, + Collection[str]]]] = [] + for data in dataset: + if len(sampled_requests) == num_requests: + break + + prompt = data["turns"][0][0]['content'] + + prompt_token_ids = tokenizer(prompt).input_ids + if fixed_output_len is None: + # Default max output len is set to 128 + print("--hf-output-len is not provided. Using default value 128.") + fixed_output_len = 128 + + prompt_len = len(prompt_token_ids) + output_len = fixed_output_len + + assert isinstance( + data["images"][0], + Image), ("Input image format must be `PIL.Image.Image`, " + f"given {type(data['image'])}.") + image: Image = data["images"][0] + image = image.convert("RGB") + image_data = io.BytesIO() + image.save(image_data, format='JPEG') + image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") + mm_content = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + sampled_requests.append((prompt, prompt_len, output_len, mm_content)) + + return sampled_requests + + def sample_hf_requests( dataset_path: str, - dataset_subset: str, + dataset_subset: Optional[str], dataset_split: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, + random_seed: int, fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + + # Special case for vision_arena dataset + if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \ + and dataset_subset is None: + assert dataset_split == "train" + dataset = load_dataset(dataset_path, + name=dataset_subset, + split=dataset_split, + streaming=True) + dataset = dataset.shuffle(seed=random_seed) + return sample_vision_arena_requests(dataset, num_requests, tokenizer, + fixed_output_len) + dataset = load_dataset(dataset_path, name=dataset_subset, split=dataset_split, streaming=True) assert "conversations" in dataset.features, ( "HF Dataset must have 'conversations' column.") - filtered_dataset = dataset.shuffle().filter( - lambda x: len(x["conversations"]) >= 2) + filter_func = lambda x: len(x["conversations"]) >= 2 + filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func) sampled_requests: List[Tuple[str, int, int, Dict[str, Collection[str]]]] = [] for data in filtered_dataset: @@ -247,6 +341,19 @@ def sample_hf_requests( "url": f"data:image/jpeg;base64,{image_base64}" }, } + elif "image" in data and isinstance(data["image"], str): + if (data["image"].startswith("http://") or \ + data["image"].startswith("file://")): + image_url = data["image"] + else: + image_url = f"file://{data['image']}" + + mm_content = { + "type": "image_url", + "image_url": { + "url": image_url + }, + } else: mm_content = None @@ -293,8 +400,33 @@ def sample_random_requests( async def get_request( input_requests: List[Tuple[str, int, int]], request_rate: float, + burstiness: float = 1.0, ) -> AsyncGenerator[Tuple[str, int, int], None]: + """ + Asynchronously generates requests at a 
specified rate + with OPTIONAL burstiness. + + Args: + input_requests: + A list of input requests, each represented as a tuple. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + """ input_requests = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + theta = 1.0 / (request_rate * burstiness) + for request in input_requests: yield request @@ -302,8 +434,9 @@ async def get_request( # If the request rate is infinity, then we don't need to wait. continue - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) # The next request will be sent after the interval. await asyncio.sleep(interval) @@ -315,28 +448,39 @@ def calculate_metrics( tokenizer: PreTrainedTokenizerBase, selected_percentile_metrics: List[str], selected_percentiles: List[float], + goodput_config_dict: Dict[str, float], ) -> Tuple[BenchmarkMetrics, List[int]]: actual_output_lens: List[int] = [] total_input = 0 completed = 0 + good_completed = 0 itls: List[float] = [] tpots: List[float] = [] + all_tpots: List[float] = [] ttfts: List[float] = [] e2els: List[float] = [] for i in range(len(outputs)): if outputs[i].success: - # We use the tokenizer to count the number of output tokens for all - # serving backends instead of looking at len(outputs[i].itl) since - # multiple output tokens may be bundled together - # Note : this may inflate the output token count slightly - output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + output_len = outputs[i].output_tokens + + if output_len is None: + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) actual_output_lens.append(output_len) total_input += input_requests[i][1] + tpot = 0 if output_len > 1: - tpots.append( - (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) itls += outputs[i].itl ttfts.append(outputs[i].ttft) e2els.append(outputs[i].latency) @@ -344,6 +488,28 @@ def calculate_metrics( else: actual_output_lens.append(0) + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + 
valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + if completed == 0: warnings.warn( "All requests failed. This is likely due to a misconfiguration " @@ -354,6 +520,7 @@ def calculate_metrics( total_input=total_input, total_output=sum(actual_output_lens), request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * @@ -372,9 +539,9 @@ def calculate_metrics( median_itl_ms=np.median(itls or 0) * 1000, percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles], - mean_e2el_ms=np.median(e2els or 0) * 1000, + mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, - median_e2el_ms=np.mean(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles], ) @@ -387,16 +554,21 @@ async def benchmark( api_url: str, base_url: str, model_id: str, + model_name: str, tokenizer: PreTrainedTokenizerBase, input_requests: List[Tuple[str, int, int]], logprobs: Optional[int], best_of: int, request_rate: float, + burstiness: float, disable_tqdm: bool, profile: bool, selected_percentile_metrics: List[str], selected_percentiles: List[str], ignore_eos: bool, + goodput_config_dict: Dict[str, float], + max_concurrency: Optional[int], + lora_modules: Optional[List[str]], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -412,15 +584,17 @@ async def benchmark( "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( model=model_id, + model_name=model_name, prompt=test_prompt, api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, logprobs=logprobs, - best_of=None, + best_of=best_of, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, ) + test_output = await request_func(request_func_input=test_input) if not test_output.success: raise ValueError( @@ -429,9 +603,15 @@ async def benchmark( else: print("Initial test run completed. Starting main benchmark run...") + if lora_modules: + # For each input request, choose a LoRA module at random. + lora_modules = iter( + [random.choice(lora_modules) for _ in range(len(input_requests))]) + if profile: print("Starting profiler...") profile_input = RequestFuncInput(model=model_id, + model_name=model_name, prompt=test_prompt, api_url=base_url + "/start_profile", prompt_len=test_prompt_len, @@ -444,15 +624,43 @@ async def benchmark( if profile_output.success: print("Profiler started") + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. 
+ # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): + async for request in get_request(input_requests, request_rate, burstiness): prompt, prompt_len, output_len, mm_content = request - request_func_input = RequestFuncInput(model=model_id, + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, prompt=prompt, api_url=api_url, prompt_len=prompt_len, @@ -463,8 +671,8 @@ async def benchmark( ignore_eos=ignore_eos) tasks.append( asyncio.create_task( - request_func(request_func_input=request_func_input, - pbar=pbar))) + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -494,6 +702,7 @@ async def benchmark( tokenizer=tokenizer, selected_percentile_metrics=selected_percentile_metrics, selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) @@ -505,6 +714,9 @@ async def benchmark( metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", @@ -516,6 +728,8 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, + "request_goodput:": + metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], @@ -569,6 +783,67 @@ def process_one_metric( return result +def check_goodput_args(args): + # Check and parse goodput arguments + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. 
" + "The service level objective value should be " + "non-negative.") + return goodput_config_dict + + +def parse_goodput(slo_pairs): + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. " + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return goodput_config_dict + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: Dict[str, Any], + file_name: str) -> None: + metrics = [ + "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", + "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", + "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + ] + # These raw data might be useful, but they are rather big. They can be added + # later if needed + ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={k: [results[k]] + for k in metrics}, + extra_info={ + k: results[k] + for k in results if k not in metrics and k not in ignored_metrics + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" + with open(pt_file, "w") as f: + json.dump(pt_records, f) + + def main(args: argparse.Namespace): print(args) random.seed(args.seed) @@ -576,7 +851,9 @@ def main(args: argparse.Namespace): backend = args.backend model_id = args.model + model_name = args.served_model_name tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" @@ -586,6 +863,7 @@ def main(args: argparse.Namespace): base_url = f"http://{args.host}:{args.port}" tokenizer = get_tokenizer(tokenizer_id, + tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) if args.dataset is not None: @@ -609,6 +887,14 @@ def main(args: argparse.Namespace): fixed_output_len=args.sharegpt_output_len, ) + elif args.dataset_name == "burstgpt": + input_requests = sample_burstgpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + random_seed=args.seed, + tokenizer=tokenizer, + ) + elif args.dataset_name == "sonnet": # Do not format the prompt, pass to message directly if args.backend == "openai-chat": @@ -646,6 +932,7 @@ def main(args: argparse.Namespace): dataset_split=args.hf_split, num_requests=args.num_prompts, tokenizer=tokenizer, + random_seed=args.seed, fixed_output_len=args.hf_output_len, ) @@ -662,17 +949,25 @@ def main(args: argparse.Namespace): else: raise ValueError(f"Unknown dataset: {args.dataset_name}") + goodput_config_dict = check_goodput_args(args) + + # Avoid GC processing "static" data - reduce pause times. 
+ gc.collect() + gc.freeze() + benchmark_result = asyncio.run( benchmark( backend=backend, api_url=api_url, base_url=base_url, model_id=model_id, + model_name=model_name, tokenizer=tokenizer, input_requests=input_requests, logprobs=args.logprobs, best_of=args.best_of, request_rate=args.request_rate, + burstiness=args.burstiness, disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), @@ -680,6 +975,9 @@ def main(args: argparse.Namespace): float(p) for p in args.metric_percentiles.split(",") ], ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, )) # Save config and results to json @@ -707,21 +1005,26 @@ def main(args: argparse.Namespace): ) # Traffic - result_json["request_rate"] = ( - args.request_rate if args.request_rate < float("inf") else "inf") + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency # Merge with benchmark result result_json = {**result_json, **benchmark_result} # Save to file base_model_id = model_id.split("/")[-1] - file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa + max_concurrency_str = (f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None else "") + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa if args.result_filename: file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) with open(file_name, "w", encoding='utf-8') as outfile: json.dump(result_json, outfile) + save_to_pytorch_benchmark_format(args, result_json, file_name) if __name__ == "__main__": @@ -739,7 +1042,8 @@ def main(args: argparse.Namespace): default=None, help="Server or API base url if not using http host and port.", ) - parser.add_argument("--host", type=str, default="localhost") + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") parser.add_argument("--port", type=int, default=8000) parser.add_argument( "--endpoint", @@ -758,7 +1062,7 @@ def main(args: argparse.Namespace): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random", "hf"], + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", @@ -766,6 +1070,19 @@ def main(args: argparse.Namespace): default=None, help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.") + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. 
This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + parser.add_argument( "--model", type=str, @@ -808,8 +1125,20 @@ def main(args: argparse.Namespace): default=float("inf"), help="Number of requests per second. If this is inf, " "then all the requests are sent at time 0. " - "Otherwise, we use Poisson process to synthesize " - "the request arrival times.", + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", ) parser.add_argument("--seed", type=int, default=0) parser.add_argument( @@ -879,6 +1208,17 @@ def main(args: argparse.Namespace): "Default value is \"99\". " "Use \"--percentile-metrics\" to select metrics.", ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve") # group for dataset specific arguments sonnet_group = parser.add_argument_group("sonnet dataset options") @@ -960,5 +1300,30 @@ def main(args: argparse.Namespace): "from the sampled HF dataset.", ) + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral', 'custom'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer. \n*' + '"custom" will use --tokenizer to select the preregistered tokenizer.') + + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. 
For each request, the " + "script chooses a LoRA module at random.") + args = parser.parse_args() main(args) From 870cfe2d5a16f7dc2f8cd43eca4ea3297feefedc Mon Sep 17 00:00:00 2001 From: Tom Stesco Date: Wed, 19 Feb 2025 21:24:18 +0000 Subject: [PATCH 3/6] add benchmark_utils.py --- benchmarks/benchmark_serving.py | 1 + benchmarks/benchmark_utils.py | 39 +++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 benchmarks/benchmark_utils.py diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3b188803b0dcc..9760737ccec3e 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 r"""Benchmark online serving throughput. On the server side, run one of the following commands: diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py new file mode 100644 index 0000000000000..dc6d31f6fdb99 --- /dev/null +++ b/benchmarks/benchmark_utils.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +from typing import Any, Dict, List + + +def convert_to_pytorch_benchmark_format(args: argparse.Namespace, + metrics: Dict[str, List], + extra_info: Dict[str, Any]) -> List: + """ + Save the benchmark results in the format used by PyTorch OSS benchmark with + on metric per record + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + records = [] + if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False): + return records + + for name, benchmark_values in metrics.items(): + record = { + "benchmark": { + "name": "vLLM benchmark", + "extra_info": { + "args": vars(args), + }, + }, + "model": { + "name": args.model, + }, + "metric": { + "name": name, + "benchmark_values": benchmark_values, + "extra_info": extra_info, + }, + } + records.append(record) + + return records \ No newline at end of file From c22daa0f3ee6f5aa90c9a0933f11f45c7e366044 Mon Sep 17 00:00:00 2001 From: Tom Stesco Date: Wed, 19 Feb 2025 21:31:00 +0000 Subject: [PATCH 4/6] uplift --- benchmarks/backend_request_func.py | 148 +++++++++++++++++++---------- 1 file changed, 100 insertions(+), 48 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 0cb3e72e4fb64..19e12f2e9af29 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + import json import os import sys @@ -22,8 +24,10 @@ class RequestFuncInput: prompt_len: int output_len: int model: str + model_name: Optional[str] = None best_of: int = 1 logprobs: Optional[int] = None + extra_body: Optional[dict] = None multi_modal_content: Optional[dict] = None ignore_eos: bool = False @@ -33,9 +37,11 @@ class RequestFuncOutput: generated_text: str = "" success: bool = False latency: float = 0.0 + output_tokens: int = 0 ttft: float = 0.0 # Time to first token itl: List[float] = field( default_factory=list) # List of inter-token latencies + tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" @@ -47,13 +53,15 @@ async def async_request_tgi( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: params = { "best_of": request_func_input.best_of, "max_new_tokens": request_func_input.output_len, "do_sample": True, 
"temperature": 0.01, # TGI does not accept 0.0 temperature. "top_p": 0.99, # TGI does not accept 1.0 top_p. + "truncate": request_func_input.prompt_len, # TGI does not accept ignore_eos flag. } payload = { @@ -75,11 +83,11 @@ async def async_request_tgi( continue chunk_bytes = chunk_bytes.decode("utf-8") - #NOTE: Sometimes TGI returns a ping response without + # NOTE: Sometimes TGI returns a ping response without # any data, we should skip it. if chunk_bytes.startswith(":"): continue - chunk = remove_prefix(chunk_bytes, "data:") + chunk = chunk_bytes.removeprefix("data:") data = json.loads(chunk) timestamp = time.perf_counter() @@ -118,7 +126,8 @@ async def async_request_trt_llm( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: assert request_func_input.best_of == 1 payload = { "accumulate_tokens": True, @@ -144,15 +153,15 @@ async def async_request_trt_llm( if not chunk_bytes: continue - chunk = remove_prefix(chunk_bytes.decode("utf-8"), - "data:") + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data:") data = json.loads(chunk) output.generated_text += data["text_output"] timestamp = time.perf_counter() # First token if ttft == 0.0: - ttft = time.perf_counter() - st + ttft = timestamp - st output.ttft = ttft # Decoding phase @@ -182,7 +191,8 @@ async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: assert request_func_input.best_of == 1 payload = { @@ -230,15 +240,25 @@ async def async_request_openai_completions( ("completions", "profile") ), "OpenAI Completions API URL must end with 'completions' or 'profile'." 
- async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: payload = { - "model": request_func_input.model, + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, + # "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, + # "logprobs": request_func_input.logprobs, "stream": True, - "ignore_eos": request_func_input.ignore_eos, + "stream_options": { + "include_usage": True, + }, } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } @@ -247,32 +267,34 @@ async def async_request_openai_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: + first_chunk_received = False async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = remove_prefix(chunk_bytes.decode("utf-8"), - "data: ") - if chunk == "[DONE]": - latency = time.perf_counter() - st - else: + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": data = json.loads(chunk) # NOTE: Some completion API might have a last # usage summary response without a token so we # want to check a token was generated - if data["choices"][0]["text"]: + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") timestamp = time.perf_counter() # First token - if ttft == 0.0: + if not first_chunk_received: + first_chunk_received = True ttft = time.perf_counter() - st output.ttft = ttft @@ -282,11 +304,19 @@ async def async_request_openai_completions( most_recent_timestamp) most_recent_timestamp = timestamp - generated_text += data["choices"][0]["text"] - + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." + "This response will be marked as failed!") output.generated_text = generated_text - output.success = True - output.latency = latency + output.latency = most_recent_timestamp - st else: output.error = response.reason or "" output.success = False @@ -309,12 +339,14 @@ async def async_request_openai_chat_completions( "chat/completions" ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
- async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: content.append(request_func_input.multi_modal_content) payload = { - "model": request_func_input.model, + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, "messages": [ { "role": "user", @@ -322,10 +354,16 @@ async def async_request_openai_chat_completions( }, ], "temperature": 0.0, - "max_tokens": request_func_input.output_len, + "max_completion_tokens": request_func_input.output_len, "stream": True, - "ignore_eos": request_func_input.ignore_eos, + "stream_options": { + "include_usage": True, + }, } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", @@ -347,19 +385,17 @@ async def async_request_openai_chat_completions( if not chunk_bytes: continue - chunk = remove_prefix(chunk_bytes.decode("utf-8"), - "data: ") - if chunk == "[DONE]": - latency = time.perf_counter() - st - else: + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) - delta = data["choices"][0]["delta"] - if delta.get("content", None): + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") # First token if ttft == 0.0: - ttft = time.perf_counter() - st + ttft = timestamp - st output.ttft = ttft # Decoding phase @@ -367,13 +403,16 @@ async def async_request_openai_chat_completions( output.itl.append(timestamp - most_recent_timestamp) - generated_text += delta["content"] + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True - output.latency = latency + output.latency = most_recent_timestamp - st else: output.error = response.reason or "" output.success = False @@ -387,14 +426,6 @@ async def async_request_openai_chat_completions( return output -# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix) -# introduced in Python 3.9 -def remove_prefix(text: str, prefix: str) -> str: - if text.startswith(prefix): - return text[len(prefix):] - return text - - def get_model(pretrained_model_name_or_path: str) -> str: if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': from modelscope import snapshot_download @@ -409,14 +440,35 @@ def get_model(pretrained_model_name_or_path: str) -> str: def get_tokenizer( - pretrained_model_name_or_path: str, trust_remote_code: bool + pretrained_model_name_or_path: str, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: if pretrained_model_name_or_path is not None and not os.path.exists( pretrained_model_name_or_path): pretrained_model_name_or_path = get_model( pretrained_model_name_or_path) - return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, - trust_remote_code=trust_remote_code) + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError( + "Cannot use the fast tokenizer in slow tokenizer mode.") + 
kwargs["use_fast"] = False + if tokenizer_mode == "mistral": + try: + from vllm.transformers_utils.tokenizer import MistralTokenizer + except ImportError as e: + raise ImportError("MistralTokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use mistral tokenizer mode.") from e + return MistralTokenizer.from_pretrained( + str(pretrained_model_name_or_path)) + else: + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **kwargs, + ) ASYNC_REQUEST_FUNCS = { From 3429acf14e46436948db6865b90178c6375d0217 Mon Sep 17 00:00:00 2001 From: Tom Stesco Date: Wed, 19 Feb 2025 21:39:51 +0000 Subject: [PATCH 5/6] backport removeprefix --- benchmarks/backend_request_func.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 19e12f2e9af29..29125012b17f9 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -46,6 +46,10 @@ class RequestFuncOutput: error: str = "" +def backport_removeprefix(string: str, prefix: str) -> str: + return string[len(prefix):] if string.startswith(prefix) else string + + async def async_request_tgi( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, @@ -87,7 +91,7 @@ async def async_request_tgi( # any data, we should skip it. if chunk_bytes.startswith(":"): continue - chunk = chunk_bytes.removeprefix("data:") + chunk = backport_removeprefix(chunk_bytes, "data:") data = json.loads(chunk) timestamp = time.perf_counter() @@ -153,8 +157,7 @@ async def async_request_trt_llm( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data:") + chunk = backport_removeprefix(chunk_bytes.decode("utf-8"), "data:") data = json.loads(chunk) output.generated_text += data["text_output"] @@ -279,8 +282,7 @@ async def async_request_openai_completions( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = backport_removeprefix(chunk_bytes.decode("utf-8"), "data: ") if chunk != "[DONE]": data = json.loads(chunk) @@ -385,8 +387,7 @@ async def async_request_openai_chat_completions( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = backport_removeprefix(chunk_bytes.decode("utf-8"), "data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) From 82515b4f897edbffea98c3c1587c25ce1a1753dd Mon Sep 17 00:00:00 2001 From: Tom Stesco Date: Sat, 22 Feb 2025 01:59:10 +0000 Subject: [PATCH 6/6] add TODO: remove before upstream: comments --- benchmarks/backend_request_func.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 29125012b17f9..2c1355219ca9c 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -45,7 +45,9 @@ class RequestFuncOutput: prompt_len: int = 0 error: str = "" - +# TODO: remove before upstream: remove once we can drop support of Python 3.8 +# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix) +# introduced in Python 3.9 def backport_removeprefix(string: str, prefix: str) -> str: return string[len(prefix):] if string.startswith(prefix) else string @@ -250,6 +252,7 @@ async def async_request_openai_completions( if request_func_input.model_name else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, + 
# TODO: remove before upstream: best_of and logprobs not currently supported: https://github.com/tenstorrent/vllm/issues/44 # "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, # "logprobs": request_func_input.logprobs,
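
The sketches below are standalone illustrations of behavior the patches above introduce. They are approximate reference snippets, not additional hunks, and any model names, file names, or numeric values in them are placeholders.

For image datasets (the vision-arena special case and HF datasets with an image column), the sampled multi-modal content is an OpenAI-style image_url entry carrying a base64 JPEG data URL, as built in sample_vision_arena_requests()/sample_hf_requests(). A small sketch of that encoding, assuming Pillow is installed and using a throwaway 8x8 image:

import base64
import io

from PIL import Image

image = Image.new("RGB", (8, 8), color=(255, 0, 0))  # placeholder image
buf = io.BytesIO()
image.convert("RGB").save(buf, format="JPEG")
image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
mm_content = {
    "type": "image_url",
    "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
}
print(mm_content["image_url"]["url"][:48] + "...")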
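
The request generator in benchmark_serving.py samples inter-arrival times from a gamma distribution whose shape parameter is the --burstiness factor; with burstiness of 1 this reduces to the exponential distribution, i.e. Poisson arrivals. A minimal sketch of that sampling, assuming only numpy:

import numpy as np

def sample_intervals(request_rate: float, burstiness: float, n: int) -> np.ndarray:
    # Scale parameter theta keeps the mean interval at 1 / request_rate,
    # because a gamma(shape=k, scale=theta) draw has mean k * theta.
    assert burstiness > 0, "burstiness must be positive"
    theta = 1.0 / (request_rate * burstiness)
    return np.random.gamma(shape=burstiness, scale=theta, size=n)

# 0 < burstiness < 1 clusters requests together; burstiness > 1 spreads them
# out more evenly. The mean arrival rate is unchanged either way.
print(sample_intervals(request_rate=10.0, burstiness=0.5, n=100_000).mean())  # ~0.1 s
print(sample_intervals(request_rate=10.0, burstiness=2.0, n=100_000).mean())  # ~0.1 s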
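
The --goodput option takes KEY:VALUE pairs in milliseconds and counts a request toward goodput only if every configured SLO is met. A simplified sketch of that check, collapsing the dict-of-lists bookkeeping in calculate_metrics() down to two metrics with made-up per-request latencies:

def parse_slos(slo_pairs):
    # "ttft:300" (milliseconds on the CLI) -> {"ttft": 0.3} (seconds)
    return {name: float(val) / 1000.0
            for name, val in (pair.split(":") for pair in slo_pairs)}

slos = parse_slos(["ttft:300", "tpot:50"])
ttfts = [0.12, 0.45, 0.20]      # per-request time-to-first-token, seconds
tpots = [0.030, 0.049, 0.070]   # per-request time-per-output-token, seconds

good_completed = sum(
    1 for ttft, tpot in zip(ttfts, tpots)
    if ttft <= slos["ttft"] and tpot <= slos["tpot"]  # every SLO must hold
)
print(f"goodput-eligible requests: {good_completed} of {len(ttfts)}")  # 1 of 3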
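
--max-concurrency caps in-flight requests with an asyncio.Semaphore while the arrival generator keeps issuing requests at the configured rate, which is why the effective throughput can fall below --request-rate when the server cannot keep up. A self-contained sketch of the limited_request_func pattern, with a dummy coroutine standing in for the real HTTP call:

import asyncio
from typing import List, Optional

async def fake_request(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for the real HTTP call
    return i

async def run(num_requests: int, max_concurrency: Optional[int]) -> List[int]:
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited(i: int) -> int:
        if semaphore is None:
            return await fake_request(i)
        async with semaphore:  # at most max_concurrency requests in flight
            return await fake_request(i)

    tasks = [asyncio.create_task(limited(i)) for i in range(num_requests)]
    return await asyncio.gather(*tasks)

# 20 requests, at most 4 in flight at once -> roughly 5 waves of ~0.1 s each.
print(asyncio.run(run(num_requests=20, max_concurrency=4)))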
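
When the SAVE_TO_PYTORCH_BENCHMARK_FORMAT environment variable is set, convert_to_pytorch_benchmark_format() emits one record per metric, which save_to_pytorch_benchmark_format() writes to a *.pytorch.json file next to the regular results. The record below shows the shape of one such entry; the model name and metric values are invented placeholders:

import json

record = {
    "benchmark": {
        "name": "vLLM benchmark",
        "extra_info": {"args": {"model": "my-org/my-model", "request_rate": 4.0}},
    },
    "model": {"name": "my-org/my-model"},
    "metric": {
        "name": "median_ttft_ms",
        "benchmark_values": [83.1],
        "extra_info": {"completed": 512, "request_throughput": 3.9},
    },
}
print(json.dumps(record, indent=2))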