Skip to content

Commit 61ab9d0

Browse files
authored
core: Support uncorrelated EXISTS (#14474)
* Accept any uncorrelated plan when checking subquery correlation: for the purpose of decorrelation, an uncorrelated plan is a unit, so no verification needs to be performed on it.
* Extract variable: extract a variable from a long `if` condition involving a `match`. This improves readability.
* Simplify control flow: handle the unhandled case by returning immediately. This adds an additional return point to the function, but removes the subsequent `if`. At the point of this additional return we know why we bail out (some unhandled situation), whereas later the `None` filter could be construed as a true condition.
* Add more EXISTS SLT tests
* Support uncorrelated EXISTS
* fixup! Support uncorrelated EXISTS
* fixup! Support uncorrelated EXISTS
1 parent d5ff3e7 commit 61ab9d0

File tree

6 files changed

+181
-75
lines changed

6 files changed

+181
-75
lines changed

datafusion/expr/src/logical_plan/builder.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1602,7 +1602,7 @@ pub fn table_scan_with_filter_and_fetch(
16021602
)
16031603
}
16041604

1605-
fn table_source(table_schema: &Schema) -> Arc<dyn TableSource> {
1605+
pub fn table_source(table_schema: &Schema) -> Arc<dyn TableSource> {
16061606
let table_schema = Arc::new(table_schema.clone());
16071607
Arc::new(LogicalTableSource { table_schema })
16081608
}

datafusion/expr/src/logical_plan/invariants.rs

+24-21
Original file line numberDiff line numberDiff line change
@@ -249,31 +249,26 @@ pub fn check_subquery_expr(
249249

250250
// Recursively check the unsupported outer references in the sub query plan.
251251
fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
252-
check_inner_plan(inner_plan, true)
252+
check_inner_plan(inner_plan)
253253
}
254254

255255
// Recursively check the unsupported outer references in the sub query plan.
256256
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
257-
fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> {
258-
if !can_contain_outer_ref && inner_plan.contains_outer_reference() {
259-
return plan_err!("Accessing outer reference columns is not allowed in the plan");
260-
}
257+
fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> {
261258
// We want to support as many operators as possible inside the correlated subquery
262259
match inner_plan {
263260
LogicalPlan::Aggregate(_) => {
264261
inner_plan.apply_children(|plan| {
265-
check_inner_plan(plan, can_contain_outer_ref)?;
262+
check_inner_plan(plan)?;
266263
Ok(TreeNodeRecursion::Continue)
267264
})?;
268265
Ok(())
269266
}
270-
LogicalPlan::Filter(Filter { input, .. }) => {
271-
check_inner_plan(input, can_contain_outer_ref)
272-
}
267+
LogicalPlan::Filter(Filter { input, .. }) => check_inner_plan(input),
273268
LogicalPlan::Window(window) => {
274269
check_mixed_out_refer_in_window(window)?;
275270
inner_plan.apply_children(|plan| {
276-
check_inner_plan(plan, can_contain_outer_ref)?;
271+
check_inner_plan(plan)?;
277272
Ok(TreeNodeRecursion::Continue)
278273
})?;
279274
Ok(())
@@ -290,7 +285,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re
290285
| LogicalPlan::SubqueryAlias(_)
291286
| LogicalPlan::Unnest(_) => {
292287
inner_plan.apply_children(|plan| {
293-
check_inner_plan(plan, can_contain_outer_ref)?;
288+
check_inner_plan(plan)?;
294289
Ok(TreeNodeRecursion::Continue)
295290
})?;
296291
Ok(())
@@ -303,7 +298,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re
303298
}) => match join_type {
304299
JoinType::Inner => {
305300
inner_plan.apply_children(|plan| {
306-
check_inner_plan(plan, can_contain_outer_ref)?;
301+
check_inner_plan(plan)?;
307302
Ok(TreeNodeRecursion::Continue)
308303
})?;
309304
Ok(())
@@ -312,26 +307,34 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re
312307
| JoinType::LeftSemi
313308
| JoinType::LeftAnti
314309
| JoinType::LeftMark => {
315-
check_inner_plan(left, can_contain_outer_ref)?;
316-
check_inner_plan(right, false)
310+
check_inner_plan(left)?;
311+
check_no_outer_references(right)
317312
}
318313
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
319-
check_inner_plan(left, false)?;
320-
check_inner_plan(right, can_contain_outer_ref)
314+
check_no_outer_references(left)?;
315+
check_inner_plan(right)
321316
}
322317
JoinType::Full => {
323318
inner_plan.apply_children(|plan| {
324-
check_inner_plan(plan, false)?;
319+
check_no_outer_references(plan)?;
325320
Ok(TreeNodeRecursion::Continue)
326321
})?;
327322
Ok(())
328323
}
329324
},
330325
LogicalPlan::Extension(_) => Ok(()),
331-
_ => plan_err!(
332-
"Unsupported operator in the subquery plan: {}",
326+
plan => check_no_outer_references(plan),
327+
}
328+
}
329+
330+
fn check_no_outer_references(inner_plan: &LogicalPlan) -> Result<()> {
331+
if inner_plan.contains_outer_reference() {
332+
plan_err!(
333+
"Accessing outer reference columns is not allowed in the plan: {}",
333334
inner_plan.display()
334-
),
335+
)
336+
} else {
337+
Ok(())
335338
}
336339
}
337340

@@ -473,6 +476,6 @@ mod test {
473476
}),
474477
});
475478

476-
check_inner_plan(&plan, true).unwrap();
479+
check_inner_plan(&plan).unwrap();
477480
}
478481
}

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

+115-34
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
3333
use datafusion_expr::logical_plan::{JoinType, Subquery};
3434
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
3535
use datafusion_expr::{
36-
exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
36+
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
3737
LogicalPlan, LogicalPlanBuilder, Operator,
3838
};
3939

@@ -342,7 +342,7 @@ fn build_join(
342342
replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some)
343343
})?;
344344

345-
if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) {
345+
let join_filter = match (join_filter_opt, in_predicate_opt) {
346346
(
347347
Some(join_filter),
348348
Some(Expr::BinaryExpr(BinaryExpr {
@@ -353,9 +353,9 @@ fn build_join(
353353
) => {
354354
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
355355
let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
356-
Some(in_predicate.and(join_filter))
356+
in_predicate.and(join_filter)
357357
}
358-
(Some(join_filter), _) => Some(join_filter),
358+
(Some(join_filter), _) => join_filter,
359359
(
360360
_,
361361
Some(Expr::BinaryExpr(BinaryExpr {
@@ -366,24 +366,23 @@ fn build_join(
366366
) => {
367367
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
368368
let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
369-
Some(in_predicate)
369+
in_predicate
370370
}
371-
_ => None,
372-
} {
373-
// join our sub query into the main plan
374-
let new_plan = LogicalPlanBuilder::from(left.clone())
375-
.join_on(sub_query_alias, join_type, Some(join_filter))?
376-
.build()?;
377-
debug!(
378-
"predicate subquery optimized:\n{}",
379-
new_plan.display_indent()
380-
);
381-
Ok(Some(new_plan))
382-
} else {
383-
Ok(None)
384-
}
371+
(None, None) => lit(true),
372+
_ => return Ok(None),
373+
};
374+
// join our sub query into the main plan
375+
let new_plan = LogicalPlanBuilder::from(left.clone())
376+
.join_on(sub_query_alias, join_type, Some(join_filter))?
377+
.build()?;
378+
debug!(
379+
"predicate subquery optimized:\n{}",
380+
new_plan.display_indent()
381+
);
382+
Ok(Some(new_plan))
385383
}
386384

385+
#[derive(Debug)]
387386
struct SubqueryInfo {
388387
query: Subquery,
389388
where_in_expr: Option<Expr>,
@@ -429,6 +428,7 @@ mod tests {
429428
use crate::test::*;
430429

431430
use arrow::datatypes::{DataType, Field, Schema};
431+
use datafusion_expr::builder::table_source;
432432
use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan};
433433

434434
fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
@@ -1423,7 +1423,14 @@ mod tests {
14231423
.project(vec![col("customer.c_custkey")])?
14241424
.build()?;
14251425

1426-
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
1426+
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
1427+
\n LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]\
1428+
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
1429+
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
1430+
\n Projection: orders.o_custkey [o_custkey:Int64]\
1431+
\n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
1432+
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
1433+
assert_optimized_plan_equal(plan, expected)
14271434
}
14281435

14291436
/// Test for correlated exists subquery not equal
@@ -1608,14 +1615,14 @@ mod tests {
16081615
.project(vec![col("customer.c_custkey")])?
16091616
.build()?;
16101617

1611-
// not optimized
1612-
let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
1613-
Filter: EXISTS (<subquery>) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
1614-
Subquery: [o_custkey:Int64]
1615-
Projection: orders.o_custkey [o_custkey:Int64]
1616-
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1617-
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1618-
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
1618+
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
1619+
\n Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\
1620+
\n LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\
1621+
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
1622+
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
1623+
\n Projection: orders.o_custkey [o_custkey:Int64]\
1624+
\n Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
1625+
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
16191626

16201627
assert_optimized_plan_equal(plan, expected)
16211628
}
@@ -1654,7 +1661,13 @@ mod tests {
16541661
.project(vec![col("test.b")])?
16551662
.build()?;
16561663

1657-
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
1664+
let expected = "Projection: test.b [b:UInt32]\
1665+
\n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
1666+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1667+
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
1668+
\n Projection: sq.c [c:UInt32]\
1669+
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
1670+
assert_optimized_plan_equal(plan, expected)
16581671
}
16591672

16601673
/// Test for single NOT exists subquery filter
@@ -1666,7 +1679,13 @@ mod tests {
16661679
.project(vec![col("test.b")])?
16671680
.build()?;
16681681

1669-
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
1682+
let expected = "Projection: test.b [b:UInt32]\
1683+
\n LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
1684+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1685+
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
1686+
\n Projection: sq.c [c:UInt32]\
1687+
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
1688+
assert_optimized_plan_equal(plan, expected)
16701689
}
16711690

16721691
#[test]
@@ -1750,12 +1769,12 @@ mod tests {
17501769

17511770
// Subquery and outer query refer to the same table.
17521771
let expected = "Projection: test.b [b:UInt32]\
1753-
\n Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
1754-
\n Subquery: [c:UInt32]\
1772+
\n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
1773+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1774+
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
17551775
\n Projection: test.c [c:UInt32]\
17561776
\n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\
1757-
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1758-
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
1777+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
17591778

17601779
assert_optimized_plan_equal(plan, expected)
17611780
}
@@ -1844,6 +1863,68 @@ mod tests {
18441863
assert_optimized_plan_equal(plan, expected)
18451864
}
18461865

1866+
#[test]
1867+
fn exists_uncorrelated_unnest() -> Result<()> {
1868+
let subquery_table_source = table_source(&Schema::new(vec![Field::new(
1869+
"arr",
1870+
DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
1871+
true,
1872+
)]));
1873+
let subquery = LogicalPlanBuilder::scan_with_filters(
1874+
"sq",
1875+
subquery_table_source,
1876+
None,
1877+
vec![],
1878+
)?
1879+
.unnest_column("arr")?
1880+
.build()?;
1881+
let table_scan = test_table_scan()?;
1882+
let plan = LogicalPlanBuilder::from(table_scan)
1883+
.filter(exists(Arc::new(subquery)))?
1884+
.project(vec![col("test.b")])?
1885+
.build()?;
1886+
1887+
let expected = "Projection: test.b [b:UInt32]\
1888+
\n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\
1889+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1890+
\n SubqueryAlias: __correlated_sq_1 [arr:Int32;N]\
1891+
\n Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]\
1892+
\n TableScan: sq [arr:List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]";
1893+
assert_optimized_plan_equal(plan, expected)
1894+
}
1895+
1896+
#[test]
1897+
fn exists_correlated_unnest() -> Result<()> {
1898+
let table_scan = test_table_scan()?;
1899+
let subquery_table_source = table_source(&Schema::new(vec![Field::new(
1900+
"a",
1901+
DataType::List(Arc::new(Field::new_list_field(DataType::UInt32, true))),
1902+
true,
1903+
)]));
1904+
let subquery = LogicalPlanBuilder::scan_with_filters(
1905+
"sq",
1906+
subquery_table_source,
1907+
None,
1908+
vec![],
1909+
)?
1910+
.unnest_column("a")?
1911+
.filter(col("a").eq(out_ref_col(DataType::UInt32, "test.b")))?
1912+
.build()?;
1913+
let plan = LogicalPlanBuilder::from(table_scan)
1914+
.filter(exists(Arc::new(subquery)))?
1915+
.project(vec![col("test.b")])?
1916+
.build()?;
1917+
1918+
let expected = "Projection: test.b [b:UInt32]\
1919+
\n LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]\
1920+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
1921+
\n SubqueryAlias: __correlated_sq_1 [a:UInt32;N]\
1922+
\n Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]\
1923+
\n TableScan: sq [a:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]";
1924+
1925+
assert_optimized_plan_equal(plan, expected)
1926+
}
1927+
18471928
#[test]
18481929
fn upper_case_ident() -> Result<()> {
18491930
let fields = vec![

datafusion/sqllogictest/test_files/explain.slt

+11-7
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,17 @@ query TT
423423
explain select a from t1 where exists (select count(*) from t2);
424424
----
425425
logical_plan
426-
01)Filter: EXISTS (<subquery>)
427-
02)--Subquery:
428-
03)----Projection: count(*)
429-
04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]
430-
05)--------TableScan: t2
431-
06)--TableScan: t1 projection=[a]
432-
physical_plan_error This feature is not implemented: Physical plan does not support logical expression Exists(Exists { subquery: <subquery>, negated: false })
426+
01)LeftSemi Join:
427+
02)--TableScan: t1 projection=[a]
428+
03)--SubqueryAlias: __correlated_sq_1
429+
04)----Projection:
430+
05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]
431+
06)--------TableScan: t2 projection=[]
432+
physical_plan
433+
01)NestedLoopJoinExec: join_type=LeftSemi
434+
02)--MemoryExec: partitions=1, partition_sizes=[0]
435+
03)--ProjectionExec: expr=[]
436+
04)----PlaceholderRowExec
433437

434438
statement ok
435439
drop table t1;

0 commit comments

Comments
 (0)