@@ -22,11 +22,12 @@ use databend_common_exception::Result;
22
22
use databend_common_expression:: types:: decimal:: Decimal ;
23
23
use databend_common_expression:: types:: decimal:: Decimal128Type ;
24
24
use databend_common_expression:: types:: decimal:: Decimal256Type ;
25
+ use databend_common_expression:: types:: nullable:: NullableColumnBuilder ;
25
26
use databend_common_expression:: types:: number:: Number ;
26
- use databend_common_expression:: types:: number:: F64 ;
27
27
use databend_common_expression:: types:: DataType ;
28
28
use databend_common_expression:: types:: DecimalDataType ;
29
29
use databend_common_expression:: types:: Float64Type ;
30
+ use databend_common_expression:: types:: NullableType ;
30
31
use databend_common_expression:: types:: NumberDataType ;
31
32
use databend_common_expression:: types:: NumberType ;
32
33
use databend_common_expression:: types:: ValueType ;
@@ -69,44 +70,47 @@ impl<const TYPE: u8> StddevState<TYPE> {
69
70
}
70
71
71
72
fn state_merge ( & mut self , other : & Self ) -> Result < ( ) > {
72
- if other. count == 0 {
73
- return Ok ( ( ) ) ;
74
- }
75
73
if self . count == 0 {
76
74
self . count = other. count ;
77
75
self . mean = other. mean ;
78
76
self . dsquared = other. dsquared ;
79
77
return Ok ( ( ) ) ;
80
78
}
81
79
82
- let count = self . count + other. count ;
83
- let mean = ( self . count as f64 * self . mean + other. count as f64 * other. mean ) / count as f64 ;
84
- let delta = other. mean - self . mean ;
80
+ if other. count > 0 {
81
+ let count = self . count + other. count ;
82
+ let mean =
83
+ ( self . count as f64 * self . mean + other. count as f64 * other. mean ) / count as f64 ;
84
+ let delta = other. mean - self . mean ;
85
85
86
- self . count = count;
87
- self . mean = mean;
88
- self . dsquared = other. dsquared
89
- + self . dsquared
90
- + delta * delta * other. count as f64 * self . count as f64 / count as f64 ;
86
+ self . dsquared = other. dsquared
87
+ + self . dsquared
88
+ + delta * delta * other. count as f64 * self . count as f64 / count as f64 ;
89
+
90
+ self . mean = mean;
91
+ self . count = count;
92
+ }
91
93
92
94
Ok ( ( ) )
93
95
}
94
96
95
- fn state_merge_result ( & mut self , builder : & mut Vec < F64 > ) -> Result < ( ) > {
96
- let result = if self . count <= 1 {
97
- 0f64
97
+ fn state_merge_result (
98
+ & mut self ,
99
+ builder : & mut NullableColumnBuilder < Float64Type > ,
100
+ ) -> Result < ( ) > {
101
+ // For single-record inputs, VAR_SAMP and STDDEV_SAMP should return NULL
102
+ if self . count <= 1 && ( TYPE == VAR_SAMP || TYPE == STD_SAMP ) {
103
+ builder. push_null ( ) ;
98
104
} else {
99
- match TYPE {
105
+ let value = match TYPE {
100
106
STD_POP => ( self . dsquared / self . count as f64 ) . sqrt ( ) ,
101
107
STD_SAMP => ( self . dsquared / ( self . count - 1 ) as f64 ) . sqrt ( ) ,
102
108
VAR_POP => self . dsquared / self . count as f64 ,
103
109
VAR_SAMP => self . dsquared / ( self . count - 1 ) as f64 ,
104
110
_ => unreachable ! ( ) ,
105
- }
111
+ } ;
112
+ builder. push ( value. into ( ) ) ;
106
113
} ;
107
-
108
- builder. push ( result. into ( ) ) ;
109
-
110
114
Ok ( ( ) )
111
115
}
112
116
}
@@ -116,7 +120,8 @@ struct NumberAggregateStddevState<const TYPE: u8> {
116
120
state : StddevState < TYPE > ,
117
121
}
118
122
119
- impl < T , const TYPE : u8 > UnaryState < T , Float64Type > for NumberAggregateStddevState < TYPE >
123
+ impl < T , const TYPE : u8 > UnaryState < T , NullableType < Float64Type > >
124
+ for NumberAggregateStddevState < TYPE >
120
125
where
121
126
T : ValueType ,
122
127
T :: Scalar : Number + AsPrimitive < f64 > ,
@@ -136,7 +141,7 @@ where
136
141
137
142
fn merge_result (
138
143
& mut self ,
139
- builder : & mut Vec < F64 > ,
144
+ builder : & mut NullableColumnBuilder < Float64Type > ,
140
145
_function_data : Option < & dyn FunctionData > ,
141
146
) -> Result < ( ) > {
142
147
self . state . state_merge_result ( builder)
@@ -158,7 +163,8 @@ struct DecimalNumberAggregateStddevState<const TYPE: u8> {
158
163
state : StddevState < TYPE > ,
159
164
}
160
165
161
- impl < T , const TYPE : u8 > UnaryState < T , Float64Type > for DecimalNumberAggregateStddevState < TYPE >
166
+ impl < T , const TYPE : u8 > UnaryState < T , NullableType < Float64Type > >
167
+ for DecimalNumberAggregateStddevState < TYPE >
162
168
where
163
169
T : ValueType ,
164
170
T :: Scalar : Decimal + BorshSerialize + BorshDeserialize ,
@@ -184,7 +190,7 @@ where
184
190
185
191
fn merge_result (
186
192
& mut self ,
187
- builder : & mut Vec < F64 > ,
193
+ builder : & mut NullableColumnBuilder < Float64Type > ,
188
194
_function_data : Option < & dyn FunctionData > ,
189
195
) -> Result < ( ) > {
190
196
self . state . state_merge_result ( builder)
@@ -197,33 +203,32 @@ pub fn try_create_aggregate_stddev_pop_function<const TYPE: u8>(
197
203
arguments : Vec < DataType > ,
198
204
) -> Result < Arc < dyn AggregateFunction > > {
199
205
assert_unary_arguments ( display_name, arguments. len ( ) ) ?;
206
+
207
+ let return_type = DataType :: Number ( NumberDataType :: Float64 ) . wrap_nullable ( ) ;
200
208
with_number_mapped_type ! ( |NUM_TYPE | match & arguments[ 0 ] {
201
209
DataType :: Number ( NumberDataType :: NUM_TYPE ) => {
202
- let return_type = DataType :: Number ( NumberDataType :: Float64 ) ;
203
210
AggregateUnaryFunction :: <
204
211
NumberAggregateStddevState <TYPE >,
205
212
NumberType <NUM_TYPE >,
206
- Float64Type ,
213
+ NullableType < Float64Type > ,
207
214
>:: try_create_unary( display_name, return_type, params, arguments[ 0 ] . clone( ) )
208
215
}
209
216
DataType :: Decimal ( DecimalDataType :: Decimal128 ( s) ) => {
210
- let return_type = DataType :: Number ( NumberDataType :: Float64 ) ;
211
217
let func = AggregateUnaryFunction :: <
212
218
DecimalNumberAggregateStddevState <TYPE >,
213
219
Decimal128Type ,
214
- Float64Type ,
220
+ NullableType < Float64Type > ,
215
221
>:: try_create(
216
222
display_name, return_type, params, arguments[ 0 ] . clone( )
217
223
)
218
224
. with_function_data( Box :: new( DecimalFuncData { scale: s. scale } ) ) ;
219
225
Ok ( Arc :: new( func) )
220
226
}
221
227
DataType :: Decimal ( DecimalDataType :: Decimal256 ( s) ) => {
222
- let return_type = DataType :: Number ( NumberDataType :: Float64 ) ;
223
228
let func = AggregateUnaryFunction :: <
224
229
DecimalNumberAggregateStddevState <TYPE >,
225
230
Decimal256Type ,
226
- Float64Type ,
231
+ NullableType < Float64Type > ,
227
232
>:: try_create(
228
233
display_name, return_type, params, arguments[ 0 ] . clone( )
229
234
)
0 commit comments