Skip to content

Commit 346966e

Browse files
authored
feat(query): array function support variant type as argument (#17428)
1 parent f8eec68 commit 346966e

File tree

2 files changed

+182
-46
lines changed

2 files changed

+182
-46
lines changed

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 147 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,43 +1857,6 @@ impl<'a> TypeChecker<'a> {
18571857
args: &[&Expr],
18581858
lambda: &Lambda,
18591859
) -> Result<Box<(ScalarExpr, DataType)>> {
1860-
if func_name.starts_with("json_") && !args.is_empty() {
1861-
let target_type = if func_name.starts_with("json_array") {
1862-
TypeName::Array(Box::new(TypeName::Nullable(Box::new(TypeName::Variant))))
1863-
} else {
1864-
TypeName::Map {
1865-
key_type: Box::new(TypeName::String),
1866-
val_type: Box::new(TypeName::Nullable(Box::new(TypeName::Variant))),
1867-
}
1868-
};
1869-
let func_name = &func_name[5..];
1870-
let mut new_args: Vec<Expr> = args.iter().map(|v| (*v).to_owned()).collect();
1871-
new_args[0] = Expr::Cast {
1872-
span: new_args[0].span(),
1873-
expr: Box::new(new_args[0].clone()),
1874-
target_type,
1875-
pg_style: false,
1876-
};
1877-
1878-
let args: Vec<&Expr> = new_args.iter().collect();
1879-
let result = self.resolve_lambda_function(span, func_name, &args, lambda)?;
1880-
1881-
let target_type = if result.1.is_nullable() {
1882-
DataType::Variant.wrap_nullable()
1883-
} else {
1884-
DataType::Variant
1885-
};
1886-
1887-
let result_expr = ScalarExpr::CastExpr(CastExpr {
1888-
span: new_args[0].span(),
1889-
is_try: false,
1890-
argument: Box::new(result.0.clone()),
1891-
target_type: Box::new(target_type.clone()),
1892-
});
1893-
1894-
return Ok(Box::new((result_expr, target_type)));
1895-
}
1896-
18971860
if matches!(
18981861
self.bind_context.expr_context,
18991862
ExprContext::InLambdaFunction
@@ -1903,13 +1866,6 @@ impl<'a> TypeChecker<'a> {
19031866
)
19041867
.set_span(span));
19051868
}
1906-
let params = lambda
1907-
.params
1908-
.iter()
1909-
.map(|param| param.name.to_lowercase())
1910-
.collect::<Vec<_>>();
1911-
1912-
self.check_lambda_param_count(func_name, params.len(), span)?;
19131869

19141870
if args.len() != 1 {
19151871
return Err(ErrorCode::SemanticError(format!(
@@ -1919,7 +1875,46 @@ impl<'a> TypeChecker<'a> {
19191875
))
19201876
.set_span(span));
19211877
}
1922-
let box (mut arg, arg_type) = self.resolve(args[0])?;
1878+
let box (mut arg, mut arg_type) = self.resolve(args[0])?;
1879+
1880+
let mut func_name = func_name;
1881+
let mut is_cast_variant = false;
1882+
if arg_type.remove_nullable() == DataType::Variant {
1883+
if func_name.starts_with("json_") {
1884+
func_name = &func_name[5..];
1885+
}
1886+
// Try auto cast the Variant type to Array(Variant) or Map(String, Variant),
1887+
// so that the lambda functions support variant type as argument.
1888+
let mut target_type = if func_name.starts_with("array") {
1889+
DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::Variant))))
1890+
} else {
1891+
DataType::Map(Box::new(DataType::Tuple(vec![
1892+
DataType::String,
1893+
DataType::Nullable(Box::new(DataType::Variant)),
1894+
])))
1895+
};
1896+
if arg_type.is_nullable() {
1897+
target_type = target_type.wrap_nullable();
1898+
}
1899+
1900+
arg = ScalarExpr::CastExpr(CastExpr {
1901+
span: None,
1902+
is_try: false,
1903+
argument: Box::new(arg.clone()),
1904+
target_type: Box::new(target_type.clone()),
1905+
});
1906+
arg_type = target_type;
1907+
1908+
is_cast_variant = true;
1909+
}
1910+
1911+
let params = lambda
1912+
.params
1913+
.iter()
1914+
.map(|param| param.name.to_lowercase())
1915+
.collect::<Vec<_>>();
1916+
1917+
self.check_lambda_param_count(func_name, params.len(), span)?;
19231918

19241919
let inner_ty = match arg_type.remove_nullable() {
19251920
DataType::Array(box inner_ty) => inner_ty.clone(),
@@ -2134,7 +2129,22 @@ impl<'a> TypeChecker<'a> {
21342129
}
21352130
};
21362131

2137-
Ok(Box::new((lambda_func, data_type)))
2132+
if is_cast_variant {
2133+
let result_target_type = if data_type.is_nullable() {
2134+
DataType::Nullable(Box::new(DataType::Variant))
2135+
} else {
2136+
DataType::Variant
2137+
};
2138+
let result_target_scalar = ScalarExpr::CastExpr(CastExpr {
2139+
span: None,
2140+
is_try: false,
2141+
argument: Box::new(lambda_func),
2142+
target_type: Box::new(result_target_type.clone()),
2143+
});
2144+
Ok(Box::new((result_target_scalar, result_target_type)))
2145+
} else {
2146+
Ok(Box::new((lambda_func, data_type)))
2147+
}
21382148
}
21392149

21402150
fn check_lambda_param_count(
@@ -2768,6 +2778,12 @@ impl<'a> TypeChecker<'a> {
27682778
)));
27692779
}
27702780

2781+
if let Some(rewritten_func_func) =
2782+
self.try_rewrite_array_function(span, func_name, &params, &mut args, &mut arg_types)
2783+
{
2784+
return rewritten_func_func;
2785+
}
2786+
27712787
self.resolve_scalar_function_call(span, func_name, params, args)
27722788
}
27732789

@@ -3641,6 +3657,91 @@ impl<'a> TypeChecker<'a> {
36413657
}
36423658
}
36433659

3660+
fn array_functions() -> &'static [Ascii<&'static str>] {
3661+
static ARRAY_FUNCTIONS: &[Ascii<&'static str>] = &[
3662+
Ascii::new("array_count"),
3663+
Ascii::new("array_max"),
3664+
Ascii::new("array_min"),
3665+
Ascii::new("array_any"),
3666+
Ascii::new("array_approx_count_distinct"),
3667+
Ascii::new("array_unique"),
3668+
Ascii::new("array_sort_asc_null_first"),
3669+
Ascii::new("array_sort_desc_null_first"),
3670+
Ascii::new("array_sort_asc_null_last"),
3671+
Ascii::new("array_sort_desc_null_last"),
3672+
Ascii::new("array_remove_first"),
3673+
Ascii::new("array_remove_last"),
3674+
Ascii::new("array_distinct"),
3675+
];
3676+
ARRAY_FUNCTIONS
3677+
}
3678+
3679+
fn try_rewrite_array_function(
3680+
&mut self,
3681+
span: Span,
3682+
func_name: &str,
3683+
params: &[Scalar],
3684+
args: &mut [ScalarExpr],
3685+
arg_types: &mut [DataType],
3686+
) -> Option<Result<Box<(ScalarExpr, DataType)>>> {
3687+
// Try auto cast the Variant type to Array(Variant),
3688+
// so that the array functions support Variant type as argument.
3689+
let uni_case_func_name = Ascii::new(func_name);
3690+
if Self::array_functions().contains(&uni_case_func_name)
3691+
&& !arg_types.is_empty()
3692+
&& arg_types[0].remove_nullable() == DataType::Variant
3693+
{
3694+
let target_type = if arg_types[0].is_nullable() {
3695+
DataType::Nullable(Box::new(DataType::Array(Box::new(DataType::Nullable(
3696+
Box::new(DataType::Variant),
3697+
)))))
3698+
} else {
3699+
DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::Variant))))
3700+
};
3701+
let arg = args[0].clone();
3702+
args[0] = ScalarExpr::CastExpr(CastExpr {
3703+
span: None,
3704+
is_try: false,
3705+
argument: Box::new(arg),
3706+
target_type: Box::new(target_type.clone()),
3707+
});
3708+
arg_types[0] = target_type;
3709+
3710+
let result =
3711+
self.resolve_scalar_function_call(span, func_name, params.to_vec(), args.to_vec());
3712+
if func_name == "array_remove_first"
3713+
|| func_name == "array_remove_last"
3714+
|| func_name == "array_distinct"
3715+
|| func_name == "array_sort_asc_null_first"
3716+
|| func_name == "array_sort_desc_null_first"
3717+
|| func_name == "array_sort_asc_null_last"
3718+
|| func_name == "array_sort_desc_null_last"
3719+
{
3720+
if result.is_err() {
3721+
return Some(result);
3722+
}
3723+
let box (result_scalar, result_type) = result.unwrap();
3724+
3725+
let result_target_type = if result_type.is_nullable() {
3726+
DataType::Nullable(Box::new(DataType::Variant))
3727+
} else {
3728+
DataType::Variant
3729+
};
3730+
let result_target_scalar = ScalarExpr::CastExpr(CastExpr {
3731+
span: None,
3732+
is_try: false,
3733+
argument: Box::new(result_scalar),
3734+
target_type: Box::new(result_target_type.clone()),
3735+
});
3736+
Some(Ok(Box::new((result_target_scalar, result_target_type))))
3737+
} else {
3738+
Some(result)
3739+
}
3740+
} else {
3741+
None
3742+
}
3743+
}
3744+
36443745
fn resolve_trim_function(
36453746
&mut self,
36463747
span: Span,

tests/sqllogictests/suites/query/functions/02_0061_function_array.test

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,46 @@ select json_array_transform(try_cast(col1 as Variant), a -> a::Int + col2), json
393393
NULL NULL
394394
[12,13] []
395395

396+
query TT
397+
select array_transform(try_cast(col1 as Variant), a -> a::Int + col2), array_filter(try_cast(col1 as Variant), a -> a::Int = col2) from t3;
398+
----
399+
[3,4,5] [2]
400+
[null,null] []
401+
NULL NULL
402+
[12,13] []
403+
396404
query TT
397405
select json_array_reduce([1,2,3,4]::Variant, (x, y) -> 3 + x + y), json_array_transform(parse_json('"aa"'), data -> CONCAT(data::String, 'bend'));
398406
----
399407
19 []
400408

409+
query TT
410+
select array_reduce([1,2,3,4]::Variant, (x, y) -> 3 + x + y), array_transform(parse_json('"aa"'), data -> CONCAT(data::String, 'bend'));
411+
----
412+
19 []
413+
414+
statement ok
415+
create or replace table t4(col1 Variant Null)
416+
417+
statement ok
418+
insert into t4 values('[3,2,1,1]'),('[4,5,null,true]'),(null),('[7,"c","d"]')
419+
420+
query TTTT
421+
select array_count(col1), array_max(col1), array_min(col1), array_approx_count_distinct(col1) from t4;
422+
----
423+
4 3 1 3
424+
4 null true 4
425+
NULL NULL NULL NULL
426+
3 "d" 7 3
427+
428+
query TTTTT
429+
select array_sort(col1), array_unique(col1), array_remove_first(col1), array_remove_last(col1), array_distinct(col1) from t4;
430+
----
431+
[1,1,2,3] 3 [2,1,1] [3,2,1] [3,2,1]
432+
[true,4,5,null] 4 [5,null,true] [4,5,null] [4,5,null,true]
433+
NULL NULL NULL NULL NULL
434+
[7,"c","d"] 3 ["c","d"] [7,"c"] [7,"c","d"]
435+
401436
query T
402437
SELECT arrays_zip(1, 'a', null);
403438
----

0 commit comments

Comments
 (0)