|
4 | 4 | from typing import get_args
|
5 | 5 |
|
6 | 6 | import click
|
| 7 | +from pydantic import ValidationError |
7 | 8 |
|
8 | 9 | from guidellm.backend import BackendType
|
9 | 10 | from guidellm.benchmark import ProfileType
|
@@ -65,13 +66,10 @@ def cli():
|
65 | 66 | "--scenario",
|
66 | 67 | type=str,
|
67 | 68 | default=None,
|
68 |
| - help=( |
69 |
| - "TODO: A scenario or path to config" |
70 |
| - ), |
| 69 | + help=("TODO: A scenario or path to config"), |
71 | 70 | )
|
72 | 71 | @click.option(
|
73 | 72 | "--target",
|
74 |
| - required=True, |
75 | 73 | type=str,
|
76 | 74 | help="The target path for the backend to run benchmarks against. For example, http://localhost:8000",
|
77 | 75 | )
|
@@ -123,7 +121,6 @@ def cli():
|
123 | 121 | )
|
124 | 122 | @click.option(
|
125 | 123 | "--data",
|
126 |
| - required=True, |
127 | 124 | type=str,
|
128 | 125 | help=(
|
129 | 126 | "The HuggingFace dataset ID, a path to a HuggingFace dataset, "
|
@@ -151,7 +148,6 @@ def cli():
|
151 | 148 | )
|
152 | 149 | @click.option(
|
153 | 150 | "--rate-type",
|
154 |
| - required=True, |
155 | 151 | type=click.Choice(STRATEGY_PROFILE_CHOICES),
|
156 | 152 | help=(
|
157 | 153 | "The type of benchmark to run. "
|
@@ -303,12 +299,19 @@ def benchmark(
|
303 | 299 | random_seed=random_seed,
|
304 | 300 | )
|
305 | 301 |
|
306 |
| - # If a scenario file was specified read from it |
307 |
| - if scenario is None: |
308 |
| - _scenario = GenerativeTextScenario.model_validate(overrides) |
309 |
| - else: |
310 |
| - # TODO: Support pre-defined scenarios |
311 |
| - _scenario = GenerativeTextScenario.from_file(scenario, overrides) |
| 302 | + try: |
| 303 | + # If a scenario file was specified read from it |
| 304 | + if scenario is None: |
| 305 | + _scenario = GenerativeTextScenario.model_validate(overrides) |
| 306 | + else: |
| 307 | + # TODO: Support pre-defined scenarios |
| 308 | + _scenario = GenerativeTextScenario.from_file(scenario, overrides) |
| 309 | + except ValidationError as e: |
| 310 | + errs = e.errors(include_url=False, include_context=True, include_input=True) |
| 311 | + param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-") |
| 312 | + raise click.BadParameter( |
| 313 | + errs[0]["msg"], ctx=click_ctx, param_hint=param_name |
| 314 | + ) from e |
312 | 315 |
|
313 | 316 | asyncio.run(
|
314 | 317 | benchmark_with_scenario(
|
|
0 commit comments