Skip to content

Commit 0e226d1

Browse files
committed
Handle required arg parsing with pydantic
1 parent 0baf0d6 commit 0e226d1

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

src/guidellm/__main__.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import get_args
55

66
import click
7+
from pydantic import ValidationError
78

89
from guidellm.backend import BackendType
910
from guidellm.benchmark import ProfileType
@@ -65,13 +66,10 @@ def cli():
6566
"--scenario",
6667
type=str,
6768
default=None,
68-
help=(
69-
"TODO: A scenario or path to config"
70-
),
69+
help=("TODO: A scenario or path to config"),
7170
)
7271
@click.option(
7372
"--target",
74-
required=True,
7573
type=str,
7674
help="The target path for the backend to run benchmarks against. For example, http://localhost:8000",
7775
)
@@ -123,7 +121,6 @@ def cli():
123121
)
124122
@click.option(
125123
"--data",
126-
required=True,
127124
type=str,
128125
help=(
129126
"The HuggingFace dataset ID, a path to a HuggingFace dataset, "
@@ -151,7 +148,6 @@ def cli():
151148
)
152149
@click.option(
153150
"--rate-type",
154-
required=True,
155151
type=click.Choice(STRATEGY_PROFILE_CHOICES),
156152
help=(
157153
"The type of benchmark to run. "
@@ -303,12 +299,19 @@ def benchmark(
303299
random_seed=random_seed,
304300
)
305301

306-
# If a scenario file was specified read from it
307-
if scenario is None:
308-
_scenario = GenerativeTextScenario.model_validate(overrides)
309-
else:
310-
# TODO: Support pre-defined scenarios
311-
_scenario = GenerativeTextScenario.from_file(scenario, overrides)
302+
try:
303+
# If a scenario file was specified read from it
304+
if scenario is None:
305+
_scenario = GenerativeTextScenario.model_validate(overrides)
306+
else:
307+
# TODO: Support pre-defined scenarios
308+
_scenario = GenerativeTextScenario.from_file(scenario, overrides)
309+
except ValidationError as e:
310+
errs = e.errors(include_url=False, include_context=True, include_input=True)
311+
param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-")
312+
raise click.BadParameter(
313+
errs[0]["msg"], ctx=click_ctx, param_hint=param_name
314+
) from e
312315

313316
asyncio.run(
314317
benchmark_with_scenario(

0 commit comments

Comments
 (0)