Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 6696706

Browse files
authored
Subquery Unnesting Agg NULL case workarounds (#257)
- Add an outer join with the deduplicated "left side", and a corresponding projection node, to pass along NULL values as expected. - Add a specific workaround in the projection node to coalesce NULL -> 0 if we have a COUNT(*).
1 parent 3fd39a8 commit 6696706

File tree

9 files changed

+492
-232
lines changed

9 files changed

+492
-232
lines changed

optd-datafusion-repr/src/rules/subquery/depjoin_pushdown.rs

+90-13
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
44
// https://opensource.org/licenses/MIT.
55

6-
use optd_core::nodes::{PlanNodeOrGroup, PredNode};
7-
// TODO: No push past join
8-
// TODO: Sideways information passing??
6+
use datafusion_expr::{AggregateFunction, BuiltinScalarFunction};
7+
use optd_core::nodes::{PlanNodeOrGroup, PredNode, Value};
98
use optd_core::optimizer::Optimizer;
109
use optd_core::rules::{Rule, RuleMatcher};
1110

1211
use crate::plan_nodes::{
13-
ArcDfPlanNode, ArcDfPredNode, BinOpPred, BinOpType, ColumnRefPred, ConstantPred, DependentJoin,
14-
DfNodeType, DfPredType, DfReprPlanNode, DfReprPredNode, ExternColumnRefPred, JoinType,
15-
ListPred, LogOpPred, LogOpType, LogicalAgg, LogicalFilter, LogicalJoin, LogicalProjection,
16-
PredExt, RawDependentJoin,
12+
ArcDfPlanNode, ArcDfPredNode, BinOpPred, BinOpType, ColumnRefPred, ConstantPred, ConstantType,
13+
DependentJoin, DfNodeType, DfPredType, DfReprPlanNode, DfReprPredNode, ExternColumnRefPred,
14+
FuncPred, FuncType, JoinType, ListPred, LogOpPred, LogOpType, LogicalAgg, LogicalFilter,
15+
LogicalJoin, LogicalProjection, PredExt, RawDependentJoin,
1716
};
1817
use crate::rules::macros::define_rule;
1918
use crate::OptimizerExt;
@@ -288,11 +287,8 @@ define_rule!(
288287
/// deduplicated set).
289288
/// For info on why we do the outer join, refer to the Unnesting Arbitrary Queries
290289
/// talk by Mark Raasveldt. The correlated columns are covered in the original paper.
291-
///
292-
/// TODO: the outer join is not implemented yet, so some edge cases won't work.
293-
/// Run SQLite tests to catch these, I guess.
294290
fn apply_dep_join_past_agg(
295-
_optimizer: &impl Optimizer<DfNodeType>,
291+
optimizer: &impl Optimizer<DfNodeType>,
296292
binding: ArcDfPlanNode,
297293
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
298294
let join = DependentJoin::from_plan_node(binding).unwrap();
@@ -305,6 +301,8 @@ fn apply_dep_join_past_agg(
305301
let groups = agg.groups();
306302
let right = agg.child();
307303

304+
let left_schema_size = optimizer.get_schema_of(left.clone()).len();
305+
308306
// Cross join should always have true cond
309307
assert!(cond == ConstantPred::bool(true).into_pred_node());
310308

@@ -345,11 +343,90 @@ fn apply_dep_join_past_agg(
345343
);
346344

347345
let new_dep_join =
348-
DependentJoin::new_unchecked(left, right, cond, extern_cols, JoinType::Cross);
346+
DependentJoin::new_unchecked(left.clone(), right, cond, extern_cols, JoinType::Cross);
349347

348+
let new_agg_exprs_size = new_exprs.len();
349+
let new_agg_groups_size = new_groups.len();
350+
let new_agg_schema_size = new_agg_groups_size + new_agg_exprs_size;
350351
let new_agg = LogicalAgg::new(new_dep_join.into_plan_node(), new_exprs, new_groups);
351352

352-
vec![new_agg.into_plan_node().into()]
353+
// Add left outer join above the agg node, joining the deduplicated set
354+
// with the new agg node.
355+
356+
// Both sides will have an agg now, so we want to match the correlated
357+
// columns from the left with those from the right
358+
let outer_join_cond = LogOpPred::new(
359+
LogOpType::And,
360+
correlated_col_indices
361+
.iter()
362+
.enumerate()
363+
.map(|(i, _)| {
364+
assert!(i + left_schema_size < left_schema_size + new_agg_schema_size);
365+
BinOpPred::new(
366+
ColumnRefPred::new(i).into_pred_node(),
367+
// We *prepend* the correlated columns to the groups list,
368+
// so we don't need to take into account the old
369+
// group-by expressions to get the corresponding correlated
370+
// column.
371+
ColumnRefPred::new(left_schema_size + i).into_pred_node(),
372+
BinOpType::Eq,
373+
)
374+
.into_pred_node()
375+
})
376+
.collect(),
377+
);
378+
379+
let new_outer_join = LogicalJoin::new_unchecked(
380+
left,
381+
new_agg.into_plan_node(),
382+
outer_join_cond.into_pred_node(),
383+
JoinType::LeftOuter,
384+
);
385+
386+
// We have to maintain the same schema above outer join as w/o it, but we
387+
// also need to use the groups from the deduplicated left side, and the
388+
// exprs from the new agg node. If we use everything from the new agg,
389+
// we don't maintain nulls as desired.
390+
let outer_join_proj = LogicalProjection::new(
391+
// The meaning is to take everything from the left side, and everything
392+
// from the right side *that is not in the left side*. I am unsure
393+
// of the correctness of this projection in every case.
394+
new_outer_join.into_plan_node(),
395+
ListPred::new(
396+
(0..left_schema_size)
397+
.chain(left_schema_size + left_schema_size..left_schema_size + new_agg_schema_size)
398+
.map(|x| {
399+
// Count(*) special case: We want all NULLs to be transformed into 0s.
400+
if x >= left_schema_size + new_agg_groups_size {
401+
// If this node corresponds to an agg function, and
402+
// it's a count(*), apply the workaround
403+
let expr =
404+
exprs.to_vec()[x - left_schema_size - new_agg_groups_size].clone();
405+
if expr.typ == DfPredType::Func(FuncType::Agg(AggregateFunction::Count)) {
406+
let expr_child = expr.child(0).child(0);
407+
408+
if expr_child.typ == DfPredType::Constant(ConstantType::UInt8)
409+
&& expr_child.data == Some(Value::UInt8(1))
410+
{
411+
return FuncPred::new(
412+
FuncType::Scalar(BuiltinScalarFunction::Coalesce),
413+
ListPred::new(vec![
414+
ColumnRefPred::new(x).into_pred_node(),
415+
ConstantPred::int64(0).into_pred_node(),
416+
]),
417+
)
418+
.into_pred_node();
419+
}
420+
}
421+
}
422+
423+
ColumnRefPred::new(x).into_pred_node()
424+
})
425+
.collect(),
426+
),
427+
);
428+
429+
vec![outer_join_proj.into_plan_node().into()]
353430
}
354431

355432
// Heuristics-only rule. If we don't have references to the external columns on the right side,

optd-perfbench/src/datafusion_dbms.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ impl DatafusionDBMS {
153153

154154
let batches = df.collect().await?;
155155

156-
let options = FormatOptions::default();
156+
let options = FormatOptions::default().with_null("NULL");
157157

158158
for batch in batches {
159159
let converters = batch
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
include _basic_tables.slt.part
2+
3+
# This query has NULL values from the subquery agg. It won't work without the
4+
# outer join fix.
5+
# It also has an out-of-order extern column [#1]
6+
query
7+
select
8+
v1,
9+
v2,
10+
(
11+
select avg(v4)
12+
from t2
13+
where v4 = v2
14+
) as avg_v4
15+
from t1 order by v1;
16+
----
17+
1 100 NULL
18+
2 200 200.0
19+
2 250 250.0
20+
3 300 300.0
21+
3 300 300.0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
include _basic_tables.slt.part
2+
3+
# This query uses a count(*) agg function, with nulls. Nulls should be
4+
# transformed from NULL to 0 when they come from count(*).
5+
# It won't work without the outer join fix + a special case on count(*).
6+
# It also has an out-of-order extern column [#1]
7+
query
8+
select
9+
v1,
10+
v2,
11+
(
12+
select count(*)
13+
from t2
14+
where v4 = v2
15+
) as avg_v4
16+
from t1 order by v1;
17+
----
18+
1 100 0
19+
2 200 1
20+
2 250 1
21+
3 300 1
22+
3 300 1

optd-sqllogictest/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ impl DatafusionDBMS {
107107
};
108108

109109
let batches = df.collect().await?;
110-
let options = FormatOptions::default();
110+
let options = FormatOptions::default().with_null("NULL");
111111

112112
for batch in batches {
113113
if types.is_empty() {

optd-sqlplannertest/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ impl DatafusionDBMS {
183183

184184
let batches = df.collect().await?;
185185

186-
let options = FormatOptions::default();
186+
let options = FormatOptions::default().with_null("NULL");
187187

188188
for batch in batches {
189189
let converters = batch

0 commit comments

Comments
 (0)