Skip to content

Commit b046891

Browse files
committed
update from mean branch, added sumOf
1 parent 1317fa1 commit b046891

File tree

4 files changed

+40
-46
lines changed

4 files changed

+40
-46
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

+34-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
@file:OptIn(ExperimentalTypeInference::class)
2+
13
package org.jetbrains.kotlinx.dataframe.api
24

35
import org.jetbrains.kotlinx.dataframe.AnyRow
@@ -13,19 +15,16 @@ import org.jetbrains.kotlinx.dataframe.annotations.Refine
1315
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1416
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
1517
import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf
16-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
1718
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators
18-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast
1919
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
2020
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
2121
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2222
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
23-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of
2423
import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
2524
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
2625
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
2726
import org.jetbrains.kotlinx.dataframe.impl.zero
28-
import org.jetbrains.kotlinx.dataframe.math.sumOf
27+
import kotlin.experimental.ExperimentalTypeInference
2928
import kotlin.reflect.KProperty
3029
import kotlin.reflect.typeOf
3130

@@ -52,8 +51,37 @@ public fun DataColumn<Double?>.sum(): Double = Aggregators.sum.aggregate(this) a
5251
@JvmName("sumNumber")
5352
public fun DataColumn<Number?>.sum(): Number = Aggregators.sum.aggregate(this)
5453

55-
public inline fun <T, reified R : Number> DataColumn<T>.sumOf(noinline expression: (T) -> R): R? =
56-
(Aggregators.sum as Aggregator<*, *>).cast<R>().aggregateOf(this, expression)
54+
@JvmName("sumOfInt")
55+
@OverloadResolutionByLambdaReturnType
56+
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Int?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
57+
58+
@JvmName("sumOfShort")
59+
@OverloadResolutionByLambdaReturnType
60+
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Short?): Int =
61+
Aggregators.sum.aggregateOf(this, expression) as Int
62+
63+
@JvmName("sumOfByte")
64+
@OverloadResolutionByLambdaReturnType
65+
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
66+
67+
@JvmName("sumOfLong")
68+
@OverloadResolutionByLambdaReturnType
69+
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Long?): Long =
70+
Aggregators.sum.aggregateOf(this, expression) as Long
71+
72+
@JvmName("sumOfFloat")
73+
@OverloadResolutionByLambdaReturnType
74+
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Float?): Float =
75+
Aggregators.sum.aggregateOf(this, expression) as Float
76+
77+
@JvmName("sumOfDouble")
78+
@OverloadResolutionByLambdaReturnType
79+
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Double?): Double =
80+
Aggregators.sum.aggregateOf(this, expression) as Double
81+
82+
@JvmName("sumOfNumber")
83+
@OverloadResolutionByLambdaReturnType
84+
public fun <T> DataColumn<T>.sumOf(expression: (T) -> Number?): Number = Aggregators.sum.aggregateOf(this, expression)
5785

5886
// endregion
5987

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ internal fun <T, C, R> Aggregator<*, R>.aggregateFor(
4141
}
4242

4343
internal fun <T, C, R> AggregateInternalDsl<T>.aggregateFor(
44-
columns: ColumnsForAggregateSelector<T, C>,
44+
columns: ColumnsForAggregateSelector<T, C?>,
4545
aggregator: Aggregator<C, R>,
4646
) {
4747
val cols = df.getAggregateColumns(columns)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt

+5-5
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,26 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.internal
1515
import org.jetbrains.kotlinx.dataframe.impl.emptyPath
1616

1717
@PublishedApi
18-
internal fun <T, C, R> Aggregator<*, R>.aggregateAll(data: DataFrame<T>, columns: ColumnsSelector<T, C>): R =
18+
internal fun <T, C, R> Aggregator<*, R>.aggregateAll(data: DataFrame<T>, columns: ColumnsSelector<T, C?>): R =
1919
data.aggregateAll(cast2(), columns)
2020

2121
internal fun <T, R, C> Aggregator<*, R>.aggregateAll(
2222
data: Grouped<T>,
2323
name: String?,
24-
columns: ColumnsSelector<T, C>,
24+
columns: ColumnsSelector<T, C?>,
2525
): DataFrame<T> = data.aggregateAll(cast(), columns, name)
2626

2727
internal fun <T, R, C> Aggregator<*, R>.aggregateAll(
2828
data: PivotGroupBy<T>,
29-
columns: ColumnsSelector<T, C>,
29+
columns: ColumnsSelector<T, C?>,
3030
): DataFrame<T> = data.aggregateAll(cast(), columns)
3131

3232
internal fun <T, C, R> DataFrame<T>.aggregateAll(aggregator: Aggregator<C, R>, columns: ColumnsSelector<T, C?>): R =
3333
aggregator.aggregate(get(columns))
3434

3535
internal fun <T, C, R> Grouped<T>.aggregateAll(
3636
aggregator: Aggregator<C, R>,
37-
columns: ColumnsSelector<T, C>,
37+
columns: ColumnsSelector<T, C?>,
3838
name: String?,
3939
): DataFrame<T> =
4040
aggregateInternal {
@@ -48,7 +48,7 @@ internal fun <T, C, R> Grouped<T>.aggregateAll(
4848

4949
internal fun <T, C, R> PivotGroupBy<T>.aggregateAll(
5050
aggregator: Aggregator<C, R>,
51-
columns: ColumnsSelector<T, C>,
51+
columns: ColumnsSelector<T, C?>,
5252
): DataFrame<T> =
5353
aggregate {
5454
val cols = get(columns)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt

-34
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,6 @@ import kotlin.reflect.full.withNullability
88
import kotlin.reflect.typeOf
99
import kotlin.sequences.filterNotNull
1010

11-
@PublishedApi
12-
internal fun <T, R : Number> Iterable<T>.sumOf(type: KType, selector: (T) -> R?): Number =
13-
asSequence().sumOf(type, selector)
14-
15-
@Suppress("UNCHECKED_CAST")
16-
@PublishedApi
17-
internal fun <T, R : Number> Sequence<T>.sumOf(type: KType, selector: (T) -> R?): Number {
18-
if (type.isMarkedNullable) {
19-
return filterNotNull().sumOf(type.withNullability(false), selector)
20-
}
21-
return when (type.withNullability(false)) {
22-
typeOf<Double>() -> sumOf(selector as (T) -> Double)
23-
24-
typeOf<Float>() -> map(selector as (T) -> Float).sum()
25-
26-
typeOf<Int>() -> sumOf(selector as (T) -> Int)
27-
28-
// Note: returns Int
29-
typeOf<Short>() -> map(selector as (T) -> Short).sum()
30-
31-
// Note: returns Int
32-
typeOf<Byte>() -> map(selector as (T) -> Byte).sum()
33-
34-
typeOf<Long>() -> sumOf(selector as (T) -> Long)
35-
36-
nothingType -> 0.0
37-
38-
typeOf<Number>() ->
39-
error("Encountered non-specific Number type in sumOf function. This should not occur.")
40-
41-
else -> throw IllegalArgumentException("sumOf is not supported for $type")
42-
}
43-
}
44-
4511
internal fun Iterable<Number?>.sum(type: KType): Number = asSequence().sum(type)
4612

4713
@Suppress("UNCHECKED_CAST")

0 commit comments

Comments
 (0)