Skip to content

Commit f0fe82c

Browse files
committed
Add a helper method to get scenario defaults
1 parent 9f08500 commit f0fe82c

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

src/guidellm/__main__.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import json
33
from pathlib import Path
4-
from typing import get_args
4+
from typing import Any, get_args
55

66
import click
77
from pydantic import ValidationError
@@ -27,7 +27,7 @@ def parse_json(ctx, param, value): # noqa: ARG001
2727
raise click.BadParameter(f"{param.name} must be a valid JSON string.") from err
2828

2929

30-
def set_if_not_default(ctx: click.Context, **kwargs):
30+
def set_if_not_default(ctx: click.Context, **kwargs) -> dict[str, Any]:
3131
"""
3232
Set the value of a click option if it is not the default value.
3333
This is useful for setting options that are not None by default.
@@ -66,20 +66,20 @@ def cli():
6666
"The type of backend to use to run requests against. Defaults to 'openai_http'."
6767
f" Supported types: {', '.join(get_args(BackendType))}"
6868
),
69-
default=GenerativeTextScenario.model_fields["backend_type"].default,
69+
default=GenerativeTextScenario.get_default("backend_type"),
7070
)
7171
@click.option(
7272
"--backend-args",
7373
callback=parse_json,
74-
default=GenerativeTextScenario.model_fields["backend_args"].default,
74+
default=GenerativeTextScenario.get_default("backend_args"),
7575
help=(
7676
"A JSON string containing any arguments to pass to the backend as a "
7777
"dict with **kwargs."
7878
),
7979
)
8080
@click.option(
8181
"--model",
82-
default=GenerativeTextScenario.model_fields["model"].default,
82+
default=GenerativeTextScenario.get_default("model"),
8383
type=str,
8484
help=(
8585
"The ID of the model to benchmark within the backend. "
@@ -88,7 +88,7 @@ def cli():
8888
)
8989
@click.option(
9090
"--processor",
91-
default=GenerativeTextScenario.model_fields["processor"].default,
91+
default=GenerativeTextScenario.get_default("processor"),
9292
type=str,
9393
help=(
9494
"The processor or tokenizer to use to calculate token counts for statistics "
@@ -98,7 +98,7 @@ def cli():
9898
)
9999
@click.option(
100100
"--processor-args",
101-
default=GenerativeTextScenario.model_fields["processor_args"].default,
101+
default=GenerativeTextScenario.get_default("processor_args"),
102102
callback=parse_json,
103103
help=(
104104
"A JSON string containing any arguments to pass to the processor constructor "
@@ -116,7 +116,7 @@ def cli():
116116
)
117117
@click.option(
118118
"--data-args",
119-
default=GenerativeTextScenario.model_fields["data_args"].default,
119+
default=GenerativeTextScenario.get_default("data_args"),
120120
callback=parse_json,
121121
help=(
122122
"A JSON string containing any arguments to pass to the dataset creation "
@@ -125,7 +125,7 @@ def cli():
125125
)
126126
@click.option(
127127
"--data-sampler",
128-
default=GenerativeTextScenario.model_fields["data_sampler"].default,
128+
default=GenerativeTextScenario.get_default("data_sampler"),
129129
type=click.Choice(["random"]),
130130
help=(
131131
"The data sampler type to use. 'random' will add a random shuffle on the data. "
@@ -142,7 +142,7 @@ def cli():
142142
)
143143
@click.option(
144144
"--rate",
145-
default=GenerativeTextScenario.model_fields["rate"].default,
145+
default=GenerativeTextScenario.get_default("rate"),
146146
help=(
147147
"The rates to run the benchmark at. "
148148
"Can be a single number or a comma-separated list of numbers. "
@@ -155,7 +155,7 @@ def cli():
155155
@click.option(
156156
"--max-seconds",
157157
type=float,
158-
default=GenerativeTextScenario.model_fields["max_seconds"].default,
158+
default=GenerativeTextScenario.get_default("max_seconds"),
159159
help=(
160160
"The maximum number of seconds each benchmark can run for. "
161161
"If None, will run until max_requests or the data is exhausted."
@@ -164,7 +164,7 @@ def cli():
164164
@click.option(
165165
"--max-requests",
166166
type=int,
167-
default=GenerativeTextScenario.model_fields["max_requests"].default,
167+
default=GenerativeTextScenario.get_default("max_requests"),
168168
help=(
169169
"The maximum number of requests each benchmark can run for. "
170170
"If None, will run until max_seconds or the data is exhausted."
@@ -173,7 +173,7 @@ def cli():
173173
@click.option(
174174
"--warmup-percent",
175175
type=float,
176-
default=GenerativeTextScenario.model_fields["warmup_percent"].default,
176+
default=GenerativeTextScenario.get_default("warmup_percent"),
177177
help=(
178178
"The percent of the benchmark (based on max-seconds, max-requets, "
179179
"or lenth of dataset) to run as a warmup and not include in the final results. "
@@ -183,7 +183,7 @@ def cli():
183183
@click.option(
184184
"--cooldown-percent",
185185
type=float,
186-
default=GenerativeTextScenario.model_fields["cooldown_percent"].default,
186+
default=GenerativeTextScenario.get_default("cooldown_percent"),
187187
help=(
188188
"The percent of the benchmark (based on max-seconds, max-requets, or lenth "
189189
"of dataset) to run as a cooldown and not include in the final results. "
@@ -228,11 +228,11 @@ def cli():
228228
"The number of samples to save in the output file. "
229229
"If None (default), will save all samples."
230230
),
231-
default=GenerativeTextScenario.model_fields["output_sampling"].default,
231+
default=GenerativeTextScenario.get_default("output_sampling"),
232232
)
233233
@click.option(
234234
"--random-seed",
235-
default=GenerativeTextScenario.model_fields["random_seed"].default,
235+
default=GenerativeTextScenario.get_default("random_seed"),
236236
type=int,
237237
help="The random seed to use for benchmarking to ensure reproducibility.",
238238
)

src/guidellm/benchmark/entrypoints.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from guidellm.request import GenerativeRequestLoader
2020
from guidellm.scheduler import StrategyType
2121

22-
type benchmark_type = Literal["generative_text"]
23-
2422

2523
async def benchmark_with_scenario(scenario: Scenario, **kwargs):
2624
"""

src/guidellm/benchmark/scenario.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ def parse_float_list(value: Union[str, float, list[float]]) -> list[float]:
4141
class Scenario(StandardBaseModel):
4242
target: str
4343

44+
@classmethod
45+
def get_default(cls: type[T], field: str) -> Any:
46+
"""Get default values for model fields"""
47+
return cls.model_fields[field].default
48+
4449
@classmethod
4550
def from_file(
4651
cls: type[T], filename: Union[str, Path], overrides: Optional[dict] = None

0 commit comments

Comments
 (0)