1
1
use std:: { collections:: HashMap , sync:: Arc } ;
2
2
3
3
use crate :: plan_nodes:: {
4
- BinOpType , ColumnRefExpr , ConstantExpr , ConstantType , LogOpType , OptRelNode , UnOpType ,
4
+ BinOpType , ColumnRefExpr , ConstantExpr , ConstantType , ExprList , LogOpType , OptRelNode , UnOpType ,
5
5
} ;
6
6
use crate :: properties:: column_ref:: { ColumnRefPropertyBuilder , GroupColumnRefs } ;
7
7
use crate :: {
@@ -11,8 +11,8 @@ use crate::{
11
11
use arrow_schema:: { ArrowError , DataType } ;
12
12
use datafusion:: arrow:: array:: {
13
13
Array , BooleanArray , Date32Array , Decimal128Array , Float32Array , Float64Array , Int16Array ,
14
- Int32Array , Int8Array , RecordBatch , RecordBatchIterator , RecordBatchReader , UInt16Array ,
15
- UInt32Array , UInt8Array ,
14
+ Int32Array , Int8Array , RecordBatch , RecordBatchIterator , RecordBatchReader , StringArray ,
15
+ UInt16Array , UInt32Array , UInt8Array ,
16
16
} ;
17
17
use itertools:: Itertools ;
18
18
use optd_core:: {
@@ -22,6 +22,7 @@ use optd_core::{
22
22
} ;
23
23
use optd_gungnir:: stats:: hyperloglog:: { self , HyperLogLog } ;
24
24
use optd_gungnir:: stats:: tdigest:: { self , TDigest } ;
25
+ use optd_gungnir:: utils:: arith_encoder;
25
26
use serde:: { Deserialize , Serialize } ;
26
27
27
28
fn compute_plan_node_cost < T : RelNodeTyp , C : CostModel < T > > (
@@ -181,6 +182,7 @@ impl DataFusionPerTableStats {
181
182
| DataType :: UInt32
182
183
| DataType :: Float32
183
184
| DataType :: Float64
185
+ | DataType :: Utf8
184
186
)
185
187
}
186
188
@@ -222,6 +224,10 @@ impl DataFusionPerTableStats {
222
224
val as f64
223
225
}
224
226
227
+ fn str_to_f64 ( string : & str ) -> f64 {
228
+ arith_encoder:: encode ( string)
229
+ }
230
+
225
231
match col_type {
226
232
DataType :: Boolean => {
227
233
generate_stats_for_col ! ( { col, distr, hll, BooleanArray , to_f64_safe } )
@@ -256,6 +262,9 @@ impl DataFusionPerTableStats {
256
262
DataType :: Decimal128 ( _, _) => {
257
263
generate_stats_for_col ! ( { col, distr, hll, Decimal128Array , i128_to_f64 } )
258
264
}
265
+ DataType :: Utf8 => {
266
+ generate_stats_for_col ! ( { col, distr, hll, StringArray , str_to_f64 } )
267
+ }
259
268
_ => unreachable ! ( ) ,
260
269
}
261
270
}
@@ -323,6 +332,10 @@ const DEFAULT_EQ_SEL: f64 = 0.005;
323
332
const DEFAULT_INEQ_SEL : f64 = 0.3333333333333333 ;
324
333
// Default selectivity estimate for pattern-match operators such as LIKE
325
334
const DEFAULT_MATCH_SEL : f64 = 0.005 ;
335
+ // Default selectivity if we have no information
336
+ const DEFAULT_UNK_SEL : f64 = 0.005 ;
337
+ // Default n-distinct estimate for derived columns or columns lacking statistics
338
+ const DEFAULT_N_DISTINCT : u64 = 200 ;
326
339
327
340
const INVALID_SEL : f64 = 0.01 ;
328
341
@@ -401,37 +414,33 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
401
414
OptRelNodeTyp :: PhysicalEmptyRelation => Self :: cost ( 0.5 , 0.01 , 0.0 ) ,
402
415
OptRelNodeTyp :: PhysicalLimit => {
403
416
let ( row_cnt, compute_cost, _) = Self :: cost_tuple ( & children[ 0 ] ) ;
404
- let row_cnt = if let Some ( context) = context {
405
- if let Some ( optimizer) = optimizer {
406
- let mut fetch_expr =
407
- optimizer. get_all_group_bindings ( context. children_group_ids [ 2 ] , false ) ;
408
- assert ! (
409
- fetch_expr. len( ) == 1 ,
410
- "fetch expression should be the only expr in the group"
411
- ) ;
412
- let fetch_expr = fetch_expr. pop ( ) . unwrap ( ) ;
413
- assert ! (
414
- matches!(
415
- fetch_expr. typ,
416
- OptRelNodeTyp :: Constant ( ConstantType :: UInt64 )
417
- ) ,
418
- "fetch type can only be UInt64"
419
- ) ;
420
- let fetch = ConstantExpr :: from_rel_node ( fetch_expr)
421
- . unwrap ( )
422
- . value ( )
423
- . as_u64 ( ) ;
424
- // u64::MAX represents None
425
- if fetch == u64:: MAX {
426
- row_cnt
427
- } else {
428
- row_cnt. min ( fetch as f64 )
429
- }
417
+ let row_cnt = if let ( Some ( context) , Some ( optimizer) ) = ( context, optimizer) {
418
+ let mut fetch_expr =
419
+ optimizer. get_all_group_bindings ( context. children_group_ids [ 2 ] , false ) ;
420
+ assert ! (
421
+ fetch_expr. len( ) == 1 ,
422
+ "fetch expression should be the only expr in the group"
423
+ ) ;
424
+ let fetch_expr = fetch_expr. pop ( ) . unwrap ( ) ;
425
+ assert ! (
426
+ matches!(
427
+ fetch_expr. typ,
428
+ OptRelNodeTyp :: Constant ( ConstantType :: UInt64 )
429
+ ) ,
430
+ "fetch type can only be UInt64"
431
+ ) ;
432
+ let fetch = ConstantExpr :: from_rel_node ( fetch_expr)
433
+ . unwrap ( )
434
+ . value ( )
435
+ . as_u64 ( ) ;
436
+ // u64::MAX represents None
437
+ if fetch == u64:: MAX {
438
+ row_cnt
430
439
} else {
431
- panic ! ( "compute_cost() should not be called if optimizer is None" )
440
+ row_cnt . min ( fetch as f64 )
432
441
}
433
442
} else {
434
- panic ! ( "compute_cost() should not be called if context is None" )
443
+ ( row_cnt * DEFAULT_UNK_SEL ) . max ( 1.0 )
435
444
} ;
436
445
Self :: cost ( row_cnt, compute_cost, 0.0 )
437
446
}
@@ -499,10 +508,15 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
499
508
Self :: cost ( row_cnt, row_cnt * row_cnt. ln_1p ( ) . max ( 1.0 ) , 0.0 )
500
509
}
501
510
OptRelNodeTyp :: PhysicalAgg => {
502
- let ( row_cnt, _, _) = Self :: cost_tuple ( & children[ 0 ] ) ;
511
+ let child_row_cnt = Self :: row_cnt ( & children[ 0 ] ) ;
512
+ let row_cnt = self . get_agg_row_cnt ( context, optimizer, child_row_cnt) ;
503
513
let ( _, compute_cost_1, _) = Self :: cost_tuple ( & children[ 1 ] ) ;
504
514
let ( _, compute_cost_2, _) = Self :: cost_tuple ( & children[ 2 ] ) ;
505
- Self :: cost ( row_cnt, row_cnt * ( compute_cost_1 + compute_cost_2) , 0.0 )
515
+ Self :: cost (
516
+ row_cnt,
517
+ child_row_cnt * ( compute_cost_1 + compute_cost_2) ,
518
+ 0.0 ,
519
+ )
506
520
}
507
521
OptRelNodeTyp :: List => {
508
522
let compute_cost = children
@@ -544,6 +558,58 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
544
558
}
545
559
}
546
560
561
+ fn get_agg_row_cnt (
562
+ & self ,
563
+ context : Option < RelNodeContext > ,
564
+ optimizer : Option < & CascadesOptimizer < OptRelNodeTyp > > ,
565
+ child_row_cnt : f64 ,
566
+ ) -> f64 {
567
+ if let ( Some ( context) , Some ( optimizer) ) = ( context, optimizer) {
568
+ let group_by_id = context. children_group_ids [ 2 ] ;
569
+ let mut group_by_exprs: Vec < Arc < RelNode < OptRelNodeTyp > > > =
570
+ optimizer. get_all_group_bindings ( group_by_id, false ) ;
571
+ assert ! (
572
+ group_by_exprs. len( ) == 1 ,
573
+ "ExprList expression should be the only expression in the GROUP BY group"
574
+ ) ;
575
+ let group_by = group_by_exprs. pop ( ) . unwrap ( ) ;
576
+ let group_by = ExprList :: from_rel_node ( group_by) . unwrap ( ) ;
577
+ if group_by. is_empty ( ) {
578
+ 1.0
579
+ } else {
580
+ // Multiply the n-distinct of all the group by columns.
581
+ // TODO: improve with multi-dimensional n-distinct
582
+ let base_table_col_refs = optimizer
583
+ . get_property_by_group :: < ColumnRefPropertyBuilder > ( context. group_id , 1 ) ;
584
+ base_table_col_refs
585
+ . iter ( )
586
+ . take ( group_by. len ( ) )
587
+ . map ( |col_ref| match col_ref {
588
+ ColumnRef :: BaseTableColumnRef { table, col_idx } => {
589
+ let table_stats = self . per_table_stats_map . get ( table) ;
590
+ let column_stats = table_stats. map ( |table_stats| {
591
+ table_stats. per_column_stats_vec . get ( * col_idx) . unwrap ( )
592
+ } ) ;
593
+
594
+ if let Some ( Some ( column_stats) ) = column_stats {
595
+ column_stats. ndistinct as f64
596
+ } else {
597
+ // The column type is not supported or stats are missing.
598
+ DEFAULT_N_DISTINCT as f64
599
+ }
600
+ }
601
+ ColumnRef :: Derived => DEFAULT_N_DISTINCT as f64 ,
602
+ _ => panic ! (
603
+ "GROUP BY base table column ref must either be derived or base table"
604
+ ) ,
605
+ } )
606
+ . product ( )
607
+ }
608
+ } else {
609
+ ( child_row_cnt * DEFAULT_UNK_SEL ) . max ( 1.0 )
610
+ }
611
+ }
612
+
547
613
/// The expr_tree input must be a "mixed expression tree"
548
614
/// An "expression node" refers to a RelNode that returns true for is_expression()
549
615
/// A "full expression tree" is where every node in the tree is an expression node
0 commit comments