Skip to content

Commit 31d0fac

Browse files
authored
fix(query): fix left semi optimize to inner join (#17458)
* fix(query): fix left semi optimize to inner join * fix(query): fix left semi optimize to inner join * fix(query): fix left semi optimize to inner join * fix(query): fix left semi optimize to inner join * fix(query): fix left semi optimize to inner join * fix(query): fix left semi optimize to inner join
1 parent 61cdf5a commit 31d0fac

File tree

8 files changed

+318
-2232
lines changed

8 files changed

+318
-2232
lines changed

src/query/expression/src/types/date.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ use databend_common_io::cursor_ext::ReadBytesExt;
2424
use jiff::civil::Date;
2525
use jiff::fmt::strtime;
2626
use jiff::tz::TimeZone;
27-
use log::error;
2827
use num_traits::AsPrimitive;
2928

3029
use super::number::SimpleDomain;
@@ -54,7 +53,6 @@ pub fn clamp_date(days: i64) -> i32 {
5453
if (DATE_MIN as i64..=DATE_MAX as i64).contains(&days) {
5554
days as i32
5655
} else {
57-
error!("{}", format!("date {} is out of range", days));
5856
DATE_MIN
5957
}
6058
}

src/query/functions/src/aggregates/aggregate_stddev.rs

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ use databend_common_exception::Result;
2222
use databend_common_expression::types::decimal::Decimal;
2323
use databend_common_expression::types::decimal::Decimal128Type;
2424
use databend_common_expression::types::decimal::Decimal256Type;
25+
use databend_common_expression::types::nullable::NullableColumnBuilder;
2526
use databend_common_expression::types::number::Number;
26-
use databend_common_expression::types::number::F64;
2727
use databend_common_expression::types::DataType;
2828
use databend_common_expression::types::DecimalDataType;
2929
use databend_common_expression::types::Float64Type;
30+
use databend_common_expression::types::NullableType;
3031
use databend_common_expression::types::NumberDataType;
3132
use databend_common_expression::types::NumberType;
3233
use databend_common_expression::types::ValueType;
@@ -69,44 +70,47 @@ impl<const TYPE: u8> StddevState<TYPE> {
6970
}
7071

7172
fn state_merge(&mut self, other: &Self) -> Result<()> {
72-
if other.count == 0 {
73-
return Ok(());
74-
}
7573
if self.count == 0 {
7674
self.count = other.count;
7775
self.mean = other.mean;
7876
self.dsquared = other.dsquared;
7977
return Ok(());
8078
}
8179

82-
let count = self.count + other.count;
83-
let mean = (self.count as f64 * self.mean + other.count as f64 * other.mean) / count as f64;
84-
let delta = other.mean - self.mean;
80+
if other.count > 0 {
81+
let count = self.count + other.count;
82+
let mean =
83+
(self.count as f64 * self.mean + other.count as f64 * other.mean) / count as f64;
84+
let delta = other.mean - self.mean;
8585

86-
self.count = count;
87-
self.mean = mean;
88-
self.dsquared = other.dsquared
89-
+ self.dsquared
90-
+ delta * delta * other.count as f64 * self.count as f64 / count as f64;
86+
self.dsquared = other.dsquared
87+
+ self.dsquared
88+
+ delta * delta * other.count as f64 * self.count as f64 / count as f64;
89+
90+
self.mean = mean;
91+
self.count = count;
92+
}
9193

9294
Ok(())
9395
}
9496

95-
fn state_merge_result(&mut self, builder: &mut Vec<F64>) -> Result<()> {
96-
let result = if self.count <= 1 {
97-
0f64
97+
fn state_merge_result(
98+
&mut self,
99+
builder: &mut NullableColumnBuilder<Float64Type>,
100+
) -> Result<()> {
101+
// For single-record inputs, VAR_SAMP and STDDEV_SAMP should return NULL
102+
if self.count <= 1 && (TYPE == VAR_SAMP || TYPE == STD_SAMP) {
103+
builder.push_null();
98104
} else {
99-
match TYPE {
105+
let value = match TYPE {
100106
STD_POP => (self.dsquared / self.count as f64).sqrt(),
101107
STD_SAMP => (self.dsquared / (self.count - 1) as f64).sqrt(),
102108
VAR_POP => self.dsquared / self.count as f64,
103109
VAR_SAMP => self.dsquared / (self.count - 1) as f64,
104110
_ => unreachable!(),
105-
}
111+
};
112+
builder.push(value.into());
106113
};
107-
108-
builder.push(result.into());
109-
110114
Ok(())
111115
}
112116
}
@@ -116,7 +120,8 @@ struct NumberAggregateStddevState<const TYPE: u8> {
116120
state: StddevState<TYPE>,
117121
}
118122

119-
impl<T, const TYPE: u8> UnaryState<T, Float64Type> for NumberAggregateStddevState<TYPE>
123+
impl<T, const TYPE: u8> UnaryState<T, NullableType<Float64Type>>
124+
for NumberAggregateStddevState<TYPE>
120125
where
121126
T: ValueType,
122127
T::Scalar: Number + AsPrimitive<f64>,
@@ -136,7 +141,7 @@ where
136141

137142
fn merge_result(
138143
&mut self,
139-
builder: &mut Vec<F64>,
144+
builder: &mut NullableColumnBuilder<Float64Type>,
140145
_function_data: Option<&dyn FunctionData>,
141146
) -> Result<()> {
142147
self.state.state_merge_result(builder)
@@ -158,7 +163,8 @@ struct DecimalNumberAggregateStddevState<const TYPE: u8> {
158163
state: StddevState<TYPE>,
159164
}
160165

161-
impl<T, const TYPE: u8> UnaryState<T, Float64Type> for DecimalNumberAggregateStddevState<TYPE>
166+
impl<T, const TYPE: u8> UnaryState<T, NullableType<Float64Type>>
167+
for DecimalNumberAggregateStddevState<TYPE>
162168
where
163169
T: ValueType,
164170
T::Scalar: Decimal + BorshSerialize + BorshDeserialize,
@@ -184,7 +190,7 @@ where
184190

185191
fn merge_result(
186192
&mut self,
187-
builder: &mut Vec<F64>,
193+
builder: &mut NullableColumnBuilder<Float64Type>,
188194
_function_data: Option<&dyn FunctionData>,
189195
) -> Result<()> {
190196
self.state.state_merge_result(builder)
@@ -197,33 +203,32 @@ pub fn try_create_aggregate_stddev_pop_function<const TYPE: u8>(
197203
arguments: Vec<DataType>,
198204
) -> Result<Arc<dyn AggregateFunction>> {
199205
assert_unary_arguments(display_name, arguments.len())?;
206+
207+
let return_type = DataType::Number(NumberDataType::Float64).wrap_nullable();
200208
with_number_mapped_type!(|NUM_TYPE| match &arguments[0] {
201209
DataType::Number(NumberDataType::NUM_TYPE) => {
202-
let return_type = DataType::Number(NumberDataType::Float64);
203210
AggregateUnaryFunction::<
204211
NumberAggregateStddevState<TYPE>,
205212
NumberType<NUM_TYPE>,
206-
Float64Type,
213+
NullableType<Float64Type>,
207214
>::try_create_unary(display_name, return_type, params, arguments[0].clone())
208215
}
209216
DataType::Decimal(DecimalDataType::Decimal128(s)) => {
210-
let return_type = DataType::Number(NumberDataType::Float64);
211217
let func = AggregateUnaryFunction::<
212218
DecimalNumberAggregateStddevState<TYPE>,
213219
Decimal128Type,
214-
Float64Type,
220+
NullableType<Float64Type>,
215221
>::try_create(
216222
display_name, return_type, params, arguments[0].clone()
217223
)
218224
.with_function_data(Box::new(DecimalFuncData { scale: s.scale }));
219225
Ok(Arc::new(func))
220226
}
221227
DataType::Decimal(DecimalDataType::Decimal256(s)) => {
222-
let return_type = DataType::Number(NumberDataType::Float64);
223228
let func = AggregateUnaryFunction::<
224229
DecimalNumberAggregateStddevState<TYPE>,
225230
Decimal256Type,
226-
Float64Type,
231+
NullableType<Float64Type>,
227232
>::try_create(
228233
display_name, return_type, params, arguments[0].clone()
229234
)

src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ evaluation (internal):
690690
| Column | Data |
691691
+--------+-------------------------------------------------------------------------+
692692
| x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } |
693-
| Output | NullableColumn { column: Float64([0, 0]), validity: [0b______11] } |
693+
| Output | NullableColumn { column: Float64([0, 0]), validity: [0b______00] } |
694694
+--------+-------------------------------------------------------------------------+
695695

696696

@@ -710,7 +710,7 @@ evaluation (internal):
710710
| Column | Data |
711711
+--------+-----------------------------------------------------------------------------------------+
712712
| dec | NullableColumn { column: Decimal128([1.10, 2.20, 0.00, 3.30]), validity: [0b____1011] } |
713-
| Output | NullableColumn { column: Float64([0, 0.7778174593]), validity: [0b______11] } |
713+
| Output | NullableColumn { column: Float64([0, 0.7778174593]), validity: [0b______10] } |
714714
+--------+-----------------------------------------------------------------------------------------+
715715

716716

@@ -720,7 +720,7 @@ evaluation (internal):
720720
| Column | Data |
721721
+--------+-------------------------------------------------------------------------+
722722
| x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } |
723-
| Output | NullableColumn { column: Float64([0, 0]), validity: [0b______11] } |
723+
| Output | NullableColumn { column: Float64([0, 0]), validity: [0b______00] } |
724724
+--------+-------------------------------------------------------------------------+
725725

726726

src/query/sql/src/planner/optimizer/rule/rewrite/rule_semi_to_inner_join.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::HashMap;
1516
use std::collections::HashSet;
1617
use std::sync::Arc;
1718

1819
use databend_common_exception::Result;
20+
use databend_common_expression::types::DataType;
1921

2022
use crate::optimizer::extract::Matcher;
2123
use crate::optimizer::rule::Rule;
@@ -88,12 +90,20 @@ impl Rule for RuleSemiToInnerJoin {
8890
};
8991

9092
// Traverse child to find join keys in group by keys
91-
let mut group_by_keys = HashSet::new();
93+
let mut group_by_keys = HashMap::new();
9294
find_group_by_keys(child, &mut group_by_keys)?;
93-
if condition_cols
94-
.iter()
95-
.all(|condition| group_by_keys.contains(condition))
96-
{
95+
96+
// If condition are all group by keys and not nullable
97+
// we can rewrite semi join to inner join
98+
// inner join will ignore null values but semi join will keep them
99+
// this happens in Q38
100+
if condition_cols.iter().all(|condition| {
101+
if let Some(t) = group_by_keys.get(condition) {
102+
!t.is_nullable_or_null()
103+
} else {
104+
false
105+
}
106+
}) {
97107
join.join_type = JoinType::Inner;
98108
let mut join_expr = SExpr::create_binary(
99109
Arc::new(join.into()),
@@ -111,15 +121,18 @@ impl Rule for RuleSemiToInnerJoin {
111121
}
112122
}
113123

114-
fn find_group_by_keys(child: &SExpr, group_by_keys: &mut HashSet<IndexType>) -> Result<()> {
124+
fn find_group_by_keys(
125+
child: &SExpr,
126+
group_by_keys: &mut HashMap<IndexType, Box<DataType>>,
127+
) -> Result<()> {
115128
match child.plan() {
116129
RelOperator::EvalScalar(_) | RelOperator::Filter(_) | RelOperator::Window(_) => {
117130
find_group_by_keys(child.child(0)?, group_by_keys)?;
118131
}
119132
RelOperator::Aggregate(agg) => {
120133
for item in agg.group_items.iter() {
121134
if let ScalarExpr::BoundColumnRef(c) = &item.scalar {
122-
group_by_keys.insert(c.column.index);
135+
group_by_keys.insert(c.column.index, c.column.data_type.clone());
123136
}
124137
}
125138
}

tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,11 @@ select stddev_samp(a) from d;
419419
----
420420
7.055988000745655
421421

422+
query F
423+
select stddev_samp(a) from d where b = 'abc';
424+
----
425+
NULL
426+
422427
query TTTTT
423428
select json_array_agg(a), json_array_agg(b), json_array_agg(c), json_array_agg(d), json_array_agg(e), json_array_agg('a') from d
424429
----

0 commit comments

Comments
 (0)