Skip to content

Commit c7c1b09

Browse files
authored
Merge pull request #7 from jihe520/newdev
feat: local code interpreter
2 parents c85ea60 + 2ab9f5d commit c7c1b09

File tree

14 files changed

+1534
-727
lines changed

14 files changed

+1534
-727
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
- [ ] 更多测试案例
4242
- [ ] docker 部署
4343
- [ ] 引入用户的交互(选择模型,重写等等)
44-
- [ ] codeinterpreter 接入云端 如 e2b 等供应商..
44+
- [x] codeinterpreter 接入云端 如 e2b 等供应商..
4545
- [ ] 多语言: R 语言, matlab
4646
- [ ] 绘图 napkin,draw.io
4747

@@ -64,16 +64,14 @@
6464

6565
```bash
6666
ENV=dev
67-
#兼容 OpenAI 格式都行,具体看官方文档
67+
# 兼容 OpenAI 格式都行,具体看官方文档
6868
DEEPSEEK_API_KEY=
6969
DEEPSEEK_MODEL=
7070
DEEPSEEK_BASE_URL=
7171
# 模型最大问答次数
7272
MAX_CHAT_TURNS=60
7373
# 思考反思次数
7474
MAX_RETRIES=5
75-
# https://e2b.dev/
76-
E2B_API_KEY=
7775

7876
LOG_LEVEL=DEBUG
7977
DEBUG=true

backend/.env.dev.example

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ DEEPSEEK_BASE_URL=
77
MAX_CHAT_TURNS=60
88
# 思考反思次数
99
MAX_RETRIES=5
10-
# https://e2b.dev/
11-
E2B_API_KEY=
10+
11+
# 不需要填,默认调用本地 Python
12+
# E2B_API_KEY=
1213

1314
LOG_LEVEL=DEBUG
1415
DEBUG=true

backend/app/config/setting.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pydantic import AnyUrl, BeforeValidator, computed_field, field_validator, Field
22
from pydantic_settings import BaseSettings, SettingsConfigDict
33
import os
4-
from typing import Annotated
4+
from typing import Annotated, Optional
55

66

77
def parse_cors(value: str) -> list[str]:
@@ -22,14 +22,18 @@ class Settings(BaseSettings):
2222
DEEPSEEK_BASE_URL: str
2323
MAX_CHAT_TURNS: int
2424
MAX_RETRIES: int
25-
E2B_API_KEY: str
25+
E2B_API_KEY: Optional[str] = None
2626
LOG_LEVEL: str
2727
DEBUG: bool
2828
REDIS_URL: str
2929
REDIS_MAX_CONNECTIONS: int
3030
CORS_ALLOW_ORIGINS: Annotated[list[str] | str, BeforeValidator(parse_cors)]
3131

32-
model_config = SettingsConfigDict(env_file=".env.dev", env_file_encoding="utf-8")
32+
model_config = SettingsConfigDict(
33+
env_file=".env.dev",
34+
env_file_encoding="utf-8",
35+
extra="allow",
36+
)
3337

3438
def get_deepseek_config(self) -> dict:
3539
return {

backend/app/core/agents.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
from app.utils.enums import CompTemplate, FormatOutPut
1414
from app.utils.log_util import logger
1515
from app.config.setting import settings
16-
from app.tools.code_interpreter import E2BCodeInterpreter
1716
from app.utils.common_utils import get_current_files
1817
from app.utils.redis_manager import redis_manager
1918
from app.schemas.response import SystemMessage
19+
from app.tools.base_interpreter import BaseCodeInterpreter
2020

2121

2222
class Agent:
@@ -93,7 +93,7 @@ def __init__(
9393
work_dir: str, # 工作目录
9494
max_chat_turns: int = settings.MAX_CHAT_TURNS, # 最大聊天次数
9595
max_retries: int = settings.MAX_RETRIES, # 最大反思次数
96-
code_interpreter: E2BCodeInterpreter = None,
96+
code_interpreter: BaseCodeInterpreter = None,
9797
) -> None:
9898
super().__init__(task_id, model, max_chat_turns)
9999
self.work_dir = work_dir

backend/app/core/functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
},
2222
]
2323

24+
# have installed: numpy scipy pandas matplotlib seaborn scikit-learn xgboost
25+
2426
# TODO: pip install python
2527

2628
# TODO: read files

backend/app/core/prompts.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from app.utils.enums import FormatOutPut
22

3+
# TODO: 设计成一个类?
4+
35
MODELER_PROMPT = """
46
role:你是一名数学建模经验丰富的建模手,负责建模部分。
57
task:你需要根据用户要求和数据建立数学模型求解问题。
@@ -9,6 +11,8 @@
911
**不需要建立复杂的模型,简单规划需要步骤**
1012
"""
1113

14+
# TODO : 对于特大 csv 读取
15+
1216
CODER_PROMPT = """You are an AI code interpreter.
1317
Your goal is to help users do a variety of jobs by executing Python code.
1418
@@ -72,12 +76,12 @@ def get_writer_prompt(
7276
return f"""
7377
role:你是一名数学建模经验丰富的写作手,负责写作部分。
7478
task: 根据问题和如下的模板写出解答,
75-
skill:熟练掌握{format_output}排版,
76-
output:你需要按照要求的格式排版,只输出{format_output}排版的内容
79+
skill:熟练掌握{format_output}排版,如图片、**公式**、表格、列表等
80+
output:你需要按照要求的格式排版,只输出正确的{format_output}排版的内容
7781
7882
1. 当你输入图像引用时候,使用![image_name](image_name.png)
7983
2. 你不需要输出markdown的这个```markdown格式,只需要输出markdown的内容,
80-
3. Latex公式使用$$ $$包裹
84+
3. LaTex: 行内公式(Inline Formula) 和 块级公式(Block Formula
8185
4. 严格按照参考用户输入的格式模板以及**正确的编号顺序**
8286
5. 不需要询问用户
8387
6. 当提到图片时,请使用提供的图片列表中的文件名

backend/app/core/workflow.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from app.utils.common_utils import create_work_dir, simple_chat, get_config_template
88
from app.models.user_output import UserOutput
99
from app.config.setting import settings
10+
from app.tools.interpreter_factory import create_interpreter
1011
from app.core.llm import DeepSeekModel
11-
from app.tools.code_interpreter import E2BCodeInterpreter
1212
import json
1313
from app.utils.redis_manager import redis_manager
1414
from app.utils.notebook_serializer import NotebookSerializer
15+
from app.tools.base_interpreter import BaseCodeInterpreter
1516

1617

1718
class WorkFlow:
@@ -58,9 +59,10 @@ async def execute(self, problem: Problem):
5859
SystemMessage(content="正在创建代码沙盒环境"),
5960
)
6061

61-
e2b_code_interpreter = await E2BCodeInterpreter.create(
62-
workd_dir=self.work_dir,
62+
code_interpreter = await create_interpreter(
63+
kind="local",
6364
task_id=self.task_id,
65+
work_dir=self.work_dir,
6466
notebook_serializer=notebook_serializer,
6567
timeout=3000,
6668
)
@@ -81,7 +83,7 @@ async def execute(self, problem: Problem):
8183
work_dir=self.work_dir,
8284
max_chat_turns=settings.MAX_CHAT_TURNS,
8385
max_retries=settings.MAX_RETRIES,
84-
code_interpreter=e2b_code_interpreter,
86+
code_interpreter=code_interpreter,
8587
)
8688

8789
################################################ solution steps
@@ -106,7 +108,7 @@ async def execute(self, problem: Problem):
106108

107109
# TODO: 是否可以不需要coder_response
108110
writer_prompt = self.get_writer_prompt(
109-
key, coder_response, e2b_code_interpreter, config_template
111+
key, coder_response, code_interpreter, config_template
110112
)
111113

112114
await redis_manager.publish_message(
@@ -125,7 +127,7 @@ async def execute(self, problem: Problem):
125127
## TODO: 图片引用错误
126128
writer_response = await writer_agent.run(
127129
writer_prompt,
128-
available_images=await e2b_code_interpreter.get_created_images(key),
130+
available_images=await code_interpreter.get_created_images(key),
129131
sub_title=key,
130132
)
131133

@@ -138,7 +140,7 @@ async def execute(self, problem: Problem):
138140

139141
# 关闭沙盒
140142

141-
await e2b_code_interpreter.shutdown_sandbox()
143+
await code_interpreter.cleanup()
142144
logger.info(user_output.get_res())
143145

144146
################################################ write steps
@@ -223,7 +225,7 @@ def get_writer_prompt(
223225
self,
224226
key: str,
225227
coder_response: str,
226-
code_interpreter: E2BCodeInterpreter,
228+
code_interpreter: BaseCodeInterpreter,
227229
config_template: dict,
228230
) -> str:
229231
"""根据不同的key生成对应的writer_prompt

backend/app/routers/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def config():
3131
}
3232

3333

34-
@router.post("/modeling/")
34+
@router.post("/modeling")
3535
async def modeling(
3636
background_tasks: BackgroundTasks,
3737
ques_all: str = Form(...), # 从表单获取

backend/app/tools/base_interpreter.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# base_interpreter.py
2+
import abc
3+
import re
4+
from app.utils.notebook_serializer import NotebookSerializer
5+
from app.utils.redis_manager import redis_manager
6+
from app.utils.common_utils import get_current_files
7+
from app.utils.log_util import logger
8+
from app.schemas.response import (
9+
CoderMessage,
10+
OutputItem,
11+
AgentType,
12+
)
13+
14+
15+
class BaseCodeInterpreter(abc.ABC):
16+
def __init__(
17+
self,
18+
task_id: str,
19+
work_dir: str,
20+
notebook_serializer: NotebookSerializer,
21+
):
22+
self.task_id = task_id
23+
self.work_dir = work_dir
24+
self.notebook_serializer = notebook_serializer
25+
self.section_output: dict[str, dict[str, list[str]]] = {}
26+
self.created_images: list[str] = []
27+
28+
@abc.abstractmethod
29+
async def initialize(self):
30+
"""初始化解释器,必要时上传文件、启动内核等"""
31+
...
32+
33+
@abc.abstractmethod
34+
async def _pre_execute_code(self):
35+
"""执行初始化代码"""
36+
...
37+
38+
@abc.abstractmethod
39+
async def execute_code(self, code: str) -> tuple[str, bool, str]:
40+
"""执行一段代码,返回 (输出文本, 是否出错, 错误信息)"""
41+
...
42+
43+
@abc.abstractmethod
44+
async def cleanup(self):
45+
"""清理资源,比如关闭沙箱或内核"""
46+
...
47+
48+
async def _push_to_websocket(self, content_to_display: list[OutputItem] | None):
49+
logger.info("执行结果已推送到WebSocket")
50+
51+
agent_msg = CoderMessage(
52+
agent_type=AgentType.CODER,
53+
code_results=content_to_display,
54+
files=get_current_files(self.work_dir, "all"),
55+
)
56+
logger.debug(f"发送消息: {agent_msg.model_dump_json()}")
57+
await redis_manager.publish_message(
58+
self.task_id,
59+
agent_msg,
60+
)
61+
62+
def add_section(self, section_name: str) -> None:
63+
"""确保添加的section结构正确"""
64+
65+
if section_name not in self.section_output:
66+
self.section_output[section_name] = {"content": [], "images": []}
67+
68+
def add_content(self, section: str, text: str) -> None:
69+
"""向指定section添加文本内容"""
70+
self.add_section(section)
71+
self.section_output[section]["content"].append(text)
72+
73+
def get_code_output(self, section: str) -> str:
74+
"""获取指定section的代码输出"""
75+
return "\n".join(self.section_output[section]["content"])
76+
77+
def delete_color_control_char(self, string):
78+
ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
79+
return ansi_escape.sub("", string)
80+
81+
def _truncate_text(self, text: str, max_length: int = 1000) -> str:
82+
"""截断文本,保留开头和结尾的重要信息"""
83+
if len(text) <= max_length:
84+
return text
85+
86+
half_length = max_length // 2
87+
return text[:half_length] + "\n... (内容已截断) ...\n" + text[-half_length:]

0 commit comments

Comments
 (0)