Skip to content

Commit 9dea250

Browse files
Add prefill tps in oga-bench (#250)
* Add prefill tps in oga-bench
  Signed-off-by: David Fan <[email protected]>
* link fix
---------
Signed-off-by: David Fan <[email protected]>
Co-authored-by: ramkrishna2910 <[email protected]>
1 parent 8c46f6b commit 9dea250

File tree

4 files changed

+21
-8
lines changed

4 files changed

+21
-8
lines changed

src/turnkeyml/llm/cache.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,10 @@ class Keys:
2121
PER_ITERATION_LATENCY = "per_iteration_latency"
2222
MEAN_LATENCY = "mean_latency"
2323
STD_DEV_LATENCY = "std_dev_latency"
24-
MEAN_TOKENS_PER_SECOND = "mean_tokens_per_second"
24+
TOKEN_GENERATION_TOKENS_PER_SECOND = "token_generation_tokens_per_second"
2525
STD_DEV_TOKENS_PER_SECOND = "std_dev_tokens_per_second"
2626
SECONDS_TO_FIRST_TOKEN = "seconds_to_first_token"
27+
PREFILL_TOKENS_PER_SECOND = "prefill_tokens_per_second"
2728
STD_DEV_SECONDS_TO_FIRST_TOKEN = "std_dev_seconds_to_first_token"
2829
CHECKPOINT = "checkpoint"
2930
DTYPE = "dtype"

src/turnkeyml/llm/tools/huggingface_bench.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,10 @@ class HuggingfaceBench(Tool):
110110
def __init__(self):
111111
super().__init__(monitor_message="Benchmarking Huggingface LLM")
112112

113-
self.status_stats = [Keys.SECONDS_TO_FIRST_TOKEN, Keys.MEAN_TOKENS_PER_SECOND]
113+
self.status_stats = [
114+
Keys.SECONDS_TO_FIRST_TOKEN,
115+
Keys.TOKEN_GENERATION_TOKENS_PER_SECOND,
116+
]
114117

115118
@staticmethod
116119
def parser(parser: argparse.ArgumentParser = None, add_help: bool = True):
@@ -283,11 +286,13 @@ def run(
283286
[token_len for _, token_len in decode_per_iteration_result]
284287
)
285288
# Subtract 1 so that we don't count the prefill token
286-
mean_tokens_per_second = (mean_token_len - 1) / mean_decode_latency
289+
token_generation_tokens_per_second = (mean_token_len - 1) / mean_decode_latency
287290

288291
# Save performance data to stats
289292
state.save_stat(Keys.SECONDS_TO_FIRST_TOKEN, mean_time_to_first_token)
290-
state.save_stat(Keys.MEAN_TOKENS_PER_SECOND, mean_tokens_per_second)
293+
state.save_stat(
294+
Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second
295+
)
291296
state.save_stat(Keys.PROMPT_TOKENS, input_ids.shape[1])
292297

293298
return state

src/turnkeyml/llm/tools/ort_genai/oga_bench.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,8 @@ def __init__(self):
3232

3333
self.status_stats = [
3434
Keys.SECONDS_TO_FIRST_TOKEN,
35-
Keys.MEAN_TOKENS_PER_SECOND,
35+
Keys.PREFILL_TOKENS_PER_SECOND,
36+
Keys.TOKEN_GENERATION_TOKENS_PER_SECOND,
3637
Keys.PROMPT_TOKENS,
3738
]
3839

@@ -144,10 +145,16 @@ def run(
144145
per_iteration_tokens_per_second.append(model.tokens_per_second)
145146

146147
mean_time_to_first_token = statistics.mean(per_iteration_time_to_first_token)
147-
mean_tokens_per_second = statistics.mean(per_iteration_tokens_per_second)
148+
prefill_tokens_per_second = input_ids_len / mean_time_to_first_token
149+
token_generation_tokens_per_second = statistics.mean(
150+
per_iteration_tokens_per_second
151+
)
148152

149153
state.save_stat(Keys.SECONDS_TO_FIRST_TOKEN, mean_time_to_first_token)
150-
state.save_stat(Keys.MEAN_TOKENS_PER_SECOND, mean_tokens_per_second)
154+
state.save_stat(Keys.PREFILL_TOKENS_PER_SECOND, prefill_tokens_per_second)
155+
state.save_stat(
156+
Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second
157+
)
151158
state.save_stat(Keys.PROMPT_TOKENS, input_ids_len)
152159

153160
return state

test/llm_api.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def test_001_huggingface_bench(self):
7878

7979
stats = fs.Stats(state.cache_dir, state.build_name).stats
8080

81-
assert stats[Keys.MEAN_TOKENS_PER_SECOND] > 0
81+
assert stats[Keys.TOKEN_GENERATION_TOKENS_PER_SECOND] > 0
8282

8383

8484
if __name__ == "__main__":

0 commit comments

Comments
 (0)