3
3
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
4
4
// https://opensource.org/licenses/MIT.
5
5
6
- use optd_core:: nodes:: { PlanNodeOrGroup , PredNode } ;
7
- // TODO: No push past join
8
- // TODO: Sideways information passing??
6
+ use datafusion_expr:: { AggregateFunction , BuiltinScalarFunction } ;
7
+ use optd_core:: nodes:: { PlanNodeOrGroup , PredNode , Value } ;
9
8
use optd_core:: optimizer:: Optimizer ;
10
9
use optd_core:: rules:: { Rule , RuleMatcher } ;
11
10
12
11
use crate :: plan_nodes:: {
13
- ArcDfPlanNode , ArcDfPredNode , BinOpPred , BinOpType , ColumnRefPred , ConstantPred , DependentJoin ,
14
- DfNodeType , DfPredType , DfReprPlanNode , DfReprPredNode , ExternColumnRefPred , JoinType ,
15
- ListPred , LogOpPred , LogOpType , LogicalAgg , LogicalFilter , LogicalJoin , LogicalProjection ,
16
- PredExt , RawDependentJoin ,
12
+ ArcDfPlanNode , ArcDfPredNode , BinOpPred , BinOpType , ColumnRefPred , ConstantPred , ConstantType ,
13
+ DependentJoin , DfNodeType , DfPredType , DfReprPlanNode , DfReprPredNode , ExternColumnRefPred ,
14
+ FuncPred , FuncType , JoinType , ListPred , LogOpPred , LogOpType , LogicalAgg , LogicalFilter ,
15
+ LogicalJoin , LogicalProjection , PredExt , RawDependentJoin ,
17
16
} ;
18
17
use crate :: rules:: macros:: define_rule;
19
18
use crate :: OptimizerExt ;
@@ -288,11 +287,8 @@ define_rule!(
288
287
/// deduplicated set).
289
288
/// For info on why we do the outer join, refer to the Unnesting Arbitrary Queries
290
289
/// talk by Mark Raasveldt. The correlated columns are covered in the original paper.
291
- ///
292
- /// TODO: the outer join is not implemented yet, so some edge cases won't work.
293
- /// Run SQLite tests to catch these, I guess.
294
290
fn apply_dep_join_past_agg (
295
- _optimizer : & impl Optimizer < DfNodeType > ,
291
+ optimizer : & impl Optimizer < DfNodeType > ,
296
292
binding : ArcDfPlanNode ,
297
293
) -> Vec < PlanNodeOrGroup < DfNodeType > > {
298
294
let join = DependentJoin :: from_plan_node ( binding) . unwrap ( ) ;
@@ -305,6 +301,8 @@ fn apply_dep_join_past_agg(
305
301
let groups = agg. groups ( ) ;
306
302
let right = agg. child ( ) ;
307
303
304
+ let left_schema_size = optimizer. get_schema_of ( left. clone ( ) ) . len ( ) ;
305
+
308
306
// Cross join should always have true cond
309
307
assert ! ( cond == ConstantPred :: bool ( true ) . into_pred_node( ) ) ;
310
308
@@ -345,11 +343,90 @@ fn apply_dep_join_past_agg(
345
343
) ;
346
344
347
345
let new_dep_join =
348
- DependentJoin :: new_unchecked ( left, right, cond, extern_cols, JoinType :: Cross ) ;
346
+ DependentJoin :: new_unchecked ( left. clone ( ) , right, cond, extern_cols, JoinType :: Cross ) ;
349
347
348
+ let new_agg_exprs_size = new_exprs. len ( ) ;
349
+ let new_agg_groups_size = new_groups. len ( ) ;
350
+ let new_agg_schema_size = new_agg_groups_size + new_agg_exprs_size;
350
351
let new_agg = LogicalAgg :: new ( new_dep_join. into_plan_node ( ) , new_exprs, new_groups) ;
351
352
352
- vec ! [ new_agg. into_plan_node( ) . into( ) ]
353
+ // Add left outer join above the agg node, joining the deduplicated set
354
+ // with the new agg node.
355
+
356
+ // Both sides will have an agg now, so we want to match the correlated
357
+ // columns from the left with those from the right
358
+ let outer_join_cond = LogOpPred :: new (
359
+ LogOpType :: And ,
360
+ correlated_col_indices
361
+ . iter ( )
362
+ . enumerate ( )
363
+ . map ( |( i, _) | {
364
+ assert ! ( i + left_schema_size < left_schema_size + new_agg_schema_size) ;
365
+ BinOpPred :: new (
366
+ ColumnRefPred :: new ( i) . into_pred_node ( ) ,
367
+ // We *prepend* the correlated columns to the groups list,
368
+ // so we don't need to take into account the old
369
+ // group-by expressions to get the corresponding correlated
370
+ // column.
371
+ ColumnRefPred :: new ( left_schema_size + i) . into_pred_node ( ) ,
372
+ BinOpType :: Eq ,
373
+ )
374
+ . into_pred_node ( )
375
+ } )
376
+ . collect ( ) ,
377
+ ) ;
378
+
379
+ let new_outer_join = LogicalJoin :: new_unchecked (
380
+ left,
381
+ new_agg. into_plan_node ( ) ,
382
+ outer_join_cond. into_pred_node ( ) ,
383
+ JoinType :: LeftOuter ,
384
+ ) ;
385
+
386
+ // We have to maintain the same schema above outer join as w/o it, but we
387
+ // also need to use the groups from the deduplicated left side, and the
388
+ // exprs from the new agg node. If we use everything from the new agg,
389
+ // we don't maintain nulls as desired.
390
+ let outer_join_proj = LogicalProjection :: new (
391
+ // The meaning is to take everything from the left side, and everything
392
+ // from the right side *that is not in the left side*. I am unsure
393
+ // of the correctness of this project in every case.
394
+ new_outer_join. into_plan_node ( ) ,
395
+ ListPred :: new (
396
+ ( 0 ..left_schema_size)
397
+ . chain ( left_schema_size + left_schema_size..left_schema_size + new_agg_schema_size)
398
+ . map ( |x| {
399
+ // Count(*) special case: We want all NULLs to be transformed into 0s.
400
+ if x >= left_schema_size + new_agg_groups_size {
401
+ // If this node corresponds to an agg function, and
402
+ // it's a count(*), apply the workaround
403
+ let expr =
404
+ exprs. to_vec ( ) [ x - left_schema_size - new_agg_groups_size] . clone ( ) ;
405
+ if expr. typ == DfPredType :: Func ( FuncType :: Agg ( AggregateFunction :: Count ) ) {
406
+ let expr_child = expr. child ( 0 ) . child ( 0 ) ;
407
+
408
+ if expr_child. typ == DfPredType :: Constant ( ConstantType :: UInt8 )
409
+ && expr_child. data == Some ( Value :: UInt8 ( 1 ) )
410
+ {
411
+ return FuncPred :: new (
412
+ FuncType :: Scalar ( BuiltinScalarFunction :: Coalesce ) ,
413
+ ListPred :: new ( vec ! [
414
+ ColumnRefPred :: new( x) . into_pred_node( ) ,
415
+ ConstantPred :: int64( 0 ) . into_pred_node( ) ,
416
+ ] ) ,
417
+ )
418
+ . into_pred_node ( ) ;
419
+ }
420
+ }
421
+ }
422
+
423
+ ColumnRefPred :: new ( x) . into_pred_node ( )
424
+ } )
425
+ . collect ( ) ,
426
+ ) ,
427
+ ) ;
428
+
429
+ vec ! [ outer_join_proj. into_plan_node( ) . into( ) ]
353
430
}
354
431
355
432
// Heuristics-only rule. If we don't have references to the external columns on the right side,
0 commit comments