Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 4e47fad

Browse files
authored
refactor & fix: refactor the interface of range filter selectivity, fix null frac (#155)
- Refactor the interface of `get_column_range_selectivity` so that `BETWEEN` can also use this function. - Fix NULL handling. - Update unit tests accordingly.
1 parent 3b0e6b7 commit 4e47fad

File tree

2 files changed

+93
-94
lines changed

2 files changed

+93
-94
lines changed

optd-core/src/rel_node.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub trait RelNodeTyp:
2727
fn list_typ() -> Self;
2828
}
2929

30-
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
30+
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
3131
pub struct SerializableOrderedF64(pub OrderedFloat<f64>);
3232

3333
impl Serialize for SerializableOrderedF64 {
@@ -51,7 +51,7 @@ impl<'de> Deserialize<'de> for SerializableOrderedF64 {
5151
}
5252
}
5353

54-
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
54+
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
5555
pub enum Value {
5656
UInt8(u8),
5757
UInt16(u16),

optd-datafusion-repr/src/cost/base_cost.rs

+91-92
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::ops::Bound;
12
use std::{collections::HashMap, sync::Arc};
23

34
use crate::plan_nodes::{
@@ -985,34 +986,28 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
985986
BinOpType::Neq => {
986987
self.get_column_equality_selectivity(table, *col_idx, value, false)
987988
}
988-
BinOpType::Lt => self.get_column_range_selectivity(
989-
table,
990-
*col_idx,
991-
value,
992-
is_left_col_ref,
993-
false,
994-
),
995-
BinOpType::Leq => self.get_column_range_selectivity(
996-
table,
997-
*col_idx,
998-
value,
999-
is_left_col_ref,
1000-
true,
1001-
),
1002-
BinOpType::Gt => self.get_column_range_selectivity(
1003-
table,
1004-
*col_idx,
1005-
value,
1006-
!is_left_col_ref,
1007-
false,
1008-
),
1009-
BinOpType::Geq => self.get_column_range_selectivity(
1010-
table,
1011-
*col_idx,
1012-
value,
1013-
!is_left_col_ref,
1014-
true,
1015-
),
989+
BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => {
990+
let start = match (comp_bin_op_typ, is_left_col_ref) {
991+
(BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded,
992+
(BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Unbounded,
993+
(BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Excluded(value),
994+
(BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Included(value),
995+
_ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"),
996+
};
997+
let end = match (comp_bin_op_typ, is_left_col_ref) {
998+
(BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Excluded(value),
999+
(BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Included(value),
1000+
(BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Unbounded,
1001+
(BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Unbounded,
1002+
_ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"),
1003+
};
1004+
self.get_column_range_selectivity(
1005+
table,
1006+
*col_idx,
1007+
start,
1008+
end,
1009+
)
1010+
},
10161011
_ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"),
10171012
}
10181013
}
@@ -1148,56 +1143,61 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
11481143
}
11491144
}
11501145

1151-
/// Get the selectivity of an expression of the form "column </<=/>=/> value" (or "value </<=/>=/> column")
1152-
/// Computes selectivity based off of statistics
1153-
/// Range predicates are handled entirely differently from equality predicates so this is its own function
1154-
/// If it is unable to find the statistics, it returns None
1155-
/// Like in the Postgres source code, we decompose the four operators "</<=/>=/>" into "is_lt" and "is_eq"
1156-
/// The "is_lt" and "is_eq" values are set as if column is on the left hand side
1157-
fn get_column_range_selectivity(
1146+
/// Compute the frequency of values in a column less than or equal to the given value.
1147+
fn get_column_leq_value_freq(per_column_stats: &PerColumnStats<M, D>, value: &Value) -> f64 {
1148+
// because distr does not include the values in MCVs, we need to compute the CDFs there as well
1149+
// because nulls return false in any comparison, they are never included when computing range selectivity
1150+
let distr_leq_freq = per_column_stats.distr.cdf(value);
1151+
let value = value.clone();
1152+
let pred = Box::new(move |val: &Value| val <= &value);
1153+
let mcvs_leq_freq = per_column_stats.mcvs.freq_over_pred(pred);
1154+
distr_leq_freq + mcvs_leq_freq
1155+
}
1156+
1157+
/// Compute the frequency of values in a column less than the given value.
1158+
fn get_column_lt_value_freq(
11581159
&self,
1160+
per_column_stats: &PerColumnStats<M, D>,
11591161
table: &str,
11601162
col_idx: usize,
11611163
value: &Value,
1162-
is_col_lt_val: bool,
1163-
is_col_eq_val: bool,
1164+
) -> f64 {
1165+
// depending on whether value is in mcvs or not, we use different logic to turn total_leq_cdf into total_lt_cdf
1166+
// this logic just so happens to be the exact same logic as get_column_equality_selectivity implements
1167+
Self::get_column_leq_value_freq(per_column_stats, value)
1168+
- self.get_column_equality_selectivity(table, col_idx, value, true)
1169+
}
1170+
1171+
/// Get the selectivity of an expression of the form "column </<=/>=/> value" (or "value </<=/>=/> column").
1172+
/// Computes selectivity based off of statistics.
1173+
/// Range predicates are handled entirely differently from equality predicates so this is its own function.
1174+
/// If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL.
1175+
/// The selectivity is computed as the quantile of the right bound minus the quantile of the left bound, scaled by the non-null fraction (1 - null_frac).
1176+
fn get_column_range_selectivity(
1177+
&self,
1178+
table: &str,
1179+
col_idx: usize,
1180+
start: Bound<&Value>,
1181+
end: Bound<&Value>,
11641182
) -> f64 {
11651183
if let Some(per_column_stats) = self.get_per_column_stats(table, col_idx) {
1166-
// because distr does not include the values in MCVs, we need to compute the CDFs there as well
1167-
// because nulls return false in any comparison, they are never included when computing range selectivity
1168-
let distr_leq_freq = per_column_stats.distr.cdf(value);
1169-
let value_clone = value.clone(); // clone the value so that we can move it into the closure to avoid lifetime issues
1170-
// TODO: in a future PR, figure out how to make Values comparable. rn I just hardcoded as_i32() to work around this
1171-
let pred = Box::new(move |val: &Value| val.as_i32() <= value_clone.as_i32());
1172-
let mcvs_leq_freq = per_column_stats.mcvs.freq_over_pred(pred);
1173-
let total_leq_freq = distr_leq_freq + mcvs_leq_freq;
1174-
1175-
// depending on whether value is in mcvs or not, we use different logic to turn total_leq_cdf into total_lt_cdf
1176-
// this logic just so happens to be the exact same logic as get_column_equality_selectivity implements
1177-
let total_lt_freq =
1178-
total_leq_freq - self.get_column_equality_selectivity(table, col_idx, value, true);
1179-
1180-
// use either total_leq_freq or total_lt_freq to get the selectivity
1181-
if is_col_lt_val {
1182-
if is_col_eq_val {
1183-
// this branch means <=
1184-
total_leq_freq
1185-
} else {
1186-
// this branch means <
1187-
total_lt_freq
1184+
let left_quantile = match start {
1185+
Bound::Unbounded => 0.0,
1186+
Bound::Included(value) => {
1187+
self.get_column_lt_value_freq(per_column_stats, table, col_idx, value)
11881188
}
1189-
} else {
1190-
// clippy wants me to collapse this into an else if, but keeping two nested if else statements is clearer
1191-
#[allow(clippy::collapsible_else_if)]
1192-
if is_col_eq_val {
1193-
// this branch means >=, which is 1 - < - null_frac
1194-
// we need to subtract null_frac since that isn't included in >= either
1195-
1.0 - total_lt_freq - per_column_stats.null_frac
1196-
} else {
1197-
// this branch means >. same logic as above
1198-
1.0 - total_leq_freq - per_column_stats.null_frac
1189+
Bound::Excluded(value) => Self::get_column_leq_value_freq(per_column_stats, value),
1190+
};
1191+
let right_quantile = match end {
1192+
Bound::Unbounded => 1.0,
1193+
Bound::Included(value) => Self::get_column_leq_value_freq(per_column_stats, value),
1194+
Bound::Excluded(value) => {
1195+
self.get_column_lt_value_freq(per_column_stats, table, col_idx, value)
11991196
}
1200-
}
1197+
};
1198+
assert!(left_quantile <= right_quantile);
1199+
// `Distribution` does not account for NULL values, so the selectivity is smaller than frequency.
1200+
(right_quantile - left_quantile) * (1.0 - per_column_stats.null_frac)
12011201
} else {
12021202
DEFAULT_INEQ_SEL
12031203
}
@@ -1541,7 +1541,7 @@ mod tests {
15411541
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
15421542
));
15431543
let expr_tree = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(15)));
1544-
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
1544+
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
15451545
let column_refs = vec![ColumnRef::BaseTableColumnRef {
15461546
table: String::from(TABLE1_NAME),
15471547
col_idx: 0,
@@ -1565,18 +1565,18 @@ mod tests {
15651565
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
15661566
));
15671567
let expr_tree = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(15)));
1568-
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
1568+
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
15691569
let column_refs = vec![ColumnRef::BaseTableColumnRef {
15701570
table: String::from(TABLE1_NAME),
15711571
col_idx: 0,
15721572
}];
15731573
assert_approx_eq::assert_approx_eq!(
15741574
cost_model.get_filter_selectivity(expr_tree, &column_refs),
1575-
0.7
1575+
0.7 * 0.9
15761576
);
15771577
assert_approx_eq::assert_approx_eq!(
15781578
cost_model.get_filter_selectivity(expr_tree_rev, &column_refs),
1579-
0.7
1579+
0.7 * 0.9
15801580
);
15811581
}
15821582

@@ -1598,7 +1598,7 @@ mod tests {
15981598
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
15991599
));
16001600
let expr_tree = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(15)));
1601-
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
1601+
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
16021602
let column_refs = vec![ColumnRef::BaseTableColumnRef {
16031603
table: String::from(TABLE1_NAME),
16041604
col_idx: 0,
@@ -1627,7 +1627,7 @@ mod tests {
16271627
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
16281628
));
16291629
let expr_tree = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(15)));
1630-
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
1630+
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
16311631
let column_refs = vec![ColumnRef::BaseTableColumnRef {
16321632
table: String::from(TABLE1_NAME),
16331633
col_idx: 0,
@@ -1651,7 +1651,7 @@ mod tests {
16511651
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
16521652
));
16531653
let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15)));
1654-
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
1654+
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
16551655
let column_refs = vec![ColumnRef::BaseTableColumnRef {
16561656
table: String::from(TABLE1_NAME),
16571657
col_idx: 0,
@@ -1675,18 +1675,18 @@ mod tests {
16751675
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
16761676
));
16771677
let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15)));
1678-
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
1678+
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
16791679
let column_refs = vec![ColumnRef::BaseTableColumnRef {
16801680
table: String::from(TABLE1_NAME),
16811681
col_idx: 0,
16821682
}];
16831683
assert_approx_eq::assert_approx_eq!(
16841684
cost_model.get_filter_selectivity(expr_tree, &column_refs),
1685-
0.6
1685+
0.6 * 0.9
16861686
);
16871687
assert_approx_eq::assert_approx_eq!(
16881688
cost_model.get_filter_selectivity(expr_tree_rev, &column_refs),
1689-
0.6
1689+
0.6 * 0.9
16901690
);
16911691
}
16921692

@@ -1708,7 +1708,7 @@ mod tests {
17081708
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
17091709
));
17101710
let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15)));
1711-
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
1711+
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
17121712
// TODO(phw2): make column_refs a function
17131713
let column_refs = vec![ColumnRef::BaseTableColumnRef {
17141714
table: String::from(TABLE1_NAME),
@@ -1742,7 +1742,7 @@ mod tests {
17421742
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
17431743
));
17441744
let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15)));
1745-
let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0));
1745+
let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), col_ref(0));
17461746
let column_refs = vec![ColumnRef::BaseTableColumnRef {
17471747
table: String::from(TABLE1_NAME),
17481748
col_idx: 0,
@@ -1768,7 +1768,7 @@ mod tests {
17681768
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
17691769
));
17701770
let expr_tree = bin_op(BinOpType::Gt, col_ref(0), cnst(Value::Int32(15)));
1771-
let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), col_ref(0));
1771+
let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), col_ref(0));
17721772
let column_refs = vec![ColumnRef::BaseTableColumnRef {
17731773
table: String::from(TABLE1_NAME),
17741774
col_idx: 0,
@@ -1792,19 +1792,18 @@ mod tests {
17921792
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
17931793
));
17941794
let expr_tree = bin_op(BinOpType::Gt, col_ref(0), cnst(Value::Int32(15)));
1795-
let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), col_ref(0));
1795+
let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), col_ref(0));
17961796
let column_refs = vec![ColumnRef::BaseTableColumnRef {
17971797
table: String::from(TABLE1_NAME),
17981798
col_idx: 0,
17991799
}];
1800-
// we have to subtract 0.1 since we don't want to include them in GT either
18011800
assert_approx_eq::assert_approx_eq!(
18021801
cost_model.get_filter_selectivity(expr_tree, &column_refs),
1803-
1.0 - 0.7 - 0.1
1802+
(1.0 - 0.7) * 0.9
18041803
);
18051804
assert_approx_eq::assert_approx_eq!(
18061805
cost_model.get_filter_selectivity(expr_tree_rev, &column_refs),
1807-
1.0 - 0.7 - 0.1
1806+
(1.0 - 0.7) * 0.9
18081807
);
18091808
}
18101809

@@ -1818,7 +1817,7 @@ mod tests {
18181817
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
18191818
));
18201819
let expr_tree = bin_op(BinOpType::Geq, col_ref(0), cnst(Value::Int32(15)));
1821-
let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), col_ref(0));
1820+
let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), col_ref(0));
18221821
let column_refs = vec![ColumnRef::BaseTableColumnRef {
18231822
table: String::from(TABLE1_NAME),
18241823
col_idx: 0,
@@ -1842,19 +1841,19 @@ mod tests {
18421841
TestDistribution::new(vec![(Value::Int32(15), 0.7)]),
18431842
));
18441843
let expr_tree = bin_op(BinOpType::Geq, col_ref(0), cnst(Value::Int32(15)));
1845-
let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), col_ref(0));
1844+
let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), col_ref(0));
18461845
let column_refs = vec![ColumnRef::BaseTableColumnRef {
18471846
table: String::from(TABLE1_NAME),
18481847
col_idx: 0,
18491848
}];
1850-
// we have to subtract 0.1 since we don't want to include them in GT either
1849+
// we have to add 0.1 since it's Geq
18511850
assert_approx_eq::assert_approx_eq!(
18521851
cost_model.get_filter_selectivity(expr_tree, &column_refs),
1853-
1.0 - 0.6 - 0.1
1852+
(1.0 - 0.7 + 0.1) * 0.9
18541853
);
18551854
assert_approx_eq::assert_approx_eq!(
18561855
cost_model.get_filter_selectivity(expr_tree_rev, &column_refs),
1857-
1.0 - 0.6 - 0.1
1856+
(1.0 - 0.7 + 0.1) * 0.9
18581857
);
18591858
}
18601859

0 commit comments

Comments
 (0)