Skip to content

Commit 26e23bd

Browse files
committed
refactor: support position and param meta info
1 parent 79345f5 commit 26e23bd

File tree

14 files changed

+135
-58
lines changed

14 files changed

+135
-58
lines changed

framework/web/api/workflow/routes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async def get_workflow(group_id: str, workflow_id: str):
5252
'type_name': block_registry.get_block_type_name(block.__class__),
5353
'name': block.name,
5454
'config': builder.nodes_by_name[block.name].spec.kwargs,
55-
'position': block.position if hasattr(block, 'position') else {'x': 0, 'y': 0}
55+
'position': builder.nodes_by_name[block.name].position if hasattr(builder.nodes_by_name[block.name], 'position') else {'x': 0, 'y': 0}
5656
})
5757

5858
wires = []
@@ -107,7 +107,9 @@ async def create_workflow(group_id: str, workflow_id: str):
107107
builder.use(block_class, name=block_def.name, **block_def.config)
108108
else:
109109
builder.chain(block_class, name=block_def.name, **block_def.config)
110-
110+
111+
builder.update_position(block_def.name, block_def.position)
112+
111113
# 添加连接
112114
for wire in workflow_def.wires:
113115
source_block = next(b for b in builder.blocks if b.name == wire.source_block)
@@ -156,6 +158,7 @@ async def update_workflow(group_id: str, workflow_id: str):
156158
else:
157159
builder.chain(block_class, name=block_def.name, **block_def.config)
158160

161+
builder.update_position(block_def.name, block_def.position)
159162
# 添加连接
160163
for wire in workflow_def.wires:
161164
source_block = next(b for b in builder.blocks if b.name == wire.source_block)
Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
from .base import Block, ConditionBlock, LoopBlock, LoopEndBlock
22
from .registry import BlockRegistry
33
from .schema import BlockInput, BlockOutput, BlockConfig
4+
from .param import ParamMeta
5+
from .input_output import Input, Output
46

5-
__all__ = ["Block", "ConditionBlock", "LoopBlock", "LoopEndBlock", "BlockRegistry", "BlockInput", "BlockOutput", "BlockConfig"]
7+
__all__ = [
8+
"Block",
9+
"ConditionBlock",
10+
"LoopBlock",
11+
"LoopEndBlock",
12+
"BlockRegistry",
13+
"BlockInput",
14+
"BlockOutput",
15+
"BlockConfig",
16+
"ParamMeta",
17+
"Input",
18+
"Output"
19+
]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Optional
2+
3+
class ParamMeta:
4+
def __init__(self, label: Optional[str] = None, description: Optional[str] = None):
5+
self.label = label
6+
self.description = description
7+
8+
def __repr__(self):
9+
return f"ParamMeta(label={self.label}, description={self.description})"
10+
11+
def __str__(self):
12+
return self.__repr__()
13+
14+

framework/workflow/core/block/registry.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,72 @@
11
from inspect import Parameter, signature
2-
from typing import Dict, List, Tuple, Type, Optional
2+
from typing import Dict, List, Tuple, Type, Optional, get_origin, get_args, Annotated, Union
33
import warnings
44
from framework.workflow.core.block import Block
5+
from framework.workflow.core.block.param import ParamMeta
56
from .schema import BlockConfig, BlockInput, BlockOutput
67

8+
def extract_block_param(param):
9+
"""
10+
提取 Block 参数信息,包括类型字符串、标签、是否必需、描述和默认值。
11+
"""
12+
param_type = param.annotation
13+
required = True
14+
label = param.name
15+
description = None
16+
default = param.default if param.default != Parameter.empty else None
17+
18+
if get_origin(param_type) is Annotated:
19+
args = get_args(param_type)
20+
if len(args) > 0:
21+
actual_type = args[0]
22+
metadata = args[1] if len(args) > 1 else None
23+
24+
if isinstance(metadata, ParamMeta):
25+
label = metadata.label
26+
description = metadata.description
27+
28+
# 递归调用 extract_block_param 处理实际类型
29+
block_config = extract_block_param(Parameter(name=param.name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=actual_type, default=default))
30+
type_string = block_config.type
31+
required = block_config.required # 继承 required 属性
32+
else:
33+
type_string = "Any"
34+
elif get_origin(param_type) is Union:
35+
args = get_args(param_type)
36+
# 检查 Union 中是否包含 NoneType
37+
if type(None) in args:
38+
required = False
39+
# 移除 NoneType,并递归处理剩余的类型
40+
non_none_args = [arg for arg in args if arg is not type(None)]
41+
if len(non_none_args) == 1:
42+
block_config = extract_block_param(Parameter(name=param.name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=non_none_args[0], default=default))
43+
type_string = block_config.type
44+
else:
45+
# 如果 Union 中包含多个非 NoneType,则返回 Union 类型
46+
type_string = f"Union[{', '.join(get_type_name(arg) for arg in non_none_args)}]"
47+
else:
48+
# 如果 Union 中不包含 NoneType,则直接返回 Union 类型
49+
type_string = f"Union[{', '.join(get_type_name(arg) for arg in args)}]"
50+
else:
51+
type_string = get_type_name(param_type)
52+
53+
return BlockConfig(
54+
name=param.name, # 设置名称
55+
description=description,
56+
type=type_string,
57+
required=required,
58+
default=default, # 设置默认值
59+
label=label
60+
)
61+
62+
def get_type_name(type_obj):
63+
"""
64+
获取类型的名称。
65+
"""
66+
if hasattr(type_obj, '__name__'):
67+
return type_obj.__name__
68+
return str(type_obj)
69+
770
class BlockRegistry:
871
"""Block 注册表,用于管理所有已注册的 block"""
972

@@ -102,21 +165,9 @@ def extract_block_info(self, block_type: Type[Block]) -> Tuple[Dict[str, BlockIn
102165
if param.name in builtin_params:
103166
continue
104167

105-
param_type = param.annotation
106-
# 解 Optional[T] 类型
107-
if hasattr(param_type, '__args__') and param_type.__name__ == 'Optional':
168+
block_config = extract_block_param(param)
108169

109-
actual_type = param_type.__args__[0]
110-
else:
111-
actual_type = param_type
112-
113-
configs[param.name] = BlockConfig(
114-
name=param.name,
115-
description='', # 暂时没有描述信息
116-
type=str(actual_type.__name__),
117-
required=param.default == Parameter.empty, # 没有默认值则为必需
118-
default=param.default if param.default != Parameter.empty else None
119-
)
170+
configs[param.name] = block_config
120171
return inputs, outputs, configs
121172

122173
def get_builtin_params(self) -> List[str]:

framework/workflow/core/block/schema.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class BlockOutput(BaseModel):
2222
class BlockConfig(BaseModel):
2323
"""Block配置项定义"""
2424
name: str
25-
description: str
25+
description: Optional[str] = None
2626
type: str
2727
required: bool = True
28-
default: Optional[Any] = None
28+
default: Optional[Any] = None
29+
label: Optional[str] = None

framework/workflow/core/workflow/builder.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class Node:
4040
is_loop: bool = False # 添加 is_loop 标记
4141
parent: 'Node' = None
4242
spec: BlockSpec = None
43+
position: Optional[Dict[str, int]] = None
4344

4445
def __post_init__(self):
4546
self.next_nodes = self.next_nodes or []
@@ -339,6 +340,11 @@ def build(self, container: DependencyContainer) -> Workflow:
339340
block.container = container
340341

341342
return Workflow(self.name, self.blocks, self.wires)
343+
344+
def update_position(self, name: str, position: Tuple[int, int]):
345+
"""更新节点的位置"""
346+
node = self.nodes_by_name[name]
347+
node.position = position
342348

343349
def save_to_yaml(self, file_path: str, container: DependencyContainer):
344350
"""将工作流保存为 YAML 格式"""
@@ -356,7 +362,8 @@ def serialize_node(node: Node) -> dict:
356362
block_data = {
357363
'type': registry.get_block_type_name(node.block.__class__),
358364
'name': node.block.name,
359-
'params': node.spec.kwargs
365+
'params': node.spec.kwargs,
366+
'position': node.position
360367
}
361368

362369
if node.is_parallel:
@@ -443,7 +450,7 @@ def get_block_class(type_name: str) -> Type[Block]:
443450
builder.use(block_class, name=block_data['name'], **params)
444451
else:
445452
builder.chain(block_class, name=block_data['name'], **params)
446-
453+
builder.update_position(block_data['name'], block_data['position'])
447454
# 第二遍:建立连接
448455
for block_data in workflow_data['blocks']:
449456
if 'connected_to' in block_data:

framework/workflow/implementations/blocks/im/messages.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import asyncio
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, Optional, Annotated
3+
4+
from pydantic import Field
35
from framework.im.adapter import IMAdapter
46
from framework.im.manager import IMManager
57
from framework.im.message import IMMessage
68
from framework.im.sender import ChatSender
79
from framework.ioc.container import DependencyContainer
8-
from framework.workflow.core.block import Block
9-
from framework.workflow.core.block.input_output import Input
10-
from framework.workflow.core.block.input_output import Output
11-
10+
from framework.workflow.core.block import Block, Input, Output, ParamMeta
1211

1312
class GetIMMessage(Block):
1413
"""获取 IM 消息"""
@@ -30,7 +29,7 @@ class SendIMMessage(Block):
3029
outputs = {}
3130
container: DependencyContainer
3231

33-
def __init__(self, im_name: Optional[str] = None):
32+
def __init__(self, im_name: Annotated[Optional[str], ParamMeta(label="IM 适配器名称")] = None):
3433
self.im_name = im_name
3534

3635
def execute(self, msg: IMMessage) -> Dict[str, Any]:

framework/workflow/implementations/blocks/im/states.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import asyncio
2-
from typing import Any, Dict, Optional
2+
from typing import Annotated, Any, Dict
33
from framework.im.adapter import EditStateAdapter, IMAdapter
44
from framework.im.message import IMMessage
55
from framework.im.sender import ChatSender
66
from framework.ioc.container import DependencyContainer
7-
from framework.workflow.core.block import Block
8-
from framework.workflow.core.block.input_output import Input
9-
from framework.workflow.core.block.input_output import Output
7+
from framework.workflow.core.block import Block, Input, ParamMeta
108

119
# Toggle edit state
1210
class ToggleEditState(Block):
@@ -15,7 +13,7 @@ class ToggleEditState(Block):
1513
outputs = {}
1614
container: DependencyContainer
1715

18-
def __init__(self, is_editing: bool):
16+
def __init__(self, is_editing: Annotated[bool, ParamMeta(label="是否编辑", description="是否切换到编辑状态")]):
1917
self.is_editing = is_editing
2018

2119
def execute(self, sender: ChatSender) -> Dict[str, Any]:

framework/workflow/implementations/blocks/llm/chat.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Annotated, Any, Dict, List, Optional
22
import re
33
from framework.llm.format.message import LLMChatMessage
44
from framework.llm.format.request import LLMChatRequest
@@ -7,9 +7,7 @@
77
from framework.ioc.container import DependencyContainer
88
from framework.llm.llm_registry import LLMAbility
99
from framework.logger import get_logger
10-
from framework.workflow.core.block import Block
11-
from framework.workflow.core.block.input_output import Input
12-
from framework.workflow.core.block.input_output import Output
10+
from framework.workflow.core.block import Block, Input, Output, ParamMeta
1311
from framework.config.global_config import GlobalConfig
1412
from framework.im.message import IMMessage, TextMessage
1513
from framework.workflow.core.execution.executor import WorkflowExecutor
@@ -89,7 +87,7 @@ class ChatCompletion(Block):
8987
outputs = {"resp": Output("resp", "LLM 对话响应", LLMChatResponse, "LLM 对话响应")}
9088
container: DependencyContainer
9189

92-
def __init__(self, model_name: Optional[str] = None):
90+
def __init__(self, model_name: Annotated[Optional[str], ParamMeta(label="模型 ID", description="要使用的模型 ID")] = None):
9391
self.model_name = model_name
9492
self.logger = get_logger("ChatCompletionBlock")
9593

framework/workflow/implementations/blocks/memory/chat_memory.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Annotated, Any, Dict, List, Optional
22
from framework.im.message import IMMessage
33
from framework.im.sender import ChatSender
44
from framework.ioc.container import DependencyContainer
5-
from framework.workflow.core.block import Block
6-
from framework.workflow.core.block.input_output import Input
7-
from framework.workflow.core.block.input_output import Output
5+
from framework.workflow.core.block import Block, ParamMeta, Input, Output
86
from framework.memory.memory_manager import MemoryManager
97
from framework.memory.registry import ScopeRegistry, ComposerRegistry, DecomposerRegistry
108
from framework.llm.format.response import LLMChatResponse
@@ -15,7 +13,7 @@ class ChatMemoryQuery(Block):
1513
outputs = {"memory_content": Output("memory_content", "记忆内容", str, "记忆内容")}
1614
container: DependencyContainer
1715

18-
def __init__(self, scope_type: Optional[str] = None):
16+
def __init__(self, scope_type: Annotated[Optional[str], ParamMeta(label="级别", description="要查询记忆的级别")]):
1917
self.scope_type = scope_type
2018

2119

@@ -49,7 +47,7 @@ class ChatMemoryStore(Block):
4947
outputs = {}
5048
container: DependencyContainer
5149

52-
def __init__(self, scope_type: Optional[str] = None):
50+
def __init__(self, scope_type: Annotated[Optional[str], ParamMeta(label="级别", description="要查询记忆的级别")]):
5351
self.scope_type = scope_type
5452

5553
def execute(self, user_msg: IMMessage, llm_resp: LLMChatResponse) -> Dict[str, Any]:

0 commit comments

Comments
 (0)