1
+ @file:OptIn(ExperimentalTypeInference ::class )
2
+ @file:Suppress(" LocalVariableName" )
3
+
1
4
package org.jetbrains.kotlinx.dataframe.api
2
5
3
6
import org.jetbrains.kotlinx.dataframe.AnyRow
@@ -13,50 +16,97 @@ import org.jetbrains.kotlinx.dataframe.annotations.Refine
13
16
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
14
17
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
15
18
import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf
16
- import org.jetbrains.kotlinx.dataframe.columns.values
17
- import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
18
19
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators
19
- import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast
20
20
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
21
21
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
22
22
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
23
- import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
23
+ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
24
+ import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns
24
25
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
25
- import org.jetbrains.kotlinx.dataframe.impl.zero
26
- import org.jetbrains.kotlinx.dataframe.math.sum
27
- import org.jetbrains.kotlinx.dataframe.math.sumOf
26
+ import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
27
+ import kotlin.experimental.ExperimentalTypeInference
28
+ import kotlin.reflect.KClass
28
29
import kotlin.reflect.KProperty
29
- import kotlin.reflect.full.isSubtypeOf
30
+ import kotlin.reflect.KType
30
31
import kotlin.reflect.typeOf
31
32
33
+ /* TODO KDocs
34
+ * Calculating the sum is supported for all primitive number types.
35
+ * Nulls are filtered out.
36
+ * The return type is always the same as the input type (never null), except for `Byte` and `Short`,
37
+ * which are converted to `Int`.
38
+ * Empty input will result in 0 in the supplied number type.
39
+ * For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the sum.
40
+ */
41
+
32
42
// region DataColumn
33
43
34
- @JvmName(" sumT" )
35
- public fun <T : Number > DataColumn<T>.sum (): T = values.sum(type())
44
+ @JvmName(" sumShort" )
45
+ public fun DataColumn<Short?>.sum (): Int = Aggregators .sum.aggregate(this ) as Int
46
+
47
+ @JvmName(" sumByte" )
48
+ public fun DataColumn<Byte?>.sum (): Int = Aggregators .sum.aggregate(this ) as Int
36
49
37
- @JvmName(" sumTNullable" )
38
- public fun <T : Number > DataColumn<T?>.sum (): T = values.sum(type())
50
+ @Suppress(" UNCHECKED_CAST" )
51
+ @JvmName(" sumNumber" )
52
+ public fun <T : Number > DataColumn<T?>.sum (): T = Aggregators .sum.aggregate(this ) as T
39
53
40
- public inline fun <T , reified R : Number > DataColumn<T>.sumOf (noinline expression : (T ) -> R ): R ? =
41
- (Aggregators .sum as Aggregator <* , * >).cast<R >().aggregateOf(this , expression)
54
+ @JvmName(" sumOfShort" )
55
+ @OverloadResolutionByLambdaReturnType
56
+ public fun <C > DataColumn<C>.sumOf (expression : (C ) -> Short? ): Int =
57
+ Aggregators .sum.aggregateOf(this , expression) as Int
58
+
59
+ @JvmName(" sumOfByte" )
60
+ @OverloadResolutionByLambdaReturnType
61
+ public fun <C > DataColumn<C>.sumOf (expression : (C ) -> Byte? ): Int = Aggregators .sum.aggregateOf(this , expression) as Int
62
+
63
+ @JvmName(" sumOfNumber" )
64
+ @OverloadResolutionByLambdaReturnType
65
+ public inline fun <C , reified V : Number > DataColumn<C>.sumOf (crossinline expression : (C ) -> V ? ): V =
66
+ Aggregators .sum.aggregateOf(this , expression) as V
42
67
43
68
// endregion
44
69
45
70
// region DataRow
46
71
47
- public fun AnyRow.rowSum (): Number =
48
- Aggregators .sum.aggregateCalculatingType(
49
- values = values().filterIsInstance< Number >(),
50
- valueTypes = columnTypes().filter { it.isSubtypeOf(typeOf< Number ?>()) }.toSet(),
51
- ) ? : 0
72
+ public fun AnyRow.rowSum (): Number = Aggregators .sum.aggregateOfRow( this , primitiveOrMixedNumberColumns())
73
+
74
+ @JvmName( " rowSumOfShort " )
75
+ public inline fun < reified T : Short ? > AnyRow. rowSumOf ( _kClass : KClass < Short > = Short : :class): Int =
76
+ rowSumOf(typeOf< T >()) as Int
52
77
53
- public inline fun <reified T : Number > AnyRow.rowSumOf (): T = values().filterIsInstance<T >().sum(typeOf<T >())
78
+ @JvmName(" rowSumOfByte" )
79
+ public inline fun <reified T : Byte ? > AnyRow.rowSumOf (_kClass : KClass <Byte > = Byte : :class): Int =
80
+ rowSumOf(typeOf<T >()) as Int
54
81
82
+ @JvmName(" rowSumOfInt" )
83
+ public inline fun <reified T : Int ? > AnyRow.rowSumOf (_kClass : KClass <Int > = Int : :class): Int =
84
+ rowSumOf(typeOf<T >()) as Int
85
+
86
+ @JvmName(" rowSumOfLong" )
87
+ public inline fun <reified T : Long ? > AnyRow.rowSumOf (_kClass : KClass <Long > = Long : :class): Long =
88
+ rowSumOf(typeOf<T >()) as Long
89
+
90
+ @JvmName(" rowSumOfFloat" )
91
+ public inline fun <reified T : Float ? > AnyRow.rowSumOf (_kClass : KClass <Float > = Float : :class): Float =
92
+ rowSumOf(typeOf<T >()) as Float
93
+
94
+ @JvmName(" rowSumOfDouble" )
95
+ public inline fun <reified T : Double ? > AnyRow.rowSumOf (_kClass : KClass <Double > = Double : :class): Double =
96
+ rowSumOf(typeOf<T >()) as Double
97
+
98
+ // unfortunately, we cannot make a `reified T : Number?` due to clashes
99
+ public fun AnyRow.rowSumOf (type : KType ): Number {
100
+ require(type.isPrimitiveOrMixedNumber()) {
101
+ " Type $type is not a primitive number type. Mean only supports primitive number types."
102
+ }
103
+ return Aggregators .sum.aggregateOfRow(this ) { colsOf(type) }
104
+ }
55
105
// endregion
56
106
57
107
// region DataFrame
58
108
59
- public fun <T > DataFrame<T>.sum (): DataRow <T > = sumFor(numberColumns ())
109
+ public fun <T > DataFrame<T>.sum (): DataRow <T > = sumFor(primitiveOrMixedNumberColumns ())
60
110
61
111
public fun <T , C : Number > DataFrame<T>.sumFor (columns : ColumnsForAggregateSelector <T , C ?>): DataRow <T > =
62
112
Aggregators .sum.aggregateFor(this , columns)
@@ -71,28 +121,70 @@ public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: ColumnReference<C
71
121
public fun <T , C : Number > DataFrame<T>.sumFor (vararg columns : KProperty <C ?>): DataRow <T > =
72
122
sumFor { columns.toColumnSet() }
73
123
124
+ @JvmName(" sumShort" )
125
+ @OverloadResolutionByLambdaReturnType
126
+ public fun <T > DataFrame<T>.sum (columns : ColumnsSelector <T , Short ?>): Int =
127
+ Aggregators .sum.aggregateAll(this , columns) as Int
128
+
129
+ @JvmName(" sumByte" )
130
+ @OverloadResolutionByLambdaReturnType
131
+ public fun <T > DataFrame<T>.sum (columns : ColumnsSelector <T , Byte ?>): Int =
132
+ Aggregators .sum.aggregateAll(this , columns) as Int
133
+
134
+ @JvmName(" sumNumber" )
135
+ @OverloadResolutionByLambdaReturnType
74
136
public inline fun <T , reified C : Number > DataFrame<T>.sum (noinline columns : ColumnsSelector <T , C ?>): C =
75
- ( Aggregators .sum.aggregateAll(this , columns) as C ? ) ? : C :: class .zero()
137
+ Aggregators .sum.aggregateAll(this , columns) as C
76
138
139
+ @JvmName(" sumShort" )
140
+ @AccessApiOverload
141
+ public fun <T > DataFrame<T>.sum (vararg columns : ColumnReference <Short ?>): Int = sum { columns.toColumnSet() }
142
+
143
+ @JvmName(" sumByte" )
144
+ @AccessApiOverload
145
+ public fun <T > DataFrame<T>.sum (vararg columns : ColumnReference <Byte ?>): Int = sum { columns.toColumnSet() }
146
+
147
+ @JvmName(" sumNumber" )
77
148
@AccessApiOverload
78
149
public inline fun <T , reified C : Number > DataFrame<T>.sum (vararg columns : ColumnReference <C ?>): C =
79
150
sum { columns.toColumnSet() }
80
151
81
- public fun <T > DataFrame<T>.sum (vararg columns : String ): Number = sum { columns.toColumnsSetOf() }
152
+ public fun <T > DataFrame<T>.sum (vararg columns : String ): Number = sum { columns.toColumnsSetOf<Number ?>() }
153
+
154
+ @JvmName(" sumShort" )
155
+ @AccessApiOverload
156
+ public fun <T > DataFrame<T>.sum (vararg columns : KProperty <Short ?>): Int = sum { columns.toColumnSet() }
82
157
158
+ @JvmName(" sumByte" )
159
+ @AccessApiOverload
160
+ public fun <T > DataFrame<T>.sum (vararg columns : KProperty <Byte ?>): Int = sum { columns.toColumnSet() }
161
+
162
+ @JvmName(" sumNumber" )
83
163
@AccessApiOverload
84
164
public inline fun <T , reified C : Number > DataFrame<T>.sum (vararg columns : KProperty <C ?>): C =
85
165
sum { columns.toColumnSet() }
86
166
87
- public inline fun <T , reified C : Number ?> DataFrame<T>.sumOf (crossinline expression : RowExpression <T , C >): C =
88
- rows().sumOf(typeOf<C >()) { expression(it, it) }
167
+ @JvmName(" sumOfShort" )
168
+ @OverloadResolutionByLambdaReturnType
169
+ public fun <T > DataFrame<T>.sumOf (expression : RowExpression <T , Short ?>): Int =
170
+ Aggregators .sum.aggregateOf(this , expression) as Int
171
+
172
+ @JvmName(" sumOfByte" )
173
+ @OverloadResolutionByLambdaReturnType
174
+ public fun <T > DataFrame<T>.sumOf (expression : RowExpression <T , Byte ?>): Int =
175
+ Aggregators .sum.aggregateOf(this , expression) as Int
176
+
177
+ @JvmName(" sumOfNumber" )
178
+ @OverloadResolutionByLambdaReturnType
179
+ public inline fun <T , reified C : Number > DataFrame<T>.sumOf (crossinline expression : RowExpression <T , C ?>): C =
180
+ Aggregators .sum.aggregateOf(this , expression) as C
89
181
90
182
// endregion
91
183
92
184
// region GroupBy
93
185
@Refine
94
186
@Interpretable(" GroupBySum1" )
95
- public fun <T > Grouped<T>.sum (): DataFrame <T > = sumFor(numberColumns ())
187
+ public fun <T > Grouped<T>.sum (): DataFrame <T > = sumFor(primitiveOrMixedNumberColumns ())
96
188
97
189
@Refine
98
190
@Interpretable(" GroupBySum0" )
@@ -136,7 +228,7 @@ public inline fun <T, reified R : Number> Grouped<T>.sumOf(
136
228
137
229
// region Pivot
138
230
139
- public fun <T > Pivot<T>.sum (separate : Boolean = false): DataRow <T > = sumFor(separate, numberColumns ())
231
+ public fun <T > Pivot<T>.sum (separate : Boolean = false): DataRow <T > = sumFor(separate, primitiveOrMixedNumberColumns ())
140
232
141
233
public fun <T , R : Number > Pivot<T>.sumFor (
142
234
separate : Boolean = false,
@@ -166,14 +258,15 @@ public fun <T, C : Number> Pivot<T>.sum(vararg columns: ColumnReference<C?>): Da
166
258
@AccessApiOverload
167
259
public fun <T , C : Number > Pivot<T>.sum (vararg columns : KProperty <C ?>): DataRow <T > = sum { columns.toColumnSet() }
168
260
169
- public inline fun <T , reified R : Number > Pivot<T>.sumOf (crossinline expression : RowExpression <T , R >): DataRow <T > =
261
+ public inline fun <T , reified R : Number > Pivot<T>.sumOf (crossinline expression : RowExpression <T , R ? >): DataRow <T > =
170
262
delegate { sumOf(expression) }
171
263
172
264
// endregion
173
265
174
266
// region PivotGroupBy
175
267
176
- public fun <T > PivotGroupBy<T>.sum (separate : Boolean = false): DataFrame <T > = sumFor(separate, numberColumns())
268
+ public fun <T > PivotGroupBy<T>.sum (separate : Boolean = false): DataFrame <T > =
269
+ sumFor(separate, primitiveOrMixedNumberColumns())
177
270
178
271
public fun <T , R : Number > PivotGroupBy<T>.sumFor (
179
272
separate : Boolean = false,
@@ -209,7 +302,7 @@ public fun <T, C : Number> PivotGroupBy<T>.sum(vararg columns: KProperty<C?>): D
209
302
sum { columns.toColumnSet() }
210
303
211
304
public inline fun <T , reified R : Number > PivotGroupBy<T>.sumOf (
212
- crossinline expression : RowExpression <T , R >,
305
+ crossinline expression : RowExpression <T , R ? >,
213
306
): DataFrame <T > = Aggregators .sum.aggregateOf(this , expression)
214
307
215
308
// endregion
0 commit comments