Skip to content

Commit 099f4e5

Browse files
committed
Move cli helpers to separate file and add click Union type back
1 parent c79b60e commit 099f4e5

File tree

2 files changed

+74
-30
lines changed

2 files changed

+74
-30
lines changed

src/guidellm/__main__.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
2-
import json
32
from pathlib import Path
4-
from typing import Any, get_args
3+
from typing import get_args
54

65
import click
76
from pydantic import ValidationError
@@ -12,34 +11,13 @@
1211
from guidellm.benchmark.scenario import GenerativeTextScenario
1312
from guidellm.config import print_config
1413
from guidellm.scheduler import StrategyType
14+
from guidellm.utils import cli as cli_tools
1515

1616
STRATEGY_PROFILE_CHOICES = set(
1717
list(get_args(ProfileType)) + list(get_args(StrategyType))
1818
)
1919

2020

21-
def parse_json(ctx, param, value): # noqa: ARG001
22-
if value is None:
23-
return None
24-
try:
25-
return json.loads(value)
26-
except json.JSONDecodeError as err:
27-
raise click.BadParameter(f"{param.name} must be a valid JSON string.") from err
28-
29-
30-
def set_if_not_default(ctx: click.Context, **kwargs) -> dict[str, Any]:
31-
"""
32-
Set the value of a click option if it is not the default value.
33-
This is useful for setting options that are not None by default.
34-
"""
35-
values = {}
36-
for k, v in kwargs.items():
37-
if ctx.get_parameter_source(k) != click.core.ParameterSource.DEFAULT:
38-
values[k] = v
39-
40-
return values
41-
42-
4321
@click.group()
4422
def cli():
4523
pass
@@ -50,7 +28,10 @@ def cli():
5028
)
5129
@click.option(
5230
"--scenario",
53-
type=str,
31+
type=cli_tools.Union(
32+
click.Path(exists=True, readable=True, file_okay=True, dir_okay=False),
33+
click.STRING
34+
),
5435
default=None,
5536
help=("TODO: A scenario or path to config"),
5637
)
@@ -70,7 +51,7 @@ def cli():
7051
)
7152
@click.option(
7253
"--backend-args",
73-
callback=parse_json,
54+
callback=cli_tools.parse_json,
7455
default=GenerativeTextScenario.get_default("backend_args"),
7556
help=(
7657
"A JSON string containing any arguments to pass to the backend as a "
@@ -99,7 +80,7 @@ def cli():
9980
@click.option(
10081
"--processor-args",
10182
default=GenerativeTextScenario.get_default("processor_args"),
102-
callback=parse_json,
83+
callback=cli_tools.parse_json,
10384
help=(
10485
"A JSON string containing any arguments to pass to the processor constructor "
10586
"as a dict with **kwargs."
@@ -117,7 +98,7 @@ def cli():
11798
@click.option(
11899
"--data-args",
119100
default=GenerativeTextScenario.get_default("data_args"),
120-
callback=parse_json,
101+
callback=cli_tools.parse_json,
121102
help=(
122103
"A JSON string containing any arguments to pass to the dataset creation "
123104
"as a dict with **kwargs."
@@ -218,7 +199,7 @@ def cli():
218199
)
219200
@click.option(
220201
"--output-extras",
221-
callback=parse_json,
202+
callback=cli_tools.parse_json,
222203
help="A JSON string of extra data to save with the output benchmarks",
223204
)
224205
@click.option(
@@ -263,7 +244,7 @@ def benchmark(
263244
):
264245
click_ctx = click.get_current_context()
265246

266-
overrides = set_if_not_default(
247+
overrides = cli_tools.set_if_not_default(
267248
click_ctx,
268249
target=target,
269250
backend_type=backend_type,

src/guidellm/utils/cli.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import json
2+
from typing import Any
3+
4+
import click
5+
6+
7+
def parse_json(ctx, param, value): # noqa: ARG001
8+
if value is None:
9+
return None
10+
try:
11+
return json.loads(value)
12+
except json.JSONDecodeError as err:
13+
raise click.BadParameter(f"{param.name} must be a valid JSON string.") from err
14+
15+
16+
def set_if_not_default(ctx: click.Context, **kwargs) -> dict[str, Any]:
17+
"""
18+
Set the value of a click option if it is not the default value.
19+
This is useful for setting options that are not None by default.
20+
"""
21+
values = {}
22+
for k, v in kwargs.items():
23+
if ctx.get_parameter_source(k) != click.core.ParameterSource.DEFAULT:
24+
values[k] = v
25+
26+
return values
27+
28+
29+
class Union(click.ParamType):
30+
"""
31+
A custom click parameter type that allows for multiple types to be accepted.
32+
"""
33+
34+
def __init__(self, *types: click.ParamType):
35+
self.types = types
36+
self.name = "".join(t.name for t in types)
37+
38+
def convert(self, value, param, ctx):
39+
fails = []
40+
for t in self.types:
41+
try:
42+
return t.convert(value, param, ctx)
43+
except click.BadParameter as e:
44+
fails.append(str(e))
45+
continue
46+
47+
self.fail("; ".join(fails) or f"Invalid value: {value}") # noqa: RET503
48+
49+
50+
def get_metavar(self, param: click.Parameter) -> str:
51+
def get_choices(t: click.ParamType) -> str:
52+
meta = t.get_metavar(param)
53+
return meta if meta is not None else t.name
54+
55+
# Get the choices for each type in the union.
56+
choices_str = "|".join(map(get_choices, self.types))
57+
58+
# Use curly braces to indicate a required argument.
59+
if param.required and param.param_type_name == "argument":
60+
return f"{{{choices_str}}}"
61+
62+
# Use square braces to indicate an option or optional argument.
63+
return f"[{choices_str}]"

0 commit comments

Comments
 (0)