Skip to content

Commit f2e7a32

Browse files
committed
extra overloads for sum api, fixed parts of aggregateOfRow
1 parent b046891 commit f2e7a32

File tree

6 files changed

+120
-77
lines changed

6 files changed

+120
-77
lines changed

core/api/core.api

+15-9
Original file line numberDiff line numberDiff line change
@@ -3816,6 +3816,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/StdKt {
38163816

38173817
public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38183818
public static final fun rowSum (Lorg/jetbrains/kotlinx/dataframe/DataRow;)Ljava/lang/Number;
3819+
public static final fun rowSumOf (Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/reflect/KType;)Ljava/lang/Number;
38193820
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38203821
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Ljava/lang/Number;
38213822
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
@@ -3839,6 +3840,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38393840
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;Ljava/lang/String;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38403841
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Pivot;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38413842
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
3843+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I
3844+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3845+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I
3846+
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I
38423847
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38433848
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
38443849
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
@@ -3863,8 +3868,15 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
38633868
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38643869
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
38653870
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
3866-
public static final fun sumT (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3867-
public static final fun sumTNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3871+
public static final fun sumNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
3872+
public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I
3873+
public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3874+
public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I
3875+
public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3876+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I
3877+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
3878+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I
3879+
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I
38683880
}
38693881

38703882
public final class org/jetbrains/kotlinx/dataframe/api/TailKt {
@@ -6153,13 +6165,7 @@ public final class org/jetbrains/kotlinx/dataframe/math/StdKt {
61536165
}
61546166

61556167
public final class org/jetbrains/kotlinx/dataframe/math/SumKt {
6156-
public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigDecimal;
6157-
public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigInteger;
6158-
public static final fun sum (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number;
6159-
public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigDecimal;
6160-
public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigInteger;
6161-
public static final fun sumNullableT (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number;
6162-
public static final fun sumOf (Ljava/lang/Iterable;Lkotlin/reflect/KType;Lkotlin/jvm/functions/Function1;)Ljava/lang/Number;
6168+
public static final fun sumNullableT (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;)Ljava/lang/Number;
61636169
}
61646170

61656171
public abstract class org/jetbrains/kotlinx/dataframe/schema/ColumnSchema {

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

+2-5
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,13 @@ public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
4848
// region DataRow
4949

5050
public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
51-
Aggregators.mean(skipNA).aggregateOfRow(this) {
52-
colsOf<Number?> { it.isPrimitiveNumber() }
53-
}
51+
Aggregators.mean(skipNA).aggregateOfRow(this, primitiveNumberColumns())
5452

5553
public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
5654
require(typeOf<T>().withNullability(false) in primitiveNumberTypes) {
5755
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
5856
}
59-
return Aggregators.mean(skipNA)
60-
.aggregateOfRow(this) { colsOf<T>() }
57+
return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf<T>() }
6158
}
6259

6360
// endregion

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

+89-51
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
@file:OptIn(ExperimentalTypeInference::class)
2+
@file:Suppress("LocalVariableName")
23

34
package org.jetbrains.kotlinx.dataframe.api
45

@@ -21,83 +22,78 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
2122
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2223
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
2324
import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
25+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
2426
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
2527
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
26-
import org.jetbrains.kotlinx.dataframe.impl.zero
2728
import kotlin.experimental.ExperimentalTypeInference
29+
import kotlin.reflect.KClass
2830
import kotlin.reflect.KProperty
31+
import kotlin.reflect.KType
32+
import kotlin.reflect.full.withNullability
2933
import kotlin.reflect.typeOf
3034

3135
// region DataColumn
3236

33-
@JvmName("sumInt")
34-
public fun DataColumn<Int?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
35-
3637
@JvmName("sumShort")
3738
public fun DataColumn<Short?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
3839

3940
@JvmName("sumByte")
4041
public fun DataColumn<Byte?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
4142

42-
@JvmName("sumLong")
43-
public fun DataColumn<Long?>.sum(): Long = Aggregators.sum.aggregate(this) as Long
44-
45-
@JvmName("sumFloat")
46-
public fun DataColumn<Float?>.sum(): Float = Aggregators.sum.aggregate(this) as Float
47-
48-
@JvmName("sumDouble")
49-
public fun DataColumn<Double?>.sum(): Double = Aggregators.sum.aggregate(this) as Double
50-
43+
@Suppress("UNCHECKED_CAST")
5144
@JvmName("sumNumber")
52-
public fun DataColumn<Number?>.sum(): Number = Aggregators.sum.aggregate(this)
53-
54-
@JvmName("sumOfInt")
55-
@OverloadResolutionByLambdaReturnType
56-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Int?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
45+
public fun <T : Number> DataColumn<T?>.sum(): T = Aggregators.sum.aggregate(this) as T
5746

5847
@JvmName("sumOfShort")
5948
@OverloadResolutionByLambdaReturnType
60-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Short?): Int =
49+
public fun <C> DataColumn<C>.sumOf(expression: (C) -> Short?): Int =
6150
Aggregators.sum.aggregateOf(this, expression) as Int
6251

6352
@JvmName("sumOfByte")
6453
@OverloadResolutionByLambdaReturnType
65-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
66-
67-
@JvmName("sumOfLong")
68-
@OverloadResolutionByLambdaReturnType
69-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Long?): Long =
70-
Aggregators.sum.aggregateOf(this, expression) as Long
71-
72-
@JvmName("sumOfFloat")
73-
@OverloadResolutionByLambdaReturnType
74-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Float?): Float =
75-
Aggregators.sum.aggregateOf(this, expression) as Float
76-
77-
@JvmName("sumOfDouble")
78-
@OverloadResolutionByLambdaReturnType
79-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Double?): Double =
80-
Aggregators.sum.aggregateOf(this, expression) as Double
54+
public fun <C> DataColumn<C>.sumOf(expression: (C) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
8155

8256
@JvmName("sumOfNumber")
8357
@OverloadResolutionByLambdaReturnType
84-
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Number?): Number = Aggregators.sum.aggregateOf(this, expression)
58+
public inline fun <C, reified V : Number> DataColumn<C>.sumOf(crossinline expression: (C) -> V?): V =
59+
Aggregators.sum.aggregateOf(this, expression) as V
8560

8661
// endregion
8762

8863
// region DataRow
8964

90-
public fun AnyRow.rowSum(): Number =
91-
Aggregators.sum.aggregateOfRow(this) {
92-
colsOf<Number?> { it.isPrimitiveNumber() }
93-
}
65+
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveNumberColumns())
9466

95-
public inline fun <reified T : Number> AnyRow.rowSumOf(): Number /*todo*/ {
96-
require(typeOf<T>() in primitiveNumberTypes) {
97-
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
67+
@JvmName("rowSumOfShort")
68+
public inline fun <reified T : Short?> AnyRow.rowSumOf(_kClass: KClass<Short> = Short::class): Int =
69+
rowSumOf(typeOf<T>()) as Int
70+
71+
@JvmName("rowSumOfByte")
72+
public inline fun <reified T : Byte?> AnyRow.rowSumOf(_kClass: KClass<Byte> = Byte::class): Int =
73+
rowSumOf(typeOf<T>()) as Int
74+
75+
@JvmName("rowSumOfInt")
76+
public inline fun <reified T : Int?> AnyRow.rowSumOf(_kClass: KClass<Int> = Int::class): Int =
77+
rowSumOf(typeOf<T>()) as Int
78+
79+
@JvmName("rowSumOfLong")
80+
public inline fun <reified T : Long?> AnyRow.rowSumOf(_kClass: KClass<Long> = Long::class): Long =
81+
rowSumOf(typeOf<T>()) as Long
82+
83+
@JvmName("rowSumOfFloat")
84+
public inline fun <reified T : Float?> AnyRow.rowSumOf(_kClass: KClass<Float> = Float::class): Float =
85+
rowSumOf(typeOf<T>()) as Float
86+
87+
@JvmName("rowSumOfDouble")
88+
public inline fun <reified T : Double?> AnyRow.rowSumOf(_kClass: KClass<Double> = Double::class): Double =
89+
rowSumOf(typeOf<T>()) as Double
90+
91+
// unfortunately, we cannot make a `reified T : Number?` due to clashes
92+
public fun AnyRow.rowSumOf(type: KType): Number {
93+
require(type.withNullability(false) in primitiveNumberTypes) {
94+
"Type $type is not a primitive number type. Mean only supports primitive number types."
9895
}
99-
return Aggregators.sum
100-
.aggregateOfRow(this) { colsOf<T>() }
96+
return Aggregators.sum.aggregateOfRow(this) { colsOf(type) }
10197
}
10298
// endregion
10399

@@ -118,21 +114,63 @@ public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: ColumnReference<C
118114
public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: KProperty<C?>): DataRow<T> =
119115
sumFor { columns.toColumnSet() }
120116

117+
@JvmName("sumShort")
118+
@OverloadResolutionByLambdaReturnType
119+
public fun <T> DataFrame<T>.sum(columns: ColumnsSelector<T, Short?>): Int =
120+
Aggregators.sum.aggregateAll(this, columns) as Int
121+
122+
@JvmName("sumByte")
123+
@OverloadResolutionByLambdaReturnType
124+
public fun <T> DataFrame<T>.sum(columns: ColumnsSelector<T, Byte?>): Int =
125+
Aggregators.sum.aggregateAll(this, columns) as Int
126+
127+
@JvmName("sumNumber")
128+
@OverloadResolutionByLambdaReturnType
121129
public inline fun <T, reified C : Number> DataFrame<T>.sum(noinline columns: ColumnsSelector<T, C?>): C =
122-
(Aggregators.sum.aggregateAll(this, columns) as C?) ?: C::class.zero()
130+
Aggregators.sum.aggregateAll(this, columns) as C
123131

132+
@JvmName("sumShort")
133+
@AccessApiOverload
134+
public fun <T> DataFrame<T>.sum(vararg columns: ColumnReference<Short?>): Int = sum { columns.toColumnSet() }
135+
136+
@JvmName("sumByte")
137+
@AccessApiOverload
138+
public fun <T> DataFrame<T>.sum(vararg columns: ColumnReference<Byte?>): Int = sum { columns.toColumnSet() }
139+
140+
@JvmName("sumNumber")
124141
@AccessApiOverload
125142
public inline fun <T, reified C : Number> DataFrame<T>.sum(vararg columns: ColumnReference<C?>): C =
126143
sum { columns.toColumnSet() }
127144

128-
public fun <T> DataFrame<T>.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf() }
145+
public fun <T> DataFrame<T>.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf<Number?>() }
146+
147+
@JvmName("sumShort")
148+
@AccessApiOverload
149+
public fun <T> DataFrame<T>.sum(vararg columns: KProperty<Short?>): Int = sum { columns.toColumnSet() }
150+
151+
@JvmName("sumByte")
152+
@AccessApiOverload
153+
public fun <T> DataFrame<T>.sum(vararg columns: KProperty<Byte?>): Int = sum { columns.toColumnSet() }
129154

155+
@JvmName("sumNumber")
130156
@AccessApiOverload
131157
public inline fun <T, reified C : Number> DataFrame<T>.sum(vararg columns: KProperty<C?>): C =
132158
sum { columns.toColumnSet() }
133159

134-
public inline fun <T, reified C : Number?> DataFrame<T>.sumOf(crossinline expression: RowExpression<T, C>): C =
135-
rows().sumOf(typeOf<C>()) { expression(it, it) }
160+
@JvmName("sumOfShort")
161+
@OverloadResolutionByLambdaReturnType
162+
public fun <T> DataFrame<T>.sumOf(expression: RowExpression<T, Short?>): Int =
163+
Aggregators.sum.aggregateOf(this, expression) as Int
164+
165+
@JvmName("sumOfByte")
166+
@OverloadResolutionByLambdaReturnType
167+
public fun <T> DataFrame<T>.sumOf(expression: RowExpression<T, Byte?>): Int =
168+
Aggregators.sum.aggregateOf(this, expression) as Int
169+
170+
@JvmName("sumOfNumber")
171+
@OverloadResolutionByLambdaReturnType
172+
public inline fun <T, reified C : Number> DataFrame<T>.sumOf(crossinline expression: RowExpression<T, C?>): C =
173+
Aggregators.sum.aggregateOf(this, expression) as C
136174

137175
// endregion
138176

@@ -213,7 +251,7 @@ public fun <T, C : Number> Pivot<T>.sum(vararg columns: ColumnReference<C?>): Da
213251
@AccessApiOverload
214252
public fun <T, C : Number> Pivot<T>.sum(vararg columns: KProperty<C?>): DataRow<T> = sum { columns.toColumnSet() }
215253

216-
public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression: RowExpression<T, R>): DataRow<T> =
254+
public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression: RowExpression<T, R?>): DataRow<T> =
217255
delegate { sumOf(expression) }
218256

219257
// endregion
@@ -256,7 +294,7 @@ public fun <T, C : Number> PivotGroupBy<T>.sum(vararg columns: KProperty<C?>): D
256294
sum { columns.toColumnSet() }
257295

258296
public inline fun <T, reified R : Number> PivotGroupBy<T>.sumOf(
259-
crossinline expression: RowExpression<T, R>,
297+
crossinline expression: RowExpression<T, R?>,
260298
): DataFrame<T> = Aggregators.sum.aggregateOf(this, expression)
261299

262300
// endregion

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt

+6
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation
22

33
import org.jetbrains.kotlinx.dataframe.AnyCol
44
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
5+
import org.jetbrains.kotlinx.dataframe.DataRow
56
import org.jetbrains.kotlinx.dataframe.aggregation.Aggregatable
67
import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue
8+
import org.jetbrains.kotlinx.dataframe.api.cast
79
import org.jetbrains.kotlinx.dataframe.api.filter
810
import org.jetbrains.kotlinx.dataframe.api.isNumber
911
import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber
1012
import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable
1113
import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion
1214
import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType
15+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
1316

1417
internal inline fun <T> Aggregatable<T>.remainingColumns(
1518
crossinline predicate: (AnyCol) -> Boolean,
@@ -27,6 +30,9 @@ internal fun <T> Aggregatable<T>.numberColumns(): ColumnsSelector<T, Number?> =
2730
internal fun <T> Aggregatable<T>.primitiveNumberColumns(): ColumnsSelector<T, Number?> =
2831
remainingColumns { it.isPrimitiveNumber() } as ColumnsSelector<T, Number?>
2932

33+
internal fun <T> DataRow<T>.primitiveNumberColumns(): ColumnsSelector<T, Number?> =
34+
{ cols { it.isPrimitiveNumber() }.cast() }
35+
3036
internal fun NamedValue.toColumnWithPath() =
3137
path to createColumnGuessingType(
3238
name = path.last(),

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes
22

3-
import org.jetbrains.kotlinx.dataframe.AnyRow
43
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
4+
import org.jetbrains.kotlinx.dataframe.DataRow
55
import org.jetbrains.kotlinx.dataframe.api.getColumns
66
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
77

@@ -14,7 +14,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
1414
* @param columns selector of which columns inside the [row] to aggregate
1515
*/
1616
@PublishedApi
17-
internal fun <V, R> Aggregator<V, R>.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R {
17+
internal fun <T, V, R> Aggregator<V, R>.aggregateOfRow(row: DataRow<T>, columns: ColumnsSelector<T, V?>): R {
1818
val filteredColumns = row.df().getColumns(columns)
1919
return aggregateCalculatingType(
2020
values = filteredColumns.mapNotNull { row[it] },

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt

+6-10
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,14 @@ internal fun Sequence<Number?>.sum(type: KType): Number {
4545

4646
/** T: Number? -> T */
4747
internal val sumTypeConversion: CalculateReturnTypeOrNull = { type, _ ->
48-
when (type.withNullability(false)) {
49-
typeOf<Int>(),
50-
typeOf<Short>(),
51-
typeOf<Byte>(),
52-
-> typeOf<Int>()
48+
when (val type = type.withNullability(false)) {
49+
// type changes to Int
50+
typeOf<Short>(), typeOf<Byte>() -> typeOf<Int>()
5351

54-
typeOf<Long>() -> typeOf<Long>()
55-
56-
typeOf<Double>() -> typeOf<Double>()
57-
58-
typeOf<Float>() -> typeOf<Float>()
52+
// type remains the same
53+
typeOf<Int>(), typeOf<Long>(), typeOf<Double>(), typeOf<Float>() -> type
5954

55+
// defaults to Double
6056
else -> typeOf<Double>()
6157
}
6258
}

0 commit comments

Comments
 (0)