Commit c086116

Merge pull request #197 from RobotSail/fix-mmlu
Allows MMLU to have the system_prompt provided to it
2 parents: 4cf3e14 + ad12276

6 files changed (+140 −28 lines)


.pylintrc (+2 −1)

@@ -448,7 +448,8 @@ disable=raw-checker-failed,
         pointless-statement,
         wrong-import-order,
         line-too-long,
-        dangerous-default-value
+        dangerous-default-value,
+        too-many-instance-attributes

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

.spellcheck-en-custom.txt (+2)

@@ -26,3 +26,5 @@ TODO
 tox
 venv
 vllm
+barebones
+LM

CHANGELOG.md (+5)

@@ -1,3 +1,8 @@
+## 0.4.2
+
+* Adds the ability to provide a custom system prompt to the MMLU-based evaluators. When a system prompt is provided, LM-eval applies the chat template under the hood; otherwise it passes the model a barebones prompt.
+* Adds an `extra_args` parameter to the `.run` method of all MMLU-based evaluators. This way, consumers can pass any additional arguments they want directly through to the `lm_eval.evaluator.simple_evaluate` function.
+
 ## 0.4

 * Added ability to specify a custom http client to MT-Bench
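
Taken together, the two new changelog entries describe a small API addition. The sketch below is illustrative only; the model name, prompt, and extra_args values mirror the test script added in this commit rather than being requirements:

# Illustrative usage of the system_prompt and extra_args additions.
from instructlab.eval.mmlu import MMLUEvaluator

evaluator = MMLUEvaluator(
    model_path="instructlab/granite-7b-lab",  # any Hugging Face model name or path
    tasks=["mmlu_anatomy", "mmlu_astronomy"],
    # providing a system prompt makes lm-eval apply the chat template
    system_prompt="You are an intelligent AI language model.",
)

# extra_args is forwarded to lm_eval.evaluator.simple_evaluate
overall_score, individual_scores = evaluator.run(
    extra_args={"log_samples": True, "write_out": True}
)

print(overall_score)
print(individual_scores)
print(evaluator.results["n-shot"])  # full simple_evaluate output is kept on .results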

scripts/test_mmlu.py (+59 −2)

@@ -1,16 +1,73 @@
+# Standard
+from typing import Dict, List, Tuple, TypedDict
+
 # First Party
 from instructlab.eval.mmlu import MMLUEvaluator

+SYSTEM_PROMPT = """I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."""
+
+
+class MMLUSample(TypedDict):
+    """
+    Example of a single sample returned from lm_eval when running MMLU.
+    This is not a comprehensive type, just the subset of fields we care about for this test.
+    """
+
+    # Arguments is the list of (prompt, answer) pairs passed to MMLU as few-shot samples.
+    # They will not be present with few_shot=0
+    arguments: List[Tuple[str, str]]
+
+
+def all_samples_contain_system_prompt(
+    samples: Dict[str, List[MMLUSample]], prompt: str
+) -> bool:
+    """
+    Given a mapping of evaluation --> list of results, validates that all few-shot examples
+    included the system prompt
+    """
+    for topic, samples_set in samples.items():
+        for sample in samples_set:
+            for mmlu_prompt, _ in sample["arguments"]:
+                if prompt not in mmlu_prompt:
+                    # we are looking for the exact system prompt, so no need to normalize to lowercase
+                    print(f"found a sample in the '{topic}' MMLU topic set")
+                    return False
+
+    return True
+

 def test_minimal_mmlu():
     print("===> Executing 'test_minimal_mmlu'...")
     try:
         model_path = "instructlab/granite-7b-lab"
         tasks = ["mmlu_anatomy", "mmlu_astronomy"]
-        mmlu = MMLUEvaluator(model_path=model_path, tasks=tasks)
-        overall_score, individual_scores = mmlu.run()
+        mmlu = MMLUEvaluator(
+            model_path=model_path,
+            tasks=tasks,
+            system_prompt=SYSTEM_PROMPT,
+        )
+        overall_score, individual_scores = mmlu.run(
+            extra_args={"log_samples": True, "write_out": True}
+        )
+        samples = mmlu.results["samples"]
+
         print(overall_score)
         print(individual_scores)
+
+        # we need n-shots > 1 to be able to validate the inclusion of the system prompt
+        eligible_samples = {
+            topic: samples[topic]
+            for topic, shot in mmlu.results["n-shot"].items()
+            if shot > 1
+        }
+        if eligible_samples:
+            if not all_samples_contain_system_prompt(eligible_samples, SYSTEM_PROMPT):
+                return False
+        else:
+            print(
+                "MMLU was run in zero-shot mode, cannot confirm that system prompt was included, skipping check..."
+            )
+
     except Exception as exc:
         print(f"'test_minimal_mmlu' failed: {exc}")
         return False
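
For reference, the prompt check in this script only relies on the `arguments` field typed by `MMLUSample`. Below is a hand-built sketch of the structure it walks; real lm_eval samples carry many more fields, and the values here are made up:

# Made-up samples shaped like the subset of lm_eval output the test inspects.
fake_samples = {
    "mmlu_anatomy": [
        {"arguments": [("SYSTEM PROMPT\n\nQuestion: ...", " (A)")]},
    ],
}

assert all_samples_contain_system_prompt(fake_samples, "SYSTEM PROMPT")
assert not all_samples_contain_system_prompt(fake_samples, "some other prompt")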

src/instructlab/eval/mmlu.py (+63 −23)

@@ -7,12 +7,12 @@
 """

 # Standard
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
 import os

 # Third Party
-from lm_eval.evaluator import simple_evaluate  # type: ignore
-from lm_eval.tasks import TaskManager  # type: ignore
+from lm_eval.evaluator import simple_evaluate
+from lm_eval.tasks import TaskManager
 import torch

 # First Party
@@ -102,6 +102,8 @@ class AbstractMMLUEvaluator(Evaluator):
         few_shots      number of examples
         batch_size     batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times.
         device         PyTorch device (e.g. "cpu" or "cuda:0") for running models
+        system_prompt  system prompt to be used when applying the chat template
+        results        full output from the `lm_eval.evaluator.simple_evaluate` function after MMLU has run.
     """

     def __init__(
@@ -113,26 +115,43 @@ def __init__(
         few_shots: int = 5,
         batch_size: Optional[Union[int, str]] = "auto",
         device: str = ("cuda" if torch.cuda.is_available() else "cpu"),
+        system_prompt: Optional[str] = None,
     ) -> None:
         self.model_path = model_path
+        self.system_prompt = system_prompt
         self.tasks_dir = tasks_dir
         self.tasks = tasks
         self.model_dtype = model_dtype
         self.few_shots = few_shots
         self.batch_size = batch_size
         self.device = device
+        self._results = None

-    def run(self, server_url: str | None = None) -> tuple:
+    @property
+    def results(self) -> Dict[str, Any] | None:
+        """
+        Returns the results of the last MMLU evaluation, if one has taken place.
+
+        Returns:
+            Dict[str, Any] | None: The output from `lm_eval.evaluator.simple_evaluate`
+        """
+        return self._results
+
+    def run(
+        self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None
+    ) -> tuple:
         """
         Runs evaluation

         Attributes
             server_url  Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
+            extra_args  Dictionary containing any extra arguments to be passed into the `lm_eval.evaluator.simple_evaluate` function.

         Returns:
             overall_score      Average score for the task group
             individual_scores  Individual scores for each task in the task group
         """
+        extra_args = {} if not extra_args else extra_args
         logger.debug(locals())

         # TODO: make this a parameter for class?
@@ -153,7 +172,10 @@ def run(self, server_url: str | None = None) -> tuple:

         return overall_score, individual_scores

-    def _run_mmlu(self, server_url: str | None = None) -> dict:
+    def _run_mmlu(
+        self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None
+    ) -> dict:
+        extra_args = {} if not extra_args else extra_args
         if server_url is not None:
             # Requires lm_eval >= 0.4.4
             model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface"
@@ -168,17 +190,25 @@ def _run_mmlu(self, server_url: str | None = None) -> dict:
         if not os.access(self.tasks_dir, os.R_OK):
             raise InvalidTasksDirError(self.tasks_dir)
         tm = TaskManager(verbosity="DEBUG", include_path=self.tasks_dir)
-        mmlu_output = self._simple_evaluate_with_error_handling(
-            model=model,
-            model_args=model_args,
-            tasks=self.tasks,
-            num_fewshot=self.few_shots,
-            batch_size=self.batch_size,
-            device=self.device,
-            task_manager=tm,
-        )
-        results = mmlu_output["results"]
-        return results
+        should_apply_chat_template = self.system_prompt is not None
+
+        # configure the args here so users can override them as necessary
+        simple_evaluate_kwargs = {
+            "model": model,
+            "model_args": model_args,
+            "tasks": self.tasks,
+            "num_fewshot": self.few_shots,
+            "batch_size": self.batch_size,
+            "device": self.device,
+            "task_manager": tm,
+            "system_instruction": self.system_prompt,
+            "apply_chat_template": should_apply_chat_template,
+        }
+        simple_evaluate_kwargs.update(extra_args)
+
+        results = self._simple_evaluate_with_error_handling(**simple_evaluate_kwargs)
+        self._results = results
+        return results["results"]

     # This method converts general errors from simple_evaluate
     # into a more user-understandable error
@@ -213,12 +243,13 @@ class MMLUEvaluator(AbstractMMLUEvaluator):
     Evaluator for Massive Multitask Language Understanding (MMLU)

     Attributes:
-        model_path   absolute path to or name of a huggingface model
-        tasks        list of tasks for MMLU to test the model with
-        model_dtype  dtype of model when served
-        few_shots    number of examples
-        batch_size   batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times.
-        device       PyTorch device (e.g. "cpu" or "cuda:0") for running models
+        model_path     absolute path to or name of a huggingface model
+        tasks          list of tasks for MMLU to test the model with
+        model_dtype    dtype of model when served
+        few_shots      number of examples
+        batch_size     batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times.
+        device         PyTorch device (e.g. "cpu" or "cuda:0") for running models
+        system_prompt  system prompt to be used when applying the chat template
     """

     name = "mmlu"
@@ -231,9 +262,17 @@ def __init__(
         few_shots: int = 5,
         batch_size: Optional[Union[int, str]] = "auto",
         device: str = ("cuda" if torch.cuda.is_available() else "cpu"),
+        system_prompt: Optional[str] = None,
     ) -> None:
         super().__init__(
-            model_path, None, tasks, model_dtype, few_shots, batch_size, device
+            model_path,
+            None,
+            tasks,
+            model_dtype,
+            few_shots,
+            batch_size,
+            device,
+            system_prompt=system_prompt,
         )


@@ -243,6 +282,7 @@ class MMLUBranchEvaluator(AbstractMMLUEvaluator):

     Attributes:
         model_path     absolute path to or name of a huggingface model
+        system_prompt  system prompt to be used when applying the chat template
         tasks_dir      path where the <TASK_NAME>.jsonl and <TASK_NAME>_task.yaml files for the branches being evaluated are stored
         tasks          group name that is shared by all the MMLUBranch tasks
         model_dtype    dtype of model when served
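
Since `_run_mmlu` now builds the default `simple_evaluate` keyword arguments in a dict and then applies `extra_args` via `dict.update`, any key a caller supplies overrides the evaluator's default, and unknown keys are passed through untouched. A standalone sketch of that merge order, with illustrative values only:

# Illustrative only: mirrors the dict.update merge used in _run_mmlu.
simple_evaluate_kwargs = {"num_fewshot": 5, "batch_size": "auto", "apply_chat_template": True}
extra_args = {"num_fewshot": 0, "log_samples": True}  # caller-supplied overrides and additions

simple_evaluate_kwargs.update(extra_args)  # caller keys win; new keys are added
assert simple_evaluate_kwargs == {
    "num_fewshot": 0,
    "batch_size": "auto",
    "apply_chat_template": True,
    "log_samples": True,
}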

tests/test_mmlu.py (+9 −2)

@@ -48,7 +48,10 @@ def test_mmlu_branch(eval_mock):
     tasks_dir = f"{os.path.dirname(os.path.realpath(__file__))}/testdata/sdg"
     tasks = ["mmlu_pr"]
     mmlu = MMLUBranchEvaluator(
-        model_path=MODEL_EXAMPLE, tasks_dir=tasks_dir, tasks=tasks
+        model_path=MODEL_EXAMPLE,
+        tasks_dir=tasks_dir,
+        tasks=tasks,
+        system_prompt="You are an intelligent AI language model.",
     )
     overall_score, individual_scores = mmlu.run()

@@ -62,7 +65,11 @@
 )
 def test_mmlu(eval_mock):
     tasks = ["mmlu_anatomy", "mmlu_astronomy", "mmlu_algebra"]
-    mmlu = MMLUEvaluator(model_path=MODEL_EXAMPLE, tasks=tasks)
+    mmlu = MMLUEvaluator(
+        model_path=MODEL_EXAMPLE,
+        tasks=tasks,
+        system_prompt="You are an intelligent AI language model.",
+    )
     overall_score, individual_scores = mmlu.run()

     eval_mock.assert_called()
