Skip to content

Commit 40cc370

Browse files
authored
Merge pull request #138 from alimaredia/mtbench-branch-judgement-return-overall-score
return overall_score from MTBenchBranch.generate_judgement()
2 parents fa22ef5 + b22b40b commit 40cc370

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

.github/workflows/e2e-nvidia-t4-x1.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ jobs:
142142
working-directory: ./instructlab
143143
run: |
144144
. venv/bin/activate
145-
./scripts/basic-workflow-tests.sh -m
145+
./scripts/basic-workflow-tests.sh -msq
146146
147147
stop-runner:
148148
name: Stop external EC2 runner

src/instructlab/eval/mt_bench.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,12 @@ def judge_answers(
246246
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
247247
248248
Returns:
249+
overall_score overall score from the evaluation
249250
qa_pairs Question and answer pairs (with scores) from the evaluation
251+
error_rate percentage of questions dropped due to errors during evaluation
250252
"""
251253
logger.debug(locals())
252-
_, qa_pairs, _, error_rate = mt_bench_judgment.generate_judgment(
254+
overall_score, qa_pairs, _, error_rate = mt_bench_judgment.generate_judgment(
253255
self.model_name,
254256
self.judge_model_name,
255257
server_url,
@@ -261,4 +263,4 @@ def judge_answers(
261263
bench_name="mt_bench_branch",
262264
merge_system_user_message=self.merge_system_user_message,
263265
)
264-
return qa_pairs, error_rate
266+
return overall_score, qa_pairs, error_rate

tests/test_branch_judge_answers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
"../taxonomy",
1111
"main",
1212
)
13-
qa_pairs, error_rate = mt_bench_branch.judge_answers("http://localhost:8000/v1")
13+
overall_score, qa_pairs, error_rate = mt_bench_branch.judge_answers(
14+
"http://localhost:8000/v1"
15+
)
16+
17+
print(f"Overall Score: {overall_score}")
1418
print(f"Error Rate: {error_rate}")
1519
print(f"QA Pair 0:")
1620
pprint.pprint(qa_pairs[0])

0 commit comments

Comments
 (0)