Skip to content

Commit 5ebabf9

Browse files
committedMar 18, 2025·
extra overloads for sum api, fixed parts of aggregateOfRow
1 parent b046891 commit 5ebabf9

File tree

6 files changed

+131
-80
lines changed

6 files changed

+131
-80
lines changed
 

‎core/api/core.api

+15-9
Original file line numberDiff line numberDiff line change
@@ -3816,6 +3816,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/StdKt {
38163816

38173817
public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38183818
public static final fun rowSum (Lorg/jetbrains/kotlinx/dataframe/DataRow;)Ljava/lang/Number;
3819+
public static final fun rowSumOf (Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/reflect/KType;)Ljava/lang/Number;
38193820
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38203821
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Ljava/lang/Number;
38213822
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
@@ -3839,6 +3840,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38393840
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;Ljava/lang/String;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38403841
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Pivot;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38413842
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
3843+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I
3844+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3845+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I
3846+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I
38423847
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38433848
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38443849
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
@@ -3863,8 +3868,15 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38633868
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38643869
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38653870
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
3866-
public static final fun sumT (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3867-
public static final fun sumTNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3871+
public static final fun sumNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3872+
public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I
3873+
public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3874+
public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I
3875+
public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3876+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I
3877+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3878+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I
3879+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I
38683880
}
38693881

38703882
public final class org/jetbrains/kotlinx/dataframe/api/TailKt {
@@ -6153,13 +6165,7 @@ public final class org/jetbrains/kotlinx/dataframe/math/StdKt {
61536165
}
61546166

61556167
public final class org/jetbrains/kotlinx/dataframe/math/SumKt {
6156-
public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigDecimal;
6157-
public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigInteger;
6158-
public static final fun sum (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number;
6159-
public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigDecimal;
6160-
public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigInteger;
6161-
public static final fun sumNullableT (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number;
6162-
public static final fun sumOf (Ljava/lang/Iterable;Lkotlin/reflect/KType;Lkotlin/jvm/functions/Function1;)Ljava/lang/Number;
6168+
public static final fun sumNullableT (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;)Ljava/lang/Number;
61636169
}
61646170

61656171
public abstract class org/jetbrains/kotlinx/dataframe/schema/ColumnSchema {

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

+4-8
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ import kotlin.reflect.KProperty
2525
import kotlin.reflect.full.withNullability
2626
import kotlin.reflect.typeOf
2727

28-
/*
29-
* TODO KDocs:
28+
/* TODO KDocs:
3029
* Calculating the mean is supported for all primitive number types.
31-
* Nulls are filtered from columns.
30+
* Nulls are filtered out.
3231
* The return type is always Double, Double.NaN for empty input, never null.
3332
* (May introduce loss of precision for Longs).
3433
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean.
@@ -48,16 +47,13 @@ public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
4847
// region DataRow
4948

5049
public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
51-
Aggregators.mean(skipNA).aggregateOfRow(this) {
52-
colsOf<Number?> { it.isPrimitiveNumber() }
53-
}
50+
Aggregators.mean(skipNA).aggregateOfRow(this, primitiveNumberColumns())
5451

5552
public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
5653
require(typeOf<T>().withNullability(false) in primitiveNumberTypes) {
5754
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
5855
}
59-
return Aggregators.mean(skipNA)
60-
.aggregateOfRow(this) { colsOf<T>() }
56+
return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf<T>() }
6157
}
6258

6359
// endregion

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

+98-51
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@file:OptIn(ExperimentalTypeInference::class)
2+
@file:Suppress("LocalVariableName")
23

34
package org.jetbrains.kotlinx.dataframe.api
45

@@ -21,83 +22,87 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
2122
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2223
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
2324
import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
25+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
2426
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
2527
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
26-
import org.jetbrains.kotlinx.dataframe.impl.zero
2728
import kotlin.experimental.ExperimentalTypeInference
29+
import kotlin.reflect.KClass
2830
import kotlin.reflect.KProperty
31+
import kotlin.reflect.KType
32+
import kotlin.reflect.full.withNullability
2933
import kotlin.reflect.typeOf
3034

31-
// region DataColumn
35+
/* TODO KDocs
36+
* Calculating the sum is supported for all primitive number types.
37+
* Nulls are filtered out.
38+
* The return type is always the same as the input type (never null), except for `Byte` and `Short`,
39+
* which are converted to `Int`.
40+
* Empty input will result in 0 in the supplied number type.
41+
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the sum.
42+
*/
3243

33-
@JvmName("sumInt")
34-
public fun DataColumn<Int?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
44+
// region DataColumn
3545

3646
@JvmName("sumShort")
3747
public fun DataColumn<Short?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
3848

3949
@JvmName("sumByte")
4050
public fun DataColumn<Byte?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
4151

42-
@JvmName("sumLong")
43-
public fun DataColumn<Long?>.sum(): Long = Aggregators.sum.aggregate(this) as Long
44-
45-
@JvmName("sumFloat")
46-
public fun DataColumn<Float?>.sum(): Float = Aggregators.sum.aggregate(this) as Float
47-
48-
@JvmName("sumDouble")
49-
public fun DataColumn<Double?>.sum(): Double = Aggregators.sum.aggregate(this) as Double
50-
52+
@Suppress("UNCHECKED_CAST")
5153
@JvmName("sumNumber")
52-
public fun DataColumn<Number?>.sum(): Number = Aggregators.sum.aggregate(this)
53-
54-
@JvmName("sumOfInt")
55-
@OverloadResolutionByLambdaReturnType
56-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Int?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
54+
public fun <T : Number> DataColumn<T?>.sum(): T = Aggregators.sum.aggregate(this) as T
5755

5856
@JvmName("sumOfShort")
5957
@OverloadResolutionByLambdaReturnType
60-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Short?): Int =
58+
public fun <C> DataColumn<C>.sumOf(expression: (C) -> Short?): Int =
6159
Aggregators.sum.aggregateOf(this, expression) as Int
6260

6361
@JvmName("sumOfByte")
6462
@OverloadResolutionByLambdaReturnType
65-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
66-
67-
@JvmName("sumOfLong")
68-
@OverloadResolutionByLambdaReturnType
69-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Long?): Long =
70-
Aggregators.sum.aggregateOf(this, expression) as Long
71-
72-
@JvmName("sumOfFloat")
73-
@OverloadResolutionByLambdaReturnType
74-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Float?): Float =
75-
Aggregators.sum.aggregateOf(this, expression) as Float
76-
77-
@JvmName("sumOfDouble")
78-
@OverloadResolutionByLambdaReturnType
79-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Double?): Double =
80-
Aggregators.sum.aggregateOf(this, expression) as Double
63+
public fun <C> DataColumn<C>.sumOf(expression: (C) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
8164

8265
@JvmName("sumOfNumber")
8366
@OverloadResolutionByLambdaReturnType
84-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Number?): Number = Aggregators.sum.aggregateOf(this, expression)
67+
public inline fun <C, reified V : Number> DataColumn<C>.sumOf(crossinline expression: (C) -> V?): V =
68+
Aggregators.sum.aggregateOf(this, expression) as V
8569

8670
// endregion
8771

8872
// region DataRow
8973

90-
public fun AnyRow.rowSum(): Number =
91-
Aggregators.sum.aggregateOfRow(this) {
92-
colsOf<Number?> { it.isPrimitiveNumber() }
93-
}
74+
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveNumberColumns())
9475

95-
public inline fun <reified T : Number> AnyRow.rowSumOf(): Number /*todo*/ {
96-
require(typeOf<T>() in primitiveNumberTypes) {
97-
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
76+
@JvmName("rowSumOfShort")
77+
public inline fun <reified T : Short?> AnyRow.rowSumOf(_kClass: KClass<Short> = Short::class): Int =
78+
rowSumOf(typeOf<T>()) as Int
79+
80+
@JvmName("rowSumOfByte")
81+
public inline fun <reified T : Byte?> AnyRow.rowSumOf(_kClass: KClass<Byte> = Byte::class): Int =
82+
rowSumOf(typeOf<T>()) as Int
83+
84+
@JvmName("rowSumOfInt")
85+
public inline fun <reified T : Int?> AnyRow.rowSumOf(_kClass: KClass<Int> = Int::class): Int =
86+
rowSumOf(typeOf<T>()) as Int
87+
88+
@JvmName("rowSumOfLong")
89+
public inline fun <reified T : Long?> AnyRow.rowSumOf(_kClass: KClass<Long> = Long::class): Long =
90+
rowSumOf(typeOf<T>()) as Long
91+
92+
@JvmName("rowSumOfFloat")
93+
public inline fun <reified T : Float?> AnyRow.rowSumOf(_kClass: KClass<Float> = Float::class): Float =
94+
rowSumOf(typeOf<T>()) as Float
95+
96+
@JvmName("rowSumOfDouble")
97+
public inline fun <reified T : Double?> AnyRow.rowSumOf(_kClass: KClass<Double> = Double::class): Double =
98+
rowSumOf(typeOf<T>()) as Double
99+
100+
// unfortunately, we cannot make a `reified T : Number?` due to clashes
101+
public fun AnyRow.rowSumOf(type: KType): Number {
102+
require(type.withNullability(false) in primitiveNumberTypes) {
103+
"Type $type is not a primitive number type. Mean only supports primitive number types."
98104
}
99-
return Aggregators.sum
100-
.aggregateOfRow(this) { colsOf<T>() }
105+
return Aggregators.sum.aggregateOfRow(this) { colsOf(type) }
101106
}
102107
// endregion
103108

@@ -118,21 +123,63 @@ public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: ColumnReference<C
118123
public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: KProperty<C?>): DataRow<T> =
119124
sumFor { columns.toColumnSet() }
120125

126+
@JvmName("sumShort")
127+
@OverloadResolutionByLambdaReturnType
128+
public fun <T> DataFrame<T>.sum(columns: ColumnsSelector<T, Short?>): Int =
129+
Aggregators.sum.aggregateAll(this, columns) as Int
130+
131+
@JvmName("sumByte")
132+
@OverloadResolutionByLambdaReturnType
133+
public fun <T> DataFrame<T>.sum(columns: ColumnsSelector<T, Byte?>): Int =
134+
Aggregators.sum.aggregateAll(this, columns) as Int
135+
136+
@JvmName("sumNumber")
137+
@OverloadResolutionByLambdaReturnType
121138
public inline fun <T, reified C : Number> DataFrame<T>.sum(noinline columns: ColumnsSelector<T, C?>): C =
122-
(Aggregators.sum.aggregateAll(this, columns) as C?) ?: C::class.zero()
139+
Aggregators.sum.aggregateAll(this, columns) as C
123140

141+
@JvmName("sumShort")
142+
@AccessApiOverload
143+
public fun <T> DataFrame<T>.sum(vararg columns: ColumnReference<Short?>): Int = sum { columns.toColumnSet() }
144+
145+
@JvmName("sumByte")
146+
@AccessApiOverload
147+
public fun <T> DataFrame<T>.sum(vararg columns: ColumnReference<Byte?>): Int = sum { columns.toColumnSet() }
148+
149+
@JvmName("sumNumber")
124150
@AccessApiOverload
125151
public inline fun <T, reified C : Number> DataFrame<T>.sum(vararg columns: ColumnReference<C?>): C =
126152
sum { columns.toColumnSet() }
127153

128-
public fun <T> DataFrame<T>.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf() }
154+
public fun <T> DataFrame<T>.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf<Number?>() }
155+
156+
@JvmName("sumShort")
157+
@AccessApiOverload
158+
public fun <T> DataFrame<T>.sum(vararg columns: KProperty<Short?>): Int = sum { columns.toColumnSet() }
159+
160+
@JvmName("sumByte")
161+
@AccessApiOverload
162+
public fun <T> DataFrame<T>.sum(vararg columns: KProperty<Byte?>): Int = sum { columns.toColumnSet() }
129163

164+
@JvmName("sumNumber")
130165
@AccessApiOverload
131166
public inline fun <T, reified C : Number> DataFrame<T>.sum(vararg columns: KProperty<C?>): C =
132167
sum { columns.toColumnSet() }
133168

134-
public inline fun <T, reified C : Number?> DataFrame<T>.sumOf(crossinline expression: RowExpression<T, C>): C =
135-
rows().sumOf(typeOf<C>()) { expression(it, it) }
169+
@JvmName("sumOfShort")
170+
@OverloadResolutionByLambdaReturnType
171+
public fun <T> DataFrame<T>.sumOf(expression: RowExpression<T, Short?>): Int =
172+
Aggregators.sum.aggregateOf(this, expression) as Int
173+
174+
@JvmName("sumOfByte")
175+
@OverloadResolutionByLambdaReturnType
176+
public fun <T> DataFrame<T>.sumOf(expression: RowExpression<T, Byte?>): Int =
177+
Aggregators.sum.aggregateOf(this, expression) as Int
178+
179+
@JvmName("sumOfNumber")
180+
@OverloadResolutionByLambdaReturnType
181+
public inline fun <T, reified C : Number> DataFrame<T>.sumOf(crossinline expression: RowExpression<T, C?>): C =
182+
Aggregators.sum.aggregateOf(this, expression) as C
136183

137184
// endregion
138185

@@ -213,7 +260,7 @@ public fun <T, C : Number> Pivot<T>.sum(vararg columns: ColumnReference<C?>): Da
213260
@AccessApiOverload
214261
public fun <T, C : Number> Pivot<T>.sum(vararg columns: KProperty<C?>): DataRow<T> = sum { columns.toColumnSet() }
215262

216-
public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression: RowExpression<T, R>): DataRow<T> =
263+
public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression: RowExpression<T, R?>): DataRow<T> =
217264
delegate { sumOf(expression) }
218265

219266
// endregion
@@ -256,7 +303,7 @@ public fun <T, C : Number> PivotGroupBy<T>.sum(vararg columns: KProperty<C?>): D
256303
sum { columns.toColumnSet() }
257304

258305
public inline fun <T, reified R : Number> PivotGroupBy<T>.sumOf(
259-
crossinline expression: RowExpression<T, R>,
306+
crossinline expression: RowExpression<T, R?>,
260307
): DataFrame<T> = Aggregators.sum.aggregateOf(this, expression)
261308

262309
// endregion

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt

+6
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation
22

33
import org.jetbrains.kotlinx.dataframe.AnyCol
44
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
5+
import org.jetbrains.kotlinx.dataframe.DataRow
56
import org.jetbrains.kotlinx.dataframe.aggregation.Aggregatable
67
import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue
8+
import org.jetbrains.kotlinx.dataframe.api.cast
79
import org.jetbrains.kotlinx.dataframe.api.filter
810
import org.jetbrains.kotlinx.dataframe.api.isNumber
911
import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber
1012
import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable
1113
import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion
1214
import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType
15+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
1316

1417
internal inline fun <T> Aggregatable<T>.remainingColumns(
1518
crossinline predicate: (AnyCol) -> Boolean,
@@ -27,6 +30,9 @@ internal fun <T> Aggregatable<T>.numberColumns(): ColumnsSelector<T, Number?> =
2730
internal fun <T> Aggregatable<T>.primitiveNumberColumns(): ColumnsSelector<T, Number?> =
2831
remainingColumns { it.isPrimitiveNumber() } as ColumnsSelector<T, Number?>
2932

33+
internal fun <T> DataRow<T>.primitiveNumberColumns(): ColumnsSelector<T, Number?> =
34+
{ cols { it.isPrimitiveNumber() }.cast() }
35+
3036
internal fun NamedValue.toColumnWithPath() =
3137
path to createColumnGuessingType(
3238
name = path.last(),

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes
22

3-
import org.jetbrains.kotlinx.dataframe.AnyRow
43
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
4+
import org.jetbrains.kotlinx.dataframe.DataRow
55
import org.jetbrains.kotlinx.dataframe.api.getColumns
66
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
77

@@ -14,7 +14,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
1414
* @param columns selector of which columns inside the [row] to aggregate
1515
*/
1616
@PublishedApi
17-
internal fun <V, R> Aggregator<V, R>.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R {
17+
internal fun <T, V, R> Aggregator<V, R>.aggregateOfRow(row: DataRow<T>, columns: ColumnsSelector<T, V?>): R {
1818
val filteredColumns = row.df().getColumns(columns)
1919
return aggregateCalculatingType(
2020
values = filteredColumns.mapNotNull { row[it] },

‎core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt

+6-10
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,14 @@ internal fun Sequence<Number?>.sum(type: KType): Number {
4545

4646
/** T: Number? -> T */
4747
internal val sumTypeConversion: CalculateReturnTypeOrNull = { type, _ ->
48-
when (type.withNullability(false)) {
49-
typeOf<Int>(),
50-
typeOf<Short>(),
51-
typeOf<Byte>(),
52-
-> typeOf<Int>()
48+
when (val type = type.withNullability(false)) {
49+
// type changes to Int
50+
typeOf<Short>(), typeOf<Byte>() -> typeOf<Int>()
5351

54-
typeOf<Long>() -> typeOf<Long>()
55-
56-
typeOf<Double>() -> typeOf<Double>()
57-
58-
typeOf<Float>() -> typeOf<Float>()
52+
// type remains the same
53+
typeOf<Int>(), typeOf<Long>(), typeOf<Double>(), typeOf<Float>() -> type
5954

55+
// defaults to Double
6056
else -> typeOf<Double>()
6157
}
6258
}

0 commit comments

Comments
 (0)
Please sign in to comment.