diff --git a/sarathi/benchmark/benchmark_runner.py b/sarathi/benchmark/benchmark_runner.py
index 083dcb3..1671547 100644
--- a/sarathi/benchmark/benchmark_runner.py
+++ b/sarathi/benchmark/benchmark_runner.py
@@ -11,6 +11,8 @@
+import yaml
+
 from sarathi.benchmark.entities import Request
 from sarathi.benchmark.request_generator import RequestGeneratorRegistry
 from sarathi.benchmark.utils.random import set_seeds
+from sarathi.benchmark.utils import dataset_loader
 from sarathi.config import ReplicaConfig
 from sarathi.metrics.metrics_store import MetricsStore
 from sarathi.types import ReplicaResourceMapping, ResourceMapping
@@ -38,17 +39,23 @@ def __init__(
         os.makedirs(replica_config.output_dir, exist_ok=True)
 
         set_seeds(self.config.seed)
+
         request_generator = RequestGeneratorRegistry.get(
             self.config.request_generator_config.get_type(),
             self.config.request_generator_config,
         )
+
         self.requests = request_generator.generate()
 
+        self.run_correctness_tests = self.config.correctness_test_config is not None \
+            and self.config.correctness_test_config.run_correctness_tests
+        self.correctness_output = {}
+
         # select every nth request for this replica
         # e.g. if there are 4 replicas, and this is the 2nd replica, then
         # we will select the 2nd, 6th, 10th, ... requests
         # round robin scheduling
-        self.requests = self.requests[self.replica_id :: self.config.num_replicas]
+        # self.requests = self.requests[self.replica_id :: self.config.num_replicas]
 
         if self.config.num_replicas > 1:
             # disable per-replica wandb logging for multi-replica runs
@@ -70,14 +76,22 @@ def _get_input_params(
             temperature=0,
             top_p=1.0,
         )
-        prompt_token_ids = [1] * request.num_prefill_tokens
-
-        return {
-            "prompt": None,
-            "prompt_token_ids": prompt_token_ids,
-            "sampling_params": sampling_params,
-            "arrival_time": first_request_time + request.arrived_at,
-        }
+        if self.run_correctness_tests:
+            return {
+                "prompt": request.prompt,
+                "prompt_token_ids": None,
+                "sampling_params": sampling_params,
+                "arrival_time": first_request_time + request.arrived_at,
+            }
+        else:
+            prompt_token_ids = [1] * request.num_prefill_tokens
+
+            return {
+                "prompt": None,
+                "prompt_token_ids": prompt_token_ids,
+                "sampling_params": sampling_params,
+                "arrival_time": first_request_time + request.arrived_at,
+            }
 
     def warmup(self) -> None:
         self.llm_engine.add_request(**self._get_input_params(self.requests[0], 0))
@@ -111,6 +127,11 @@ def _run(self) -> None:
             num_steps += 1
 
             for output in step_outputs:
+                if self.run_correctness_tests:
+                    print("CORRECTNESS OUTPUT")
+                    print(output.text)
+                    self.correctness_output[output.seq_id] = self.correctness_output.get(output.seq_id, "") + output.text
+
                 if output.finished:
                     num_processed_requests += 1
                     pbar.update(1)
@@ -143,6 +164,18 @@ def run(self) -> None:
         metric_store = self.llm_engine.get_metric_store()
 
         return metric_store
 
+    def check_correctness(self) -> bool:
+        assert self.run_correctness_tests
+
+        with open(self.config.correctness_test_config.correctness_test_file, "r") as file:
+            data = yaml.safe_load(file)
+
+        for k, v in data.items():
+            if self.correctness_output[k] != v:
+                return False
+
+        return True
+
 
 class BenchmarkRunnerLauncher:
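Note: check_correctness above compares the text collected in correctness_output against a YAML
mapping from sequence id to generated text. A minimal sketch of what producing and consuming such
a file could look like follows; the file name, keys, and completions are illustrative placeholders,
not values defined by this patch.

    import yaml

    # Hypothetical baseline captured during a run with run_correctness_baseline=True:
    # keys are sequence ids as reported by the engine, values are the full decoded text.
    baseline = {
        "seq_0": "def has_close_elements(numbers, threshold):\n    ...",
        "seq_1": "def separate_paren_groups(paren_string):\n    ...",
    }

    with open("correctness_baseline.yaml", "w") as f:
        yaml.safe_dump(baseline, f)

    # check_correctness() later loads the same file with yaml.safe_load and returns
    # False on the first entry whose stored text differs from the collected output.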
diff --git a/sarathi/benchmark/config.py b/sarathi/benchmark/config.py
index bbc5d10..7d4c159 100644
--- a/sarathi/benchmark/config.py
+++ b/sarathi/benchmark/config.py
@@ -161,6 +161,35 @@ def get_type():
         return RequestLengthGeneratorType.FIXED
 
+
+@dataclass
+class DatasetRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig):
+    dataset: str = field(
+        default="ccdv/arxiv-summarization",
+        metadata={"help": "HuggingFace dataset name or local file path to load prompts from."},
+    )
+    meta_prompt: Optional[str] = field(
+        default=None,
+        metadata={"help": "Meta prompt prepended to each dataset sample."},
+    )
+    max_prompt_len: int = field(
+        default=4096, metadata={"help": "Maximum prompt length allowed."}
+    )
+    max_num_prompts: int = field(
+        default=300, metadata={"help": "Maximum number of prompts to use."}
+    )
+    max_decode_tokens: int = field(
+        default=512, metadata={"help": "Maximum number of decode tokens."}
+    )
+    tokenizer_model: str = field(
+        default="meta-llama/Meta-Llama-3-8B-Instruct", metadata={"help": "Name or path of the huggingface model to use for the tokenizer."}
+    )
+
+    @staticmethod
+    def get_type():
+        return RequestLengthGeneratorType.DATASET
+
+
 @dataclass
 class BaseRequestGeneratorConfig(BasePolyConfig):
     seed: int = field(
@@ -214,6 +243,19 @@ class TraceRequestGeneratorConfig(BaseRequestGeneratorConfig):
     def get_type():
         return RequestGeneratorType.TRACE
 
+
+@dataclass
+class CorrectnessTestConfig(BaseTestConfig):
+    run_correctness_tests: bool = field(
+        default=False, metadata={"help": "Collect correctness data in this run."}
+    )
+    run_correctness_baseline: bool = field(
+        default=False, metadata={"help": "Use this run as the ground truth for correctness tests."}
+    )
+    correctness_test_file: Optional[str] = field(
+        default=None, metadata={"help": "Ground truth file for model output. If run_correctness_baseline is True, the model output is saved to this file to be used as ground truth. Otherwise, the test reads this file and uses it as the ground truth for the correctness test."}
+    )
+
 
 @dataclass
 class BenchmarkConfig(BaseEndpointConfig):
@@ -240,6 +282,9 @@ class BenchmarkConfig(BaseEndpointConfig):
     request_generator_config: BaseRequestGeneratorConfig = field(
         default_factory=SyntheticRequestGeneratorConfig
     )
+    correctness_test_config: Optional[CorrectnessTestConfig] = field(
+        default_factory=CorrectnessTestConfig
+    )
 
     def __post_init__(self):
         super().__post_init__()
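Note: a minimal sketch of how the new config classes are expected to compose, using the field
names introduced in this patch; the dataset, prompt cap, and file name below are illustrative
values only.

    from sarathi.benchmark.config import (
        BenchmarkConfig,
        CorrectnessTestConfig,
        DatasetRequestLengthGeneratorConfig,
        SyntheticRequestGeneratorConfig,
    )

    config = BenchmarkConfig(
        request_generator_config=SyntheticRequestGeneratorConfig(
            length_generator_config=DatasetRequestLengthGeneratorConfig(
                dataset="openai_humaneval",
                max_num_prompts=50,
            ),
        ),
        correctness_test_config=CorrectnessTestConfig(
            run_correctness_tests=True,
            correctness_test_file="correctness_baseline.yaml",
        ),
    )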
diff --git a/sarathi/benchmark/correctness_test_runner.py b/sarathi/benchmark/correctness_test_runner.py
new file mode 100644
index 0000000..c3490f3
--- /dev/null
+++ b/sarathi/benchmark/correctness_test_runner.py
@@ -0,0 +1,277 @@
+import logging
+import os
+import time
+
+import ray
+import wandb
+from tqdm import tqdm
+
+from sarathi import LLMEngine, SamplingParams
+from sarathi.benchmark.config import BenchmarkConfig
+from sarathi.benchmark.entities import Request
+from sarathi.benchmark.request_generator import RequestGeneratorRegistry
+from sarathi.benchmark.utils.random import set_seeds
+from sarathi.benchmark.utils import dataset_loader
+from sarathi.config import ReplicaConfig
+from sarathi.metrics.metrics_store import MetricsStore
+from sarathi.types import ReplicaResourceMapping, ResourceMapping
+from sarathi.utils import get_ip
+
+logger = logging.getLogger(__name__)
+
+
+class CorrectnessRunner:
+
+    def __init__(
+        self,
+        replica_id: int,
+        config: BenchmarkConfig,
+        resource_mapping: ResourceMapping,
+        dataset: str,
+    ) -> None:
+        self.replica_id = replica_id
+        self.config = config
+
+        replica_config = ReplicaConfig(
+            replica_id,
+            self.config.output_dir,
+            resource_mapping,
+        )
+        os.makedirs(replica_config.output_dir, exist_ok=True)
+
+        set_seeds(self.config.seed)
+
+        # load raw prompts for the requested dataset; no single input string,
+        # meta prompt, or sample cap is applied here
+        self.requests = list(
+            dataset_loader.get_data_loader(
+                None, dataset, None, None, self.config.model_config.model
+            )
+        )
+
+        # select every nth request for this replica
+        # e.g. if there are 4 replicas, and this is the 2nd replica, then
+        # we will select the 2nd, 6th, 10th, ... requests
+        # round robin scheduling
+        self.requests = self.requests[self.replica_id :: self.config.num_replicas]
+
+        if self.config.num_replicas > 1:
+            # disable per-replica wandb logging for multi-replica runs
+            # so that we can aggregate metrics across all replicas
+            self.config.metrics_config.wandb_project = None
+
+        system_config = self.config.create_system_config(replica_config)
+        self.llm_engine = LLMEngine.from_system_config(system_config)
+
+        if wandb.run is not None:
+            wandb.config.update(self.config.to_dict())
+
+    def _get_input_params(
+        self, request: Request, first_request_time: float
+    ) -> SamplingParams:
+        sampling_params = SamplingParams(
+            ignore_eos=True,
+            max_tokens=request.num_decode_tokens,
+            temperature=0,
+            top_p=1.0,
+        )
+
+        return {
+            "prompt": request.prompt,
+            "prompt_token_ids": None,
+            "sampling_params": sampling_params,
+            "arrival_time": first_request_time + request.arrived_at,
+        }
+
+    def warmup(self) -> None:
+        self.llm_engine.add_request(**self._get_input_params(self.requests[0], 0))
+
+        is_completed = False
+        while not is_completed:
+            step_outputs = self.llm_engine.step()
+            is_completed = step_outputs[0].finished
+
+        self.llm_engine.reset_metrics()
+
+    def _run(self) -> None:
+        if self.config.enable_profiling:
+            self.llm_engine.start_profiling()
+
+        num_processed_requests = 0
+        num_steps = 0
+        pbar = tqdm(
+            total=len(self.requests),
+            desc=f"Replica {self.replica_id} processed requests",
+        )
+        start_time = time.monotonic()
+
+        # Run the engine.
+        while num_processed_requests < len(self.requests):
+            elapsed_time = time.monotonic() - start_time
+            if elapsed_time > self.config.time_limit:
+                break
+
+            step_outputs = self.llm_engine.step()
+            num_steps += 1
+
+            for output in step_outputs:
+                if output.finished:
+                    num_processed_requests += 1
+                    pbar.update(1)
+
+        end_time = time.monotonic()
+        pbar.close()
+
+        logger.info(
+            f"Replica {self.replica_id} exiting after processing {len(self.requests)} ({num_steps} iterations), Total time taken: {end_time - start_time:.2f} seconds"
+        )
+
+        if self.config.enable_profiling:
+            self.llm_engine.stop_profiling()
+
+    def _add_requests(self) -> None:
+        index = 0
+        first_request_time = time.monotonic()
+        while index < len(self.requests):
+            request = self.requests[index]
+            self.llm_engine.add_request(
+                **self._get_input_params(request, first_request_time)
+            )
+            index += 1
+
+    def run(self) -> None:
+        self.llm_engine.reset_metrics()
+
+        self._add_requests()
+        self._run()
+
+        self.llm_engine.pull_worker_metrics()
+
+        metric_store = self.llm_engine.get_metric_store()
+        return metric_store
+
+
+class CorrectnessRunnerLauncher:
+
+    def __init__(self, config: BenchmarkConfig) -> None:
+        self.config = config
+        self.is_multi_replica = self.config.num_replicas > 1
+
+        ray.init(ignore_reinit_error=True)
+
+        self._validate_cluster_resources()
+        self.runners = self._create_runners()
+
+        if self.is_multi_replica:
+            self.aggregate_metric_store = self._create_aggregate_metric_store()
+
+    def _validate_cluster_resources(self):
+        num_replicas = self.config.num_replicas
+        num_gpus_required = num_replicas * self.config.parallel_config.world_size
+
+        available_resources = ray.available_resources()
+
+        assert (
+            available_resources["GPU"] >= num_gpus_required
+        ), f"Insufficient GPUs. Required: {num_gpus_required}, Available: {available_resources['GPU']}"
+
+    def _get_replica_resource_mapping(self) -> ReplicaResourceMapping:
+        if self.config.replica_resource_mapping:
+            assert len(self.config.replica_resource_mapping) == self.config.num_replicas
+            logger.info(
+                f"Replica resource mapping: {self.config.replica_resource_mapping}"
+            )
+            return self.config.replica_resource_mapping
+
+        cluster_resources_keys = list(ray.available_resources().keys())
+        num_gpus = ray.available_resources()["GPU"]
+        ip_addresses = [
+            x
+            for x in cluster_resources_keys
+            if x.startswith("node:") and x != "node:__internal_head__"
+        ]
+
+        runner_ip = f"node:{get_ip()}"
+
+        ip_addresses.remove(runner_ip)
+        ip_addresses.insert(0, runner_ip)
+
+        num_nodes = len(ip_addresses)
+        assert num_nodes > 0, "No nodes found in the cluster"
+        assert num_gpus > 0, "No GPUs found in the cluster"
+        assert (
+            num_gpus % num_nodes == 0
+        ), f"Number of GPUs ({num_gpus}) is not a multiple of number of nodes ({num_nodes})"
+        num_gpus_per_node = int(num_gpus // num_nodes)
+        num_replicas = self.config.num_replicas
+        num_gpus_per_replica = self.config.parallel_config.world_size
+
+        assert (
+            num_gpus >= num_replicas * num_gpus_per_replica
+        ), f"Insufficient GPUs. Required: {num_replicas * num_gpus_per_replica}, Available: {num_gpus}"
+
+        replica_resource_mapping = []
+
+        available_gpus = []
+        for ip_address in ip_addresses:
+            for gpu_id in reversed(range(num_gpus_per_node)):
+                available_gpus.append((ip_address, gpu_id))
+
+        for _ in range(num_replicas):
+            resource_mapping = []
+            for _ in range(num_gpus_per_replica):
+                resource_mapping.append(available_gpus.pop(0))
+            replica_resource_mapping.append(resource_mapping)
+
+        logger.info(f"Replica resource mapping: {replica_resource_mapping}")
+
+        return replica_resource_mapping
+
+    def _create_runners(self):
+        replica_resource_mapping = self._get_replica_resource_mapping()
+
+        # the dataset name is taken from the length generator config
+        dataset = self.config.request_generator_config.length_generator_config.dataset
+
+        if not self.is_multi_replica:
+            return [CorrectnessRunner(0, self.config, replica_resource_mapping[0], dataset)]
+
+        runner_class = ray.remote(num_cpus=1)(CorrectnessRunner)
+
+        runners = []
+
+        for replica_id in range(self.config.num_replicas):
+            runners.append(
+                runner_class.options(
+                    resources={
+                        replica_resource_mapping[replica_id][0][0]: 0.01,
+                    },
+                ).remote(
+                    replica_id,
+                    self.config,
+                    replica_resource_mapping[replica_id],
+                    dataset,
+                )
+            )
+
+        return runners
+
+    def _create_aggregate_metric_store(self):
+        replica_config = ReplicaConfig(
+            replica_id=0,  # dummy replica id
+            output_dir=self.config.output_dir,
+        )
+        metrics_store = MetricsStore.get_instance(
+            replica_config,
+            self.config.model_config,
+            self.config.metrics_config,
+        )
+
+        if wandb.run is not None:
+            wandb.config.update(self.config.to_dict())
+
+        metrics_store.mark_initial_memory_profiling_done()
+
+        return metrics_store
+
+    def run(self):
+        if self.is_multi_replica:
+            ray.get([runner.warmup.remote() for runner in self.runners])
+
+            runner_metrics = ray.get([runner.run.remote() for runner in self.runners])
+
+            for runner_metric in runner_metrics:
+                self.aggregate_metric_store.merge(runner_metric)
+
+            if wandb.run is not None:
+                wandb.config.update(self.config.__dict__)
+
+            self.aggregate_metric_store.plot()
+        else:
+            metric_store = self.runners[0].run()
+            metric_store.plot()
+
+        wandb.finish()
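Note: a minimal usage sketch for the launcher defined above. It assumes a single replica and a
config built as in the earlier sketch, i.e. one whose length generator config carries a dataset
field; the variable name config is a placeholder.

    from sarathi.benchmark.correctness_test_runner import CorrectnessRunnerLauncher

    # `config` is a BenchmarkConfig with a DatasetRequestLengthGeneratorConfig, as sketched earlier.
    launcher = CorrectnessRunnerLauncher(config)
    launcher.run()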
diff --git a/sarathi/benchmark/entities/request.py b/sarathi/benchmark/entities/request.py
index 2d74498..b274138 100644
--- a/sarathi/benchmark/entities/request.py
+++ b/sarathi/benchmark/entities/request.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Tuple
+from typing import Tuple, Optional
 
 from sarathi.benchmark.entities.base_entity import BaseEntity
 
@@ -13,11 +13,13 @@ def __init__(
         arrived_at: float,
         num_prefill_tokens: int,
         num_decode_tokens: int,
+        prompt: Optional[str] = None,
     ):
         self._id = Request.generate_id()
         self._arrived_at = arrived_at
         self._num_prefill_tokens = num_prefill_tokens
         self._num_decode_tokens = num_decode_tokens
+        self._prompt = prompt
 
         assert num_prefill_tokens > 0
         assert num_decode_tokens > 0
@@ -28,6 +30,10 @@ def size(self) -> Tuple[int, int]:
     @property
     def arrived_at(self) -> float:
         return self._arrived_at
+
+    @property
+    def prompt(self) -> Optional[str]:
+        return self._prompt
 
     @property
     def num_prefill_tokens(self) -> int:
diff --git a/sarathi/benchmark/main.py b/sarathi/benchmark/main.py
index 333033d..71b2cb1 100644
--- a/sarathi/benchmark/main.py
+++ b/sarathi/benchmark/main.py
@@ -13,11 +13,20 @@
 
 def main() -> None:
-    config = BenchmarkConfig.create_from_cli_args()
+    # NOTE: CorrectnessTestConfig, SyntheticRequestGeneratorConfig and
+    # DatasetRequestLengthGeneratorConfig must be imported from sarathi.benchmark.config
+    config = BenchmarkConfig(
+        correctness_test_config=CorrectnessTestConfig(run_correctness_tests=True),
+        request_generator_config=SyntheticRequestGeneratorConfig(
+            length_generator_config=DatasetRequestLengthGeneratorConfig(
+                dataset="openai_humaneval"
+            )
+        ),
+    )
 
-    os.makedirs(config.output_dir, exist_ok=True)
-    with open(os.path.join(config.output_dir, "config.yaml"), "w") as f:
-        yaml.dump(config.to_dict(), f)
+    # os.makedirs(config.output_dir, exist_ok=True)
+    # with open(os.path.join(config.output_dir, "config.yaml"), "w") as f:
+    #     yaml.dump(config.to_dict(), f)
 
     logger.info(f"Starting benchmark with config: {config}")
 
diff --git a/sarathi/benchmark/request_generator/base_request_length_generator.py b/sarathi/benchmark/request_generator/base_request_length_generator.py
index 1609de3..dfe6b7d 100644
--- a/sarathi/benchmark/request_generator/base_request_length_generator.py
+++ b/sarathi/benchmark/request_generator/base_request_length_generator.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Tuple
+from typing import Tuple, Union
 
 from sarathi.benchmark.config import BaseRequestLengthGeneratorConfig
 
@@ -10,5 +10,5 @@ def __init__(self, config: BaseRequestLengthGeneratorConfig):
         self.config = config
 
     @abstractmethod
-    def get_next_num_tokens(self) -> Tuple[float, float]:
+    def get_next_num_tokens(self) -> Tuple[Union[str, float], float]:
         pass
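Note: with the widened signature, the first element returned by get_next_num_tokens is either a
token count or a raw prompt string. A sketch of how a caller is expected to branch on it,
mirroring the synthetic request generator change later in this patch:

    # given any BaseRequestLengthGenerator instance `length_generator`:
    prefill, decode_tokens = length_generator.get_next_num_tokens()
    if prefill is None or decode_tokens is None:
        pass  # generator exhausted
    elif isinstance(prefill, str):
        # dataset-backed generator: character length stands in for the token count
        prompt, num_prefill_tokens = prefill, len(prefill)
    else:
        prompt, num_prefill_tokens = None, int(prefill)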
diff --git a/sarathi/benchmark/request_generator/dataset_request_length_generator.py b/sarathi/benchmark/request_generator/dataset_request_length_generator.py
new file mode 100644
index 0000000..5e398f2
--- /dev/null
+++ b/sarathi/benchmark/request_generator/dataset_request_length_generator.py
@@ -0,0 +1,35 @@
+import logging
+from typing import Tuple, Union
+
+import numpy as np
+import pandas as pd
+
+from sarathi.benchmark.utils import dataset_loader
+from sarathi.benchmark.config import DatasetRequestLengthGeneratorConfig
+from sarathi.benchmark.request_generator.base_request_length_generator import (
+    BaseRequestLengthGenerator,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetRequestLengthGenerator(BaseRequestLengthGenerator):
+
+    def __init__(self, config: DatasetRequestLengthGeneratorConfig):
+        super().__init__(config)
+        self.next_request_idx = 0
+        prompts = dataset_loader.get_data_loader(
+            None, config.dataset, config.meta_prompt, None, config.tokenizer_model
+        )
+        self.requests = [
+            prompt for prompt in prompts if len(prompt) <= config.max_prompt_len
+        ][: config.max_num_prompts]
+        self.decode_tokens = config.max_decode_tokens
+
+    def get_next_num_tokens(self) -> Tuple[Union[str, float], float]:
+        if self.next_request_idx >= len(self.requests):
+            return None, None
+
+        row = self.requests[self.next_request_idx]
+        self.next_request_idx += 1
+
+        return (
+            row,
+            self.decode_tokens,
+        )
\ No newline at end of file
diff --git a/sarathi/benchmark/request_generator/request_length_generator_registry.py b/sarathi/benchmark/request_generator/request_length_generator_registry.py
index da3f12e..ef76896 100644
--- a/sarathi/benchmark/request_generator/request_length_generator_registry.py
+++ b/sarathi/benchmark/request_generator/request_length_generator_registry.py
@@ -7,6 +7,9 @@
 from sarathi.benchmark.request_generator.uniform_request_length_generator import (
     UniformRequestLengthGenerator,
 )
+from sarathi.benchmark.request_generator.dataset_request_length_generator import (
+    DatasetRequestLengthGenerator,
+)
 from sarathi.benchmark.request_generator.zipf_request_length_generator import (
     ZipfRequestLengthGenerator,
 )
@@ -30,3 +33,6 @@ class RequestLengthGeneratorRegistry(BaseRegistry):
 RequestLengthGeneratorRegistry.register(
     RequestLengthGeneratorType.FIXED, FixedRequestLengthGenerator
 )
+RequestLengthGeneratorRegistry.register(
+    RequestLengthGeneratorType.DATASET, DatasetRequestLengthGenerator
+)
diff --git a/sarathi/benchmark/request_generator/synthetic_request_generator.py b/sarathi/benchmark/request_generator/synthetic_request_generator.py
index a835791..1b69b1f 100644
--- a/sarathi/benchmark/request_generator/synthetic_request_generator.py
+++ b/sarathi/benchmark/request_generator/synthetic_request_generator.py
@@ -43,11 +43,19 @@ def _generate_next_request(self, last_arrived_at: float) -> Request:
 
         if prefill_tokens is None or decode_tokens is None:
             return None
+
+        # custom request prompt: the length generator returned a raw prompt string,
+        # so its character length is used as a stand-in for the prefill token count
+        prompt = None
+        if isinstance(prefill_tokens, str):
+            prompt = prefill_tokens
+            prefill_tokens = len(prefill_tokens)
 
         return Request(
             arrived_at=arrived_at,
             num_prefill_tokens=int(prefill_tokens),
             num_decode_tokens=int(decode_tokens),
+            prompt=prompt,
         )
 
     def _generate_requests(self) -> List[Request]:
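Note: a sketch of resolving the new DATASET length generator through the registry. It assumes the
registry's get() follows the same (type, config) call pattern used for RequestGeneratorRegistry in
benchmark_runner.py; the dataset name is illustrative.

    from sarathi.benchmark.config import DatasetRequestLengthGeneratorConfig
    from sarathi.benchmark.request_generator.request_length_generator_registry import (
        RequestLengthGeneratorRegistry,
    )

    length_config = DatasetRequestLengthGeneratorConfig(dataset="openai_humaneval")
    generator = RequestLengthGeneratorRegistry.get(length_config.get_type(), length_config)
    prompt_or_tokens, decode_tokens = generator.get_next_num_tokens()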
diff --git a/sarathi/benchmark/utils/dataset_loader.py b/sarathi/benchmark/utils/dataset_loader.py
new file mode 100644
index 0000000..e273fa9
--- /dev/null
+++ b/sarathi/benchmark/utils/dataset_loader.py
@@ -0,0 +1,83 @@
+from collections.abc import Iterable
+from itertools import islice
+from typing import Optional
+
+from transformers import AutoTokenizer
+
+from datasets import load_dataset
+
+DATASET_FIELDS = {
+    "xsum": ("document", "train"),
+    "openai_humaneval": ("prompt", "test"),
+    "ccdv/arxiv-summarization": ("article", "train"),
+    "lmsys/lmsys-chat-1m": ("conversation", "train"),
+    "OpenGVLab/ShareGPT-4o": ("conversations", "image_caption"),
+    "Fredithefish/ShareGPT-unfiltered-alpaca-lora-format": ("instruction", "train"),
+    "openai/gsm8k": ("question", "main"),
+    # 'another_dataset': 'text_field',
+    # ... add other datasets and their respective fields
+}
+
+
+def get_data_loader(
+    input_string: Optional[str],
+    dataset_str: Optional[str],
+    meta_prompt: Optional[str],
+    max_samples: Optional[int],
+    model_for_tokenizer_chat_template: Optional[str],
+) -> Iterable[str]:
+    if input_string:
+        return [input_string]
+    if dataset_str not in DATASET_FIELDS:
+        # assume bwb path
+        return islice(load_bwb(dataset_str, meta_prompt), max_samples)
+
+    field_name = DATASET_FIELDS[dataset_str][0]
+    split_name = DATASET_FIELDS[dataset_str][1]
+    if dataset_str != "OpenGVLab/ShareGPT-4o":
+        dataset = load_dataset(dataset_str, split=split_name)
+    else:
+        dataset = load_dataset(dataset_str, 'image_caption')
+
+    if "conversation" not in field_name:
+        sample_iterator = map(
+            lambda x: f"{meta_prompt} {x}", map(lambda x: x[field_name], dataset)
+        )
+    else:
+        # reformat sharegpt-4
+        if "lmsys" not in dataset_str:
+            dataset['images'] = dataset['images'].map(format_conversations)
+            dataset = dataset['images']
+        # don't use meta prompt, this is chat case
+        tokenizer = AutoTokenizer.from_pretrained(model_for_tokenizer_chat_template)
+        conversations = dataset[field_name]
+        sample_iterator = tokenizer.apply_chat_template(
+            conversations, tokenize=False, add_generation_prompt=True
+        )
+
+    if not max_samples:
+        return sample_iterator
+
+    return islice(sample_iterator, max_samples)
+
+
+def format_conversations(sample):
+    sample['conversations'] = [
+        {
+            "content": turn["value"],
+            "role": "user" if turn["from"] == "human" else "assistant",
+        }
+        for turn in sample['conversations']
+    ]
+    return sample
+
+
+def load_bwb(file_path, meta_prompt, chunk_size=1000):
+    translation_chunks = []
+    try:
+        with open(file_path, 'r', encoding='utf-8') as file:
+            while True:
+                chunk = file.read(chunk_size)
+                if not chunk:
+                    break
+                chunk = meta_prompt + chunk
+                translation_chunks.append(chunk)
+    except FileNotFoundError:
+        print(f"Error: File '{file_path}' not found.")
+
+    return translation_chunks
\ No newline at end of file
diff --git a/sarathi/types.py b/sarathi/types.py
index fbf2cbb..3fd45c3 100644
--- a/sarathi/types.py
+++ b/sarathi/types.py
@@ -32,7 +32,7 @@ class RequestLengthGeneratorType(Enum):
     ZIPF = "ZIPF"
     TRACE = "TRACE"
     FIXED = "FIXED"
-
+    DATASET = "DATASET"
 
 class AttentionBackend(Enum):
     FLASHINFER = "FLASHINFER"
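Note: a minimal sketch of calling the loader above for a plain (non-chat) dataset; the sample cap
and empty meta prompt are illustrative choices, not defaults defined by the patch.

    from sarathi.benchmark.utils.dataset_loader import get_data_loader

    prompts = get_data_loader(
        input_string=None,
        dataset_str="openai_humaneval",
        meta_prompt="",
        max_samples=5,
        model_for_tokenizer_chat_template=None,
    )
    for prompt in prompts:
        print(prompt[:80])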
diff --git a/unit_test/correctness_test.py b/unit_test/correctness_test.py
new file mode 100644
index 0000000..8007359
--- /dev/null
+++ b/unit_test/correctness_test.py
@@ -0,0 +1,52 @@
+import glob
+import shutil
+import json
+import os
+
+import pandas as pd
+import pytest
+
+from sarathi.benchmark.benchmark_runner import BenchmarkRunnerLauncher
+from sarathi.benchmark.config import BenchmarkConfig, SyntheticRequestGeneratorConfig, \
+    DatasetRequestLengthGeneratorConfig, CorrectnessTestConfig
+from sarathi.config.config import ModelConfig, ParallelConfig
+
+
+# pytest -k "test_correctness"
+@pytest.mark.parametrize(
+    "model, max_model_len, pp_size, tp_size, dataset, request_pattern, baseline_run, baseline_file",
+    [
+        ("meta-llama/Meta-Llama-3-8B", 4096, 1, 1, "OpenGVLab/ShareGPT-4o", "uniform", False, None),
+    ],
+)
+def test_correctness(model: str, max_model_len: int, pp_size: int, tp_size: int, dataset: str, request_pattern: str, baseline_run: bool, baseline_file: str):
+    # TODO: Test over 3d space
+    model_config = ModelConfig(
+        model=model,
+        max_model_len=max_model_len,
+    )
+    parallel_config = ParallelConfig(
+        pipeline_parallel_size=pp_size,
+        tensor_parallel_size=tp_size,
+    )
+    request_generator_config = SyntheticRequestGeneratorConfig(
+        length_generator_config=DatasetRequestLengthGeneratorConfig(dataset=dataset)
+    )
+    correctness_test_config = CorrectnessTestConfig(
+        run_correctness_tests=True,
+        run_correctness_baseline=baseline_run,
+        correctness_test_file=baseline_file,
+    )
+    cwd = os.getcwd()
+    output_dir = os.path.join(cwd, "benchmark_output")
+    benchmark_config = BenchmarkConfig(
+        log_level="error",
+        output_dir=output_dir,
+        model_config=model_config,
+        parallel_config=parallel_config,
+        request_generator_config=request_generator_config,
+        correctness_test_config=correctness_test_config,
+    )
+    BenchmarkRunnerLauncher(benchmark_config).run()
+
+    print("correctness_test finished.")
\ No newline at end of file
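Note: the test above can also be driven programmatically; a sketch of the equivalent invocation
(the same as running pytest -k test_correctness unit_test/correctness_test.py from the repo root).

    import sys

    import pytest

    sys.exit(pytest.main(["-k", "test_correctness", "unit_test/correctness_test.py"]))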