Add core correctness testing setup #50

Open · wants to merge 12 commits into base: main
48 changes: 40 additions & 8 deletions sarathi/benchmark/benchmark_runner.py
@@ -11,6 +11,7 @@
import yaml

from sarathi.benchmark.entities import Request
from sarathi.benchmark.request_generator import RequestGeneratorRegistry
from sarathi.benchmark.utils.random import set_seeds
from sarathi.benchmark.utils import dataset_loader
from sarathi.config import ReplicaConfig
from sarathi.metrics.metrics_store import MetricsStore
from sarathi.types import ReplicaResourceMapping, ResourceMapping
@@ -38,17 +39,22 @@ def __init__(
os.makedirs(replica_config.output_dir, exist_ok=True)

set_seeds(self.config.seed)

request_generator = RequestGeneratorRegistry.get(
self.config.request_generator_config.get_type(),
self.config.request_generator_config,
)

self.requests = request_generator.generate()

self.run_correctness_tests = (
    self.config.correctness_test_config is not None
    and self.config.correctness_test_config.run_correctness_tests
)
# accumulated generated text per sequence id, used by check_correctness()
self.correctness_output = {}

# round-robin scheduling: select every nth request for this replica,
# e.g. with 4 replicas, the 2nd replica selects the 2nd, 6th, 10th, ... requests
self.requests = self.requests[self.replica_id :: self.config.num_replicas]
# self.requests = self.requests[self.replica_id :: self.config.num_replicas]
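
For reference, a minimal standalone sketch of the round-robin selection performed by that slice (illustration only, not part of this diff):

```python
# Round-robin sharding: with 4 replicas, replica index 1 takes requests 1, 5, 9, ...
requests = list(range(12))  # stand-ins for Request objects
num_replicas, replica_id = 4, 1
shard = requests[replica_id::num_replicas]
assert shard == [1, 5, 9]
```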

if self.config.num_replicas > 1:
# disable per-replica wandb logging for multi-replica runs
@@ -70,14 +76,24 @@ def _get_input_params(
temperature=0,
top_p=1.0,
)
prompt_token_ids = [1] * request.num_prefill_tokens

return {
"prompt": None,
"prompt_token_ids": prompt_token_ids,
"sampling_params": sampling_params,
"arrival_time": first_request_time + request.arrived_at,
}
if self.run_correctness_tests:
    # correctness runs pass the real prompt text so the engine tokenizes it
    # and the generated output can be compared against ground truth
    return {
        "prompt": request.prompt,
        "prompt_token_ids": None,
        "sampling_params": sampling_params,
        "arrival_time": first_request_time + request.arrived_at,
    }
else:
    # benchmark runs only need the token count, so a placeholder prompt
    # of the requested prefill length is used instead of real text
    prompt_token_ids = [1] * request.num_prefill_tokens

    return {
        "prompt": None,
        "prompt_token_ids": prompt_token_ids,
        "sampling_params": sampling_params,
        "arrival_time": first_request_time + request.arrived_at,
    }

def warmup(self) -> None:
self.llm_engine.add_request(**self._get_input_params(self.requests[0], 0))
@@ -111,6 +127,11 @@ def _run(self) -> None:
num_steps += 1

for output in step_outputs:
if self.run_correctness_tests:
    print("CORRECTNESS OUTPUT")
    print(output.text)
    # accumulate streamed text per sequence for the final correctness check
    self.correctness_output[output.seq_id] = (
        self.correctness_output.get(output.seq_id, "") + output.text
    )

if output.finished:
num_processed_requests += 1
pbar.update(1)
@@ -143,6 +164,17 @@ def run(self) -> None:
metric_store = self.llm_engine.get_metric_store()
return metric_store

def check_correctness(self) -> bool:
    assert self.run_correctness_tests

    # load the ground-truth outputs and compare them against what this run produced
    ground_truth_file = self.config.correctness_test_config.correctness_test_file
    with open(ground_truth_file, "r") as file:
        data = yaml.safe_load(file)

    for seq_id, expected_text in data.items():
        if self.correctness_output[seq_id] != expected_text:
            return False

    return True
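
The ground-truth file read above appears to be a flat YAML mapping from sequence id to expected output text. A minimal sketch of writing such a baseline from the collected outputs (the dump_correctness_baseline helper is hypothetical, not part of this PR; it assumes PyYAML and the correctness_output dict populated in _run):

```python
import yaml

def dump_correctness_baseline(correctness_output: dict, file_path: str) -> None:
    # Hypothetical helper: persist per-sequence outputs as a YAML ground-truth file
    # that check_correctness() can later read back and compare against.
    with open(file_path, "w") as f:
        yaml.safe_dump(correctness_output, f)

# The resulting file would look roughly like:
#   seq_0: "The paper proposes ..."
#   seq_1: "In summary, ..."
```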


class BenchmarkRunnerLauncher:

45 changes: 45 additions & 0 deletions sarathi/benchmark/config.py
@@ -161,6 +161,35 @@ def get_type():
return RequestLengthGeneratorType.FIXED



@dataclass
class DatasetRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig):
dataset: str = field(
    default="ccdv/arxiv-summarization",
    metadata={"help": "Name or path of the HuggingFace dataset to draw prompts from."},
)
meta_prompt: Optional[str] = field(
    default=None,
    metadata={"help": "Meta prompt to use with the dataset."},
)
max_prompt_len: int = field(
default=4096, metadata={"help": "Maximum prompt length allowed."}
)
max_num_prompts: int = field(
default=300, metadata={"help": "Maximum number of prompts to use."}
)
max_decode_tokens: int = field(
default=512, metadata={"help": "Maximum number of decode tokens."}
)
tokenizer_model: str = field(
    default="meta-llama/Meta-Llama-3-8B-Instruct",
    metadata={"help": "Name or path of the HuggingFace model to use for the tokenizer."},
)

@staticmethod
def get_type():
return RequestLengthGeneratorType.DATASET
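
A sketch of how these fields might be consumed when building prompts. The actual dataset_loader implementation is not shown in this diff, so the datasets/transformers usage below is an illustrative assumption, including the "article" column name:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

config = DatasetRequestLengthGeneratorConfig()  # defaults as declared above
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_model)
dataset = load_dataset(config.dataset, split="train")

prompts = []
for row in dataset:
    # Assumed column name: ccdv/arxiv-summarization stores the paper body in "article".
    text = (config.meta_prompt or "") + row["article"]
    # Truncate to max_prompt_len tokens, then decode back to text for the prompt.
    token_ids = tokenizer(text, truncation=True, max_length=config.max_prompt_len)["input_ids"]
    prompts.append(tokenizer.decode(token_ids, skip_special_tokens=True))
    if len(prompts) >= config.max_num_prompts:
        break
```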


@dataclass
class BaseRequestGeneratorConfig(BasePolyConfig):
seed: int = field(
@@ -214,6 +243,19 @@ class TraceRequestGeneratorConfig(BaseRequestGeneratorConfig):
def get_type():
return RequestGeneratorType.TRACE

@dataclass
class CorrectnessTestConfig(BaseTestConfig):
run_correctness_tests: bool = field(
default=False, metadata={"help": "Collect correctness data in this run"}
)
run_correctness_baseline: bool = field(
    default=False,
    metadata={"help": "Save this run's output as the ground truth for correctness tests."},
)
correctness_test_file: Optional[str] = field(
    default=None,
    metadata={
        "help": "Ground-truth file for model output. If run_correctness_baseline is True, "
        "the model output is saved to this file to serve as ground truth; otherwise the "
        "correctness test reads its ground truth from this file."
    },
)
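
Illustrative only: how the two modes of this config might be set up (the file name is an example, not a default from this PR):

```python
# Baseline run: save this run's output as the ground truth.
baseline_cfg = CorrectnessTestConfig(
    run_correctness_tests=True,
    run_correctness_baseline=True,
    correctness_test_file="correctness_ground_truth.yml",
)

# Test run: compare this run's output against the saved ground truth.
test_cfg = CorrectnessTestConfig(
    run_correctness_tests=True,
    run_correctness_baseline=False,
    correctness_test_file="correctness_ground_truth.yml",
)
```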

@dataclass
class BenchmarkConfig(BaseEndpointConfig):
@@ -240,6 +282,9 @@ class BenchmarkConfig(BaseEndpointConfig):
request_generator_config: BaseRequestGeneratorConfig = field(
default_factory=SyntheticRequestGeneratorConfig
)
correctness_test_config: Optional[CorrectnessTestConfig] = field(
default_factory=CorrectnessTestConfig
)

def __post_init__(self):
super().__post_init__()