@@ -9,6 +9,7 @@ use egg::{Id, Language};
99use itertools:: Itertools ;
1010
1111use crate :: array;
12+ use crate :: catalog:: function:: FunctionCatalog ;
1213use crate :: catalog:: { RootCatalog , RootCatalogRef , TableRefId , DEFAULT_SCHEMA_NAME } ;
1314use crate :: parser:: * ;
1415use crate :: planner:: { Expr as Node , RecExpr , TypeError , TypeSchemaAnalysis } ;
@@ -97,6 +98,123 @@ pub struct Binder {
9798 contexts : Vec < Context > ,
9899 /// The number of occurrences of each table in the query.
99100 table_occurrences : HashMap < TableRefId , u32 > ,
101+ /// The context used in sql udf binding
102+ udf_context : UdfContext ,
103+ }
104+
105+ #[ derive( Clone , Debug , Default ) ]
106+ pub struct UdfContext {
107+ /// The mapping from `sql udf parameters` to a bound `Id` generated from `ast
108+ /// expressions` Note: The expressions are constructed during runtime, correspond to the
109+ /// actual users' input
110+ udf_param_context : HashMap < String , Id > ,
111+
112+ /// The global counter that records the calling stack depth
113+ /// of the current binding sql udf chain
114+ udf_global_counter : u32 ,
115+ }
116+
117+ impl UdfContext {
118+ pub fn new ( ) -> Self {
119+ Self {
120+ udf_param_context : HashMap :: new ( ) ,
121+ udf_global_counter : 0 ,
122+ }
123+ }
124+
125+ pub fn global_count ( & self ) -> u32 {
126+ self . udf_global_counter
127+ }
128+
129+ pub fn incr_global_count ( & mut self ) {
130+ self . udf_global_counter += 1 ;
131+ }
132+
133+ pub fn _is_empty ( & self ) -> bool {
134+ self . udf_param_context . is_empty ( )
135+ }
136+
137+ pub fn update_context ( & mut self , context : HashMap < String , Id > ) {
138+ self . udf_param_context = context;
139+ }
140+
141+ pub fn _clear ( & mut self ) {
142+ self . udf_global_counter = 0 ;
143+ self . udf_param_context . clear ( ) ;
144+ }
145+
146+ pub fn get_expr ( & self , name : & str ) -> Option < & Id > {
147+ self . udf_param_context . get ( name)
148+ }
149+
150+ pub fn get_context ( & self ) -> HashMap < String , Id > {
151+ self . udf_param_context . clone ( )
152+ }
153+
154+ /// A common utility function to extract sql udf
155+ /// expression out from the input `ast`
156+ pub fn extract_udf_expression ( ast : Vec < Statement > ) -> Result < Expr > {
157+ if ast. len ( ) != 1 {
158+ return Err ( BindError :: InvalidExpression (
159+ "the query for sql udf should contain only one statement" . to_string ( ) ,
160+ ) ) ;
161+ }
162+
163+ // Extract the expression out
164+ let Statement :: Query ( query) = ast[ 0 ] . clone ( ) else {
165+ return Err ( BindError :: InvalidExpression (
166+ "invalid function definition, please recheck the syntax" . to_string ( ) ,
167+ ) ) ;
168+ } ;
169+
170+ let SetExpr :: Select ( select) = * query. body else {
171+ return Err ( BindError :: InvalidExpression (
172+ "missing `select` body for sql udf expression, please recheck the syntax"
173+ . to_string ( ) ,
174+ ) ) ;
175+ } ;
176+
177+ if select. projection . len ( ) != 1 {
178+ return Err ( BindError :: InvalidExpression (
179+ "`projection` should contain only one `SelectItem`" . to_string ( ) ,
180+ ) ) ;
181+ }
182+
183+ let SelectItem :: UnnamedExpr ( expr) = select. projection [ 0 ] . clone ( ) else {
184+ return Err ( BindError :: InvalidExpression (
185+ "expect `UnnamedExpr` for `projection`" . to_string ( ) ,
186+ ) ) ;
187+ } ;
188+
189+ Ok ( expr)
190+ }
191+
192+ pub fn create_udf_context (
193+ args : & [ FunctionArg ] ,
194+ catalog : & Arc < FunctionCatalog > ,
195+ ) -> Result < HashMap < String , Expr > > {
196+ let mut ret: HashMap < String , Expr > = HashMap :: new ( ) ;
197+ for ( i, current_arg) in args. iter ( ) . enumerate ( ) {
198+ if let FunctionArg :: Unnamed ( _arg) = current_arg {
199+ match current_arg {
200+ FunctionArg :: Unnamed ( arg) => {
201+ let FunctionArgExpr :: Expr ( e) = arg else {
202+ return Err ( BindError :: InvalidExpression ( "invalid syntax" . to_string ( ) ) ) ;
203+ } ;
204+ if catalog. arg_names [ i] . is_empty ( ) {
205+ todo ! ( "anonymous parameters not yet supported" ) ;
206+ } else {
207+ // The index mapping here is accurate
208+ // So that we could directly use the index
209+ ret. insert ( catalog. arg_names [ i] . clone ( ) , e. clone ( ) ) ;
210+ }
211+ }
212+ _ => return Err ( BindError :: InvalidExpression ( "invalid syntax" . to_string ( ) ) ) ,
213+ }
214+ }
215+ }
216+ Ok ( ret)
217+ }
100218}
101219
102220pub fn bind_header ( mut chunk : array:: Chunk , stmt : & Statement ) -> array:: Chunk {
@@ -137,6 +255,7 @@ impl Binder {
137255 egraph : egg:: EGraph :: new ( TypeSchemaAnalysis { catalog } ) ,
138256 contexts : vec ! [ Context :: default ( ) ] ,
139257 table_occurrences : HashMap :: new ( ) ,
258+ udf_context : UdfContext :: new ( ) ,
140259 }
141260 }
142261
@@ -234,6 +353,10 @@ impl Binder {
234353 & self . egraph [ id] . nodes [ 0 ]
235354 }
236355
356+ fn _udf_context_mut ( & mut self ) -> & mut UdfContext {
357+ & mut self . udf_context
358+ }
359+
237360 fn catalog ( & self ) -> RootCatalogRef {
238361 self . catalog . clone ( )
239362 }
0 commit comments