Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit e0ebd4a

Browse files
authored
feat: implement filter selectivity for InListExpr (#154)
Only support the form of `colA = (1, 2, 3)`, where `colA` is a base table column and `(1, 2, 3)` consists of all constants. The selectivity is computed by **adding the selectivity of individual elements together**.
1 parent 4e47fad commit e0ebd4a

File tree

3 files changed

+136
-10
lines changed

3 files changed

+136
-10
lines changed

optd-datafusion-repr/src/cost/base_cost.rs

+135-5
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use std::ops::Bound;
22
use std::{collections::HashMap, sync::Arc};
33

44
use crate::plan_nodes::{
5-
BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, Expr, ExprList, LogOpExpr, LogOpType,
6-
OptRelNode, UnOpType,
5+
BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, Expr, ExprList, InListExpr, LogOpExpr,
6+
LogOpType, OptRelNode, UnOpType,
77
};
88
use crate::properties::column_ref::{ColumnRefPropertyBuilder, GroupColumnRefs};
99
use crate::{
@@ -710,7 +710,10 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
710710
OptRelNodeTyp::DataType(_) => {
711711
panic!("the selectivity of a data type is not defined")
712712
}
713-
OptRelNodeTyp::InList => UNIMPLEMENTED_SEL,
713+
OptRelNodeTyp::InList => {
714+
let in_list_expr = InListExpr::from_rel_node(expr_tree).unwrap();
715+
self.get_filter_in_list_selectivity(&in_list_expr, column_refs)
716+
}
714717
_ => unreachable!(
715718
"all expression OptRelNodeTyp were enumerated. this should be unreachable"
716719
),
@@ -1125,6 +1128,9 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
11251128
// always safe because usize is at least as large as i32
11261129
let ndistinct_as_usize = per_column_stats.ndistinct as usize;
11271130
let non_mcv_cnt = ndistinct_as_usize - per_column_stats.mcvs.cnt();
1131+
if non_mcv_cnt == 0 {
1132+
return 0.0;
1133+
}
11281134
// note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt - 1 if null_frac > 0
11291135
(non_mcv_freq - per_column_stats.null_frac) / (non_mcv_cnt as f64)
11301136
};
@@ -1220,6 +1226,61 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
12201226
}
12211227
}
12221228

1229+
/// Only support colA in (val1, val2, val3) where colA is a column ref and
1230+
/// val1, val2, val3 are constants.
1231+
fn get_filter_in_list_selectivity(
1232+
&self,
1233+
expr: &InListExpr,
1234+
column_refs: &GroupColumnRefs,
1235+
) -> f64 {
1236+
let child = expr.child();
1237+
1238+
// Check child is a column ref.
1239+
if !matches!(child.typ(), OptRelNodeTyp::ColumnRef) {
1240+
return UNIMPLEMENTED_SEL;
1241+
}
1242+
1243+
// Check all expressions in the list are constants.
1244+
let list_exprs = expr.list().to_vec();
1245+
if list_exprs
1246+
.iter()
1247+
.any(|expr| !matches!(expr.typ(), OptRelNodeTyp::Constant(_)))
1248+
{
1249+
return UNIMPLEMENTED_SEL;
1250+
}
1251+
1252+
// Convert child and const expressions to concrete types.
1253+
let col_ref_idx = ColumnRefExpr::from_rel_node(child.into_rel_node())
1254+
.unwrap()
1255+
.index();
1256+
let list_exprs = list_exprs
1257+
.into_iter()
1258+
.map(|expr| {
1259+
ConstantExpr::from_rel_node(expr.into_rel_node())
1260+
.expect("we already checked all list elements are constants")
1261+
})
1262+
.collect::<Vec<_>>();
1263+
let negated = expr.negated();
1264+
1265+
if let ColumnRef::BaseTableColumnRef { table, col_idx } = &column_refs[col_ref_idx] {
1266+
let in_sel = list_exprs
1267+
.iter()
1268+
.map(|expr| {
1269+
self.get_column_equality_selectivity(table, *col_idx, &expr.value(), true)
1270+
})
1271+
.sum::<f64>()
1272+
.min(1.0);
1273+
if negated {
1274+
1.0 - in_sel
1275+
} else {
1276+
in_sel
1277+
}
1278+
} else {
1279+
// Child is a derived column.
1280+
UNIMPLEMENTED_SEL
1281+
}
1282+
}
1283+
12231284
pub fn get_row_cnt(&self, table: &str) -> Option<usize> {
12241285
self.per_table_stats_map
12251286
.get(table)
@@ -1241,14 +1302,15 @@ impl<M: MostCommonValues, D: Distribution> PerTableStats<M, D> {
12411302
/// and optd-datafusion-repr
12421303
#[cfg(test)]
12431304
mod tests {
1305+
use itertools::Itertools;
12441306
use optd_core::rel_node::Value;
12451307
use std::collections::HashMap;
12461308

12471309
use crate::{
12481310
cost::base_cost::DEFAULT_EQ_SEL,
12491311
plan_nodes::{
1250-
BinOpExpr, BinOpType, ColumnRefExpr, ConstantExpr, Expr, ExprList, JoinType, LogOpExpr,
1251-
LogOpType, OptRelNode, OptRelNodeRef, UnOpExpr, UnOpType,
1312+
BinOpExpr, BinOpType, ColumnRefExpr, ConstantExpr, Expr, ExprList, InListExpr,
1313+
JoinType, LogOpExpr, LogOpType, OptRelNode, OptRelNodeRef, UnOpExpr, UnOpType,
12521314
},
12531315
properties::column_ref::{ColumnRef, GroupColumnRefs},
12541316
};
@@ -1410,6 +1472,18 @@ mod tests {
14101472
.into_rel_node()
14111473
}
14121474

1475+
fn in_list(col_ref_idx: u64, list: Vec<Value>, negated: bool) -> InListExpr {
1476+
InListExpr::new(
1477+
Expr::from_rel_node(col_ref(col_ref_idx)).unwrap(),
1478+
ExprList::new(
1479+
list.into_iter()
1480+
.map(|v| Expr::from_rel_node(cnst(v)).unwrap())
1481+
.collect_vec(),
1482+
),
1483+
negated,
1484+
)
1485+
}
1486+
14131487
/// The reason this isn't an associated function of PerColumnStats is because that would require
14141488
/// adding an empty() function to the trait definitions of MostCommonValues and Distribution,
14151489
/// which I wanted to avoid
@@ -1983,6 +2057,62 @@ mod tests {
19832057
);
19842058
}
19852059

2060+
#[test]
2061+
fn test_filtersel_in_list() {
2062+
let cost_model = create_one_column_cost_model(TestPerColumnStats::new(
2063+
TestMostCommonValues::new(vec![(Value::Int32(1), 0.8), (Value::Int32(2), 0.2)]),
2064+
2,
2065+
0.0,
2066+
TestDistribution::empty(),
2067+
));
2068+
let column_refs = vec![ColumnRef::BaseTableColumnRef {
2069+
table: String::from(TABLE1_NAME),
2070+
col_idx: 0,
2071+
}];
2072+
assert_approx_eq::assert_approx_eq!(
2073+
cost_model.get_filter_in_list_selectivity(
2074+
&in_list(0, vec![Value::Int32(1)], false),
2075+
&column_refs
2076+
),
2077+
0.8
2078+
);
2079+
assert_approx_eq::assert_approx_eq!(
2080+
cost_model.get_filter_in_list_selectivity(
2081+
&in_list(0, vec![Value::Int32(1), Value::Int32(2)], false),
2082+
&column_refs
2083+
),
2084+
1.0
2085+
);
2086+
assert_approx_eq::assert_approx_eq!(
2087+
cost_model.get_filter_in_list_selectivity(
2088+
&in_list(0, vec![Value::Int32(3)], false),
2089+
&column_refs
2090+
),
2091+
0.0
2092+
);
2093+
assert_approx_eq::assert_approx_eq!(
2094+
cost_model.get_filter_in_list_selectivity(
2095+
&in_list(0, vec![Value::Int32(1)], true),
2096+
&column_refs
2097+
),
2098+
0.2
2099+
);
2100+
assert_approx_eq::assert_approx_eq!(
2101+
cost_model.get_filter_in_list_selectivity(
2102+
&in_list(0, vec![Value::Int32(1), Value::Int32(2)], true),
2103+
&column_refs
2104+
),
2105+
0.0
2106+
);
2107+
assert_approx_eq::assert_approx_eq!(
2108+
cost_model.get_filter_in_list_selectivity(
2109+
&in_list(0, vec![Value::Int32(3)], true),
2110+
&column_refs
2111+
),
2112+
1.0
2113+
);
2114+
}
2115+
19862116
/// A wrapper around get_join_selectivity_from_expr_tree that extracts the table row counts from the cost model
19872117
fn test_get_join_selectivity(
19882118
cost_model: &TestOptCostModel,

optd-datafusion-repr/src/properties/column_ref.rs

-4
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@ impl PropertyBuilder<OptRelNodeTyp> for ColumnRefPropertyBuilder {
3939
data: Option<optd_core::rel_node::Value>,
4040
children: &[&Self::Prop],
4141
) -> Self::Prop {
42-
// println!(
43-
// "derive column_ref: {:?}, data: {:?}, children: {:?}",
44-
// typ, data, children
45-
// );
4642
match typ {
4743
// Should account for PhysicalScan.
4844
OptRelNodeTyp::Scan => {

optd-perftest/src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ enum Commands {
3232
#[clap(long)]
3333
#[clap(value_delimiter = ',', num_args = 1..)]
3434
// this is the current list of all queries that work in perftest
35-
#[clap(default_value = "2,3,5,6,7,8,9,10,11,12,13,14,17")]
35+
#[clap(default_value = "2,3,5,6,7,8,9,10,11,12,13,14,17,19")]
3636
#[clap(help = "The queries to get the Q-error of")]
3737
query_ids: Vec<u32>,
3838

0 commit comments

Comments
 (0)