@@ -22,11 +22,12 @@ use databend_common_exception::Result;
2222use databend_common_expression:: types:: decimal:: Decimal ;
2323use databend_common_expression:: types:: decimal:: Decimal128Type ;
2424use databend_common_expression:: types:: decimal:: Decimal256Type ;
25+ use databend_common_expression:: types:: nullable:: NullableColumnBuilder ;
2526use databend_common_expression:: types:: number:: Number ;
26- use databend_common_expression:: types:: number:: F64 ;
2727use databend_common_expression:: types:: DataType ;
2828use databend_common_expression:: types:: DecimalDataType ;
2929use databend_common_expression:: types:: Float64Type ;
30+ use databend_common_expression:: types:: NullableType ;
3031use databend_common_expression:: types:: NumberDataType ;
3132use databend_common_expression:: types:: NumberType ;
3233use databend_common_expression:: types:: ValueType ;
@@ -69,44 +70,47 @@ impl<const TYPE: u8> StddevState<TYPE> {
6970 }
7071
7172 fn state_merge ( & mut self , other : & Self ) -> Result < ( ) > {
72- if other. count == 0 {
73- return Ok ( ( ) ) ;
74- }
7573 if self . count == 0 {
7674 self . count = other. count ;
7775 self . mean = other. mean ;
7876 self . dsquared = other. dsquared ;
7977 return Ok ( ( ) ) ;
8078 }
8179
82- let count = self . count + other. count ;
83- let mean = ( self . count as f64 * self . mean + other. count as f64 * other. mean ) / count as f64 ;
84- let delta = other. mean - self . mean ;
80+ if other. count > 0 {
81+ let count = self . count + other. count ;
82+ let mean =
83+ ( self . count as f64 * self . mean + other. count as f64 * other. mean ) / count as f64 ;
84+ let delta = other. mean - self . mean ;
8585
86- self . count = count;
87- self . mean = mean;
88- self . dsquared = other. dsquared
89- + self . dsquared
90- + delta * delta * other. count as f64 * self . count as f64 / count as f64 ;
86+ self . dsquared = other. dsquared
87+ + self . dsquared
88+ + delta * delta * other. count as f64 * self . count as f64 / count as f64 ;
89+
90+ self . mean = mean;
91+ self . count = count;
92+ }
9193
9294 Ok ( ( ) )
9395 }
9496
95- fn state_merge_result ( & mut self , builder : & mut Vec < F64 > ) -> Result < ( ) > {
96- let result = if self . count <= 1 {
97- 0f64
97+ fn state_merge_result (
98+ & mut self ,
99+ builder : & mut NullableColumnBuilder < Float64Type > ,
100+ ) -> Result < ( ) > {
101+ // For single-record inputs, VAR_SAMP and STDDEV_SAMP should return NULL
102+ if self . count <= 1 && ( TYPE == VAR_SAMP || TYPE == STD_SAMP ) {
103+ builder. push_null ( ) ;
98104 } else {
99- match TYPE {
105+ let value = match TYPE {
100106 STD_POP => ( self . dsquared / self . count as f64 ) . sqrt ( ) ,
101107 STD_SAMP => ( self . dsquared / ( self . count - 1 ) as f64 ) . sqrt ( ) ,
102108 VAR_POP => self . dsquared / self . count as f64 ,
103109 VAR_SAMP => self . dsquared / ( self . count - 1 ) as f64 ,
104110 _ => unreachable ! ( ) ,
105- }
111+ } ;
112+ builder. push ( value. into ( ) ) ;
106113 } ;
107-
108- builder. push ( result. into ( ) ) ;
109-
110114 Ok ( ( ) )
111115 }
112116}
@@ -116,7 +120,8 @@ struct NumberAggregateStddevState<const TYPE: u8> {
116120 state : StddevState < TYPE > ,
117121}
118122
119- impl < T , const TYPE : u8 > UnaryState < T , Float64Type > for NumberAggregateStddevState < TYPE >
123+ impl < T , const TYPE : u8 > UnaryState < T , NullableType < Float64Type > >
124+ for NumberAggregateStddevState < TYPE >
120125where
121126 T : ValueType ,
122127 T :: Scalar : Number + AsPrimitive < f64 > ,
@@ -136,7 +141,7 @@ where
136141
137142 fn merge_result (
138143 & mut self ,
139- builder : & mut Vec < F64 > ,
144+ builder : & mut NullableColumnBuilder < Float64Type > ,
140145 _function_data : Option < & dyn FunctionData > ,
141146 ) -> Result < ( ) > {
142147 self . state . state_merge_result ( builder)
@@ -158,7 +163,8 @@ struct DecimalNumberAggregateStddevState<const TYPE: u8> {
158163 state : StddevState < TYPE > ,
159164}
160165
161- impl < T , const TYPE : u8 > UnaryState < T , Float64Type > for DecimalNumberAggregateStddevState < TYPE >
166+ impl < T , const TYPE : u8 > UnaryState < T , NullableType < Float64Type > >
167+ for DecimalNumberAggregateStddevState < TYPE >
162168where
163169 T : ValueType ,
164170 T :: Scalar : Decimal + BorshSerialize + BorshDeserialize ,
@@ -184,7 +190,7 @@ where
184190
185191 fn merge_result (
186192 & mut self ,
187- builder : & mut Vec < F64 > ,
193+ builder : & mut NullableColumnBuilder < Float64Type > ,
188194 _function_data : Option < & dyn FunctionData > ,
189195 ) -> Result < ( ) > {
190196 self . state . state_merge_result ( builder)
@@ -197,33 +203,32 @@ pub fn try_create_aggregate_stddev_pop_function<const TYPE: u8>(
197203 arguments : Vec < DataType > ,
198204) -> Result < Arc < dyn AggregateFunction > > {
199205 assert_unary_arguments ( display_name, arguments. len ( ) ) ?;
206+
207+ let return_type = DataType :: Number ( NumberDataType :: Float64 ) . wrap_nullable ( ) ;
200208 with_number_mapped_type ! ( |NUM_TYPE | match & arguments[ 0 ] {
201209 DataType :: Number ( NumberDataType :: NUM_TYPE ) => {
202- let return_type = DataType :: Number ( NumberDataType :: Float64 ) ;
203210 AggregateUnaryFunction :: <
204211 NumberAggregateStddevState <TYPE >,
205212 NumberType <NUM_TYPE >,
206- Float64Type ,
213+ NullableType < Float64Type > ,
207214 >:: try_create_unary( display_name, return_type, params, arguments[ 0 ] . clone( ) )
208215 }
209216 DataType :: Decimal ( DecimalDataType :: Decimal128 ( s) ) => {
210- let return_type = DataType :: Number ( NumberDataType :: Float64 ) ;
211217 let func = AggregateUnaryFunction :: <
212218 DecimalNumberAggregateStddevState <TYPE >,
213219 Decimal128Type ,
214- Float64Type ,
220+ NullableType < Float64Type > ,
215221 >:: try_create(
216222 display_name, return_type, params, arguments[ 0 ] . clone( )
217223 )
218224 . with_function_data( Box :: new( DecimalFuncData { scale: s. scale } ) ) ;
219225 Ok ( Arc :: new( func) )
220226 }
221227 DataType :: Decimal ( DecimalDataType :: Decimal256 ( s) ) => {
222- let return_type = DataType :: Number ( NumberDataType :: Float64 ) ;
223228 let func = AggregateUnaryFunction :: <
224229 DecimalNumberAggregateStddevState <TYPE >,
225230 Decimal256Type ,
226- Float64Type ,
231+ NullableType < Float64Type > ,
227232 >:: try_create(
228233 display_name, return_type, params, arguments[ 0 ] . clone( )
229234 )
0 commit comments