@@ -86,40 +86,43 @@ def wrapper(*args: Any, **kwargs: Any):
86
86
C = TypeVar ("C" , bound = Callable [..., Any ])
87
87
88
88
89
- def component (fn : C ) -> C :
90
- """Wraps a Python function to make it a user-defined component."""
91
-
92
- validated_fn = validate (fn )
93
-
94
- @wraps (fn )
95
- def wrapper (* args : Any , ** kw_args : Any ):
96
- prev_current_node = runtime ().context ().current_node ()
97
- component = prev_current_node .children .add ()
98
- source_code_location = None
99
- if runtime ().debug_mode :
100
- source_code_location = get_caller_source_code_location (levels = 2 )
101
- component .MergeFrom (
102
- create_component (
103
- component_name = get_component_name (fn ),
104
- proto = pb .UserDefinedType (
105
- args = [
106
- pb .UserDefinedType .Arg (
107
- arg_name = kw_arg , code_value = map_code_value (value )
108
- )
109
- for kw_arg , value in kw_args .items ()
110
- if map_code_value (value ) is not None
111
- ]
112
- ),
113
- source_code_location = source_code_location ,
89
+ def component (skip_validation : bool = False ):
90
+ def component_wrapper (fn : C ) -> C :
91
+ """Wraps a Python function to make it a user-defined component."""
92
+
93
+ validated_fn = fn if skip_validation else validate (fn )
94
+
95
+ @wraps (fn )
96
+ def wrapper (* args : Any , ** kw_args : Any ):
97
+ prev_current_node = runtime ().context ().current_node ()
98
+ component = prev_current_node .children .add ()
99
+ source_code_location = None
100
+ if runtime ().debug_mode :
101
+ source_code_location = get_caller_source_code_location (levels = 2 )
102
+ component .MergeFrom (
103
+ create_component (
104
+ component_name = get_component_name (fn ),
105
+ proto = pb .UserDefinedType (
106
+ args = [
107
+ pb .UserDefinedType .Arg (
108
+ arg_name = kw_arg , code_value = map_code_value (value )
109
+ )
110
+ for kw_arg , value in kw_args .items ()
111
+ if map_code_value (value ) is not None
112
+ ]
113
+ ),
114
+ source_code_location = source_code_location ,
115
+ )
114
116
)
115
- )
116
- runtime ().context ().set_current_node (component )
117
- ret = validated_fn (* args , ** kw_args )
118
- runtime ().context ().set_current_node (prev_current_node )
119
- return ret
117
+ runtime ().context ().set_current_node (component )
118
+ ret = validated_fn (* args , ** kw_args )
119
+ runtime ().context ().set_current_node (prev_current_node )
120
+ return ret
120
121
121
- runtime ().register_native_component_fn (fn )
122
- return cast (C , wrapper )
122
+ runtime ().register_native_component_fn (fn )
123
+ return cast (C , wrapper )
124
+
125
+ return component_wrapper
123
126
124
127
125
128
def get_component_name (fn : Callable [..., Any ]) -> pb .ComponentName :
0 commit comments