Skip to content

Commit 9aaf84d

Browse files
committedMar 18, 2025·
extra overloads for sum api, fixed parts of aggregateOfRow
1 parent b046891 commit 9aaf84d

File tree

6 files changed

+135
-85
lines changed

6 files changed

+135
-85
lines changed
 

‎core/api/core.api

+15-9
Original file line numberDiff line numberDiff line change
@@ -3816,6 +3816,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/StdKt {
38163816

38173817
public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38183818
public static final fun rowSum (Lorg/jetbrains/kotlinx/dataframe/DataRow;)Ljava/lang/Number;
3819+
public static final fun rowSumOf (Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/reflect/KType;)Ljava/lang/Number;
38193820
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38203821
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Ljava/lang/Number;
38213822
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
@@ -3839,6 +3840,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38393840
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;Ljava/lang/String;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38403841
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Pivot;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38413842
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
3843+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I
3844+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3845+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I
3846+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I
38423847
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38433848
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38443849
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
@@ -3863,8 +3868,15 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38633868
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38643869
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38653870
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
3866-
public static final fun sumT (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3867-
public static final fun sumTNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3871+
public static final fun sumNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3872+
public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I
3873+
public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3874+
public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I
3875+
public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3876+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I
3877+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3878+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I
3879+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I
38683880
}
38693881

38703882
public final class org/jetbrains/kotlinx/dataframe/api/TailKt {
@@ -6153,13 +6165,7 @@ public final class org/jetbrains/kotlinx/dataframe/math/StdKt {
61536165
}
61546166

61556167
public final class org/jetbrains/kotlinx/dataframe/math/SumKt {
6156-
public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigDecimal;
6157-
public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigInteger;
6158-
public static final fun sum (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number;
6159-
public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigDecimal;
6160-
public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigInteger;
6161-
public static final fun sumNullableT (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number;
6162-
public static final fun sumOf (Ljava/lang/Iterable;Lkotlin/reflect/KType;Lkotlin/jvm/functions/Function1;)Ljava/lang/Number;
6168+
public static final fun sumNullableT (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;)Ljava/lang/Number;
61636169
}
61646170

61656171
public abstract class org/jetbrains/kotlinx/dataframe/schema/ColumnSchema {

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

+4-8
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ import kotlin.reflect.KProperty
2525
import kotlin.reflect.full.withNullability
2626
import kotlin.reflect.typeOf
2727

28-
/*
29-
* TODO KDocs:
28+
/* TODO KDocs:
3029
* Calculating the mean is supported for all primitive number types.
31-
* Nulls are filtered from columns.
30+
* Nulls are filtered out.
3231
* The return type is always Double, Double.NaN for empty input, never null.
3332
* (May introduce loss of precision for Longs).
3433
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean.
@@ -48,16 +47,13 @@ public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
4847
// region DataRow
4948

5049
public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
51-
Aggregators.mean(skipNA).aggregateOfRow(this) {
52-
colsOf<Number?> { it.isPrimitiveNumber() }
53-
}
50+
Aggregators.mean(skipNA).aggregateOfRow(this, primitiveNumberColumns())
5451

5552
public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
5653
require(typeOf<T>().withNullability(false) in primitiveNumberTypes) {
5754
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
5855
}
59-
return Aggregators.mean(skipNA)
60-
.aggregateOfRow(this) { colsOf<T>() }
56+
return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf<T>() }
6157
}
6258

6359
// endregion

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

+102-56
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@file:OptIn(ExperimentalTypeInference::class)
2+
@file:Suppress("LocalVariableName")
23

34
package org.jetbrains.kotlinx.dataframe.api
45

@@ -20,90 +21,93 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
2021
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
2122
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2223
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
23-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
24+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
2425
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
2526
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
26-
import org.jetbrains.kotlinx.dataframe.impl.zero
2727
import kotlin.experimental.ExperimentalTypeInference
28+
import kotlin.reflect.KClass
2829
import kotlin.reflect.KProperty
30+
import kotlin.reflect.KType
31+
import kotlin.reflect.full.withNullability
2932
import kotlin.reflect.typeOf
3033

31-
// region DataColumn
34+
/* TODO KDocs
35+
* Calculating the sum is supported for all primitive number types.
36+
* Nulls are filtered out.
37+
* The return type is always the same as the input type (never null), except for `Byte` and `Short`,
38+
* which are converted to `Int`.
39+
* Empty input will result in 0 in the supplied number type.
40+
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the sum.
41+
*/
3242

33-
@JvmName("sumInt")
34-
public fun DataColumn<Int?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
43+
// region DataColumn
3544

3645
@JvmName("sumShort")
3746
public fun DataColumn<Short?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
3847

3948
@JvmName("sumByte")
4049
public fun DataColumn<Byte?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
4150

42-
@JvmName("sumLong")
43-
public fun DataColumn<Long?>.sum(): Long = Aggregators.sum.aggregate(this) as Long
44-
45-
@JvmName("sumFloat")
46-
public fun DataColumn<Float?>.sum(): Float = Aggregators.sum.aggregate(this) as Float
47-
48-
@JvmName("sumDouble")
49-
public fun DataColumn<Double?>.sum(): Double = Aggregators.sum.aggregate(this) as Double
50-
51+
@Suppress("UNCHECKED_CAST")
5152
@JvmName("sumNumber")
52-
public fun DataColumn<Number?>.sum(): Number = Aggregators.sum.aggregate(this)
53-
54-
@JvmName("sumOfInt")
55-
@OverloadResolutionByLambdaReturnType
56-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Int?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
53+
public fun <T : Number> DataColumn<T?>.sum(): T = Aggregators.sum.aggregate(this) as T
5754

5855
@JvmName("sumOfShort")
5956
@OverloadResolutionByLambdaReturnType
60-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Short?): Int =
57+
public fun <C> DataColumn<C>.sumOf(expression: (C) -> Short?): Int =
6158
Aggregators.sum.aggregateOf(this, expression) as Int
6259

6360
@JvmName("sumOfByte")
6461
@OverloadResolutionByLambdaReturnType
65-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
66-
67-
@JvmName("sumOfLong")
68-
@OverloadResolutionByLambdaReturnType
69-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Long?): Long =
70-
Aggregators.sum.aggregateOf(this, expression) as Long
71-
72-
@JvmName("sumOfFloat")
73-
@OverloadResolutionByLambdaReturnType
74-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Float?): Float =
75-
Aggregators.sum.aggregateOf(this, expression) as Float
76-
77-
@JvmName("sumOfDouble")
78-
@OverloadResolutionByLambdaReturnType
79-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Double?): Double =
80-
Aggregators.sum.aggregateOf(this, expression) as Double
62+
public fun <C> DataColumn<C>.sumOf(expression: (C) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
8163

8264
@JvmName("sumOfNumber")
8365
@OverloadResolutionByLambdaReturnType
84-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Number?): Number = Aggregators.sum.aggregateOf(this, expression)
66+
public inline fun <C, reified V : Number> DataColumn<C>.sumOf(crossinline expression: (C) -> V?): V =
67+
Aggregators.sum.aggregateOf(this, expression) as V
8568

8669
// endregion
8770

8871
// region DataRow
8972

90-
public fun AnyRow.rowSum(): Number =
91-
Aggregators.sum.aggregateOfRow(this) {
92-
colsOf<Number?> { it.isPrimitiveNumber() }
93-
}
73+
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveNumberColumns())
9474

95-
public inline fun <reified T : Number> AnyRow.rowSumOf(): Number /*todo*/ {
96-
require(typeOf<T>() in primitiveNumberTypes) {
97-
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
75+
@JvmName("rowSumOfShort")
76+
public inline fun <reified T : Short?> AnyRow.rowSumOf(_kClass: KClass<Short> = Short::class): Int =
77+
rowSumOf(typeOf<T>()) as Int
78+
79+
@JvmName("rowSumOfByte")
80+
public inline fun <reified T : Byte?> AnyRow.rowSumOf(_kClass: KClass<Byte> = Byte::class): Int =
81+
rowSumOf(typeOf<T>()) as Int
82+
83+
@JvmName("rowSumOfInt")
84+
public inline fun <reified T : Int?> AnyRow.rowSumOf(_kClass: KClass<Int> = Int::class): Int =
85+
rowSumOf(typeOf<T>()) as Int
86+
87+
@JvmName("rowSumOfLong")
88+
public inline fun <reified T : Long?> AnyRow.rowSumOf(_kClass: KClass<Long> = Long::class): Long =
89+
rowSumOf(typeOf<T>()) as Long
90+
91+
@JvmName("rowSumOfFloat")
92+
public inline fun <reified T : Float?> AnyRow.rowSumOf(_kClass: KClass<Float> = Float::class): Float =
93+
rowSumOf(typeOf<T>()) as Float
94+
95+
@JvmName("rowSumOfDouble")
96+
public inline fun <reified T : Double?> AnyRow.rowSumOf(_kClass: KClass<Double> = Double::class): Double =
97+
rowSumOf(typeOf<T>()) as Double
98+
99+
// unfortunately, we cannot make a `reified T : Number?` due to clashes
100+
public fun AnyRow.rowSumOf(type: KType): Number {
101+
require(type.withNullability(false) in primitiveNumberTypes) {
102+
"Type $type is not a primitive number type. Mean only supports primitive number types."
98103
}
99-
return Aggregators.sum
100-
.aggregateOfRow(this) { colsOf<T>() }
104+
return Aggregators.sum.aggregateOfRow(this) { colsOf(type) }
101105
}
102106
// endregion
103107

104108
// region DataFrame
105109

106-
public fun <T> DataFrame<T>.sum(): DataRow<T> = sumFor(numberColumns())
110+
public fun <T> DataFrame<T>.sum(): DataRow<T> = sumFor(primitiveNumberColumns())
107111

108112
public fun <T, C : Number> DataFrame<T>.sumFor(columns: ColumnsForAggregateSelector<T, C?>): DataRow<T> =
109113
Aggregators.sum.aggregateFor(this, columns)
@@ -118,28 +122,70 @@ public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: ColumnReference<C
118122
public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: KProperty<C?>): DataRow<T> =
119123
sumFor { columns.toColumnSet() }
120124

125+
@JvmName("sumShort")
126+
@OverloadResolutionByLambdaReturnType
127+
public fun <T> DataFrame<T>.sum(columns: ColumnsSelector<T, Short?>): Int =
128+
Aggregators.sum.aggregateAll(this, columns) as Int
129+
130+
@JvmName("sumByte")
131+
@OverloadResolutionByLambdaReturnType
132+
public fun <T> DataFrame<T>.sum(columns: ColumnsSelector<T, Byte?>): Int =
133+
Aggregators.sum.aggregateAll(this, columns) as Int
134+
135+
@JvmName("sumNumber")
136+
@OverloadResolutionByLambdaReturnType
121137
public inline fun <T, reified C : Number> DataFrame<T>.sum(noinline columns: ColumnsSelector<T, C?>): C =
122-
(Aggregators.sum.aggregateAll(this, columns) as C?) ?: C::class.zero()
138+
Aggregators.sum.aggregateAll(this, columns) as C
123139

140+
@JvmName("sumShort")
141+
@AccessApiOverload
142+
public fun <T> DataFrame<T>.sum(vararg columns: ColumnReference<Short?>): Int = sum { columns.toColumnSet() }
143+
144+
@JvmName("sumByte")
145+
@AccessApiOverload
146+
public fun <T> DataFrame<T>.sum(vararg columns: ColumnReference<Byte?>): Int = sum { columns.toColumnSet() }
147+
148+
@JvmName("sumNumber")
124149
@AccessApiOverload
125150
public inline fun <T, reified C : Number> DataFrame<T>.sum(vararg columns: ColumnReference<C?>): C =
126151
sum { columns.toColumnSet() }
127152

128-
public fun <T> DataFrame<T>.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf() }
153+
public fun <T> DataFrame<T>.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf<Number?>() }
154+
155+
@JvmName("sumShort")
156+
@AccessApiOverload
157+
public fun <T> DataFrame<T>.sum(vararg columns: KProperty<Short?>): Int = sum { columns.toColumnSet() }
158+
159+
@JvmName("sumByte")
160+
@AccessApiOverload
161+
public fun <T> DataFrame<T>.sum(vararg columns: KProperty<Byte?>): Int = sum { columns.toColumnSet() }
129162

163+
@JvmName("sumNumber")
130164
@AccessApiOverload
131165
public inline fun <T, reified C : Number> DataFrame<T>.sum(vararg columns: KProperty<C?>): C =
132166
sum { columns.toColumnSet() }
133167

134-
public inline fun <T, reified C : Number?> DataFrame<T>.sumOf(crossinline expression: RowExpression<T, C>): C =
135-
rows().sumOf(typeOf<C>()) { expression(it, it) }
168+
@JvmName("sumOfShort")
169+
@OverloadResolutionByLambdaReturnType
170+
public fun <T> DataFrame<T>.sumOf(expression: RowExpression<T, Short?>): Int =
171+
Aggregators.sum.aggregateOf(this, expression) as Int
172+
173+
@JvmName("sumOfByte")
174+
@OverloadResolutionByLambdaReturnType
175+
public fun <T> DataFrame<T>.sumOf(expression: RowExpression<T, Byte?>): Int =
176+
Aggregators.sum.aggregateOf(this, expression) as Int
177+
178+
@JvmName("sumOfNumber")
179+
@OverloadResolutionByLambdaReturnType
180+
public inline fun <T, reified C : Number> DataFrame<T>.sumOf(crossinline expression: RowExpression<T, C?>): C =
181+
Aggregators.sum.aggregateOf(this, expression) as C
136182

137183
// endregion
138184

139185
// region GroupBy
140186
@Refine
141187
@Interpretable("GroupBySum1")
142-
public fun <T> Grouped<T>.sum(): DataFrame<T> = sumFor(numberColumns())
188+
public fun <T> Grouped<T>.sum(): DataFrame<T> = sumFor(primitiveNumberColumns())
143189

144190
@Refine
145191
@Interpretable("GroupBySum0")
@@ -183,7 +229,7 @@ public inline fun <T, reified R : Number> Grouped<T>.sumOf(
183229

184230
// region Pivot
185231

186-
public fun <T> Pivot<T>.sum(separate: Boolean = false): DataRow<T> = sumFor(separate, numberColumns())
232+
public fun <T> Pivot<T>.sum(separate: Boolean = false): DataRow<T> = sumFor(separate, primitiveNumberColumns())
187233

188234
public fun <T, R : Number> Pivot<T>.sumFor(
189235
separate: Boolean = false,
@@ -213,14 +259,14 @@ public fun <T, C : Number> Pivot<T>.sum(vararg columns: ColumnReference<C?>): Da
213259
@AccessApiOverload
214260
public fun <T, C : Number> Pivot<T>.sum(vararg columns: KProperty<C?>): DataRow<T> = sum { columns.toColumnSet() }
215261

216-
public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression: RowExpression<T, R>): DataRow<T> =
262+
public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression: RowExpression<T, R?>): DataRow<T> =
217263
delegate { sumOf(expression) }
218264

219265
// endregion
220266

221267
// region PivotGroupBy
222268

223-
public fun <T> PivotGroupBy<T>.sum(separate: Boolean = false): DataFrame<T> = sumFor(separate, numberColumns())
269+
public fun <T> PivotGroupBy<T>.sum(separate: Boolean = false): DataFrame<T> = sumFor(separate, primitiveNumberColumns())
224270

225271
public fun <T, R : Number> PivotGroupBy<T>.sumFor(
226272
separate: Boolean = false,
@@ -256,7 +302,7 @@ public fun <T, C : Number> PivotGroupBy<T>.sum(vararg columns: KProperty<C?>): D
256302
sum { columns.toColumnSet() }
257303

258304
public inline fun <T, reified R : Number> PivotGroupBy<T>.sumOf(
259-
crossinline expression: RowExpression<T, R>,
305+
crossinline expression: RowExpression<T, R?>,
260306
): DataFrame<T> = Aggregators.sum.aggregateOf(this, expression)
261307

262308
// endregion

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt

+6
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation
22

33
import org.jetbrains.kotlinx.dataframe.AnyCol
44
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
5+
import org.jetbrains.kotlinx.dataframe.DataRow
56
import org.jetbrains.kotlinx.dataframe.aggregation.Aggregatable
67
import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue
8+
import org.jetbrains.kotlinx.dataframe.api.cast
79
import org.jetbrains.kotlinx.dataframe.api.filter
810
import org.jetbrains.kotlinx.dataframe.api.isNumber
911
import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber
1012
import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable
1113
import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion
1214
import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType
15+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
1316

1417
internal inline fun <T> Aggregatable<T>.remainingColumns(
1518
crossinline predicate: (AnyCol) -> Boolean,
@@ -27,6 +30,9 @@ internal fun <T> Aggregatable<T>.numberColumns(): ColumnsSelector<T, Number?> =
2730
internal fun <T> Aggregatable<T>.primitiveNumberColumns(): ColumnsSelector<T, Number?> =
2831
remainingColumns { it.isPrimitiveNumber() } as ColumnsSelector<T, Number?>
2932

33+
internal fun <T> DataRow<T>.primitiveNumberColumns(): ColumnsSelector<T, Number?> =
34+
{ cols { it.isPrimitiveNumber() }.cast() }
35+
3036
internal fun NamedValue.toColumnWithPath() =
3137
path to createColumnGuessingType(
3238
name = path.last(),

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes
22

3-
import org.jetbrains.kotlinx.dataframe.AnyRow
43
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
4+
import org.jetbrains.kotlinx.dataframe.DataRow
55
import org.jetbrains.kotlinx.dataframe.api.getColumns
66
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
77

@@ -14,7 +14,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
1414
* @param columns selector of which columns inside the [row] to aggregate
1515
*/
1616
@PublishedApi
17-
internal fun <V, R> Aggregator<V, R>.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R {
17+
internal fun <T, V, R> Aggregator<V, R>.aggregateOfRow(row: DataRow<T>, columns: ColumnsSelector<T, V?>): R {
1818
val filteredColumns = row.df().getColumns(columns)
1919
return aggregateCalculatingType(
2020
values = filteredColumns.mapNotNull { row[it] },

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt

+6-10
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,14 @@ internal fun Sequence<Number?>.sum(type: KType): Number {
4545

4646
/** T: Number? -> T */
4747
internal val sumTypeConversion: CalculateReturnTypeOrNull = { type, _ ->
48-
when (type.withNullability(false)) {
49-
typeOf<Int>(),
50-
typeOf<Short>(),
51-
typeOf<Byte>(),
52-
-> typeOf<Int>()
48+
when (val type = type.withNullability(false)) {
49+
// type changes to Int
50+
typeOf<Short>(), typeOf<Byte>() -> typeOf<Int>()
5351

54-
typeOf<Long>() -> typeOf<Long>()
55-
56-
typeOf<Double>() -> typeOf<Double>()
57-
58-
typeOf<Float>() -> typeOf<Float>()
52+
// type remains the same
53+
typeOf<Int>(), typeOf<Long>(), typeOf<Double>(), typeOf<Float>() -> type
5954

55+
// defaults to Double
6056
else -> typeOf<Double>()
6157
}
6258
}

0 commit comments

Comments
 (0)
Please sign in to comment.