|
1 | 1 | import json
|
2 | 2 | from collections.abc import Iterable
|
3 | 3 | from pathlib import Path
|
4 |
| -from typing import Any, Literal, Optional, TypeVar, Union |
| 4 | +from typing import Annotated, Any, Literal, Optional, TypeVar, Union |
5 | 5 |
|
6 | 6 | import yaml
|
7 | 7 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
8 | 8 | from loguru import logger
|
| 9 | +from pydantic import BeforeValidator |
9 | 10 | from transformers.tokenization_utils_base import ( # type: ignore[import]
|
10 | 11 | PreTrainedTokenizerBase,
|
11 | 12 | )
|
|
17 | 18 |
|
18 | 19 | __ALL__ = ["Scenario", "GenerativeTextScenario"]
|
19 | 20 |
|
| 21 | + |
def parse_float_list(
    value: Optional[Union[str, float, list[float]]],
) -> Optional[list[float]]:
    """
    Normalize a number, list of numbers, or comma-separated string to a list.

    This function is attached as a pydantic ``BeforeValidator`` to an
    optional field, so it receives the raw, not-yet-validated input and must
    tolerate ``None``.

    :param value: ``None``, a single int/float, a list (returned as-is), or a
        string containing one number or several numbers separated by commas.
    :return: ``None`` if ``value`` is ``None``; ``[value]`` for a single
        number; the same list object for a list; otherwise the floats parsed
        from the string.
    :raises ValueError: If ``value`` is of an unsupported type, or is a string
        with any piece that cannot be parsed as a float (including empty
        pieces such as in ``"1,,2"``).
    """
    # The field is Optional: an explicit None must pass through untouched
    # instead of failing on the string handling below.
    if value is None:
        return None
    if isinstance(value, (int, float)):
        return [value]
    if isinstance(value, list):
        return value
    # Anything else that is not a string cannot be split; report it with the
    # same ValueError used for malformed strings rather than leaking a
    # TypeError/AttributeError from the parsing code.
    if not isinstance(value, str):
        raise ValueError("must be a number or comma-separated list of numbers.")

    try:
        # str.split returns [value] when no comma is present, and float()
        # ignores surrounding whitespace, so "1, 2" parses as expected.
        return [float(item) for item in value.split(",")]
    except ValueError as err:
        raise ValueError(
            "must be a number or comma-separated list of numbers."
        ) from err
| 36 | + |
| 37 | + |
20 | 38 | T = TypeVar("T", bound="Scenario")
|
21 | 39 |
|
22 | 40 |
|
@@ -63,7 +81,7 @@ class Config:
|
63 | 81 | data_args: Optional[dict[str, Any]] = None
|
64 | 82 | data_sampler: Optional[Literal["random"]] = None
|
65 | 83 | rate_type: Union[StrategyType, ProfileType]
|
66 |
| - rate: Optional[Union[float, list[float]]] = None |
| 84 | + rate: Annotated[Optional[list[float]], BeforeValidator(parse_float_list)] = None |
67 | 85 | max_seconds: Optional[float] = None
|
68 | 86 | max_requests: Optional[int] = None
|
69 | 87 | warmup_percent: Optional[float] = None
|
|
0 commit comments