@@ -6,12 +6,20 @@ use crate::{
6
6
plan_nodes:: { OptRelNodeRef , OptRelNodeTyp } ,
7
7
properties:: column_ref:: ColumnRef ,
8
8
} ;
9
+ use arrow_schema:: { ArrowError , DataType } ;
10
+ use datafusion:: arrow:: array:: {
11
+ Array , BooleanArray , Date32Array , Decimal128Array , Float32Array , Float64Array , Int16Array ,
12
+ Int32Array , Int8Array , RecordBatch , RecordBatchIterator , RecordBatchReader , UInt16Array ,
13
+ UInt32Array , UInt8Array ,
14
+ } ;
9
15
use itertools:: Itertools ;
10
16
use optd_core:: {
11
17
cascades:: { CascadesOptimizer , RelNodeContext } ,
12
18
cost:: { Cost , CostModel } ,
13
19
rel_node:: { RelNode , RelNodeTyp , Value } ,
14
20
} ;
21
+ use optd_gungnir:: stats:: hyperloglog:: { self , HyperLogLog } ;
22
+ use optd_gungnir:: stats:: tdigest:: { self , TDigest } ;
15
23
16
24
fn compute_plan_node_cost < T : RelNodeTyp , C : CostModel < T > > (
17
25
model : & C ,
@@ -34,9 +42,207 @@ pub struct OptCostModel {
34
42
per_table_stats_map : BaseTableStats ,
35
43
}
36
44
45
+ struct MockMostCommonValues {
46
+ mcvs : HashMap < Value , f64 > ,
47
+ }
48
+
49
+ impl MockMostCommonValues {
50
+ pub fn empty ( ) -> Self {
51
+ MockMostCommonValues {
52
+ mcvs : HashMap :: new ( ) ,
53
+ }
54
+ }
55
+ }
56
+
57
+ impl MostCommonValues for MockMostCommonValues {
58
+ fn freq ( & self , value : & Value ) -> Option < f64 > {
59
+ self . mcvs . get ( value) . copied ( )
60
+ }
61
+
62
+ fn total_freq ( & self ) -> f64 {
63
+ self . mcvs . values ( ) . sum ( )
64
+ }
65
+
66
+ fn freq_over_pred ( & self , pred : Box < dyn Fn ( & Value ) -> bool > ) -> f64 {
67
+ self . mcvs
68
+ . iter ( )
69
+ . filter ( |( val, _) | pred ( val) )
70
+ . map ( |( _, freq) | freq)
71
+ . sum ( )
72
+ }
73
+
74
+ fn cnt ( & self ) -> usize {
75
+ self . mcvs . len ( )
76
+ }
77
+ }
78
+
37
79
pub struct PerTableStats {
38
80
row_cnt : usize ,
39
- per_column_stats_vec : Vec < PerColumnStats > ,
81
+ per_column_stats_vec : Vec < Option < PerColumnStats > > ,
82
+ }
83
+
84
+ impl PerTableStats {
85
+ pub fn from_record_batches < I : IntoIterator < Item = Result < RecordBatch , ArrowError > > > (
86
+ batch_iter : RecordBatchIterator < I > ,
87
+ ) -> anyhow:: Result < Self > {
88
+ let schema = batch_iter. schema ( ) ;
89
+ let col_types = schema
90
+ . fields ( )
91
+ . iter ( )
92
+ . map ( |f| f. data_type ( ) . clone ( ) )
93
+ . collect_vec ( ) ;
94
+ let col_cnt = col_types. len ( ) ;
95
+
96
+ let mut row_cnt = 0 ;
97
+ let mut mcvs = col_types
98
+ . iter ( )
99
+ . map ( |col_type| {
100
+ if Self :: is_type_supported ( col_type) {
101
+ Some ( MockMostCommonValues :: empty ( ) )
102
+ } else {
103
+ None
104
+ }
105
+ } )
106
+ . collect_vec ( ) ;
107
+ let mut distr = col_types
108
+ . iter ( )
109
+ . map ( |col_type| {
110
+ if Self :: is_type_supported ( col_type) {
111
+ Some ( TDigest :: new ( tdigest:: DEFAULT_COMPRESSION ) )
112
+ } else {
113
+ None
114
+ }
115
+ } )
116
+ . collect_vec ( ) ;
117
+ let mut hlls = vec ! [ HyperLogLog :: new( hyperloglog:: DEFAULT_PRECISION ) ; col_cnt] ;
118
+ let mut null_cnt = vec ! [ 0 ; col_cnt] ;
119
+
120
+ for batch in batch_iter {
121
+ let batch = batch?;
122
+ row_cnt += batch. num_rows ( ) ;
123
+
124
+ // Enumerate the columns.
125
+ for ( i, col) in batch. columns ( ) . iter ( ) . enumerate ( ) {
126
+ let col_type = & col_types[ i] ;
127
+ if Self :: is_type_supported ( col_type) {
128
+ // Update null cnt.
129
+ null_cnt[ i] += col. null_count ( ) ;
130
+
131
+ Self :: generate_stats_for_column ( col, col_type, & mut distr[ i] , & mut hlls[ i] ) ;
132
+ }
133
+ }
134
+ }
135
+
136
+ // Assemble the per-column stats.
137
+ let mut per_column_stats_vec = Vec :: with_capacity ( col_cnt) ;
138
+ for i in 0 ..col_cnt {
139
+ per_column_stats_vec. push ( if Self :: is_type_supported ( & col_types[ i] ) {
140
+ Some ( PerColumnStats {
141
+ mcvs : Box :: new ( mcvs[ i] . take ( ) . unwrap ( ) ) as Box < dyn MostCommonValues > ,
142
+ ndistinct : hlls[ i] . n_distinct ( ) ,
143
+ null_frac : null_cnt[ i] as f64 / row_cnt as f64 ,
144
+ distr : Box :: new ( distr[ i] . take ( ) . unwrap ( ) ) as Box < dyn Distribution > ,
145
+ } )
146
+ } else {
147
+ None
148
+ } ) ;
149
+ }
150
+ Ok ( Self {
151
+ row_cnt,
152
+ per_column_stats_vec,
153
+ } )
154
+ }
155
+
156
+ fn is_type_supported ( data_type : & DataType ) -> bool {
157
+ matches ! (
158
+ data_type,
159
+ DataType :: Boolean
160
+ | DataType :: Int8
161
+ | DataType :: Int16
162
+ | DataType :: Int32
163
+ | DataType :: UInt8
164
+ | DataType :: UInt16
165
+ | DataType :: UInt32
166
+ | DataType :: Float32
167
+ | DataType :: Float64
168
+ )
169
+ }
170
+
171
+ /// Generate statistics for a column.
172
+ fn generate_stats_for_column (
173
+ col : & Arc < dyn Array > ,
174
+ col_type : & DataType ,
175
+ distr : & mut Option < TDigest > ,
176
+ hll : & mut HyperLogLog ,
177
+ ) {
178
+ macro_rules! generate_stats_for_col {
179
+ ( { $col: expr, $distr: expr, $hll: expr, $array_type: path, $to_f64: ident } ) => { {
180
+ let array = $col. as_any( ) . downcast_ref:: <$array_type>( ) . unwrap( ) ;
181
+ // Filter out `None` values.
182
+ let values = array. iter( ) . filter_map( |x| x) . collect:: <Vec <_>>( ) ;
183
+
184
+ // Update distribution.
185
+ * $distr = {
186
+ let mut f64_values = values. iter( ) . map( |x| $to_f64( * x) ) . collect:: <Vec <_>>( ) ;
187
+ Some ( $distr. take( ) . unwrap( ) . merge_values( & mut f64_values) )
188
+ } ;
189
+
190
+ // Update hll.
191
+ $hll. aggregate( & values) ;
192
+ } } ;
193
+ }
194
+
195
+ /// Convert a value to f64 with no out of range or precision loss.
196
+ fn to_f64_safe < T : Into < f64 > > ( val : T ) -> f64 {
197
+ val. into ( )
198
+ }
199
+
200
+ /// Convert i128 to f64 with possible precision loss.
201
+ ///
202
+ /// Note: optd represents decimal with the significand as f64 (see `ConstantExpr::decimal`).
203
+ /// For instance 0.04 of type `Decimal128(15, 2)` is just 4.0, the type information
204
+ /// is discarded. Therefore we must use the significand to generate the statistics.
205
+ fn i128_to_f64 ( val : i128 ) -> f64 {
206
+ val as f64
207
+ }
208
+
209
+ match col_type {
210
+ DataType :: Boolean => {
211
+ generate_stats_for_col ! ( { col, distr, hll, BooleanArray , to_f64_safe } )
212
+ }
213
+ DataType :: Int8 => {
214
+ generate_stats_for_col ! ( { col, distr, hll, Int8Array , to_f64_safe } )
215
+ }
216
+ DataType :: Int16 => {
217
+ generate_stats_for_col ! ( { col, distr, hll, Int16Array , to_f64_safe } )
218
+ }
219
+ DataType :: Int32 => {
220
+ generate_stats_for_col ! ( { col, distr, hll, Int32Array , to_f64_safe } )
221
+ }
222
+ DataType :: UInt8 => {
223
+ generate_stats_for_col ! ( { col, distr, hll, UInt8Array , to_f64_safe } )
224
+ }
225
+ DataType :: UInt16 => {
226
+ generate_stats_for_col ! ( { col, distr, hll, UInt16Array , to_f64_safe } )
227
+ }
228
+ DataType :: UInt32 => {
229
+ generate_stats_for_col ! ( { col, distr, hll, UInt32Array , to_f64_safe } )
230
+ }
231
+ DataType :: Float32 => {
232
+ generate_stats_for_col ! ( { col, distr, hll, Float32Array , to_f64_safe } )
233
+ }
234
+ DataType :: Float64 => {
235
+ generate_stats_for_col ! ( { col, distr, hll, Float64Array , to_f64_safe } )
236
+ }
237
+ DataType :: Date32 => {
238
+ generate_stats_for_col ! ( { col, distr, hll, Date32Array , to_f64_safe } )
239
+ }
240
+ DataType :: Decimal128 ( _, _) => {
241
+ generate_stats_for_col ! ( { col, distr, hll, Decimal128Array , i128_to_f64 } )
242
+ }
243
+ _ => unreachable ! ( ) ,
244
+ }
245
+ }
40
246
}
41
247
42
248
pub struct PerColumnStats {
@@ -45,7 +251,7 @@ pub struct PerColumnStats {
45
251
46
252
// ndistinct _does_ include the values in mcvs
47
253
// ndistinct _does not_ include nulls
48
- ndistinct : i32 ,
254
+ ndistinct : u64 ,
49
255
50
256
// postgres uses null_frac instead of something like "num_nulls" so we'll follow suit
51
257
// my guess for why they use null_frac is because we only ever use the fraction of nulls, not the #
@@ -445,7 +651,8 @@ impl OptCostModel {
445
651
is_eq : bool ,
446
652
) -> Option < f64 > {
447
653
if let Some ( per_table_stats) = self . per_table_stats_map . get ( table) {
448
- if let Some ( per_column_stats) = per_table_stats. per_column_stats_vec . get ( col_idx) {
654
+ if let Some ( Some ( per_column_stats) ) = per_table_stats. per_column_stats_vec . get ( col_idx)
655
+ {
449
656
let eq_freq = if let Some ( freq) = per_column_stats. mcvs . freq ( value) {
450
657
freq
451
658
} else {
@@ -484,7 +691,8 @@ impl OptCostModel {
484
691
is_col_eq_val : bool ,
485
692
) -> Option < f64 > {
486
693
if let Some ( per_table_stats) = self . per_table_stats_map . get ( table) {
487
- if let Some ( per_column_stats) = per_table_stats. per_column_stats_vec . get ( col_idx) {
694
+ if let Some ( Some ( per_column_stats) ) = per_table_stats. per_column_stats_vec . get ( col_idx)
695
+ {
488
696
// because distr does not include the values in MCVs, we need to compute the CDFs there as well
489
697
// because nulls return false in any comparison, they are never included when computing range selectivity
490
698
let distr_leq_freq = per_column_stats. distr . cdf ( value) ;
@@ -555,7 +763,7 @@ impl OptCostModel {
555
763
}
556
764
557
765
impl PerTableStats {
558
- pub fn new ( row_cnt : usize , per_column_stats_vec : Vec < PerColumnStats > ) -> Self {
766
+ pub fn new ( row_cnt : usize , per_column_stats_vec : Vec < Option < PerColumnStats > > ) -> Self {
559
767
Self {
560
768
row_cnt,
561
769
per_column_stats_vec,
@@ -566,7 +774,7 @@ impl PerTableStats {
566
774
impl PerColumnStats {
567
775
pub fn new (
568
776
mcvs : Box < dyn MostCommonValues > ,
569
- ndistinct : i32 ,
777
+ ndistinct : u64 ,
570
778
null_frac : f64 ,
571
779
distr : Box < dyn Distribution > ,
572
780
) -> Self {
@@ -612,7 +820,7 @@ mod tests {
612
820
}
613
821
}
614
822
615
- fn empty ( ) -> Self {
823
+ pub fn empty ( ) -> Self {
616
824
MockMostCommonValues :: new ( vec ! [ ] )
617
825
}
618
826
}
@@ -664,7 +872,7 @@ mod tests {
664
872
OptCostModel :: new (
665
873
vec ! [ (
666
874
String :: from( TABLE1_NAME ) ,
667
- PerTableStats :: new( 100 , vec![ per_column_stats] ) ,
875
+ PerTableStats :: new( 100 , vec![ Some ( per_column_stats) ] ) ,
668
876
) ]
669
877
. into_iter ( )
670
878
. collect ( ) ,
0 commit comments