Skip to content

Commit 0d03fa7

Browse files
authored
feat(sql-udf): add support for named sql udf (#829)
* feat(function): add `CreateFunctionExecutor` & `bind_create_function` Signed-off-by: Michael Xu <xzhseh@gmail.com> * fix check Signed-off-by: Michael Xu <xzhseh@gmail.com> * feat(sql-udf): introduce UdfContext Signed-off-by: Michael Xu <xzhseh@gmail.com> * Revert "fix check" This reverts commit 0526e69. Signed-off-by: Michael Xu <xzhseh@gmail.com> * Revert "Revert "fix check"" This reverts commit 7ada184. Signed-off-by: Michael Xu <xzhseh@gmail.com> * Revert "fix check" This reverts commit 0526e69. Signed-off-by: Michael Xu <xzhseh@gmail.com> * Revert "feat(function): add `CreateFunctionExecutor` & `bind_create_function`" This reverts commit 8cc4b8e. Signed-off-by: Michael Xu <xzhseh@gmail.com> * tiny refactor & update Signed-off-by: Michael Xu <xzhseh@gmail.com> * fix format Signed-off-by: Michael Xu <xzhseh@gmail.com> * Update mod.rs Signed-off-by: Zihao Xu <xzhseh@gmail.com> * support named sql udf Signed-off-by: Michael Xu <xzhseh@gmail.com> * fix check Signed-off-by: Michael Xu <xzhseh@gmail.com> * add named sql udf related tests Signed-off-by: Michael Xu <xzhseh@gmail.com> * fix check Signed-off-by: Michael Xu <xzhseh@gmail.com> --------- Signed-off-by: Michael Xu <xzhseh@gmail.com> Signed-off-by: Zihao Xu <xzhseh@gmail.com>
1 parent 650c9c7 commit 0d03fa7

File tree

9 files changed

+246
-6
lines changed

9 files changed

+246
-6
lines changed

src/binder/create_function.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub struct CreateFunction {
1515
pub schema_name: String,
1616
pub name: String,
1717
pub arg_types: Vec<RlDataType>,
18+
pub arg_names: Vec<String>,
1819
pub return_type: RlDataType,
1920
pub language: String,
2021
pub body: String,
@@ -96,14 +97,17 @@ impl Binder {
9697
};
9798

9899
let mut arg_types = vec![];
100+
let mut arg_names = vec![];
99101
for arg in args.unwrap_or_default() {
100102
arg_types.push(RlDataType::new(DataTypeKind::from(&arg.data_type), false));
103+
arg_names.push(arg.name.map_or("".to_string(), |n| n.to_string()));
101104
}
102105

103106
let f = self.egraph.add(Node::CreateFunction(CreateFunction {
104107
schema_name,
105108
name,
106109
arg_types,
110+
arg_names,
107111
return_type,
108112
language,
109113
body,

src/binder/expr.rs

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0.
22

33
use rust_decimal::Decimal;
4+
use sqlparser::dialect::GenericDialect;
5+
use sqlparser::parser::Parser;
46

57
use super::*;
68
use crate::parser::{
@@ -86,6 +88,12 @@ impl Binder {
8688
[schema, table, column] => (Some(&schema.value), Some(&table.value), &column.value),
8789
_ => return Err(BindError::InvalidTableName(idents)),
8890
};
91+
92+
// Special check for sql udf
93+
if let Some(id) = self.udf_context.get_expr(column_name) {
94+
return Ok(*id);
95+
}
96+
8997
let map = self
9098
.current_ctx()
9199
.aliases
@@ -281,7 +289,7 @@ impl Binder {
281289

282290
fn bind_function(&mut self, func: Function) -> Result {
283291
let mut args = vec![];
284-
for arg in func.args {
292+
for arg in func.args.clone() {
285293
// ignore argument name
286294
let arg = match arg {
287295
FunctionArg::Named { arg, .. } => arg,
@@ -298,15 +306,69 @@ impl Binder {
298306
}
299307
}
300308

301-
// TODO: sql udf inlining
302-
let _catalog = self.catalog();
303-
let Ok((_schema_name, _function_name)) = split_name(&func.name) else {
309+
let catalog = self.catalog();
310+
let Ok((schema_name, function_name)) = split_name(&func.name) else {
304311
return Err(BindError::BindFunctionError(format!(
305312
"failed to parse the function name {}",
306313
func.name
307314
)));
308315
};
309316

317+
// See if the input function is sql udf
318+
if let Some(ref function_catalog) = catalog.get_function_by_name(schema_name, function_name)
319+
{
320+
// Create the brand new `udf_context`
321+
let Ok(context) =
322+
UdfContext::create_udf_context(func.args.as_slice(), function_catalog)
323+
else {
324+
return Err(BindError::InvalidExpression(
325+
"failed to create udf context".to_string(),
326+
));
327+
};
328+
329+
let mut udf_context = HashMap::new();
330+
// Bind each expression in the newly created `udf_context`
331+
for (c, e) in context {
332+
let Ok(e) = self.bind_expr(e) else {
333+
return Err(BindError::BindFunctionError(
334+
"failed to bind arguments within the given sql udf".to_string(),
335+
));
336+
};
337+
udf_context.insert(c, e);
338+
}
339+
340+
// Parse the sql body using `function_catalog`
341+
let dialect = GenericDialect {};
342+
let Ok(ast) = Parser::parse_sql(&dialect, &function_catalog.body) else {
343+
return Err(BindError::InvalidSQL);
344+
};
345+
346+
// Extract the corresponding udf expression out from `ast`
347+
let Ok(expr) = UdfContext::extract_udf_expression(ast) else {
348+
return Err(BindError::InvalidExpression(
349+
"failed to bind the sql udf expression".to_string(),
350+
));
351+
};
352+
353+
let stashed_udf_context = self.udf_context.get_context();
354+
355+
// Update the `udf_context` in `Binder` before binding
356+
self.udf_context.update_context(udf_context);
357+
358+
// Bind the expression in sql udf body
359+
let Ok(bind_result) = self.bind_expr(expr) else {
360+
return Err(BindError::InvalidExpression(
361+
"failed to bind the expression".to_string(),
362+
));
363+
};
364+
365+
// Restore the context after binding
366+
// to avoid affecting the potential subsequent binding(s)
367+
self.udf_context.update_context(stashed_udf_context);
368+
369+
return Ok(bind_result);
370+
}
371+
310372
let node = match func.name.to_string().to_lowercase().as_str() {
311373
"count" if args.is_empty() => Node::RowCount,
312374
"count" => Node::Count(args[0]),

src/binder/mod.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use egg::{Id, Language};
99
use itertools::Itertools;
1010

1111
use crate::array;
12+
use crate::catalog::function::FunctionCatalog;
1213
use crate::catalog::{RootCatalog, RootCatalogRef, TableRefId, DEFAULT_SCHEMA_NAME};
1314
use crate::parser::*;
1415
use crate::planner::{Expr as Node, RecExpr, TypeError, TypeSchemaAnalysis};
@@ -97,6 +98,123 @@ pub struct Binder {
9798
contexts: Vec<Context>,
9899
/// The number of occurrences of each table in the query.
99100
table_occurrences: HashMap<TableRefId, u32>,
101+
/// The context used in sql udf binding
102+
udf_context: UdfContext,
103+
}
104+
105+
#[derive(Clone, Debug, Default)]
106+
pub struct UdfContext {
107+
/// The mapping from `sql udf parameters` to a bound `Id` generated from `ast
108+
/// expressions` Note: The expressions are constructed during runtime, correspond to the
109+
/// actual users' input
110+
udf_param_context: HashMap<String, Id>,
111+
112+
/// The global counter that records the calling stack depth
113+
/// of the current binding sql udf chain
114+
udf_global_counter: u32,
115+
}
116+
117+
impl UdfContext {
118+
pub fn new() -> Self {
119+
Self {
120+
udf_param_context: HashMap::new(),
121+
udf_global_counter: 0,
122+
}
123+
}
124+
125+
pub fn global_count(&self) -> u32 {
126+
self.udf_global_counter
127+
}
128+
129+
pub fn incr_global_count(&mut self) {
130+
self.udf_global_counter += 1;
131+
}
132+
133+
pub fn _is_empty(&self) -> bool {
134+
self.udf_param_context.is_empty()
135+
}
136+
137+
pub fn update_context(&mut self, context: HashMap<String, Id>) {
138+
self.udf_param_context = context;
139+
}
140+
141+
pub fn _clear(&mut self) {
142+
self.udf_global_counter = 0;
143+
self.udf_param_context.clear();
144+
}
145+
146+
pub fn get_expr(&self, name: &str) -> Option<&Id> {
147+
self.udf_param_context.get(name)
148+
}
149+
150+
pub fn get_context(&self) -> HashMap<String, Id> {
151+
self.udf_param_context.clone()
152+
}
153+
154+
/// A common utility function to extract sql udf
155+
/// expression out from the input `ast`
156+
pub fn extract_udf_expression(ast: Vec<Statement>) -> Result<Expr> {
157+
if ast.len() != 1 {
158+
return Err(BindError::InvalidExpression(
159+
"the query for sql udf should contain only one statement".to_string(),
160+
));
161+
}
162+
163+
// Extract the expression out
164+
let Statement::Query(query) = ast[0].clone() else {
165+
return Err(BindError::InvalidExpression(
166+
"invalid function definition, please recheck the syntax".to_string(),
167+
));
168+
};
169+
170+
let SetExpr::Select(select) = *query.body else {
171+
return Err(BindError::InvalidExpression(
172+
"missing `select` body for sql udf expression, please recheck the syntax"
173+
.to_string(),
174+
));
175+
};
176+
177+
if select.projection.len() != 1 {
178+
return Err(BindError::InvalidExpression(
179+
"`projection` should contain only one `SelectItem`".to_string(),
180+
));
181+
}
182+
183+
let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
184+
return Err(BindError::InvalidExpression(
185+
"expect `UnnamedExpr` for `projection`".to_string(),
186+
));
187+
};
188+
189+
Ok(expr)
190+
}
191+
192+
pub fn create_udf_context(
193+
args: &[FunctionArg],
194+
catalog: &Arc<FunctionCatalog>,
195+
) -> Result<HashMap<String, Expr>> {
196+
let mut ret: HashMap<String, Expr> = HashMap::new();
197+
for (i, current_arg) in args.iter().enumerate() {
198+
if let FunctionArg::Unnamed(_arg) = current_arg {
199+
match current_arg {
200+
FunctionArg::Unnamed(arg) => {
201+
let FunctionArgExpr::Expr(e) = arg else {
202+
return Err(BindError::InvalidExpression("invalid syntax".to_string()));
203+
};
204+
if catalog.arg_names[i].is_empty() {
205+
todo!("anonymous parameters not yet supported");
206+
} else {
207+
// The index mapping here is accurate
208+
// So that we could directly use the index
209+
ret.insert(catalog.arg_names[i].clone(), e.clone());
210+
}
211+
}
212+
_ => return Err(BindError::InvalidExpression("invalid syntax".to_string())),
213+
}
214+
}
215+
}
216+
Ok(ret)
217+
}
100218
}
101219

102220
pub fn bind_header(mut chunk: array::Chunk, stmt: &Statement) -> array::Chunk {
@@ -137,6 +255,7 @@ impl Binder {
137255
egraph: egg::EGraph::new(TypeSchemaAnalysis { catalog }),
138256
contexts: vec![Context::default()],
139257
table_occurrences: HashMap::new(),
258+
udf_context: UdfContext::new(),
140259
}
141260
}
142261

@@ -234,6 +353,10 @@ impl Binder {
234353
&self.egraph[id].nodes[0]
235354
}
236355

356+
fn _udf_context_mut(&mut self) -> &mut UdfContext {
357+
&mut self.udf_context
358+
}
359+
237360
fn catalog(&self) -> RootCatalogRef {
238361
self.catalog.clone()
239362
}

src/catalog/function.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::types::DataType;
66
pub struct FunctionCatalog {
77
pub name: String,
88
pub arg_types: Vec<DataType>,
9+
pub arg_names: Vec<String>,
910
pub return_type: DataType,
1011
pub language: String,
1112
pub body: String,
@@ -15,13 +16,15 @@ impl FunctionCatalog {
1516
pub fn new(
1617
name: String,
1718
arg_types: Vec<DataType>,
19+
arg_names: Vec<String>,
1820
return_type: DataType,
1921
language: String,
2022
body: String,
2123
) -> Self {
2224
Self {
2325
name,
2426
arg_types,
27+
arg_names,
2528
return_type,
2629
language,
2730
body,

src/catalog/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ static CONTRIBUTORS_TABLE_NAME: &str = "contributors";
1818
pub const CONTRIBUTORS_TABLE_ID: TableId = 0;
1919

2020
mod column;
21-
mod function;
21+
pub mod function;
2222
mod root;
2323
mod schema;
2424
mod table;

src/catalog/root.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,21 @@ impl RootCatalog {
110110
schema.get_function_by_name(function_name)
111111
}
112112

113+
#[allow(clippy::too_many_arguments)]
113114
pub fn create_function(
114115
&self,
115116
schema_name: String,
116117
name: String,
117118
arg_types: Vec<DataType>,
119+
arg_names: Vec<String>,
118120
return_type: DataType,
119121
language: String,
120122
body: String,
121123
) {
122124
let schema_idx = self.get_schema_id_by_name(&schema_name).unwrap();
123125
let mut inner = self.inner.lock().unwrap();
124126
let schema = inner.schemas.get_mut(&schema_idx).unwrap();
125-
schema.create_function(name, arg_types, return_type, language, body);
127+
schema.create_function(name, arg_types, arg_names, return_type, language, body);
126128
}
127129
}
128130

src/catalog/schema.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ impl SchemaCatalog {
9494
&mut self,
9595
name: String,
9696
arg_types: Vec<DataType>,
97+
arg_names: Vec<String>,
9798
return_type: DataType,
9899
language: String,
99100
body: String,
@@ -103,6 +104,7 @@ impl SchemaCatalog {
103104
Arc::new(FunctionCatalog {
104105
name: name.clone(),
105106
arg_types,
107+
arg_names,
106108
return_type,
107109
language,
108110
body,

src/executor/create_function.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ impl CreateFunctionExecutor {
1717
schema_name,
1818
name,
1919
arg_types,
20+
arg_names,
2021
return_type,
2122
language,
2223
body,
@@ -26,6 +27,7 @@ impl CreateFunctionExecutor {
2627
schema_name.clone(),
2728
name.clone(),
2829
arg_types,
30+
arg_names,
2931
return_type,
3032
language,
3133
body,

0 commit comments

Comments
 (0)