
Commit 60b9eb3

gushiqiao committed
Support chat model human-eval test.

1 parent 93fe81c

4 files changed, +50 -22 lines


configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml (+1)

@@ -14,6 +14,7 @@ eval:
     bs: 1
     format_tabs: True
     inference_per_block: False
+    # add_chat_temp: True
 quant:
     method: RTN
     weight:

llmc/eval/eval_base.py (+10, -15)

@@ -13,29 +13,24 @@ class BaseEval:
     def __init__(self, tokenizer, config):
         self.tokenizer = tokenizer
         # eval_cfg
-        eval_cfg = config.eval
+        self.eval_cfg = config.eval
         self.model_type = config.model.type
-        logger.info(f'eval_cfg : {eval_cfg}')
-        self.dataset = eval_cfg['name']
+        logger.info(f'eval_cfg : {self.eval_cfg}')
+        self.dataset = self.eval_cfg['name']
         assert self.dataset in [
             'wikitext2',
             'c4',
             'ptb',
             'custom',
             'human_eval'
-        ], 'Ppl eval only support wikitext2, c4, ptb, human_eval dataset now.'
-        self.seq_len = eval_cfg.get('seq_len', None)
-        self.bs = eval_cfg['bs']
-        self.path = eval_cfg.get('path', None)
-        self.download = eval_cfg.get('download', False)
-        self.load_from_txt = eval_cfg.get('load_from_txt', False)
-        self.inference_per_block = eval_cfg.get('inference_per_block', False)
+        ], 'Eval only support wikitext2, c4, ptb, custom, human_eval dataset now.'
+        self.seq_len = self.eval_cfg.get('seq_len', None)
+        self.bs = self.eval_cfg['bs']
+        self.path = self.eval_cfg.get('path', None)
+        self.download = self.eval_cfg.get('download', False)
+        self.load_from_txt = self.eval_cfg.get('load_from_txt', False)
+        self.inference_per_block = self.eval_cfg.get('inference_per_block', False)
         self.testenc = self.build_data()
-        self.res_path = eval_cfg.get('res_path', None)
-        if self.dataset in ['human_eval']:
-            assert self.res_path is not None
-            os.makedirs(self.res_path, exist_ok=True)
-        self.format_tabs = eval_cfg.get('format_tabs', False)

     @torch.no_grad()
     def build_data(self):
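
Storing the eval config on self is what lets the human-eval-specific keys (res_path, format_tabs) migrate out of BaseEval and into the subclass, as the next file shows. A minimal sketch of the pattern, with a hypothetical CustomEval subclass and a hypothetical n_samples key:

class CustomEval(BaseEval):  # hypothetical subclass, for illustration only
    def __init__(self, tokenizer, config):
        super().__init__(tokenizer, config)
        # Dataset-specific options no longer live in BaseEval; subclasses
        # read them from the shared self.eval_cfg instead.
        self.n_samples = self.eval_cfg.get('n_samples', 1)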

llmc/eval/eval_code.py (+36, -7)
@@ -11,6 +11,15 @@


 class HumanEval(BaseEval):
+    def __init__(self, tokenizer, config):
+        super().__init__(tokenizer, config)
+        self.res_path = self.eval_cfg.get('res_path', None)
+        assert self.res_path is not None
+        os.makedirs(self.res_path, exist_ok=True)
+        self.format_tabs = self.eval_cfg.get('format_tabs', False)
+        self.instruction = self.eval_cfg.get('instruction',
+                                             'Complete the following Python code:')
+        self.add_chat_temp = self.eval_cfg.get('add_chat_temp', False)

     @torch.no_grad()
     def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
@@ -22,6 +31,7 @@ def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
             prompt = testenc[task_id]['prompt'].replace('    ', '\t')
         else:
             prompt = testenc[task_id]['prompt']
+        prompt = self.gen_prompt(prompt)
         batch_completions = self.generate_batch_completion(
             model, prompt, bs
         )
@@ -46,8 +56,24 @@ def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
         res = self.post_process(testenc)
         return res

+    def gen_prompt(self, prompt):
+        prompt = self.instruction + '\n' + prompt
+        if self.model_type in ['Starcoder']:
+            prompt = '<fim_prefix>' + prompt + '<fim_suffix><fim_middle>'
+
+        if self.add_chat_temp:
+            chat_prompt = [{'role': 'user', 'content': prompt}]
+            chat_prompt = self.tokenizer.apply_chat_template(
+                chat_prompt,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+            return chat_prompt
+
+        return prompt
+
     @torch.no_grad()
-    def generated_llama(
+    def generated(
         self,
         model,
         inputs,
@@ -56,14 +82,20 @@ def generated_llama(
         top_p=0.95,
         do_sample=True,
     ):
+
+        if hasattr(self.tokenizer, 'pad_token_id'):
+            pad_token_id = self.tokenizer.pad_token_id
+        else:
+            pad_token_id = self.tokenizer.eos_token_id
+
         generated_ids = model.model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             top_p=top_p,
             do_sample=do_sample,
             eos_token_id=self.tokenizer.eos_token_id,
-            pad_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=pad_token_id,
             use_cache=True,
         )
         return generated_ids
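
A note on the fallback above: Hugging Face tokenizers normally define pad_token_id even when no pad token is configured, with the value None, so the hasattr check is usually true regardless. A minimal sketch of the more common guard, using a hypothetical resolve_pad_token_id helper:

def resolve_pad_token_id(tokenizer):
    # pad_token_id exists on HF tokenizers but may be None; check the
    # value, not the attribute's presence, before falling back to eos.
    if getattr(tokenizer, 'pad_token_id', None) is not None:
        return tokenizer.pad_token_id
    return tokenizer.eos_token_id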
@@ -74,11 +106,8 @@ def generate_batch_completion(self, model, prompt, bs):
         inputs = self.tokenizer(input_batch, return_tensors='pt').to(model.model.device)
         input_ids_cutoff = inputs.input_ids.size(dim=1)

-        if self.model_type in ['Llama']:
-            generated_ids = self.generated_llama(model, inputs)
-            model.reset_kv()
-        else:
-            raise NotImplementedError('This model is not support yet.')
+        generated_ids = self.generated(model, inputs)
+        model.reset_kv()

         batch_completions = self.tokenizer.batch_decode(
             [ids[input_ids_cutoff:] for ids in generated_ids],
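
The add_chat_temp path in gen_prompt above is what the new (commented-out) config key enables. A standalone sketch of what that wrapping produces, assuming a chat-tuned checkpoint (the model name here is only an example; any tokenizer that ships a chat template works):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-7B-Instruct')

prompt = 'Complete the following Python code:\n' + 'def add(a, b):\n'

# apply_chat_template renders the user turn in the model's chat markup;
# add_generation_prompt=True appends the assistant header so generation
# continues as the assistant's reply.
chat_prompt = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': prompt}],
    tokenize=False,
    add_generation_prompt=True,
)
print(chat_prompt)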

llmc/models/internlm2.py (+3)

@@ -32,6 +32,9 @@ def get_pre_head_layernorm_layers(self):
     def get_layers_except_blocks(self):
         return [self.tok_embeddings, self.model.model.norm, self.model.output]

+    def get_attn_in_block(self, block):
+        return {'attention': block.attention}
+
     def skip_layer_name(self):
         return ['lm_head']
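
get_attn_in_block gives callers a uniform name-to-module map for a block's attention. A hedged sketch of a caller, assuming the wrapper also exposes its transformer blocks through a get_blocks() accessor (an assumption, not shown in this diff):

# Hypothetical walk over every block's attention module.
for block in model.get_blocks():
    for name, module in model.get_attn_in_block(block).items():
        print(f'{name}: {type(module).__name__}')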
