@@ -2,8 +2,8 @@ use std::ops::Bound;
2
2
use std:: { collections:: HashMap , sync:: Arc } ;
3
3
4
4
use crate :: plan_nodes:: {
5
- BinOpType , ColumnRefExpr , ConstantExpr , ConstantType , Expr , ExprList , LogOpExpr , LogOpType ,
6
- OptRelNode , UnOpType ,
5
+ BinOpType , ColumnRefExpr , ConstantExpr , ConstantType , Expr , ExprList , InListExpr , LogOpExpr ,
6
+ LogOpType , OptRelNode , UnOpType ,
7
7
} ;
8
8
use crate :: properties:: column_ref:: { ColumnRefPropertyBuilder , GroupColumnRefs } ;
9
9
use crate :: {
@@ -710,7 +710,10 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
710
710
OptRelNodeTyp :: DataType ( _) => {
711
711
panic ! ( "the selectivity of a data type is not defined" )
712
712
}
713
- OptRelNodeTyp :: InList => UNIMPLEMENTED_SEL ,
713
+ OptRelNodeTyp :: InList => {
714
+ let in_list_expr = InListExpr :: from_rel_node ( expr_tree) . unwrap ( ) ;
715
+ self . get_filter_in_list_selectivity ( & in_list_expr, column_refs)
716
+ }
714
717
_ => unreachable ! (
715
718
"all expression OptRelNodeTyp were enumerated. this should be unreachable"
716
719
) ,
@@ -1125,6 +1128,9 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
1125
1128
// always safe because usize is at least as large as i32
1126
1129
let ndistinct_as_usize = per_column_stats. ndistinct as usize ;
1127
1130
let non_mcv_cnt = ndistinct_as_usize - per_column_stats. mcvs . cnt ( ) ;
1131
+ if non_mcv_cnt == 0 {
1132
+ return 0.0 ;
1133
+ }
1128
1134
// note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt - 1 if null_frac > 0
1129
1135
( non_mcv_freq - per_column_stats. null_frac ) / ( non_mcv_cnt as f64 )
1130
1136
} ;
@@ -1220,6 +1226,61 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
1220
1226
}
1221
1227
}
1222
1228
1229
+ /// Only support colA in (val1, val2, val3) where colA is a column ref and
1230
+ /// val1, val2, val3 are constants.
1231
+ fn get_filter_in_list_selectivity (
1232
+ & self ,
1233
+ expr : & InListExpr ,
1234
+ column_refs : & GroupColumnRefs ,
1235
+ ) -> f64 {
1236
+ let child = expr. child ( ) ;
1237
+
1238
+ // Check child is a column ref.
1239
+ if !matches ! ( child. typ( ) , OptRelNodeTyp :: ColumnRef ) {
1240
+ return UNIMPLEMENTED_SEL ;
1241
+ }
1242
+
1243
+ // Check all expressions in the list are constants.
1244
+ let list_exprs = expr. list ( ) . to_vec ( ) ;
1245
+ if list_exprs
1246
+ . iter ( )
1247
+ . any ( |expr| !matches ! ( expr. typ( ) , OptRelNodeTyp :: Constant ( _) ) )
1248
+ {
1249
+ return UNIMPLEMENTED_SEL ;
1250
+ }
1251
+
1252
+ // Convert child and const expressions to concrete types.
1253
+ let col_ref_idx = ColumnRefExpr :: from_rel_node ( child. into_rel_node ( ) )
1254
+ . unwrap ( )
1255
+ . index ( ) ;
1256
+ let list_exprs = list_exprs
1257
+ . into_iter ( )
1258
+ . map ( |expr| {
1259
+ ConstantExpr :: from_rel_node ( expr. into_rel_node ( ) )
1260
+ . expect ( "we already checked all list elements are constants" )
1261
+ } )
1262
+ . collect :: < Vec < _ > > ( ) ;
1263
+ let negated = expr. negated ( ) ;
1264
+
1265
+ if let ColumnRef :: BaseTableColumnRef { table, col_idx } = & column_refs[ col_ref_idx] {
1266
+ let in_sel = list_exprs
1267
+ . iter ( )
1268
+ . map ( |expr| {
1269
+ self . get_column_equality_selectivity ( table, * col_idx, & expr. value ( ) , true )
1270
+ } )
1271
+ . sum :: < f64 > ( )
1272
+ . min ( 1.0 ) ;
1273
+ if negated {
1274
+ 1.0 - in_sel
1275
+ } else {
1276
+ in_sel
1277
+ }
1278
+ } else {
1279
+ // Child is a derived column.
1280
+ UNIMPLEMENTED_SEL
1281
+ }
1282
+ }
1283
+
1223
1284
pub fn get_row_cnt ( & self , table : & str ) -> Option < usize > {
1224
1285
self . per_table_stats_map
1225
1286
. get ( table)
@@ -1241,14 +1302,15 @@ impl<M: MostCommonValues, D: Distribution> PerTableStats<M, D> {
1241
1302
/// and optd-datafusion-repr
1242
1303
#[ cfg( test) ]
1243
1304
mod tests {
1305
+ use itertools:: Itertools ;
1244
1306
use optd_core:: rel_node:: Value ;
1245
1307
use std:: collections:: HashMap ;
1246
1308
1247
1309
use crate :: {
1248
1310
cost:: base_cost:: DEFAULT_EQ_SEL ,
1249
1311
plan_nodes:: {
1250
- BinOpExpr , BinOpType , ColumnRefExpr , ConstantExpr , Expr , ExprList , JoinType , LogOpExpr ,
1251
- LogOpType , OptRelNode , OptRelNodeRef , UnOpExpr , UnOpType ,
1312
+ BinOpExpr , BinOpType , ColumnRefExpr , ConstantExpr , Expr , ExprList , InListExpr ,
1313
+ JoinType , LogOpExpr , LogOpType , OptRelNode , OptRelNodeRef , UnOpExpr , UnOpType ,
1252
1314
} ,
1253
1315
properties:: column_ref:: { ColumnRef , GroupColumnRefs } ,
1254
1316
} ;
@@ -1410,6 +1472,18 @@ mod tests {
1410
1472
. into_rel_node ( )
1411
1473
}
1412
1474
1475
+ fn in_list ( col_ref_idx : u64 , list : Vec < Value > , negated : bool ) -> InListExpr {
1476
+ InListExpr :: new (
1477
+ Expr :: from_rel_node ( col_ref ( col_ref_idx) ) . unwrap ( ) ,
1478
+ ExprList :: new (
1479
+ list. into_iter ( )
1480
+ . map ( |v| Expr :: from_rel_node ( cnst ( v) ) . unwrap ( ) )
1481
+ . collect_vec ( ) ,
1482
+ ) ,
1483
+ negated,
1484
+ )
1485
+ }
1486
+
1413
1487
/// The reason this isn't an associated function of PerColumnStats is because that would require
1414
1488
/// adding an empty() function to the trait definitions of MostCommonValues and Distribution,
1415
1489
/// which I wanted to avoid
@@ -1983,6 +2057,62 @@ mod tests {
1983
2057
) ;
1984
2058
}
1985
2059
2060
/// Checks `get_filter_in_list_selectivity` on a single base-table column whose most common
/// values are `1` (frequency 0.8) and `2` (frequency 0.2), for both plain and negated lists.
#[test]
fn test_filtersel_in_list() {
    let cost_model = create_one_column_cost_model(TestPerColumnStats::new(
        TestMostCommonValues::new(vec![(Value::Int32(1), 0.8), (Value::Int32(2), 0.2)]),
        2,
        0.0,
        TestDistribution::empty(),
    ));
    let column_refs = vec![ColumnRef::BaseTableColumnRef {
        table: String::from(TABLE1_NAME),
        col_idx: 0,
    }];

    // (list values, negated, expected selectivity)
    let cases: Vec<(Vec<Value>, bool, f64)> = vec![
        // Plain IN: a single MCV, both MCVs, and a value outside the MCV list.
        (vec![Value::Int32(1)], false, 0.8),
        (vec![Value::Int32(1), Value::Int32(2)], false, 1.0),
        (vec![Value::Int32(3)], false, 0.0),
        // NOT IN: the complement of each case above.
        (vec![Value::Int32(1)], true, 0.2),
        (vec![Value::Int32(1), Value::Int32(2)], true, 0.0),
        (vec![Value::Int32(3)], true, 1.0),
    ];

    for (values, negated, expected) in cases {
        assert_approx_eq::assert_approx_eq!(
            cost_model
                .get_filter_in_list_selectivity(&in_list(0, values, negated), &column_refs),
            expected
        );
    }
}
2115
+
1986
2116
/// A wrapper around get_join_selectivity_from_expr_tree that extracts the table row counts from the cost model
1987
2117
fn test_get_join_selectivity (
1988
2118
cost_model : & TestOptCostModel ,
0 commit comments