Skip to content

Commit a546957

Browse files
committed
Move rate string parsing into scenario
1 parent 0e226d1 commit a546957

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

src/guidellm/__main__.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,6 @@ def parse_json(ctx, param, value): # noqa: ARG001
2727
raise click.BadParameter(f"{param.name} must be a valid JSON string.") from err
2828

2929

30-
def parse_number_str(ctx, param, value): # noqa: ARG001
31-
if value is None:
32-
return None
33-
34-
values = value.split(",") if "," in value else [value]
35-
36-
try:
37-
return [float(val) for val in values]
38-
except ValueError as err:
39-
raise click.BadParameter(
40-
f"{param.name} must be a number or comma-separated list of numbers."
41-
) from err
42-
43-
4430
def set_if_not_default(ctx: click.Context, **kwargs):
4531
"""
4632
Set the value of a click option if it is not the default value.
@@ -157,7 +143,6 @@ def cli():
157143
@click.option(
158144
"--rate",
159145
default=GenerativeTextScenario.model_fields["rate"].default,
160-
callback=parse_number_str,
161146
help=(
162147
"The rates to run the benchmark at. "
163148
"Can be a single number or a comma-separated list of numbers. "

src/guidellm/benchmark/scenario.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
22
from collections.abc import Iterable
33
from pathlib import Path
4-
from typing import Any, Literal, Optional, TypeVar, Union
4+
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
55

66
import yaml
77
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
88
from loguru import logger
9+
from pydantic import BeforeValidator
910
from transformers.tokenization_utils_base import ( # type: ignore[import]
1011
PreTrainedTokenizerBase,
1112
)
@@ -17,6 +18,23 @@
1718

1819
__ALL__ = ["Scenario", "GenerativeTextScenario"]
1920

21+
22+
def parse_float_list(value: Union[str, float, list[float]]) -> list[float]:
23+
if isinstance(value, (int, float)):
24+
return [value]
25+
elif isinstance(value, list):
26+
return value
27+
28+
values = value.split(",") if "," in value else [value]
29+
30+
try:
31+
return [float(val) for val in values]
32+
except ValueError as err:
33+
raise ValueError(
34+
"must be a number or comma-separated list of numbers."
35+
) from err
36+
37+
2038
T = TypeVar("T", bound="Scenario")
2139

2240

@@ -63,7 +81,7 @@ class Config:
6381
data_args: Optional[dict[str, Any]] = None
6482
data_sampler: Optional[Literal["random"]] = None
6583
rate_type: Union[StrategyType, ProfileType]
66-
rate: Optional[Union[float, list[float]]] = None
84+
rate: Annotated[Optional[list[float]], BeforeValidator(parse_float_list)] = None
6785
max_seconds: Optional[float] = None
6886
max_requests: Optional[int] = None
6987
warmup_percent: Optional[float] = None

0 commit comments

Comments
 (0)