@@ -1857,43 +1857,6 @@ impl<'a> TypeChecker<'a> {
1857
1857
args : & [ & Expr ] ,
1858
1858
lambda : & Lambda ,
1859
1859
) -> Result < Box < ( ScalarExpr , DataType ) > > {
1860
- if func_name. starts_with ( "json_" ) && !args. is_empty ( ) {
1861
- let target_type = if func_name. starts_with ( "json_array" ) {
1862
- TypeName :: Array ( Box :: new ( TypeName :: Nullable ( Box :: new ( TypeName :: Variant ) ) ) )
1863
- } else {
1864
- TypeName :: Map {
1865
- key_type : Box :: new ( TypeName :: String ) ,
1866
- val_type : Box :: new ( TypeName :: Nullable ( Box :: new ( TypeName :: Variant ) ) ) ,
1867
- }
1868
- } ;
1869
- let func_name = & func_name[ 5 ..] ;
1870
- let mut new_args: Vec < Expr > = args. iter ( ) . map ( |v| ( * v) . to_owned ( ) ) . collect ( ) ;
1871
- new_args[ 0 ] = Expr :: Cast {
1872
- span : new_args[ 0 ] . span ( ) ,
1873
- expr : Box :: new ( new_args[ 0 ] . clone ( ) ) ,
1874
- target_type,
1875
- pg_style : false ,
1876
- } ;
1877
-
1878
- let args: Vec < & Expr > = new_args. iter ( ) . collect ( ) ;
1879
- let result = self . resolve_lambda_function ( span, func_name, & args, lambda) ?;
1880
-
1881
- let target_type = if result. 1 . is_nullable ( ) {
1882
- DataType :: Variant . wrap_nullable ( )
1883
- } else {
1884
- DataType :: Variant
1885
- } ;
1886
-
1887
- let result_expr = ScalarExpr :: CastExpr ( CastExpr {
1888
- span : new_args[ 0 ] . span ( ) ,
1889
- is_try : false ,
1890
- argument : Box :: new ( result. 0 . clone ( ) ) ,
1891
- target_type : Box :: new ( target_type. clone ( ) ) ,
1892
- } ) ;
1893
-
1894
- return Ok ( Box :: new ( ( result_expr, target_type) ) ) ;
1895
- }
1896
-
1897
1860
if matches ! (
1898
1861
self . bind_context. expr_context,
1899
1862
ExprContext :: InLambdaFunction
@@ -1903,13 +1866,6 @@ impl<'a> TypeChecker<'a> {
1903
1866
)
1904
1867
. set_span ( span) ) ;
1905
1868
}
1906
- let params = lambda
1907
- . params
1908
- . iter ( )
1909
- . map ( |param| param. name . to_lowercase ( ) )
1910
- . collect :: < Vec < _ > > ( ) ;
1911
-
1912
- self . check_lambda_param_count ( func_name, params. len ( ) , span) ?;
1913
1869
1914
1870
if args. len ( ) != 1 {
1915
1871
return Err ( ErrorCode :: SemanticError ( format ! (
@@ -1919,7 +1875,46 @@ impl<'a> TypeChecker<'a> {
1919
1875
) )
1920
1876
. set_span ( span) ) ;
1921
1877
}
1922
- let box ( mut arg, arg_type) = self . resolve ( args[ 0 ] ) ?;
1878
+ let box ( mut arg, mut arg_type) = self . resolve ( args[ 0 ] ) ?;
1879
+
1880
+ let mut func_name = func_name;
1881
+ let mut is_cast_variant = false ;
1882
+ if arg_type. remove_nullable ( ) == DataType :: Variant {
1883
+ if func_name. starts_with ( "json_" ) {
1884
+ func_name = & func_name[ 5 ..] ;
1885
+ }
1886
+ // Try auto cast the Variant type to Array(Variant) or Map(String, Variant),
1887
+ // so that the lambda functions support variant type as argument.
1888
+ let mut target_type = if func_name. starts_with ( "array" ) {
1889
+ DataType :: Array ( Box :: new ( DataType :: Nullable ( Box :: new ( DataType :: Variant ) ) ) )
1890
+ } else {
1891
+ DataType :: Map ( Box :: new ( DataType :: Tuple ( vec ! [
1892
+ DataType :: String ,
1893
+ DataType :: Nullable ( Box :: new( DataType :: Variant ) ) ,
1894
+ ] ) ) )
1895
+ } ;
1896
+ if arg_type. is_nullable ( ) {
1897
+ target_type = target_type. wrap_nullable ( ) ;
1898
+ }
1899
+
1900
+ arg = ScalarExpr :: CastExpr ( CastExpr {
1901
+ span : None ,
1902
+ is_try : false ,
1903
+ argument : Box :: new ( arg. clone ( ) ) ,
1904
+ target_type : Box :: new ( target_type. clone ( ) ) ,
1905
+ } ) ;
1906
+ arg_type = target_type;
1907
+
1908
+ is_cast_variant = true ;
1909
+ }
1910
+
1911
+ let params = lambda
1912
+ . params
1913
+ . iter ( )
1914
+ . map ( |param| param. name . to_lowercase ( ) )
1915
+ . collect :: < Vec < _ > > ( ) ;
1916
+
1917
+ self . check_lambda_param_count ( func_name, params. len ( ) , span) ?;
1923
1918
1924
1919
let inner_ty = match arg_type. remove_nullable ( ) {
1925
1920
DataType :: Array ( box inner_ty) => inner_ty. clone ( ) ,
@@ -2134,7 +2129,22 @@ impl<'a> TypeChecker<'a> {
2134
2129
}
2135
2130
} ;
2136
2131
2137
- Ok ( Box :: new ( ( lambda_func, data_type) ) )
2132
+ if is_cast_variant {
2133
+ let result_target_type = if data_type. is_nullable ( ) {
2134
+ DataType :: Nullable ( Box :: new ( DataType :: Variant ) )
2135
+ } else {
2136
+ DataType :: Variant
2137
+ } ;
2138
+ let result_target_scalar = ScalarExpr :: CastExpr ( CastExpr {
2139
+ span : None ,
2140
+ is_try : false ,
2141
+ argument : Box :: new ( lambda_func) ,
2142
+ target_type : Box :: new ( result_target_type. clone ( ) ) ,
2143
+ } ) ;
2144
+ Ok ( Box :: new ( ( result_target_scalar, result_target_type) ) )
2145
+ } else {
2146
+ Ok ( Box :: new ( ( lambda_func, data_type) ) )
2147
+ }
2138
2148
}
2139
2149
2140
2150
fn check_lambda_param_count (
@@ -2768,6 +2778,12 @@ impl<'a> TypeChecker<'a> {
2768
2778
) ) ) ;
2769
2779
}
2770
2780
2781
+ if let Some ( rewritten_func_func) =
2782
+ self . try_rewrite_array_function ( span, func_name, & params, & mut args, & mut arg_types)
2783
+ {
2784
+ return rewritten_func_func;
2785
+ }
2786
+
2771
2787
self . resolve_scalar_function_call ( span, func_name, params, args)
2772
2788
}
2773
2789
@@ -3641,6 +3657,91 @@ impl<'a> TypeChecker<'a> {
3641
3657
}
3642
3658
}
3643
3659
3660
+ fn array_functions ( ) -> & ' static [ Ascii < & ' static str > ] {
3661
+ static ARRAY_FUNCTIONS : & [ Ascii < & ' static str > ] = & [
3662
+ Ascii :: new ( "array_count" ) ,
3663
+ Ascii :: new ( "array_max" ) ,
3664
+ Ascii :: new ( "array_min" ) ,
3665
+ Ascii :: new ( "array_any" ) ,
3666
+ Ascii :: new ( "array_approx_count_distinct" ) ,
3667
+ Ascii :: new ( "array_unique" ) ,
3668
+ Ascii :: new ( "array_sort_asc_null_first" ) ,
3669
+ Ascii :: new ( "array_sort_desc_null_first" ) ,
3670
+ Ascii :: new ( "array_sort_asc_null_last" ) ,
3671
+ Ascii :: new ( "array_sort_desc_null_last" ) ,
3672
+ Ascii :: new ( "array_remove_first" ) ,
3673
+ Ascii :: new ( "array_remove_last" ) ,
3674
+ Ascii :: new ( "array_distinct" ) ,
3675
+ ] ;
3676
+ ARRAY_FUNCTIONS
3677
+ }
3678
+
3679
+ fn try_rewrite_array_function (
3680
+ & mut self ,
3681
+ span : Span ,
3682
+ func_name : & str ,
3683
+ params : & [ Scalar ] ,
3684
+ args : & mut [ ScalarExpr ] ,
3685
+ arg_types : & mut [ DataType ] ,
3686
+ ) -> Option < Result < Box < ( ScalarExpr , DataType ) > > > {
3687
+ // Try auto cast the Variant type to Array(Variant),
3688
+ // so that the array functions support Variant type as argument.
3689
+ let uni_case_func_name = Ascii :: new ( func_name) ;
3690
+ if Self :: array_functions ( ) . contains ( & uni_case_func_name)
3691
+ && !arg_types. is_empty ( )
3692
+ && arg_types[ 0 ] . remove_nullable ( ) == DataType :: Variant
3693
+ {
3694
+ let target_type = if arg_types[ 0 ] . is_nullable ( ) {
3695
+ DataType :: Nullable ( Box :: new ( DataType :: Array ( Box :: new ( DataType :: Nullable (
3696
+ Box :: new ( DataType :: Variant ) ,
3697
+ ) ) ) ) )
3698
+ } else {
3699
+ DataType :: Array ( Box :: new ( DataType :: Nullable ( Box :: new ( DataType :: Variant ) ) ) )
3700
+ } ;
3701
+ let arg = args[ 0 ] . clone ( ) ;
3702
+ args[ 0 ] = ScalarExpr :: CastExpr ( CastExpr {
3703
+ span : None ,
3704
+ is_try : false ,
3705
+ argument : Box :: new ( arg) ,
3706
+ target_type : Box :: new ( target_type. clone ( ) ) ,
3707
+ } ) ;
3708
+ arg_types[ 0 ] = target_type;
3709
+
3710
+ let result =
3711
+ self . resolve_scalar_function_call ( span, func_name, params. to_vec ( ) , args. to_vec ( ) ) ;
3712
+ if func_name == "array_remove_first"
3713
+ || func_name == "array_remove_last"
3714
+ || func_name == "array_distinct"
3715
+ || func_name == "array_sort_asc_null_first"
3716
+ || func_name == "array_sort_desc_null_first"
3717
+ || func_name == "array_sort_asc_null_last"
3718
+ || func_name == "array_sort_desc_null_last"
3719
+ {
3720
+ if result. is_err ( ) {
3721
+ return Some ( result) ;
3722
+ }
3723
+ let box ( result_scalar, result_type) = result. unwrap ( ) ;
3724
+
3725
+ let result_target_type = if result_type. is_nullable ( ) {
3726
+ DataType :: Nullable ( Box :: new ( DataType :: Variant ) )
3727
+ } else {
3728
+ DataType :: Variant
3729
+ } ;
3730
+ let result_target_scalar = ScalarExpr :: CastExpr ( CastExpr {
3731
+ span : None ,
3732
+ is_try : false ,
3733
+ argument : Box :: new ( result_scalar) ,
3734
+ target_type : Box :: new ( result_target_type. clone ( ) ) ,
3735
+ } ) ;
3736
+ Some ( Ok ( Box :: new ( ( result_target_scalar, result_target_type) ) ) )
3737
+ } else {
3738
+ Some ( result)
3739
+ }
3740
+ } else {
3741
+ None
3742
+ }
3743
+ }
3744
+
3644
3745
fn resolve_trim_function (
3645
3746
& mut self ,
3646
3747
span : Span ,
0 commit comments