 class HumanEval(BaseEval):
+    def __init__(self, tokenizer, config):
+        super().__init__(tokenizer, config)
+        self.res_path = self.eval_cfg.get('res_path', None)
+        assert self.res_path is not None
+        os.makedirs(self.res_path, exist_ok=True)
+        self.format_tabs = self.eval_cfg.get('format_tabs', False)
+        self.instruction = self.eval_cfg.get('instruction',
+                                             'Complete the following Python code:')
+        self.add_chat_temp = self.eval_cfg.get('add_chat_temp', False)

     @torch.no_grad()
     def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
@@ -22,6 +31,7 @@ def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
                 prompt = testenc[task_id]['prompt'].replace('    ', '\t')
             else:
                 prompt = testenc[task_id]['prompt']
+            prompt = self.gen_prompt(prompt)
             batch_completions = self.generate_batch_completion(
                 model, prompt, bs
             )
@@ -46,8 +56,24 @@ def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
         res = self.post_process(testenc)
         return res

+    def gen_prompt(self, prompt):
+        prompt = self.instruction + '\n' + prompt
+        if self.model_type in ['Starcoder']:
+            prompt = '<fim_prefix>' + prompt + '<fim_suffix><fim_middle>'
+
+        if self.add_chat_temp:
+            chat_prompt = [{'role': 'user', 'content': prompt}]
+            chat_prompt = self.tokenizer.apply_chat_template(
+                chat_prompt,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+            return chat_prompt
+
+        return prompt
+
     @torch.no_grad()
-    def generated_llama(
+    def generated(
         self,
         model,
         inputs,
@@ -56,14 +82,20 @@ def generated_llama(
         top_p=0.95,
         do_sample=True,
     ):
+
+        if hasattr(self.tokenizer, 'pad_token_id'):
+            pad_token_id = self.tokenizer.pad_token_id
+        else:
+            pad_token_id = self.tokenizer.eos_token_id
+
         generated_ids = model.model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             top_p=top_p,
             do_sample=do_sample,
             eos_token_id=self.tokenizer.eos_token_id,
-            pad_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=pad_token_id,
             use_cache=True,
         )
         return generated_ids
@@ -74,11 +106,8 @@ def generate_batch_completion(self, model, prompt, bs):
         inputs = self.tokenizer(input_batch, return_tensors='pt').to(model.model.device)
         input_ids_cutoff = inputs.input_ids.size(dim=1)

-        if self.model_type in ['Llama']:
-            generated_ids = self.generated_llama(model, inputs)
-            model.reset_kv()
-        else:
-            raise NotImplementedError('This model is not support yet.')
+        generated_ids = self.generated(model, inputs)
+        model.reset_kv()

         batch_completions = self.tokenizer.batch_decode(
             [ids[input_ids_cutoff:] for ids in generated_ids],
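For reference, the new __init__ above reads four keys from eval_cfg. A minimal sketch of what that config section might contain, assuming it is a plain dict (the res_path value is hypothetical; only the key names and defaults come from this diff):

eval_cfg = {
    'res_path': './humaneval_res',  # required (hypothetical path): output directory, asserted non-None and created with os.makedirs
    'format_tabs': False,           # optional: rewrite indentation spaces in prompts as tabs
    'instruction': 'Complete the following Python code:',  # optional: this string is the default
    'add_chat_temp': False,         # optional: wrap the prompt with the tokenizer's chat template
}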