Skip to content

Commit 76e7c29

Browse files
committed
fix: workflow connection handling with force_connect method
1 parent bf5c8f3 commit 76e7c29

File tree

4 files changed

+13
-6
lines changed

4 files changed

+13
-6
lines changed

framework/web/api/workflow/routes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,11 @@ async def create_workflow(group_id: str, workflow_id: str):
112112
builder.update_position(block_def.name, block_def.position)
113113

114114
# 添加连接
115+
builder.wires = []
115116
for wire in workflow_def.wires:
116117
source_block = next(b for b in builder.blocks if b.name == wire.source_block)
117118
target_block = next(b for b in builder.blocks if b.name == wire.target_block)
118-
builder._connect_blocks(source_block, target_block)
119+
builder.force_connect(source_block, target_block, wire.source_output, wire.target_input)
119120

120121
# 保存工作流
121122
file_path = registry.get_workflow_path(group_id, workflow_id)
@@ -161,10 +162,11 @@ async def update_workflow(group_id: str, workflow_id: str):
161162

162163
builder.update_position(block_def.name, block_def.position)
163164
# 添加连接
165+
builder.wires = []
164166
for wire in workflow_def.wires:
165167
source_block = next(b for b in builder.blocks if b.name == wire.source_block)
166168
target_block = next(b for b in builder.blocks if b.name == wire.target_block)
167-
builder._connect_blocks(source_block, target_block)
169+
builder.force_connect(source_block, target_block, wire.source_output, wire.target_input)
168170

169171
# 保存工作流
170172
file_path = registry.get_workflow_path(group_id, workflow_id)

framework/workflow/core/workflow/builder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,11 @@ def _connect_blocks(self, source_block: Block, target_block: Block):
317317
if is_connected:
318318
break
319319

320+
def force_connect(self, source_block: Block, target_block: Block, source_output: str, target_input: str):
321+
"""强制连接两个块"""
322+
wire = Wire(source_block, source_output, target_block, target_input)
323+
self.wires.append(wire)
324+
320325
def _find_parallel_nodes(self, start_node: Node) -> List[Node]:
321326
"""查找所有并行节点"""
322327
parallel_nodes = []
@@ -452,11 +457,12 @@ def get_block_class(type_name: str) -> Type[Block]:
452457
builder.chain(block_class, name=block_data['name'], **params)
453458
builder.update_position(block_data['name'], block_data['position'])
454459
# 第二遍:建立连接
460+
builder.wires = []
455461
for block_data in workflow_data['blocks']:
456462
if 'connected_to' in block_data:
457463
source_node = builder.nodes_by_name[block_data['name']]
458464
for connection in block_data['connected_to']:
459465
target_node = builder.nodes_by_name[connection['target']]
460-
builder._connect_blocks(source_node.block, target_node.block)
466+
builder.force_connect(source_node.block, target_node.block, connection['mapping']['from'], connection['mapping']['to'])
461467

462468
return builder

framework/workflow/implementations/blocks/llm/chat.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ class ChatMessageConstructor(Block):
1616
name = "chat_message_constructor"
1717
inputs = {
1818
"user_msg": Input("user_msg", "本轮消息", IMMessage, "用户消息"),
19+
"user_prompt_format": Input("user_prompt_format", "本轮消息格式", str, "本轮消息格式", default=""),
1920
"memory_content": Input("memory_content", "上下文消息", str, "历史消息对话"),
2021
"system_prompt_format": Input("system_prompt_format", "上下文消息格式", str, "上下文消息格式", default=""),
21-
"user_prompt_format": Input("user_prompt_format", "本轮消息格式", str, "本轮消息格式", default="")
2222
}
2323
outputs = {"llm_msg": Output("llm_msg", "LLM 对话记录", List[LLMChatMessage], "LLM 对话记录")}
2424
container: DependencyContainer
@@ -74,7 +74,6 @@ def execute(self, user_msg: IMMessage, memory_content: str, system_prompt_format
7474
# 再替换其他变量
7575
system_prompt = self.substitute_variables(system_prompt_format, executor)
7676
user_prompt = self.substitute_variables(user_prompt_format, executor)
77-
7877
llm_msg = [
7978
LLMChatMessage(role='system', content=system_prompt),
8079
LLMChatMessage(role='user', content=user_prompt)

plugins/llm_preset_adapters/gemini_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def chat(self, req: LLMChatRequest) -> LLMChatResponse:
4141
"safetySettings": []
4242
}
4343

44-
self.logger.debug(f"Contents: {data['contents'][0]['parts'][0]['text']}")
44+
self.logger.debug(f"Contents: {data['contents']}")
4545
# Remove None fields
4646
data = {k: v for k, v in data.items() if v is not None}
4747

0 commit comments

Comments
 (0)