@@ -2,6 +2,7 @@
 import json
 import logging
 import multiprocessing
+import os
 import re
 import shutil
 import sys
@@ -12,6 +13,7 @@
 from typing import List, Optional
 
 import prompts
+import translator
 import utils
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -22,8 +24,9 @@
 from langchain.pydantic_v1 import BaseModel
 
 AUTOFIX_WITH_OPENAI = False
-ENABLE_STREAMING = False
+ENABLE_STREAMING = True
 REQUEST_TIMEOUT = 60
+USE_MULTIPROCESSING_FOR_TERMINATION = True
 
 
 class TypeEvalPySchema(BaseModel):
@@ -39,6 +42,7 @@ class TypeEvalPySchema(BaseModel):
     "json_based_1": prompts.json_based_1,
     "json_based_2": prompts.json_based_2,
     "questions_based_1": prompts.questions_based_1,
+    "questions_based_2": prompts.questions_based_2,
 }
 
 # Create a logger
@@ -76,17 +80,20 @@ def get_prompt(prompt_id, code_path, json_filepath):
     with open(code_path, "r") as file:
         code = file.read()
 
-    if prompt_id == "questions_based_1":
+    if prompt_id in ["questions_based_1", "questions_based_2"]:
         questions_from_json = utils.generate_questions_from_json(json_filepath)
 
         prompt = PromptTemplate(
             template=PROMPTS_MAP[prompt_id],
-            input_variables=["code", "questions"],
+            input_variables=["code", "questions", "answers"],
         )
 
         prompt_data = {
             "code": code,
-            "questions": "\nResult:\n".join(questions_from_json),
+            "questions": "\n".join(questions_from_json),
+            "answers": "\n".join(
+                [f"{x}." for x in range(1, len(questions_from_json) + 1)]
+            ),
         }
     elif prompt_id in ["json_based_1", "json_based_2"]:
         parser = PydanticOutputParser(pydantic_object=TypeEvalPySchema)
@@ -112,32 +119,35 @@ def process_file(file_path, llm, openai_llm, prompt_id):
     json_filepath = str(file_path).replace(".py", "_gt.json")
     result_filepath = str(file_path).replace(".py", f"_result.json")
 
-    # Queue for communication between processes
-    queue = multiprocessing.Queue()
+    if USE_MULTIPROCESSING_FOR_TERMINATION:
+        # Queue for communication between processes
+        queue = multiprocessing.Queue()
 
-    # Create a process for llm.invoke
-    process = multiprocessing.Process(
-        target=invoke_llm,
-        args=(llm, get_prompt(prompt_id, file_path, json_filepath), queue),
-    )
-    process.start()
+        # Create a process for llm.invoke
+        process = multiprocessing.Process(
+            target=invoke_llm,
+            args=(llm, get_prompt(prompt_id, file_path, json_filepath), queue),
+        )
+        process.start()
 
-    # Wait for the process to finish with a timeout (e.g., 60 seconds)
-    process.join(timeout=REQUEST_TIMEOUT)
+        # Wait for the process to finish with a timeout (e.g., 60 seconds)
+        process.join(timeout=REQUEST_TIMEOUT)
 
-    if process.is_alive():
-        logger.info(f"Timeout occurred for {file_path}")
-        process.terminate()  # Terminate the process if it's still running
-        process.join()
-        logger.info(f"{file_path} failed: Not a valid JSON")
-        raise utils.TimeoutException("json")
+        if process.is_alive():
+            logger.info(f"Timeout occurred for {file_path}")
+            process.terminate()  # Terminate the process if it's still running
+            process.join()
+            logger.info(f"{file_path} failed: Not a valid JSON")
+            raise utils.TimeoutException("json")
 
-    result = queue.get_nowait()
+        result = queue.get_nowait()
 
-    if isinstance(result, Exception):
-        raise result
+        if isinstance(result, Exception):
+            raise result
 
-    output = result
+        output = result
+    else:
+        output = llm.invoke(get_prompt(prompt_id, file_path, json_filepath))
 
     if isinstance(llm, ChatOpenAI):
         output = output.content
@@ -157,7 +167,13 @@ def process_file(file_path, llm, openai_llm, prompt_id):
 
     logger.info(output)
 
-    is_valid_json = utils.generate_json_file(result_filepath, output)
+    if prompt_id == "questions_based_2":
+        answers_json = utils.generate_json_from_answers(json_filepath, output)
+        translated_json = translator.translate_content(answers_json)
+    else:
+        translated_json = translator.translate_content(output)
+
+    is_valid_json = utils.generate_json_file(result_filepath, translated_json)
     if not is_valid_json:
         logger.info(f"{file_path} failed: Not a valid JSON")
         raise utils.JsonException("json")
@@ -184,33 +200,37 @@ def main_runner(args):
 
     python_files = list_python_files(results_dst)
 
-    if model.startswith("gpt-"):
-        # OpenAI models
-        llm = ChatOpenAI(
-            model_name=model,
-            temperature=temperature,
-            openai_api_key=args.openai_key,
-        )
-
-    else:
-        llm = Ollama(
-            model=model,
-            callback_manager=(
-                CallbackManager([StreamingStdOutCallbackHandler()])
-                if ENABLE_STREAMING
-                else None
-            ),
-            temperature=temperature,
-            timeout=REQUEST_TIMEOUT,
-        )
-        llm.base_url = args.ollama_url
-        if utils.is_ollama_online(llm.base_url):
+    if not model.startswith("gpt-"):
+        if utils.is_ollama_online(args.ollama_url):
             logger.info("Ollama is online!")
         else:
             logger.error("Ollama server is not online!!!")
             sys.exit(-1)
 
     for file in python_files:
+        # Recreating llm object each iteration since we might force terminate in thread
+        # Maybe there is another better way to do this
+        if model.startswith("gpt-"):
+            # OpenAI models
+            llm = ChatOpenAI(
+                model_name=model,
+                temperature=temperature,
+                openai_api_key=args.openai_key,
+            )
+
+        else:
+            llm = Ollama(
+                model=model,
+                callback_manager=(
+                    CallbackManager([StreamingStdOutCallbackHandler()])
+                    if ENABLE_STREAMING
+                    else None
+                ),
+                temperature=temperature,
+                timeout=REQUEST_TIMEOUT,
+            )
+            llm.base_url = args.ollama_url
+
         prompt_start_time = time.time()
         try:
             logger.info(file)
@@ -229,9 +249,9 @@ def main_runner(args):
 
         files_analyzed += 1
         logger.info(
-            f"Progress: {files_analyzed}/{len(python_files)} | Errors/ JSON: "
-            f" {error_count} / {json_count} | PromptTime: "
-            f" {time.time()-prompt_start_time}"
+            f"Progress: {files_analyzed}/{len(python_files)} | Total Errors / JSON"
+            f" Errors / Timeouts: {error_count}, {json_count}, {timeout_count} | "
+            f" PromptTime: {time.time()-prompt_start_time}"
        )
 
         logger.info(