
Commit 63f11db

Add perf tools for huggingface and oga (#247)
* adds hf and oga perf tools
* rev version
1 parent ee84439 commit 63f11db

File tree

7 files changed: +494 -2 lines changed


.github/actions/server-testing/action.yml

Lines changed: 11 additions & 0 deletions

@@ -3,19 +3,30 @@ description: Launch Lemonade Server and test the endpoints
 inputs:
   conda_env:
     required: true
+    description: "Location of the lemonade Conda environment on disk"
   load_command:
     required: true
+    description: "The backend-specific portion of the lemonade command used to load the model, e.g., `-i CHECKPOINT load-tool --load-tool-args`"
+  hf_home:
+    required: false
+    description: "Location of the Huggingface Cache on disk"
+    default: "~/.cache/huggingface/hub"
   amd_oga:
     required: false
     default: ""
     description: "Location of the OGA for RyzenAI NPU install directory on disk"
+  hf_token:
+    required: false
+    default: ""
 runs:
   using: "composite"
   steps:
     - name: Ensure the Lemonade serer works properly
       shell: PowerShell
       run: |
         $Env:AMD_OGA = "${{ inputs.amd_oga }}"
+        $Env:HF_HOME = "${{ inputs.hf_home }}"
+        $Env:HF_TOKEN = "${{ inputs.hf_token }}" # Required by OGA model_builder in OGA 0.4.0 but not future versions
 
         $outputFile = "output.log"
         $errorFile = "error.log"
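
The two new inputs follow standard Hugging Face conventions: HF_HOME relocates the model cache and HF_TOKEN supplies an access token (the inline comment notes it is needed by OGA 0.4.0's model_builder). As a rough illustration of what the PowerShell step sets up, here is a minimal Python sketch of the same environment configuration; the token value is a placeholder and the checkpoint is only an example, so treat this as a sketch rather than the action's implementation:

import os

# Mirror the action's defaults/inputs (illustrative values, not real credentials)
os.environ["HF_HOME"] = os.path.expanduser("~/.cache/huggingface/hub")
os.environ["HF_TOKEN"] = "<your-hf-access-token>"  # placeholder

# Set the variables before importing Hugging Face libraries so the cache
# location and token are picked up when models are downloaded.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
print("Hugging Face cache root:", os.environ["HF_HOME"])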

src/turnkeyml/llm/README.md

Lines changed: 14 additions & 0 deletions

@@ -54,6 +54,16 @@ That command will run just the management test from MMLU on your LLM and save th
 
 You can run the full suite of MMLU subjects by omitting the `--test` argument. You can learn more about this with `lemonade accuracy-mmlu -h`.
 
+## Benchmarking
+
+To measure the time-to-first-token and tokens/second of an LLM, try this:
+
+`lemonade -i facebook/opt-125m huggingface-load huggingface-bench`
+
+That command will run a few warmup iterations, then a few generation iterations where performance data is collected.
+
+The prompt size, number of output tokens, and number of iterations are all parameters. Learn more by running `lemonade huggingface-bench -h`.
+
 ## Serving
 
 You can launch a WebSocket server for your LLM with:
@@ -95,6 +105,10 @@ You can then load supported OGA models on to CPU or iGPU with the `oga-load` too
 
 You can also launch a server process with:
 
+The `oga-bench` tool is available to capture tokens/second and time-to-first-token metrics: `lemonade -i microsoft/Phi-3-mini-4k-instruct oga-load --device igpu --dtype int4 oga-bench`. Learn more with `lemonade oga-bench -h`.
+
+You can also try Phi-3-Mini-128k-Instruct with the following commands:
+
 `lemonade -i microsoft/Phi-3-mini-4k-instruct oga-load --device igpu --dtype int4 serve`
 
 You can learn more about the CPU and iGPU support in our [OGA documentation](https://github.com/onnx/turnkeyml/blob/main/docs/ort_genai_igpu.md).
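
For readers who want to see what `huggingface-bench` is measuring, the idea reduces to timing warmup and measured calls to `generate()` and then splitting prefill from decode. A minimal sketch of that idea using plain `transformers` (the model, prompt length, and iteration counts below are illustrative, not the tool's defaults):

import statistics
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
input_ids = tokenizer("word " * 62, return_tensors="pt").input_ids

def mean_generate_latency(new_tokens, iterations=5, warmup=2):
    # Warmup calls are not timed, matching the tool's behavior
    latencies = []
    with torch.no_grad():
        for i in range(warmup + iterations):
            start = time.perf_counter()
            model.generate(
                input_ids,
                max_new_tokens=new_tokens,
                min_new_tokens=new_tokens,
                pad_token_id=tokenizer.eos_token_id,
            )
            if i >= warmup:
                latencies.append(time.perf_counter() - start)
    return statistics.mean(latencies)

ttft = mean_generate_latency(1)   # prefill only: time to first token
full = mean_generate_latency(64)  # prefill + 64 new tokens
print(f"time to first token: {ttft:.3f} s")
print(f"tokens/second: {(64 - 1) / (full - ttft):.1f}")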

src/turnkeyml/llm/cli.py

Lines changed: 5 additions & 0 deletions

@@ -12,6 +12,9 @@
     AdaptHuggingface,
 )
 
+from turnkeyml.llm.tools.huggingface_bench import HuggingfaceBench
+from turnkeyml.llm.tools.ort_genai.oga_bench import OgaBench
+
 from turnkeyml.llm.tools.llamacpp import LoadLlamaCpp
 
 import turnkeyml.llm.cache as cache
@@ -31,6 +34,8 @@ def main():
         LLMPrompt,
         AdaptHuggingface,
         Serve,
+        HuggingfaceBench,
+        OgaBench,
         # Inherited from TurnkeyML
         Report,
         Cache,
src/turnkeyml/llm/tools/huggingface_bench.py

Lines changed: 293 additions & 0 deletions

@@ -0,0 +1,293 @@
import argparse
import os
from typing import List, Tuple
import time
import statistics
from contextlib import nullcontext
import torch
import tqdm
from turnkeyml.state import State
from turnkeyml.tools import Tool
from turnkeyml.llm.cache import Keys
import turnkeyml.llm.tools.ort_genai.oga_bench as general


def benchmark_huggingface_llm(
    model: torch.nn.Module,
    tokenizer,
    input_ids,
    dtype,
    num_beams: int,
    target_output_tokens: int,
    iterations: int,
    warmup_iterations: int,
) -> List[Tuple[float, int]]:

    # Inform the user whether the current execution is to measure
    # prefill or generation performance, since we need to run this
    # method once for each of those modes
    mode = "prefill" if target_output_tokens == 1 else "generation"

    amp_enabled = True if (dtype == torch.float16 or dtype == torch.bfloat16) else False
    # The "if amp_enabled else nullcontext()" is to get around a bug in PyTorch 2.1
    # where torch.cpu.amp.autocast(enabled=False) does nothing
    with (
        torch.cpu.amp.autocast(enabled=amp_enabled, dtype=dtype)
        if amp_enabled
        else nullcontext()
    ):

        per_iteration_result = []

        # Early stopping is only a valid parameter with multiple beams
        early_stopping = num_beams > 1

        with torch.no_grad(), torch.inference_mode():
            # Don't capture time for warmup
            for _ in tqdm.tqdm(range(warmup_iterations), desc=f"{mode} warmup"):
                model.generate(
                    input_ids,
                    num_beams=num_beams,
                    max_new_tokens=target_output_tokens,
                    min_new_tokens=target_output_tokens,
                    early_stopping=early_stopping,
                    pad_token_id=tokenizer.eos_token_id,
                )

            for _ in tqdm.tqdm(range(iterations), desc=f"{mode} iterations"):
                # CUDA synchronization is required prior to GPU benchmarking
                # This has no negative effect on CPU-only benchmarks, and is more robust than
                # checking `model.device == "cuda"` since it applies to multi-GPU environments
                # Synchronization is done before collecting the start time because this will
                # ensure that the GPU has finished initialization tasks such as loading weights
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                start_time = time.perf_counter()

                outputs = model.generate(
                    input_ids,
                    num_beams=num_beams,
                    max_new_tokens=target_output_tokens,
                    min_new_tokens=target_output_tokens,
                    early_stopping=early_stopping,
                    pad_token_id=tokenizer.eos_token_id,
                )

                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.perf_counter()

                latency = end_time - start_time

                token_len = outputs.shape[1] - input_ids.shape[1]

                # Only count an iteration if it produced enough tokens
                if token_len >= target_output_tokens:
                    per_iteration_result.append((latency, token_len))

    return per_iteration_result


class HuggingfaceBench(Tool):
    """
    Benchmarks the performance of the generate() method of an LLM loaded from
    Huggingface Transformers (or any object that supports a
    huggingface-like generate() method).

    Required input state:
        - DTYPE: data type of the model; used to determine if AMP should be
          enabled to convert the input data type to match the model data
          type.
        - MODEL: huggingface-like instance to benchmark.
        - INPUTS: model inputs to pass to generate() during benchmarking.

    Output state produced: None
    """

    unique_name = "huggingface-bench"

    def __init__(self):
        super().__init__(monitor_message="Benchmarking Huggingface LLM")

        self.status_stats = [Keys.SECONDS_TO_FIRST_TOKEN, Keys.MEAN_TOKENS_PER_SECOND]

    @staticmethod
    def parser(parser: argparse.ArgumentParser = None, add_help: bool = True):
        # allow inherited classes to initialize and pass in a parser, add parameters to it if so
        if parser is None:
            parser = __class__.helpful_parser(
                short_description="Benchmark a Huggingface-like LLM", add_help=add_help
            )

        parser.add_argument(
            "--iterations",
            "-i",
            required=False,
            type=int,
            default=general.default_iterations,
            help="Number of benchmarking iterations to run (default: "
            f"{general.default_iterations})",
        )

        parser.add_argument(
            "--warmup-iterations",
            "-w",
            required=False,
            type=int,
            default=general.default_warmup_runs,
            help="Number of benchmarking iterations to use for cache warmup "
            "(the results of these iterations "
            f"are not included in the results; default: {general.default_warmup_runs})",
        )

        parser.add_argument(
            "--prompt",
            "-p",
            required=False,
            default=general.default_prompt,
            help="Input prompt to the LLM. Three formats are supported. "
            f"1) integer (default: {general.default_prompt}): "
            "use a synthetic prompt with the specified length. "
            "2) str: use a user-provided prompt string "
            "3) path/to/prompt.txt: load the prompt from a text file.",
        )

        parser.add_argument(
            "--num-beams",
            required=False,
            type=int,
            default=general.default_beams,
            help=f"Number of beams for the LLM to use (default: {general.default_beams})",
        )

        parser.add_argument(
            "--output-tokens",
            required=False,
            type=int,
            default=general.default_output_tokens,
            help="Number of new tokens the LLM should make (default: "
            f"{general.default_output_tokens})",
        )

        return parser

    def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
        """
        Helper function to parse CLI arguments into the args expected
        by run()
        """

        parsed_args = super().parse(state, args, known_only)

        # Decode prompt arg into a string prompt
        if parsed_args.prompt.isdigit():
            # Generate a prompt with the requested length
            length = int(parsed_args.prompt)
            parsed_args.prompt = "word " * (length - 2)

        elif os.path.exists(parsed_args.prompt):
            with open(parsed_args.prompt, "r", encoding="utf-8") as f:
                parsed_args.prompt = f.read()

        else:
            # No change to the prompt
            pass

        return parsed_args

    def run(
        self,
        state: State,
        prompt: str = general.default_prompt,
        iterations: int = general.default_iterations,
        warmup_iterations: int = general.default_warmup_runs,
        num_beams: int = general.default_beams,
        output_tokens: int = general.default_output_tokens,
    ) -> State:
        """
        Args:
            - prompt: input prompt used as a starting point for LLM text generation
            - iterations: number of benchmarking samples to take; results are
              reported as the median and mean of the samples.
            - warmup_iterations: subset of the iterations to treat as warmup,
              and not included in the results.
            - num_beams: number of beams to use in the LLM beam search. If the LLM
              instance has hardcoded its number of beams already, this value
              must match the hardcoded value.
            - output_tokens: number of new tokens for the LLM to create.

        We don't have access to the internal timings of generate(), so time to first
        token (TTFT, aka prefill latency) and tokens/s are calculated using the following formulae:
            prefill_latency = latency of generate(output_tokens=1)
            execution_latency = latency of generate(output_tokens=output_tokens)
            tokens_per_second = (new_tokens - 1) / (execution_latency - prefill_latency)
        """

        if vars(state).get(Keys.MODEL) is None:
            raise ValueError(
                f"{self.__class__.__name__} requires that a model be passed from another tool"
            )

        if vars(state).get("num_beams") and vars(state).get("num_beams") != num_beams:
            raise ValueError(
                f"Number of beams was set to {vars(state).get('num_beams')} "
                f"in a previous tool, but it is set to {num_beams} in "
                "this tool. The values must be the same."
            )

        model = state.model
        tokenizer = state.tokenizer
        dtype = state.dtype

        # Generate the input_ids outside of the benchmarking function to make sure
        # the same input_ids are used everywhere
        input_ids = (
            tokenizer(prompt, return_tensors="pt").to(device=model.device).input_ids
        )

        # Benchmark prefill time (time to first token)
        prefill_per_iteration_result = benchmark_huggingface_llm(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            dtype=dtype,
            num_beams=num_beams,
            target_output_tokens=1,
            iterations=iterations,
            warmup_iterations=warmup_iterations,
        )

        time_to_first_token_per_iteration = [
            latency for latency, _ in prefill_per_iteration_result
        ]
        mean_time_to_first_token = statistics.mean(time_to_first_token_per_iteration)

        # Benchmark generation of all tokens
        decode_per_iteration_result = benchmark_huggingface_llm(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            dtype=dtype,
            num_beams=num_beams,
            target_output_tokens=output_tokens,
            iterations=iterations,
            warmup_iterations=warmup_iterations,
        )

        mean_execution_latency = statistics.mean(
            [latency for latency, _ in decode_per_iteration_result]
        )
        mean_decode_latency = mean_execution_latency - mean_time_to_first_token
        mean_token_len = statistics.mean(
            [token_len for _, token_len in decode_per_iteration_result]
        )
        # Subtract 1 so that we don't count the prefill token
        mean_tokens_per_second = (mean_token_len - 1) / mean_decode_latency

        # Save performance data to stats
        state.save_stat(Keys.SECONDS_TO_FIRST_TOKEN, mean_time_to_first_token)
        state.save_stat(Keys.MEAN_TOKENS_PER_SECOND, mean_tokens_per_second)
        state.save_stat(Keys.PROMPT_TOKENS, input_ids.shape[1])

        return state
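
As a worked example of the formulae in the run() docstring (the numbers are illustrative, not measured results): suppose the prefill-only runs average 0.25 s and the full-generation runs average 5.25 s while producing 256 new tokens per iteration. Then:

# Illustrative values plugged into the docstring's formulae
prefill_latency = 0.25       # mean latency of generate() with 1 output token, in seconds
execution_latency = 5.25     # mean latency of generate() with 256 output tokens, in seconds
mean_token_len = 256         # mean number of new tokens produced per measured iteration

decode_latency = execution_latency - prefill_latency         # 5.0 s
tokens_per_second = (mean_token_len - 1) / decode_latency    # 255 / 5.0 = 51.0 tokens/s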
