Skip to content

Commit e799682

Browse files
committed
Second Scenario implementation attempt
1 parent b85c6b9 commit e799682

File tree

3 files changed

+160
-34
lines changed

3 files changed

+160
-34
lines changed

src/guidellm/__main__.py

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import click
77

88
from guidellm.backend import BackendType
9-
from guidellm.benchmark import ProfileType, benchmark_generative_text
9+
from guidellm.benchmark import ProfileType
10+
from guidellm.benchmark.entrypoints import benchmark_with_scenario
11+
from guidellm.benchmark.scenario import GenerativeTextScenario
1012
from guidellm.config import print_config
1113
from guidellm.scheduler import StrategyType
1214

@@ -38,6 +40,19 @@ def parse_number_str(ctx, param, value): # noqa: ARG001
3840
) from err
3941

4042

43+
def set_if_not_default(ctx: click.Context, **kwargs):
    """
    Set the value of a click option if it is not the default value.
    This is useful for setting options that are not None by default.
    """
    # Keep only the options the user actually supplied on the command line
    # (or via env var / config) — anything still at its click default is dropped.
    return {
        key: value
        for key, value in kwargs.items()
        if ctx.get_parameter_source(key) != click.core.ParameterSource.DEFAULT
    }
54+
55+
4156
@click.group()
4257
def cli():
4358
pass
@@ -46,6 +61,14 @@ def cli():
4661
@cli.command(
4762
help="Run a benchmark against a generative model using the specified arguments."
4863
)
64+
@click.option(
65+
"--scenario",
66+
type=str,
67+
default=None,
68+
help=(
69+
"TODO: A scenario or path to config"
70+
),
71+
)
4972
@click.option(
5073
"--target",
5174
required=True,
@@ -59,20 +82,20 @@ def cli():
5982
"The type of backend to use to run requests against. Defaults to 'openai_http'."
6083
f" Supported types: {', '.join(get_args(BackendType))}"
6184
),
62-
default="openai_http",
85+
default=GenerativeTextScenario.backend_type,
6386
)
6487
@click.option(
6588
"--backend-args",
6689
callback=parse_json,
67-
default=None,
90+
default=GenerativeTextScenario.backend_args,
6891
help=(
6992
"A JSON string containing any arguments to pass to the backend as a "
7093
"dict with **kwargs."
7194
),
7295
)
7396
@click.option(
7497
"--model",
75-
default=None,
98+
default=GenerativeTextScenario.model,
7699
type=str,
77100
help=(
78101
"The ID of the model to benchmark within the backend. "
@@ -81,7 +104,7 @@ def cli():
81104
)
82105
@click.option(
83106
"--processor",
84-
default=None,
107+
default=GenerativeTextScenario.processor,
85108
type=str,
86109
help=(
87110
"The processor or tokenizer to use to calculate token counts for statistics "
@@ -91,7 +114,7 @@ def cli():
91114
)
92115
@click.option(
93116
"--processor-args",
94-
default=None,
117+
default=GenerativeTextScenario.processor_args,
95118
callback=parse_json,
96119
help=(
97120
"A JSON string containing any arguments to pass to the processor constructor "
@@ -110,6 +133,7 @@ def cli():
110133
)
111134
@click.option(
112135
"--data-args",
136+
default=GenerativeTextScenario.data_args,
113137
callback=parse_json,
114138
help=(
115139
"A JSON string containing any arguments to pass to the dataset creation "
@@ -118,7 +142,7 @@ def cli():
118142
)
119143
@click.option(
120144
"--data-sampler",
121-
default=None,
145+
default=GenerativeTextScenario.data_sampler,
122146
type=click.Choice(["random"]),
123147
help=(
124148
"The data sampler type to use. 'random' will add a random shuffle on the data. "
@@ -136,7 +160,7 @@ def cli():
136160
)
137161
@click.option(
138162
"--rate",
139-
default=None,
163+
default=GenerativeTextScenario.rate,
140164
callback=parse_number_str,
141165
help=(
142166
"The rates to run the benchmark at. "
@@ -150,6 +174,7 @@ def cli():
150174
@click.option(
151175
"--max-seconds",
152176
type=float,
177+
default=GenerativeTextScenario.max_seconds,
153178
help=(
154179
"The maximum number of seconds each benchmark can run for. "
155180
"If None, will run until max_requests or the data is exhausted."
@@ -158,6 +183,7 @@ def cli():
158183
@click.option(
159184
"--max-requests",
160185
type=int,
186+
default=GenerativeTextScenario.max_requests,
161187
help=(
162188
"The maximum number of requests each benchmark can run for. "
163189
"If None, will run until max_seconds or the data is exhausted."
@@ -166,7 +192,7 @@ def cli():
166192
@click.option(
167193
"--warmup-percent",
168194
type=float,
169-
default=None,
195+
default=GenerativeTextScenario.warmup_percent,
170196
help=(
171197
"The percent of the benchmark (based on max-seconds, max-requets, "
172198
"or lenth of dataset) to run as a warmup and not include in the final results. "
@@ -176,6 +202,7 @@ def cli():
176202
@click.option(
177203
"--cooldown-percent",
178204
type=float,
205+
default=GenerativeTextScenario.cooldown_percent,
179206
help=(
180207
"The percent of the benchmark (based on max-seconds, max-requets, or lenth "
181208
"of dataset) to run as a cooldown and not include in the final results. "
@@ -185,16 +212,19 @@ def cli():
185212
@click.option(
186213
"--disable-progress",
187214
is_flag=True,
215+
default=not GenerativeTextScenario.show_progress,
188216
help="Set this flag to disable progress updates to the console",
189217
)
190218
@click.option(
191219
"--display-scheduler-stats",
192220
is_flag=True,
221+
default=GenerativeTextScenario.show_progress_scheduler_stats,
193222
help="Set this flag to display stats for the processes running the benchmarks",
194223
)
195224
@click.option(
196225
"--disable-console-outputs",
197226
is_flag=True,
227+
default=not GenerativeTextScenario.output_console,
198228
help="Set this flag to disable console output",
199229
)
200230
@click.option(
@@ -211,6 +241,7 @@ def cli():
211241
@click.option(
212242
"--output-extras",
213243
callback=parse_json,
244+
default=GenerativeTextScenario.output_extras,
214245
help="A JSON string of extra data to save with the output benchmarks",
215246
)
216247
@click.option(
@@ -220,15 +251,16 @@ def cli():
220251
"The number of samples to save in the output file. "
221252
"If None (default), will save all samples."
222253
),
223-
default=None,
254+
default=GenerativeTextScenario.output_sampling,
224255
)
225256
@click.option(
226257
"--random-seed",
227-
default=42,
258+
default=GenerativeTextScenario.random_seed,
228259
type=int,
229260
help="The random seed to use for benchmarking to ensure reproducibility.",
230261
)
231262
def benchmark(
263+
scenario,
232264
target,
233265
backend_type,
234266
backend_args,
@@ -252,30 +284,48 @@ def benchmark(
252284
output_sampling,
253285
random_seed,
254286
):
287+
click_ctx = click.get_current_context()
288+
289+
# If a scenario file was specified read from it
290+
# TODO: This should probably be a factory method
291+
if scenario is None:
292+
_scenario = {}
293+
else:
294+
# TODO: Support pre-defined scenarios
295+
# TODO: Support other formats
296+
with Path(scenario).open() as f:
297+
_scenario = json.load(f)
298+
299+
# If any command line arguments are specified, override the scenario
300+
_scenario.update(set_if_not_default(
301+
click_ctx,
302+
target=target,
303+
backend_type=backend_type,
304+
backend_args=backend_args,
305+
model=model,
306+
processor=processor,
307+
processor_args=processor_args,
308+
data=data,
309+
data_args=data_args,
310+
data_sampler=data_sampler,
311+
rate_type=rate_type,
312+
rate=rate,
313+
max_seconds=max_seconds,
314+
max_requests=max_requests,
315+
warmup_percent=warmup_percent,
316+
cooldown_percent=cooldown_percent,
317+
show_progress=not disable_progress,
318+
show_progress_scheduler_stats=display_scheduler_stats,
319+
output_console=not disable_console_outputs,
320+
output_path=output_path,
321+
output_extras=output_extras,
322+
output_sampling=output_sampling,
323+
random_seed=random_seed,
324+
))
325+
255326
asyncio.run(
256-
benchmark_generative_text(
257-
target=target,
258-
backend_type=backend_type,
259-
backend_args=backend_args,
260-
model=model,
261-
processor=processor,
262-
processor_args=processor_args,
263-
data=data,
264-
data_args=data_args,
265-
data_sampler=data_sampler,
266-
rate_type=rate_type,
267-
rate=rate,
268-
max_seconds=max_seconds,
269-
max_requests=max_requests,
270-
warmup_percent=warmup_percent,
271-
cooldown_percent=cooldown_percent,
272-
show_progress=not disable_progress,
273-
show_progress_scheduler_stats=display_scheduler_stats,
274-
output_console=not disable_console_outputs,
275-
output_path=output_path,
276-
output_extras=output_extras,
277-
output_sampling=output_sampling,
278-
random_seed=random_seed,
327+
benchmark_with_scenario(
328+
scenario=GenerativeTextScenario(**_scenario)
279329
)
280330
)
281331

src/guidellm/benchmark/entrypoints.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,23 @@
1515
)
1616
from guidellm.benchmark.profile import ProfileType, create_profile
1717
from guidellm.benchmark.progress import GenerativeTextBenchmarkerProgressDisplay
18+
from guidellm.benchmark.scenario import GenerativeTextScenario, Scenario
1819
from guidellm.request import GenerativeRequestLoader
1920
from guidellm.scheduler import StrategyType
2021

22+
type benchmark_type = Literal["generative_text"]
23+
24+
25+
async def benchmark_with_scenario(scenario: Scenario, **kwargs):
26+
"""
27+
Run a benchmark using a scenario and specify any extra arguments
28+
"""
29+
30+
if isinstance(scenario, GenerativeTextScenario):
31+
return await benchmark_generative_text(**vars(scenario), **kwargs)
32+
else:
33+
raise ValueError(f"Unsupported Scenario type {type(scenario)}")
34+
2135

2236
async def benchmark_generative_text(
2337
target: str,

src/guidellm/benchmark/scenario.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from collections.abc import Iterable
2+
from pathlib import Path
3+
from typing import Any, Literal, Optional, Self, Union
4+
5+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
6+
from transformers.tokenization_utils_base import ( # type: ignore[import]
7+
PreTrainedTokenizerBase,
8+
)
9+
10+
from guidellm.backend.backend import BackendType
11+
from guidellm.benchmark.profile import ProfileType
12+
from guidellm.objects.pydantic import StandardBaseModel
13+
from guidellm.scheduler.strategy import StrategyType
14+
15+
# Public API of this module. NOTE: the dunder is lowercase `__all__`;
# `__ALL__` is meaningless to Python and would not restrict `import *`.
__all__ = ["Scenario", "GenerativeTextScenario"]
16+
17+
18+
class Scenario(StandardBaseModel):
    """
    Base class for benchmark scenarios: a bag of configuration fields
    that can be bulk-updated after construction.
    """

    # Target endpoint/identifier the benchmark runs against.
    target: str

    def _update(self, **fields: Any) -> Self:
        """Set every given field on this scenario, rejecting unknown names."""
        for name, value in fields.items():
            if not hasattr(self, name):
                raise ValueError(f"Invalid field {name}")
            setattr(self, name, value)
        return self

    def update(self, **fields: Any) -> Self:
        """Set the given fields, silently dropping any whose value is None."""
        non_null = {name: value for name, value in fields.items() if value is not None}
        return self._update(**non_null)
31+
32+
33+
class GenerativeTextScenario(Scenario):
    """
    All configurable parameters for a generative-text benchmark run.
    The class-level defaults double as the CLI option defaults in
    ``guidellm.__main__`` (the click options read them directly).
    """

    # Backend used to issue requests; "openai_http" is the default type.
    backend_type: BackendType = "openai_http"
    # Extra **kwargs forwarded to the backend as a dict.
    backend_args: Optional[dict[str, Any]] = None
    # ID of the model to benchmark within the backend; None means unspecified.
    model: Optional[str] = None
    # Processor/tokenizer used to calculate token counts for statistics.
    processor: Optional[Union[str, Path, PreTrainedTokenizerBase]] = None
    # Extra **kwargs forwarded to the processor constructor.
    processor_args: Optional[dict[str, Any]] = None
    # Benchmark input data: an identifier/path, an in-memory iterable of
    # prompts or records, or a HuggingFace dataset object. Required (no default).
    data: Union[
        str,
        Path,
        Iterable[Union[str, dict[str, Any]]],
        Dataset,
        DatasetDict,
        IterableDataset,
        IterableDatasetDict,
    ]
    # Extra **kwargs forwarded to dataset creation.
    data_args: Optional[dict[str, Any]] = None
    # "random" adds a random shuffle on the data; None keeps original order.
    data_sampler: Optional[Literal["random"]] = None
    # Scheduling strategy or profile controlling how request rates are driven.
    # Required (no default).
    rate_type: Union[StrategyType, ProfileType]
    # Rate(s) to run the benchmark at; a list runs one benchmark per rate.
    rate: Optional[Union[int, float, list[Union[int, float]]]] = None
    # Per-benchmark time limit in seconds; None runs until max_requests
    # or data exhaustion.
    max_seconds: Optional[float] = None
    # Per-benchmark request limit; None runs until max_seconds
    # or data exhaustion.
    max_requests: Optional[int] = None
    # Percent of the run treated as warmup and excluded from final results.
    warmup_percent: Optional[float] = None
    # Percent of the run treated as cooldown and excluded from final results.
    cooldown_percent: Optional[float] = None
    # Console progress display toggle (CLI exposes the inverse as
    # --disable-progress).
    show_progress: bool = True
    # NOTE(review): defaulting this to True makes the is_flag
    # --display-scheduler-stats option in __main__ default to on and
    # impossible to disable from the CLI — confirm intended.
    show_progress_scheduler_stats: bool = True
    # Console output toggle (CLI exposes the inverse as
    # --disable-console-outputs).
    output_console: bool = True
    # Where to write the output benchmarks file; None uses the default location.
    output_path: Optional[Union[str, Path]] = None
    # Extra data (JSON-style dict) saved alongside the output benchmarks.
    output_extras: Optional[dict[str, Any]] = None
    # Number of samples to save in the output file; None saves all samples.
    output_sampling: Optional[int] = None
    # Seed used for benchmarking to ensure reproducibility.
    random_seed: int = 42

0 commit comments

Comments
 (0)