|
1 | 1 | from inspect import Parameter, signature |
2 | | -from typing import Dict, List, Tuple, Type, Optional |
| 2 | +from typing import Dict, List, Tuple, Type, Optional, get_origin, get_args, Annotated, Union |
3 | 3 | import warnings |
4 | 4 | from framework.workflow.core.block import Block |
| 5 | +from framework.workflow.core.block.param import ParamMeta |
5 | 6 | from .schema import BlockConfig, BlockInput, BlockOutput |
6 | 7 |
|
| 8 | +def extract_block_param(param): |
| 9 | + """ |
| 10 | + 提取 Block 参数信息,包括类型字符串、标签、是否必需、描述和默认值。 |
| 11 | + """ |
| 12 | + param_type = param.annotation |
| 13 | + required = True |
| 14 | + label = param.name |
| 15 | + description = None |
| 16 | + default = param.default if param.default != Parameter.empty else None |
| 17 | + |
| 18 | + if get_origin(param_type) is Annotated: |
| 19 | + args = get_args(param_type) |
| 20 | + if len(args) > 0: |
| 21 | + actual_type = args[0] |
| 22 | + metadata = args[1] if len(args) > 1 else None |
| 23 | + |
| 24 | + if isinstance(metadata, ParamMeta): |
| 25 | + label = metadata.label |
| 26 | + description = metadata.description |
| 27 | + |
| 28 | + # 递归调用 extract_block_param 处理实际类型 |
| 29 | + block_config = extract_block_param(Parameter(name=param.name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=actual_type, default=default)) |
| 30 | + type_string = block_config.type |
| 31 | + required = block_config.required # 继承 required 属性 |
| 32 | + else: |
| 33 | + type_string = "Any" |
| 34 | + elif get_origin(param_type) is Union: |
| 35 | + args = get_args(param_type) |
| 36 | + # 检查 Union 中是否包含 NoneType |
| 37 | + if type(None) in args: |
| 38 | + required = False |
| 39 | + # 移除 NoneType,并递归处理剩余的类型 |
| 40 | + non_none_args = [arg for arg in args if arg is not type(None)] |
| 41 | + if len(non_none_args) == 1: |
| 42 | + block_config = extract_block_param(Parameter(name=param.name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=non_none_args[0], default=default)) |
| 43 | + type_string = block_config.type |
| 44 | + else: |
| 45 | + # 如果 Union 中包含多个非 NoneType,则返回 Union 类型 |
| 46 | + type_string = f"Union[{', '.join(get_type_name(arg) for arg in non_none_args)}]" |
| 47 | + else: |
| 48 | + # 如果 Union 中不包含 NoneType,则直接返回 Union 类型 |
| 49 | + type_string = f"Union[{', '.join(get_type_name(arg) for arg in args)}]" |
| 50 | + else: |
| 51 | + type_string = get_type_name(param_type) |
| 52 | + |
| 53 | + return BlockConfig( |
| 54 | + name=param.name, # 设置名称 |
| 55 | + description=description, |
| 56 | + type=type_string, |
| 57 | + required=required, |
| 58 | + default=default, # 设置默认值 |
| 59 | + label=label |
| 60 | + ) |
| 61 | + |
| 62 | +def get_type_name(type_obj): |
| 63 | + """ |
| 64 | + 获取类型的名称。 |
| 65 | + """ |
| 66 | + if hasattr(type_obj, '__name__'): |
| 67 | + return type_obj.__name__ |
| 68 | + return str(type_obj) |
| 69 | + |
7 | 70 | class BlockRegistry: |
8 | 71 | """Block 注册表,用于管理所有已注册的 block""" |
9 | 72 |
|
@@ -102,21 +165,9 @@ def extract_block_info(self, block_type: Type[Block]) -> Tuple[Dict[str, BlockIn |
102 | 165 | if param.name in builtin_params: |
103 | 166 | continue |
104 | 167 |
|
105 | | - param_type = param.annotation |
106 | | - # 解 Optional[T] 类型 |
107 | | - if hasattr(param_type, '__args__') and param_type.__name__ == 'Optional': |
| 168 | + block_config = extract_block_param(param) |
108 | 169 |
|
109 | | - actual_type = param_type.__args__[0] |
110 | | - else: |
111 | | - actual_type = param_type |
112 | | - |
113 | | - configs[param.name] = BlockConfig( |
114 | | - name=param.name, |
115 | | - description='', # 暂时没有描述信息 |
116 | | - type=str(actual_type.__name__), |
117 | | - required=param.default == Parameter.empty, # 没有默认值则为必需 |
118 | | - default=param.default if param.default != Parameter.empty else None |
119 | | - ) |
| 170 | + configs[param.name] = block_config |
120 | 171 | return inputs, outputs, configs |
121 | 172 |
|
122 | 173 | def get_builtin_params(self) -> List[str]: |
|
0 commit comments