Skip to content

Commit c79b60e

Browse files
committed
Move scenario helper methods to base pydantic class
1 parent f0fe82c commit c79b60e

File tree

3 files changed

+29
-31
lines changed

3 files changed

+29
-31
lines changed

src/guidellm/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def benchmark(
290290
_scenario = GenerativeTextScenario.model_validate(overrides)
291291
else:
292292
# TODO: Support pre-defined scenarios
293-
_scenario = GenerativeTextScenario.from_file(scenario, overrides)
293+
_scenario = GenerativeTextScenario.from_file(Path(scenario), overrides)
294294
except ValidationError as e:
295295
errs = e.errors(include_url=False, include_context=True, include_input=True)
296296
param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-")

src/guidellm/benchmark/scenario.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
import json
21
from collections.abc import Iterable
32
from pathlib import Path
4-
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
3+
from typing import Annotated, Any, Literal, Optional, Union
54

6-
import yaml
75
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
8-
from loguru import logger
96
from pydantic import BeforeValidator, Field, NonNegativeInt, PositiveFloat, PositiveInt
107
from transformers.tokenization_utils_base import ( # type: ignore[import]
118
PreTrainedTokenizerBase,
@@ -35,34 +32,9 @@ def parse_float_list(value: Union[str, float, list[float]]) -> list[float]:
3532
) from err
3633

3734

38-
T = TypeVar("T", bound="Scenario")
39-
40-
4135
class Scenario(StandardBaseModel):
4236
target: str
4337

44-
@classmethod
45-
def get_default(cls: type[T], field: str) -> Any:
46-
"""Get default values for model fields"""
47-
return cls.model_fields[field].default
48-
49-
@classmethod
50-
def from_file(
51-
cls: type[T], filename: Union[str, Path], overrides: Optional[dict] = None
52-
) -> T:
53-
try:
54-
with open(filename) as f:
55-
if str(filename).endswith(".yaml") or str(filename).endswith(".yml"):
56-
data = yaml.safe_load(f)
57-
else: # Assume everything else is json
58-
data = json.load(f)
59-
except (json.JSONDecodeError, yaml.YAMLError) as e:
60-
logger.error("Failed to parse scenario")
61-
raise e
62-
63-
data.update(overrides)
64-
return cls.model_validate(data)
65-
6638

6739
class GenerativeTextScenario(Scenario):
6840
# FIXME: This solves an issue with Pydantic and class types

src/guidellm/objects/pydantic.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
from typing import Any, Generic, TypeVar
1+
import json
2+
from pathlib import Path
3+
from typing import Any, Generic, Optional, TypeVar
24

5+
import yaml
36
from loguru import logger
47
from pydantic import BaseModel, ConfigDict, Field
58

69
__all__ = ["StandardBaseModel", "StatusBreakdown"]
710

11+
T = TypeVar("T", bound="StandardBaseModel")
812

913
class StandardBaseModel(BaseModel):
1014
"""
@@ -27,6 +31,28 @@ def __init__(self, /, **data: Any) -> None:
2731
data,
2832
)
2933

34+
@classmethod
35+
def get_default(cls: type[T], field: str) -> Any:
36+
"""Get default values for model fields"""
37+
return cls.model_fields[field].default
38+
39+
@classmethod
40+
def from_file(
41+
cls: type[T], filename: Path, overrides: Optional[dict] = None
42+
) -> T:
43+
try:
44+
with filename.open() as f:
45+
if str(filename).endswith((".yaml", ".yml")):
46+
data = yaml.safe_load(f)
47+
else: # Assume everything else is json
48+
data = json.load(f)
49+
except (json.JSONDecodeError, yaml.YAMLError) as e:
50+
logger.error(f"Failed to parse {filename} as type {cls.__name__}")
51+
raise e
52+
53+
data.update(overrides)
54+
return cls.model_validate(data)
55+
3056

3157
SuccessfulT = TypeVar("SuccessfulT")
3258
ErroredT = TypeVar("ErroredT")

0 commit comments

Comments
 (0)