66import click
77
88from guidellm .backend import BackendType
9- from guidellm .benchmark import ProfileType , benchmark_generative_text
9+ from guidellm .benchmark import ProfileType
10+ from guidellm .benchmark .entrypoints import benchmark_with_scenario
11+ from guidellm .benchmark .scenario import GenerativeTextScenario
1012from guidellm .config import print_config
1113from guidellm .scheduler import StrategyType
1214
@@ -38,6 +40,19 @@ def parse_number_str(ctx, param, value): # noqa: ARG001
3840 ) from err
3941
4042
def set_if_not_default(ctx: click.Context, **kwargs):
    """
    Filter keyword arguments down to those not sourced from a click default.

    Every keyword name is looked up through ``ctx.get_parameter_source``; the
    name/value pair is kept unless the reported source equals
    ``click.core.ParameterSource.DEFAULT``. Nothing is written back to the
    context -- the surviving pairs are only returned.

    NOTE(review): the lookup uses the keyword name as-is. If a name does not
    correspond to a parameter click knows about, the source is presumably
    ``None``, which is not equal to ``DEFAULT``, so that pair would always be
    kept -- confirm callers pass real parameter names.

    :param ctx: The click context queried for each parameter's source.
    :param kwargs: Candidate values, keyed by the parameter name to look up.
    :return: A new dict with only the pairs whose source is not the default.
    """
    default_source = click.core.ParameterSource.DEFAULT
    return {
        name: value
        for name, value in kwargs.items()
        if ctx.get_parameter_source(name) != default_source
    }
54+
55+
# Root click command group; commands are attached to it with @cli.command.
# Deliberately has no docstring: click would surface one as the group's help.
@click.group()
def cli():
    return None
@@ -46,6 +61,14 @@ def cli():
4661@cli .command (
4762 help = "Run a benchmark against a generative model using the specified arguments."
4863)
64+ @click .option (
65+ "--scenario" ,
66+ type = str ,
67+ default = None ,
68+ help = (
69+ "TODO: A scenario or path to config"
70+ ),
71+ )
4972@click .option (
5073 "--target" ,
5174 required = True ,
@@ -59,20 +82,20 @@ def cli():
5982 "The type of backend to use to run requests against. Defaults to 'openai_http'."
6083 f" Supported types: { ', ' .join (get_args (BackendType ))} "
6184 ),
62- default = "openai_http" ,
85+ default = GenerativeTextScenario . backend_type ,
6386)
6487@click .option (
6588 "--backend-args" ,
6689 callback = parse_json ,
67- default = None ,
90+ default = GenerativeTextScenario . backend_args ,
6891 help = (
6992 "A JSON string containing any arguments to pass to the backend as a "
7093 "dict with **kwargs."
7194 ),
7295)
7396@click .option (
7497 "--model" ,
75- default = None ,
98+ default = GenerativeTextScenario . model ,
7699 type = str ,
77100 help = (
78101 "The ID of the model to benchmark within the backend. "
@@ -81,7 +104,7 @@ def cli():
81104)
82105@click .option (
83106 "--processor" ,
84- default = None ,
107+ default = GenerativeTextScenario . processor ,
85108 type = str ,
86109 help = (
87110 "The processor or tokenizer to use to calculate token counts for statistics "
@@ -91,7 +114,7 @@ def cli():
91114)
92115@click .option (
93116 "--processor-args" ,
94- default = None ,
117+ default = GenerativeTextScenario . processor_args ,
95118 callback = parse_json ,
96119 help = (
97120 "A JSON string containing any arguments to pass to the processor constructor "
@@ -110,6 +133,7 @@ def cli():
110133)
111134@click .option (
112135 "--data-args" ,
136+ default = GenerativeTextScenario .data_args ,
113137 callback = parse_json ,
114138 help = (
115139 "A JSON string containing any arguments to pass to the dataset creation "
@@ -118,7 +142,7 @@ def cli():
118142)
119143@click .option (
120144 "--data-sampler" ,
121- default = None ,
145+ default = GenerativeTextScenario . data_sampler ,
122146 type = click .Choice (["random" ]),
123147 help = (
124148 "The data sampler type to use. 'random' will add a random shuffle on the data. "
@@ -136,7 +160,7 @@ def cli():
136160)
137161@click .option (
138162 "--rate" ,
139- default = None ,
163+ default = GenerativeTextScenario . rate ,
140164 callback = parse_number_str ,
141165 help = (
142166 "The rates to run the benchmark at. "
@@ -150,6 +174,7 @@ def cli():
150174@click .option (
151175 "--max-seconds" ,
152176 type = float ,
177+ default = GenerativeTextScenario .max_seconds ,
153178 help = (
154179 "The maximum number of seconds each benchmark can run for. "
155180 "If None, will run until max_requests or the data is exhausted."
@@ -158,6 +183,7 @@ def cli():
158183@click .option (
159184 "--max-requests" ,
160185 type = int ,
186+ default = GenerativeTextScenario .max_requests ,
161187 help = (
162188 "The maximum number of requests each benchmark can run for. "
163189 "If None, will run until max_seconds or the data is exhausted."
@@ -166,7 +192,7 @@ def cli():
166192@click .option (
167193 "--warmup-percent" ,
168194 type = float ,
169- default = None ,
195+ default = GenerativeTextScenario . warmup_percent ,
170196 help = (
171197 "The percent of the benchmark (based on max-seconds, max-requets, "
172198 "or lenth of dataset) to run as a warmup and not include in the final results. "
@@ -176,6 +202,7 @@ def cli():
176202@click .option (
177203 "--cooldown-percent" ,
178204 type = float ,
205+ default = GenerativeTextScenario .cooldown_percent ,
179206 help = (
180207 "The percent of the benchmark (based on max-seconds, max-requets, or lenth "
181208 "of dataset) to run as a cooldown and not include in the final results. "
@@ -185,16 +212,19 @@ def cli():
185212@click .option (
186213 "--disable-progress" ,
187214 is_flag = True ,
215+ default = not GenerativeTextScenario .show_progress ,
188216 help = "Set this flag to disable progress updates to the console" ,
189217)
190218@click .option (
191219 "--display-scheduler-stats" ,
192220 is_flag = True ,
221+ default = GenerativeTextScenario .show_progress_scheduler_stats ,
193222 help = "Set this flag to display stats for the processes running the benchmarks" ,
194223)
195224@click .option (
196225 "--disable-console-outputs" ,
197226 is_flag = True ,
227+ default = not GenerativeTextScenario .output_console ,
198228 help = "Set this flag to disable console output" ,
199229)
200230@click .option (
@@ -211,6 +241,7 @@ def cli():
211241@click .option (
212242 "--output-extras" ,
213243 callback = parse_json ,
244+ default = GenerativeTextScenario .output_extras ,
214245 help = "A JSON string of extra data to save with the output benchmarks" ,
215246)
216247@click .option (
@@ -220,15 +251,16 @@ def cli():
220251 "The number of samples to save in the output file. "
221252 "If None (default), will save all samples."
222253 ),
223- default = None ,
254+ default = GenerativeTextScenario . output_sampling ,
224255)
225256@click .option (
226257 "--random-seed" ,
227- default = 42 ,
258+ default = GenerativeTextScenario . random_seed ,
228259 type = int ,
229260 help = "The random seed to use for benchmarking to ensure reproducibility." ,
230261)
231262def benchmark (
263+ scenario ,
232264 target ,
233265 backend_type ,
234266 backend_args ,
@@ -252,30 +284,48 @@ def benchmark(
252284 output_sampling ,
253285 random_seed ,
254286):
287+ click_ctx = click .get_current_context ()
288+
289+ # If a scenario file was specified read from it
290+ # TODO: This should probably be a factory method
291+ if scenario is None :
292+ _scenario = {}
293+ else :
294+ # TODO: Support pre-defined scenarios
295+ # TODO: Support other formats
296+ with Path (scenario ).open () as f :
297+ _scenario = json .load (f )
298+
299+ # If any command line arguments are specified, override the scenario
300+ _scenario .update (set_if_not_default (
301+ click_ctx ,
302+ target = target ,
303+ backend_type = backend_type ,
304+ backend_args = backend_args ,
305+ model = model ,
306+ processor = processor ,
307+ processor_args = processor_args ,
308+ data = data ,
309+ data_args = data_args ,
310+ data_sampler = data_sampler ,
311+ rate_type = rate_type ,
312+ rate = rate ,
313+ max_seconds = max_seconds ,
314+ max_requests = max_requests ,
315+ warmup_percent = warmup_percent ,
316+ cooldown_percent = cooldown_percent ,
317+ show_progress = not disable_progress ,
318+ show_progress_scheduler_stats = display_scheduler_stats ,
319+ output_console = not disable_console_outputs ,
320+ output_path = output_path ,
321+ output_extras = output_extras ,
322+ output_sampling = output_sampling ,
323+ random_seed = random_seed ,
324+ ))
325+
255326 asyncio .run (
256- benchmark_generative_text (
257- target = target ,
258- backend_type = backend_type ,
259- backend_args = backend_args ,
260- model = model ,
261- processor = processor ,
262- processor_args = processor_args ,
263- data = data ,
264- data_args = data_args ,
265- data_sampler = data_sampler ,
266- rate_type = rate_type ,
267- rate = rate ,
268- max_seconds = max_seconds ,
269- max_requests = max_requests ,
270- warmup_percent = warmup_percent ,
271- cooldown_percent = cooldown_percent ,
272- show_progress = not disable_progress ,
273- show_progress_scheduler_stats = display_scheduler_stats ,
274- output_console = not disable_console_outputs ,
275- output_path = output_path ,
276- output_extras = output_extras ,
277- output_sampling = output_sampling ,
278- random_seed = random_seed ,
327+ benchmark_with_scenario (
328+ scenario = GenerativeTextScenario (** _scenario )
279329 )
280330 )
281331
0 commit comments