@@ -59,7 +59,8 @@ def calculate_rouge_score(model_outputs, ref_outputs):
59
59
m_result = metric .compute (
60
60
predictions = m_preds , references = m_targets , use_stemmer = True , use_aggregator = False
61
61
)
62
- m_rouge_result = {k : round (np .mean (v ) * 100 , 4 ) for k , v in m_result .items ()}
62
+ m_rouge_result = {k : round (np .mean (v ) * 100 , 4 )
63
+ for k , v in m_result .items ()}
63
64
64
65
return m_rouge_result
65
66
@@ -101,30 +102,35 @@ def maybe_remove_comma(x: str) -> str:
101
102
def try_float(x: str):
    """Convert *x* to a float, returning ``None`` when it cannot be parsed.

    Args:
        x: The text to convert.

    Returns:
        ``float(x)`` on success, otherwise ``None``.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are swallowed: ValueError for malformed
        # text, TypeError for non-numeric objects such as None, and
        # OverflowError for integers too large for a float.
        # KeyboardInterrupt / SystemExit are deliberately allowed through.
        return None
107
108
109
+
108
110
def postprocess_golang(code: str) -> str:
    """Strip the Go boilerplate from a generated solution.

    Removes every ``package main`` occurrence, parenthesised multi-line
    import blocks, single-line ``import "..."`` statements and the
    ``func main`` definition.

    Args:
        code: Go source text.

    Returns:
        The source with the boilerplate removed.
    """
    # Parenthesised import block with at least two entries.
    multi_line_imports = re.compile(
        r"^import \(\n(.+)((?:\n.+)+)\n\)", re.MULTILINE)
    # re.MULTILINE is required so that `^` matches at the start of every
    # line; without it only an import at the very beginning of the text
    # could ever be removed.
    line_imports = re.compile(r"^import \".*\"", re.MULTILINE)
    # With DOTALL the greedy `.*` runs from `func main` up to the last line
    # in the text that starts with `}`.
    func_main = re.compile(r"^func main.*^}", re.MULTILINE | re.DOTALL)

    code = code.replace("package main", "")  # Remove package main
    for pattern in (multi_line_imports, line_imports, func_main):
        code = pattern.sub("", code)

    return code
119
122
123
+
120
124
def postprocess_scala(code: str) -> str:
    """Remove the ``object Main extends App {`` wrapper from Scala source.

    Every occurrence of the opening line is deleted, then the final line
    of the remaining text is dropped.

    Args:
        code: Scala source text.

    Returns:
        The source without the wrapper opening and without its last line.
    """
    without_header = code.replace("object Main extends App {", "")
    kept_lines = without_header.splitlines(keepends=True)
    # Discard the last line; a no-op when there are no lines at all.
    del kept_lines[-1:]
    return "".join(kept_lines)
124
128
129
+
125
130
def postprocess_python(code: str) -> str:
    """Return *code* with all leading whitespace removed.

    Args:
        code: Python source text.

    Returns:
        The text with leading whitespace (including newlines) stripped.
    """
    trimmed = code.lstrip()
    return trimmed
127
132
133
+
128
134
def worker (inp_queue , out_queue ):
129
135
while True :
130
136
try :
@@ -143,7 +149,7 @@ def worker(inp_queue, out_queue):
143
149
try :
144
150
solution = solution [:solution .index ("```" )]
145
151
except ValueError :
146
- #Happens when a code block isn't closed properly
152
+ # Happens when a code block isn't closed properly
147
153
pass
148
154
149
155
if problem ["lang" ] == "go" :
@@ -153,15 +159,22 @@ def worker(inp_queue, out_queue):
153
159
elif problem ["lang" ] == "scala" :
154
160
solution = postprocess_scala (solution )
155
161
156
- # Mixtral likes escaping underscores for some reason, so let's remove these
157
- solution = solution .replace ("\_" , "_" )
162
+ # Mixtral likes escaping underscores for some reason, so let's remove
163
+ # these
164
+ solution = solution .replace ("\\ _" , "_" )
158
165
159
166
# The evaluation script evaluates `code = prompt + solution + tests`
160
- # But Mixtral regenerates the prompt in its output, so we should remove this
167
+ # But Mixtral regenerates the prompt in its output, so we should remove
168
+ # this
161
169
problem ["prompt" ] = ""
162
170
163
171
result = checker (problem , solution , timeout = 20.0 )
164
- out_queue .put ((key , problem ["lang" ], result ["passed" ], result ["result" ], problem ["response" ]))
172
+ out_queue .put (
173
+ (key ,
174
+ problem ["lang" ],
175
+ result ["passed" ],
176
+ result ["result" ],
177
+ problem ["response" ]))
165
178
166
179
167
180
def convert_pickle (df : pd .DataFrame , result_keys : dict ):
@@ -193,7 +206,8 @@ def evaluate_mbxp(n_works: int, df: pd.DataFrame, result_keys: dict):
193
206
n_problems = 0
194
207
195
208
for lang , problems in by_lang .items ():
196
- if lang not in ["cpp" , "python" , "php" , "javascript" , "ruby" , "typescript" ]:
209
+ if lang not in ["cpp" , "python" , "php" ,
210
+ "javascript" , "ruby" , "typescript" ]:
197
211
raise RuntimeError (f"{ lang } not in supported list." )
198
212
199
213
n_problems += len (problems )
@@ -213,7 +227,10 @@ def evaluate_mbxp(n_works: int, df: pd.DataFrame, result_keys: dict):
213
227
lang_counts = {}
214
228
for i in tqdm (range (n_problems )):
215
229
key , lang , passed , result , response = out_queue .get ()
216
- passes [key ] = {"passed" : passed , "result" : result , "response" : response }
230
+ passes [key ] = {
231
+ "passed" : passed ,
232
+ "result" : result ,
233
+ "response" : response }
217
234
n_passed += passed
218
235
219
236
lang_passed .setdefault (lang , 0 )
@@ -244,7 +261,8 @@ def evaluate_openorca(df: pd.DataFrame, result_keys: dict):
244
261
score = calculate_rouge_score (gen_output , gt_output )
245
262
gen_token_len = df [result_keys ['length' ]].tolist ()
246
263
gen_token_per_sample = sum (gen_token_len ) / len (gen_token_len )
247
- print (f"OpenOrca score: { score } , gen_token_per_sample: { gen_token_per_sample } " )
264
+ print (
265
+ f"OpenOrca score: { score } , gen_token_per_sample: { gen_token_per_sample } " )
248
266
return score
249
267
250
268
@@ -266,13 +284,18 @@ def evaluate_gsm8k(df: pd.DataFrame, result_keys: dict):
266
284
em = correct / total
267
285
gen_token_len = df [result_keys ['length' ]].tolist ()
268
286
gen_token_per_sample = sum (gen_token_len ) / len (gen_token_len )
269
- print (f"EM: { em } , correct: { correct } / { total } , gen_token_per_sample: { gen_token_per_sample } " )
287
+ print (
288
+ f"EM: { em } , correct: { correct } / { total } , gen_token_per_sample: { gen_token_per_sample } " )
270
289
return em
271
290
272
291
273
292
if __name__ == "__main__" :
274
293
parser = argparse .ArgumentParser ()
275
- parser .add_argument ("--n_workers" , type = int , default = 10 , help = "The number of processes to use" )
294
+ parser .add_argument (
295
+ "--n_workers" ,
296
+ type = int ,
297
+ default = 10 ,
298
+ help = "The number of processes to use" )
276
299
parser .add_argument ("--results_path" , type = str , default = "mixtral_8x7b_15000_greedy_reference_fp16_mintoken2.pkl" ,
277
300
help = "The path to the results file pickle file" )
278
301
parser .add_argument ("--result_key" , type = str , default = "ref_output" ,
@@ -307,9 +330,9 @@ def evaluate_gsm8k(df: pd.DataFrame, result_keys: dict):
307
330
"""
308
331
309
332
df = pd .read_pickle (args .results_path )
310
- df_gsm8k = df [df ['dataset' ]== "GSM8K" ].copy ()
333
+ df_gsm8k = df [df ['dataset' ] == "GSM8K" ].copy ()
311
334
evaluate_gsm8k (df_gsm8k , result_keys )
312
- df_openorca = df [df ['dataset' ]== "OpenOrca" ].copy ()
335
+ df_openorca = df [df ['dataset' ] == "OpenOrca" ].copy ()
313
336
evaluate_openorca (df_openorca , result_keys )
314
- df_mbxp = df [df ['dataset' ]== "MBXP" ].copy ()
337
+ df_mbxp = df [df ['dataset' ] == "MBXP" ].copy ()
315
338
evaluate_mbxp (args .n_workers , df_mbxp , result_keys )
0 commit comments