@@ -33,7 +33,7 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
33
33
use datafusion_expr:: logical_plan:: { JoinType , Subquery } ;
34
34
use datafusion_expr:: utils:: { conjunction, split_conjunction_owned} ;
35
35
use datafusion_expr:: {
36
- exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr , Expr , Filter ,
36
+ exists, in_subquery, lit , not, not_exists, not_in_subquery, BinaryExpr , Expr , Filter ,
37
37
LogicalPlan , LogicalPlanBuilder , Operator ,
38
38
} ;
39
39
@@ -342,7 +342,7 @@ fn build_join(
342
342
replace_qualified_name ( filter, & all_correlated_cols, & alias) . map ( Some )
343
343
} ) ?;
344
344
345
- if let Some ( join_filter) = match ( join_filter_opt, in_predicate_opt) {
345
+ let join_filter = match ( join_filter_opt, in_predicate_opt) {
346
346
(
347
347
Some ( join_filter) ,
348
348
Some ( Expr :: BinaryExpr ( BinaryExpr {
@@ -353,9 +353,9 @@ fn build_join(
353
353
) => {
354
354
let right_col = create_col_from_scalar_expr ( right. deref ( ) , alias) ?;
355
355
let in_predicate = Expr :: eq ( left. deref ( ) . clone ( ) , Expr :: Column ( right_col) ) ;
356
- Some ( in_predicate. and ( join_filter) )
356
+ in_predicate. and ( join_filter)
357
357
}
358
- ( Some ( join_filter) , _) => Some ( join_filter) ,
358
+ ( Some ( join_filter) , _) => join_filter,
359
359
(
360
360
_,
361
361
Some ( Expr :: BinaryExpr ( BinaryExpr {
@@ -366,24 +366,23 @@ fn build_join(
366
366
) => {
367
367
let right_col = create_col_from_scalar_expr ( right. deref ( ) , alias) ?;
368
368
let in_predicate = Expr :: eq ( left. deref ( ) . clone ( ) , Expr :: Column ( right_col) ) ;
369
- Some ( in_predicate)
369
+ in_predicate
370
370
}
371
- _ => None ,
372
- } {
373
- // join our sub query into the main plan
374
- let new_plan = LogicalPlanBuilder :: from ( left. clone ( ) )
375
- . join_on ( sub_query_alias, join_type, Some ( join_filter) ) ?
376
- . build ( ) ?;
377
- debug ! (
378
- "predicate subquery optimized:\n {}" ,
379
- new_plan. display_indent( )
380
- ) ;
381
- Ok ( Some ( new_plan) )
382
- } else {
383
- Ok ( None )
384
- }
371
+ ( None , None ) => lit ( true ) ,
372
+ _ => return Ok ( None ) ,
373
+ } ;
374
+ // join our sub query into the main plan
375
+ let new_plan = LogicalPlanBuilder :: from ( left. clone ( ) )
376
+ . join_on ( sub_query_alias, join_type, Some ( join_filter) ) ?
377
+ . build ( ) ?;
378
+ debug ! (
379
+ "predicate subquery optimized:\n {}" ,
380
+ new_plan. display_indent( )
381
+ ) ;
382
+ Ok ( Some ( new_plan) )
385
383
}
386
384
385
+ #[ derive( Debug ) ]
387
386
struct SubqueryInfo {
388
387
query : Subquery ,
389
388
where_in_expr : Option < Expr > ,
@@ -429,6 +428,7 @@ mod tests {
429
428
use crate :: test:: * ;
430
429
431
430
use arrow:: datatypes:: { DataType , Field , Schema } ;
431
+ use datafusion_expr:: builder:: table_source;
432
432
use datafusion_expr:: { and, binary_expr, col, lit, not, out_ref_col, table_scan} ;
433
433
434
434
fn assert_optimized_plan_equal ( plan : LogicalPlan , expected : & str ) -> Result < ( ) > {
@@ -1423,7 +1423,14 @@ mod tests {
1423
1423
. project ( vec ! [ col( "customer.c_custkey" ) ] ) ?
1424
1424
. build ( ) ?;
1425
1425
1426
- assert_optimization_skipped ( Arc :: new ( DecorrelatePredicateSubquery :: new ( ) ) , plan)
1426
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
1427
+ \n LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]\
1428
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
1429
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
1430
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
1431
+ \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
1432
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
1433
+ assert_optimized_plan_equal ( plan, expected)
1427
1434
}
1428
1435
1429
1436
/// Test for correlated exists subquery not equal
@@ -1608,14 +1615,14 @@ mod tests {
1608
1615
. project ( vec ! [ col( "customer.c_custkey" ) ] ) ?
1609
1616
. build ( ) ?;
1610
1617
1611
- // not optimized
1612
- let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
1613
- Filter: EXISTS (<subquery>) OR customer.c_custkey = Int32(1 ) [c_custkey:Int64, c_name:Utf8]
1614
- Subquery: [o_custkey :Int64]
1615
- Projection: orders.o_custkey [o_custkey:Int64]
1616
- Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1617
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1618
- TableScan: customer [c_custkey :Int64, c_name: Utf8]"# ;
1618
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64] \
1619
+ \n Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean] \
1620
+ \n LeftMark Join: Filter: Boolean(true ) [c_custkey:Int64, c_name:Utf8, mark:Boolean] \
1621
+ \n TableScan: customer [c_custkey :Int64, c_name:Utf8] \
1622
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
1623
+ \n Projection: orders.o_custkey [o_custkey:Int64] \
1624
+ \n Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
1625
+ \n TableScan: orders [o_orderkey :Int64, o_custkey:Int64, o_orderstatus: Utf8, o_totalprice:Float64;N]" ;
1619
1626
1620
1627
assert_optimized_plan_equal ( plan, expected)
1621
1628
}
@@ -1654,7 +1661,13 @@ mod tests {
1654
1661
. project ( vec ! [ col( "test.b" ) ] ) ?
1655
1662
. build ( ) ?;
1656
1663
1657
- assert_optimization_skipped ( Arc :: new ( DecorrelatePredicateSubquery :: new ( ) ) , plan)
1664
+ let expected = "Projection: test.b [b:UInt32]\
1665
+ \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
1666
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1667
+ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
1668
+ \n Projection: sq.c [c:UInt32]\
1669
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
1670
+ assert_optimized_plan_equal ( plan, expected)
1658
1671
}
1659
1672
1660
1673
/// Test for single NOT exists subquery filter
@@ -1666,7 +1679,13 @@ mod tests {
1666
1679
. project ( vec ! [ col( "test.b" ) ] ) ?
1667
1680
. build ( ) ?;
1668
1681
1669
- assert_optimization_skipped ( Arc :: new ( DecorrelatePredicateSubquery :: new ( ) ) , plan)
1682
+ let expected = "Projection: test.b [b:UInt32]\
1683
+ \n LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
1684
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1685
+ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
1686
+ \n Projection: sq.c [c:UInt32]\
1687
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
1688
+ assert_optimized_plan_equal ( plan, expected)
1670
1689
}
1671
1690
1672
1691
#[ test]
@@ -1750,12 +1769,12 @@ mod tests {
1750
1769
1751
1770
// Subquery and outer query refer to the same table.
1752
1771
let expected = "Projection: test.b [b:UInt32]\
1753
- \n Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
1754
- \n Subquery: [c:UInt32]\
1772
+ \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
1773
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1774
+ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
1755
1775
\n Projection: test.c [c:UInt32]\
1756
1776
\n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\
1757
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1758
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
1777
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
1759
1778
1760
1779
assert_optimized_plan_equal ( plan, expected)
1761
1780
}
@@ -1844,6 +1863,68 @@ mod tests {
1844
1863
assert_optimized_plan_equal ( plan, expected)
1845
1864
}
1846
1865
1866
+ #[ test]
1867
+ fn exists_uncorrelated_unnest ( ) -> Result < ( ) > {
1868
+ let subquery_table_source = table_source ( & Schema :: new ( vec ! [ Field :: new(
1869
+ "arr" ,
1870
+ DataType :: List ( Arc :: new( Field :: new_list_field( DataType :: Int32 , true ) ) ) ,
1871
+ true ,
1872
+ ) ] ) ) ;
1873
+ let subquery = LogicalPlanBuilder :: scan_with_filters (
1874
+ "sq" ,
1875
+ subquery_table_source,
1876
+ None ,
1877
+ vec ! [ ] ,
1878
+ ) ?
1879
+ . unnest_column ( "arr" ) ?
1880
+ . build ( ) ?;
1881
+ let table_scan = test_table_scan ( ) ?;
1882
+ let plan = LogicalPlanBuilder :: from ( table_scan)
1883
+ . filter ( exists ( Arc :: new ( subquery) ) ) ?
1884
+ . project ( vec ! [ col( "test.b" ) ] ) ?
1885
+ . build ( ) ?;
1886
+
1887
+ let expected = "Projection: test.b [b:UInt32]\
1888
+ \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
1889
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1890
+ \n SubqueryAlias: __correlated_sq_1 [arr:Int32;N]\
1891
+ \n Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]\
1892
+ \n TableScan: sq [arr:List(Field { name: \" item\" , data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]";
1893
+ assert_optimized_plan_equal ( plan, expected)
1894
+ }
1895
+
1896
+ #[ test]
1897
+ fn exists_correlated_unnest ( ) -> Result < ( ) > {
1898
+ let table_scan = test_table_scan ( ) ?;
1899
+ let subquery_table_source = table_source ( & Schema :: new ( vec ! [ Field :: new(
1900
+ "a" ,
1901
+ DataType :: List ( Arc :: new( Field :: new_list_field( DataType :: UInt32 , true ) ) ) ,
1902
+ true ,
1903
+ ) ] ) ) ;
1904
+ let subquery = LogicalPlanBuilder :: scan_with_filters (
1905
+ "sq" ,
1906
+ subquery_table_source,
1907
+ None ,
1908
+ vec ! [ ] ,
1909
+ ) ?
1910
+ . unnest_column ( "a" ) ?
1911
+ . filter ( col ( "a" ) . eq ( out_ref_col ( DataType :: UInt32 , "test.b" ) ) ) ?
1912
+ . build ( ) ?;
1913
+ let plan = LogicalPlanBuilder :: from ( table_scan)
1914
+ . filter ( exists ( Arc :: new ( subquery) ) ) ?
1915
+ . project ( vec ! [ col( "test.b" ) ] ) ?
1916
+ . build ( ) ?;
1917
+
1918
+ let expected = "Projection: test.b [b:UInt32]\
1919
+ \n LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]\
1920
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1921
+ \n SubqueryAlias: __correlated_sq_1 [a:UInt32;N]\
1922
+ \n Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]\
1923
+ \n TableScan: sq [a:List(Field { name: \" item\" , data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]";
1924
+
1925
+ assert_optimized_plan_equal ( plan, expected)
1926
+ }
1927
+
1847
1928
#[ test]
1848
1929
fn upper_case_ident ( ) -> Result < ( ) > {
1849
1930
let fields = vec ! [
0 commit comments