Skip to content

Commit 0baf0d6

Browse files
committed
Handle reading scenario from file in factory method
1 parent 0811fc5 commit 0baf0d6

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

src/guidellm/__main__.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -282,18 +282,7 @@ def benchmark(
282282
):
283283
click_ctx = click.get_current_context()
284284

285-
# If a scenario file was specified read from it
286-
# TODO: This should probably be a factory method
287-
if scenario is None:
288-
_scenario = {}
289-
else:
290-
# TODO: Support pre-defined scenarios
291-
# TODO: Support other formats
292-
with Path(scenario).open() as f:
293-
_scenario = json.load(f)
294-
295-
# If any command line arguments are specified, override the scenario
296-
_scenario.update(set_if_not_default(
285+
overrides = set_if_not_default(
297286
click_ctx,
298287
target=target,
299288
backend_type=backend_type,
@@ -312,11 +301,18 @@ def benchmark(
312301
cooldown_percent=cooldown_percent,
313302
output_sampling=output_sampling,
314303
random_seed=random_seed,
315-
))
304+
)
305+
306+
# If a scenario file was specified read from it
307+
if scenario is None:
308+
_scenario = GenerativeTextScenario.model_validate(overrides)
309+
else:
310+
# TODO: Support pre-defined scenarios
311+
_scenario = GenerativeTextScenario.from_file(scenario, overrides)
316312

317313
asyncio.run(
318314
benchmark_with_scenario(
319-
scenario=GenerativeTextScenario(**_scenario),
315+
scenario=_scenario,
320316
show_progress=not disable_progress,
321317
show_progress_scheduler_stats=display_scheduler_stats,
322318
output_console=not disable_console_outputs,

src/guidellm/benchmark/scenario.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import json
12
from collections.abc import Iterable
23
from pathlib import Path
3-
from typing import Any, Literal, Optional, Self, Union
4+
from typing import Any, Literal, Optional, TypeVar, Union
45

6+
import yaml
57
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
8+
from loguru import logger
69
from transformers.tokenization_utils_base import ( # type: ignore[import]
710
PreTrainedTokenizerBase,
811
)
@@ -14,20 +17,28 @@
1417

1518
__ALL__ = ["Scenario", "GenerativeTextScenario"]
1619

20+
T = TypeVar("T", bound="Scenario")
21+
1722

1823
class Scenario(StandardBaseModel):
1924
target: str
2025

21-
def _update(self, **fields: Any) -> Self:
22-
for k, v in fields.items():
23-
if not hasattr(self, k):
24-
raise ValueError(f"Invalid field {k}")
25-
setattr(self, k, v)
26-
27-
return self
26+
@classmethod
27+
def from_file(
28+
cls: type[T], filename: Union[str, Path], overrides: Optional[dict] = None
29+
) -> T:
30+
try:
31+
with open(filename) as f:
32+
if str(filename).endswith(".yaml") or str(filename).endswith(".yml"):
33+
data = yaml.safe_load(f)
34+
else: # Assume everything else is json
35+
data = json.load(f)
36+
except (json.JSONDecodeError, yaml.YAMLError) as e:
37+
logger.error("Failed to parse scenario")
38+
raise e
2839

29-
def update(self, **fields: Any) -> Self:
30-
return self._update(**{k: v for k, v in fields.items() if v is not None})
40+
data.update(overrides)
41+
return cls.model_validate(data)
3142

3243

3344
class GenerativeTextScenario(Scenario):

0 commit comments

Comments
 (0)