
Commit 5f98ddf

support lmms-eval
1 parent f7adbb9 commit 5f98ddf

File tree

6 files changed: +346 -3 lines changed

llmc/__main__.py

+17 -1
@@ -3,20 +3,22 @@
 import gc
 import json
 import os
+import sys
 import time
 
 import torch
 import torch.distributed as dist
 import yaml
 from easydict import EasyDict
+from lmms_eval.utils import make_table
 from loguru import logger
 from torch.distributed import destroy_process_group, init_process_group
 
 from llmc.compression.quantization import *
 from llmc.compression.sparsification import *
 from llmc.data import BaseDataset
 from llmc.eval import (AccuracyEval, HumanEval, PerplexityEval,
-                       TokenConsistencyEval, VLMEval)
+                       TokenConsistencyEval, VLMEval, VQAEval)
 from llmc.models import *
 from llmc.utils import (check_config, mkdirs, print_important_package_version,
                         seed_all, update_autoawq_quant_config,
@@ -49,6 +51,9 @@ def main(config):
         elif config.eval.type == 'img_txt':
             acc_eval = VLMEval(config_for_eval)
             eval_list.append(acc_eval)
+        elif config.eval.type == 'vqa':
+            vqa_eval = VQAEval(config_for_eval)
+            eval_list.append(vqa_eval)
         elif config.eval.type == 'code' and config.eval.name == 'human_eval':
             human_eval = HumanEval(model.get_tokenizer(), config_for_eval)
             eval_list.append(human_eval)
@@ -65,6 +70,11 @@ def main(config):
             for vlm_eval in eval_list:
                 results = vlm_eval.eval(model)
                 logger.info(f'{config.eval.name} results : {results}')
+        elif config.eval.type == 'vqa':
+            for vqa_eval in eval_list:
+                results = vqa_eval.eval(model)
+                logger.info(f'{config.eval.name} results :')
+                print(make_table(results))
         elif config.eval.type == 'code' and config.eval.name == 'human_eval':
             for human_eval in eval_list:
                 results = human_eval.eval(model, eval_pos='pretrain')
@@ -161,6 +171,11 @@ def main(config):
             for vlm_eval in eval_list:
                 results = vlm_eval.eval(model)
                 logger.info(f'{config.eval.name} results : {results}')
+        elif config.eval.type == 'vqa':
+            for vqa_eval in eval_list:
+                results = vqa_eval.eval(model)
+                logger.info(f'{config.eval.name} results :')
+                print(make_table(results))
         elif config.eval.type == 'code' and config.eval.name == 'human_eval':
             for human_eval in eval_list:
                 results = human_eval.eval(model, eval_pos='fake_quant')
@@ -251,6 +266,7 @@ def main(config):
 
 
 if __name__ == '__main__':
+    logger.add(sys.stdout, level='INFO')
    llmc_start_time = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
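
For reference, the new branch is selected entirely by the eval section of the llmc YAML config, which __main__.py loads with yaml.safe_load and wraps in an EasyDict. Below is a minimal sketch of a config that would route into the VQAEval path; the model type and the two paths are placeholders rather than values from this commit, and only eval.type, eval.name, eval.path, eval.bs, and model.path are keys actually read by the code added here.

import yaml
from easydict import EasyDict

# Hypothetical config sketch; 'mme' matches the dataset name allow-listed in eval_base.py below.
config_yaml = """
model:
    type: Llava                      # placeholder model type
    path: /path/to/vlm-checkpoint    # placeholder checkpoint path
eval:
    type: vqa                        # selects the new VQAEval branch
    name: mme                        # lmms-eval task name (a list of names also works)
    path: /path/to/MME               # placeholder dataset path
    bs: 1
"""

config = EasyDict(yaml.safe_load(config_yaml))
assert config.eval.type == 'vqa'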

llmc/eval/__init__.py

+1
@@ -3,3 +3,4 @@
 from .eval_ppl import PerplexityEval
 from .eval_token_consist import TokenConsistencyEval
 from .eval_vlm import VLMEval
+from .eval_vqa import VQAEval

llmc/eval/eval_base.py

+2 -1
@@ -22,7 +22,8 @@ def __init__(self, tokenizer, config):
             'c4',
             'ptb',
             'custom',
-            'human_eval'
+            'human_eval',
+            'mme',
         ], 'Eval only support wikitext2, c4, ptb, custom, human_eval dataset now.'
         self.seq_len = self.eval_cfg.get('seq_len', None)
         self.bs = self.eval_cfg['bs']

llmc/eval/eval_vqa.py

+236
@@ -0,0 +1,236 @@
+import random
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from lmms_eval.evaluator import evaluate
+from lmms_eval.evaluator_utils import run_task_tests
+from lmms_eval.loggers.evaluation_tracker import EvaluationTracker
+from lmms_eval.tasks import TaskManager, get_task_dict
+from lmms_eval.utils import get_datetime_str, simple_parse_args_string
+from loguru import logger
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+
+class VQAEval:
+    def __init__(self, config):
+        self.eval_config = config.eval
+        self.model_path = config.model.path
+        self.dataset = self.eval_config['name']
+        if not isinstance(self.dataset, list):
+            self.dataset = [self.dataset, ]
+        self.eval_dataset_path = self.eval_config['path']
+        self.eval_bs = self.eval_config['bs']
+
+    def eval(
+        self,
+        llmc_model,
+        model_args: Optional[Union[str, dict]] = None,
+        tasks: Optional[List[Union[str, dict, object]]] = None,
+        num_fewshot: Optional[int] = None,
+        batch_size: Optional[Union[int, str]] = None,
+        max_batch_size: Optional[int] = None,
+        device: Optional[str] = None,
+        use_cache: Optional[str] = None,
+        cache_requests: bool = False,
+        rewrite_requests_cache: bool = False,
+        delete_requests_cache: bool = False,
+        limit: Optional[Union[int, float]] = None,
+        bootstrap_iters: int = 100000,
+        check_integrity: bool = False,
+        write_out: bool = False,
+        log_samples: bool = True,
+        evaluation_tracker: Optional[EvaluationTracker] = None,
+        system_instruction: Optional[str] = None,
+        apply_chat_template: bool = False,
+        fewshot_as_multiturn: bool = False,
+        gen_kwargs: Optional[str] = None,
+        task_manager: Optional[TaskManager] = None,
+        verbosity: str = 'INFO',
+        predict_only: bool = False,
+        random_seed: int = 0,
+        numpy_random_seed: int = 1234,
+        torch_random_seed: int = 1234,
+        fewshot_random_seed: int = 1234,
+        datetime_str: str = get_datetime_str(),
+        cli_args=None,
+    ):
+
+        model = llmc_model.eval_name
+        model_args = 'pretrained=' + self.model_path + ',device_map=auto'
+        batch_size = self.eval_bs
+        tasks = self.dataset
+        num_fewshot = 0
+
+        seed_message = []
+        if random_seed is not None:
+            # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
+            seed_message.append(f'Setting random seed to {random_seed}')
+            random.seed(random_seed)
+
+        if numpy_random_seed is not None:
+            seed_message.append(f'Setting numpy seed to {numpy_random_seed}')
+            np.random.seed(numpy_random_seed)
+
+        if torch_random_seed is not None:
+            seed_message.append(f'Setting torch manual seed to {torch_random_seed}')
+            torch.manual_seed(torch_random_seed)
+
+        if seed_message:
+            logger.info(' | '.join(seed_message))
+
+        assert tasks != [], 'No tasks specified, or no tasks found. Please verify the task names.'
+
+        if gen_kwargs:
+            gen_kwargs = simple_parse_args_string(gen_kwargs)
+            logger.warning('generation_kwargs specified through cli.')
+            if gen_kwargs == '':
+                gen_kwargs = None
+
+        if model_args is None:
+            model_args = ''
+
+        if task_manager is None:
+            task_manager = TaskManager(verbosity, model_name=model)
+
+        task_dict = get_task_dict(tasks, task_manager)
+
+        lm = MODEL_REGISTRY[model].create_from_arg_string(
+            model_args,
+            {
+                'llmc_model': llmc_model.vlm_model,
+                'batch_size': batch_size,
+                'device': device,
+            },
+        )
+        # helper function to recursively apply config overrides to leaf subtasks,
+        # skipping their constituent groups.
+        # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
+
+        def _adjust_config(task_dict):
+            adjusted_task_dict = {}
+            for task_name, task_obj in task_dict.items():
+                if isinstance(task_obj, dict):
+                    adjusted_task_dict = {
+                        **adjusted_task_dict,
+                        **{task_name: _adjust_config(task_obj)},
+                    }
+
+                else:
+                    task_obj = task_dict[task_name]
+                    if type(task_obj) == tuple:
+                        group, task_obj = task_obj
+                        if task_obj is None:
+                            continue
+                    lm.task_dict[task_name] = task_obj.dataset
+                    if 'generate_until' in task_obj.get_config('output_type'):
+                        if gen_kwargs is not None:
+                            task_obj.set_config(key='generation_kwargs',
+                                                value=gen_kwargs, update=True)
+
+                    if predict_only:
+                        logger.info(f'Processing {task_name} in output-only mode. \
+                                      Metrics will not be calculated!')
+                        # we have to change the class properties post-hoc. This is pretty hacky.
+                        task_obj.override_metric(metric_name='bypass')
+
+                    # override tasks' fewshot values to
+                    # the provided num_fewshot arg value
+                    # except if tasks have it set to 0 manually in their configs--then
+                    # we should never overwrite that
+                    if num_fewshot is not None:
+                        if (default_num_fewshot := task_obj.get_config('num_fewshot')) == 0:
+                            logger.info(f'num_fewshot has been set to 0 for {task_name} \
+                                          in its config. Manual configuration will be ignored.')
+                        else:
+                            logger.warning(f'Overwriting default num_fewshot of {task_name} \
+                                             from {default_num_fewshot} to {num_fewshot}')
+                            task_obj.set_config(key='num_fewshot', value=num_fewshot)
+                    else:
+                        # if num_fewshot not provided, and the task does not define a default one,
+                        # default to 0
+                        if (default_num_fewshot := task_obj.get_config('num_fewshot')) is None:
+                            task_obj.set_config(key='num_fewshot', value=0)
+                    # fewshot_random_seed set for tasks, even with a default num_fewshot
+                    # (e.g. in the YAML file)
+                    task_obj.set_fewshot_seed(seed=fewshot_random_seed)
+                    # logger.info(f"Setting fewshot random generator seed to {fewshot_random_seed}")
+
+                    adjusted_task_dict[task_name] = task_obj
+
+            return adjusted_task_dict
+
+        task_dict = _adjust_config(task_dict)
+
+        if check_integrity:
+            run_task_tests(task_list=tasks)
+
+        if evaluation_tracker is not None:
+            evaluation_tracker.general_config_tracker.log_experiment_args(
+                model_source=model,
+                model_args=model_args,
+                system_instruction=system_instruction,
+                chat_template=lm.chat_template if apply_chat_template else None,
+                fewshot_as_multiturn=fewshot_as_multiturn,
+            )
+
+        results = evaluate(
+            lm=lm,
+            task_dict=task_dict,
+            limit=limit,
+            cache_requests=cache_requests,
+            rewrite_requests_cache=rewrite_requests_cache,
+            bootstrap_iters=bootstrap_iters,
+            write_out=write_out,
+            log_samples=True if predict_only else log_samples,
+            system_instruction=system_instruction,
+            apply_chat_template=apply_chat_template,
+            fewshot_as_multiturn=fewshot_as_multiturn,
+            verbosity=verbosity,
+            cli_args=cli_args,
+        )
+
+        if hasattr(lm, '_model'):
+            del lm._model
+            torch.cuda.empty_cache()
+
+        if lm.rank == 0:
+            if isinstance(model, str):
+                model_name = model
+            elif hasattr(model, 'config') and hasattr(model.config, '_name_or_path'):
+                model_name = model.config._name_or_path
+            else:
+                model_name = type(model).__name__
+
+            # add info about the model and few shot config
+            results['config'] = {
+                'model': model_name,
+                'model_args': model_args,
+            }
+            # add more detailed model info if available TODO: add model info
+            # if isinstance(lm, lm_eval.models.huggingface.HFLM):
+            #     results["config"].update(lm.get_model_info())
+            # add info about execution
+            results['config'].update(
+                {
+                    'batch_size': batch_size,
+                    'batch_sizes': (list(lm.batch_sizes.values())
+                                    if hasattr(lm, 'batch_sizes') else []),
+                    'device': device,
+                    'use_cache': use_cache,
+                    'limit': limit,
+                    'bootstrap_iters': bootstrap_iters,
+                    'gen_kwargs': gen_kwargs,
+                    'random_seed': random_seed,
+                    'numpy_seed': numpy_random_seed,
+                    'torch_seed': torch_random_seed,
+                    'fewshot_seed': fewshot_random_seed,
+                }
+            )
+            results['date'] = datetime_str
+            # add_env_info(results)  # additional environment info to results
+            # add_tokenizer_info(results, lm)  # additional info about tokenizer
+            return results
+        else:
+            return None
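
Taken on its own, the evaluator is driven the same way the __main__.py changes above drive it. A minimal usage sketch, assuming config is the parsed EasyDict config and model is an llmc model wrapper that exposes eval_name (the key looked up in MODEL_REGISTRY) and vlm_model (the underlying vision-language model), which is what VQAEval.eval expects:

from lmms_eval.utils import make_table

from llmc.eval import VQAEval

# config and model are assumed to already exist, as in main(config) above.
vqa_eval = VQAEval(config)
results = vqa_eval.eval(model)   # lmms-eval results dict on rank 0, None on other ranks
if results is not None:
    print(make_table(results))   # tabulated per-task metrics, as printed by __main__.py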
