Skip to content

Commit b8043c6

Browse files
authored
Merge pull request #1122 from Kotlin/median
Median overhaul
2 parents 51592e6 + fe5c4a9 commit b8043c6

File tree

22 files changed

+1209
-257
lines changed

22 files changed

+1209
-257
lines changed

core/api/core.api

+107-55
Large diffs are not rendered by default.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ public fun <T : Comparable<T>> DataColumn<T?>.maxOrNull(skipNaN: Boolean = skipN
3636

3737
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxBy(
3838
skipNaN: Boolean = skipNaNDefault,
39-
noinline selector: (T) -> R,
39+
crossinline selector: (T) -> R,
4040
): T & Any = maxByOrNull(skipNaN, selector).suggestIfNull("maxBy")
4141

4242
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxByOrNull(
4343
skipNaN: Boolean = skipNaNDefault,
44-
noinline selector: (T) -> R,
44+
crossinline selector: (T) -> R,
4545
): T? = Aggregators.max<R>(skipNaN).aggregateByOrNull(this, selector)
4646

4747
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxOf(
@@ -59,10 +59,10 @@ public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxOfOrNul
5959
// region DataRow
6060

6161
@Deprecated(ROW_MAX_OR_NULL, level = DeprecationLevel.ERROR)
62-
public fun AnyRow.rowMaxOrNull(): Any? = error(ROW_MAX_OR_NULL)
62+
public fun AnyRow.rowMaxOrNull(): Nothing? = error(ROW_MAX_OR_NULL)
6363

6464
@Deprecated(ROW_MAX, level = DeprecationLevel.ERROR)
65-
public fun AnyRow.rowMax(): Any = error(ROW_MAX)
65+
public fun AnyRow.rowMax(): Nothing = error(ROW_MAX)
6666

6767
public inline fun <reified T : Comparable<T>> AnyRow.rowMaxOfOrNull(skipNaN: Boolean = skipNaNDefault): T? =
6868
Aggregators.max<T>(skipNaN).aggregateOfRow(this) { colsOf<T?>() }

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt

+411-111
Large diffs are not rendered by default.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ public fun <T : Comparable<T>> DataColumn<T?>.minOrNull(skipNaN: Boolean = skipN
3636

3737
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minBy(
3838
skipNaN: Boolean = skipNaNDefault,
39-
noinline selector: (T) -> R,
39+
crossinline selector: (T) -> R,
4040
): T & Any = minByOrNull(skipNaN, selector).suggestIfNull("minBy")
4141

4242
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minByOrNull(
4343
skipNaN: Boolean = skipNaNDefault,
44-
noinline selector: (T) -> R,
44+
crossinline selector: (T) -> R,
4545
): T? = Aggregators.min<R>(skipNaN).aggregateByOrNull(this, selector)
4646

4747
public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minOf(
@@ -59,10 +59,10 @@ public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minOfOrNul
5959
// region DataRow
6060

6161
@Deprecated(ROW_MIN_OR_NULL, level = DeprecationLevel.ERROR)
62-
public fun AnyRow.rowMinOrNull(): Any? = error(ROW_MIN_OR_NULL)
62+
public fun AnyRow.rowMinOrNull(): Nothing? = error(ROW_MIN_OR_NULL)
6363

6464
@Deprecated(ROW_MIN, level = DeprecationLevel.ERROR)
65-
public fun AnyRow.rowMin(): Any = error(ROW_MIN)
65+
public fun AnyRow.rowMin(): Nothing = error(ROW_MIN)
6666

6767
public inline fun <reified T : Comparable<T>> AnyRow.rowMinOfOrNull(skipNaN: Boolean = skipNaNDefault): T? =
6868
Aggregators.min<T>(skipNaN).aggregateOfRow(this) { colsOf<T?>() }

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns
2020
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
2121
import org.jetbrains.kotlinx.dataframe.math.percentile
2222
import kotlin.reflect.KProperty
23+
import kotlin.reflect.typeOf
2324

2425
// region DataColumn
2526

@@ -52,7 +53,7 @@ public fun AnyRow.rowPercentile(percentile: Double): Any =
5253
rowPercentileOrNull(percentile).suggestIfNull("rowPercentile")
5354

5455
public inline fun <reified T : Comparable<T>> AnyRow.rowPercentileOfOrNull(percentile: Double): T? =
55-
valuesOf<T>().percentile(percentile)
56+
valuesOf<T>().percentile(percentile, typeOf<T>())
5657

5758
public inline fun <reified T : Comparable<T>> AnyRow.rowPercentileOf(percentile: Double): T =
5859
rowPercentileOfOrNull<T>(percentile).suggestIfNull("rowPercentileOf")

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

33
import org.jetbrains.kotlinx.dataframe.DataColumn
4+
import org.jetbrains.kotlinx.dataframe.api.asSequence
45
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
56
import kotlin.reflect.KType
67

@@ -28,7 +29,11 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
2829
* Aggregates the data in the given column and computes a single resulting value.
2930
* Calls [aggregateSequence].
3031
*/
31-
fun aggregateSingleColumn(column: DataColumn<Value?>): Return
32+
fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
33+
aggregateSequence(
34+
values = column.asSequence(),
35+
valueType = column.type().toValueType(),
36+
)
3237

3338
/**
3439
* Function that can give the return type of [aggregateSequence] as [KType], given the type of the input.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

+46-10
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

3+
import org.jetbrains.kotlinx.dataframe.api.skipNaNDefault
4+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.HybridAggregationHandler
35
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.ReducingAggregationHandler
46
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
57
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.AnyInputHandler
68
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.NumberInputHandler
79
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
810
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
911
import org.jetbrains.kotlinx.dataframe.math.indexOfMax
12+
import org.jetbrains.kotlinx.dataframe.math.indexOfMedian
1013
import org.jetbrains.kotlinx.dataframe.math.indexOfMin
1114
import org.jetbrains.kotlinx.dataframe.math.maxOrNull
1215
import org.jetbrains.kotlinx.dataframe.math.maxTypeConversion
1316
import org.jetbrains.kotlinx.dataframe.math.mean
1417
import org.jetbrains.kotlinx.dataframe.math.meanTypeConversion
15-
import org.jetbrains.kotlinx.dataframe.math.median
18+
import org.jetbrains.kotlinx.dataframe.math.medianConversion
19+
import org.jetbrains.kotlinx.dataframe.math.medianOrNull
1620
import org.jetbrains.kotlinx.dataframe.math.minOrNull
1721
import org.jetbrains.kotlinx.dataframe.math.minTypeConversion
1822
import org.jetbrains.kotlinx.dataframe.math.percentile
@@ -29,13 +33,23 @@ internal object Aggregators {
2933
private fun <Value : Return & Any, Return : Any?> twoStepSelectingForAny(
3034
getReturnType: CalculateReturnType,
3135
indexOfResult: IndexOfResult<Value>,
32-
stepOneReducer: Reducer<Value, Return>,
36+
stepOneSelector: Selector<Value, Return>,
3337
) = Aggregator(
34-
aggregationHandler = SelectingAggregationHandler(stepOneReducer, indexOfResult, getReturnType),
38+
aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType),
3539
inputHandler = AnyInputHandler(),
3640
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
3741
)
3842

43+
private fun <Value : Any, Return : Any?> flattenHybridForAny(
44+
getReturnType: CalculateReturnType,
45+
indexOfResult: IndexOfResult<Value>,
46+
reducer: Reducer<Value, Return>,
47+
) = Aggregator(
48+
aggregationHandler = HybridAggregationHandler(reducer, indexOfResult, getReturnType),
49+
inputHandler = AnyInputHandler(),
50+
multipleColumnsHandler = FlatteningMultipleColumnsHandler(),
51+
)
52+
3953
private fun <Value : Any, Return : Any?> twoStepReducingForAny(
4054
getReturnType: CalculateReturnType,
4155
stepOneReducer: Reducer<Value, Return>,
@@ -101,7 +115,7 @@ internal object Aggregators {
101115
private val min by withOneOption { skipNaN: Boolean ->
102116
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
103117
getReturnType = minTypeConversion,
104-
stepOneReducer = { type -> minOrNull(type, skipNaN) },
118+
stepOneSelector = { type -> minOrNull(type, skipNaN) },
105119
indexOfResult = { type -> indexOfMin(type, skipNaN) },
106120
)
107121
}
@@ -113,15 +127,15 @@ internal object Aggregators {
113127
private val max by withOneOption { skipNaN: Boolean ->
114128
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
115129
getReturnType = maxTypeConversion,
116-
stepOneReducer = { type -> maxOrNull(type, skipNaN) },
130+
stepOneSelector = { type -> maxOrNull(type, skipNaN) },
117131
indexOfResult = { type -> indexOfMax(type, skipNaN) },
118132
)
119133
}
120134

121135
// T: Number? -> Double
122-
val std by withTwoOptions { skipNA: Boolean, ddof: Int ->
136+
val std by withTwoOptions { skipNaN: Boolean, ddof: Int ->
123137
flattenReducingForNumbers(stdTypeConversion) { type ->
124-
std(type, skipNA, ddof)
138+
std(type, skipNaN, ddof)
125139
}
126140
}
127141

@@ -140,9 +154,31 @@ internal object Aggregators {
140154
}
141155
}
142156

143-
// T: Comparable<T>? -> T
144-
val median by flattenReducingForAny<Comparable<Any?>> { type ->
145-
asIterable().median(type)
157+
// T : primitive Number? -> Double?
158+
// T : Comparable<T & Any>? -> T?
159+
fun <T> medianCommon(skipNaN: Boolean): Aggregator<T & Any, T?>
160+
where T : Comparable<T & Any>? =
161+
median.invoke(skipNaN).cast2()
162+
163+
// T : Comparable<T & Any>? -> T?
164+
fun <T> medianComparables(): Aggregator<T & Any, T?>
165+
where T : Comparable<T & Any>? =
166+
medianCommon<T>(skipNaNDefault).cast2()
167+
168+
// T : primitive Number? -> Double?
169+
fun <T> medianNumbers(
170+
skipNaN: Boolean,
171+
): Aggregator<T & Any, Double?>
172+
where T : Comparable<T & Any>?, T : Number? =
173+
medianCommon<T>(skipNaN).cast2()
174+
175+
@Suppress("UNCHECKED_CAST")
176+
private val median by withOneOption { skipNaN: Boolean ->
177+
flattenHybridForAny<Comparable<Any>, Comparable<Any>?>(
178+
getReturnType = medianConversion,
179+
reducer = { type -> medianOrNull(type, skipNaN) as Comparable<Any>? },
180+
indexOfResult = { type -> indexOfMedian(type, skipNaN) },
181+
)
146182
}
147183

148184
// T: Number -> T
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers
2+
3+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
4+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
5+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
6+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.IndexOfResult
7+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Reducer
8+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
9+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
10+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
11+
import kotlin.reflect.KType
12+
import kotlin.reflect.full.withNullability
13+
14+
/**
15+
* Implementation of [AggregatorAggregationHandler] which functions like a selector and reducer:
16+
* it takes a sequence of values and returns a single value, which is likely part of the input, but not necessarily.
17+
*
18+
* In practice, this means the handler implements both [indexOfAggregationResultSingleSequence]
19+
* (meaning it can give an index of the result in the input), and [aggregateSequence] with a return type that is
20+
* potentially different from the input.
21+
* The return value of [aggregateSequence] and the value at the index retrieved from [indexOfAggregationResultSingleSequence]
22+
* may differ.
23+
*
24+
* @param reducer This function actually does the selection/reduction.
25+
* Before it is called, nulls are filtered out. The type of the values is passed as [KType] to the reducer.
26+
* @param indexOfResult This function must be supplied to give the index of the result in the input values.
27+
* @param getReturnType This function must be supplied to give the return type of [reducer] given some input type and
28+
* whether the input is empty.
29+
* When selecting, the return type is always `typeOf<Value>()` or `typeOf<Value?>()`; when reducing, it can be anything.
30+
* @see [ReducingAggregationHandler]
31+
*/
32+
internal class HybridAggregationHandler<in Value : Any, out Return : Any?>(
33+
val reducer: Reducer<Value, Return>,
34+
val indexOfResult: IndexOfResult<Value>,
35+
val getReturnType: CalculateReturnType,
36+
) : AggregatorAggregationHandler<Value, Return> {
37+
38+
/**
39+
* Function that can give the index of the aggregation result in the input [values].
40+
* Calls the supplied [indexOfResult] after preprocessing the input.
41+
*/
42+
@Suppress("UNCHECKED_CAST")
43+
override fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int {
44+
val (values, valueType) = aggregator!!.preprocessAggregation(values, valueType)
45+
return indexOfResult(values, valueType)
46+
}
47+
48+
/**
49+
* Base function of [Aggregator].
50+
*
51+
* Aggregates the given values, taking [valueType] into account,
52+
* filtering nulls (only if [valueType.type.isMarkedNullable][KType.isMarkedNullable]),
53+
* and computes a single resulting value.
54+
*
55+
* When the exact [valueType] is unknown, use [calculateValueType] or [aggregateCalculatingValueType].
56+
*
57+
* Calls the supplied [reducer].
58+
*/
59+
@Suppress("UNCHECKED_CAST")
60+
override fun aggregateSequence(values: Sequence<Value?>, valueType: ValueType): Return {
61+
val (values, valueType) = aggregator!!.preprocessAggregation(values, valueType)
62+
return reducer(
63+
// values =
64+
if (valueType.isMarkedNullable) {
65+
values.filterNotNull()
66+
} else {
67+
values as Sequence<Value>
68+
},
69+
// type =
70+
valueType.withNullability(false),
71+
)
72+
}
73+
74+
/**
75+
* Give the return type of [reducer] given some input type and whether the input is empty.
76+
* Calls the supplied [getReturnType].
77+
*/
78+
override fun calculateReturnType(valueType: KType, emptyInput: Boolean): KType =
79+
getReturnType(valueType.withNullability(false), emptyInput)
80+
81+
override var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>? = null
82+
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/aggregationHandlers/ReducingAggregationHandler.kt

-14
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers
22

3-
import org.jetbrains.kotlinx.dataframe.DataColumn
4-
import org.jetbrains.kotlinx.dataframe.api.asSequence
53
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
64
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
75
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
86
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Reducer
97
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
108
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
119
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
12-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.toValueType
1310
import kotlin.reflect.KType
1411
import kotlin.reflect.full.withNullability
1512

@@ -54,17 +51,6 @@ internal class ReducingAggregationHandler<in Value : Any, out Return : Any?>(
5451
)
5552
}
5653

57-
/**
58-
* Aggregates the data in the given column and computes a single resulting value.
59-
* Calls [aggregateSequence].
60-
*/
61-
@Suppress("UNCHECKED_CAST")
62-
override fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
63-
aggregateSequence(
64-
values = column.asSequence(),
65-
valueType = column.type().toValueType(),
66-
)
67-
6854
/** This function always returns `-1` because the result of a reducer is not in the input values. */
6955
override fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int = -1
7056

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/aggregationHandlers/SelectingAggregationHandler.kt

+1-14
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers
22

3-
import org.jetbrains.kotlinx.dataframe.DataColumn
4-
import org.jetbrains.kotlinx.dataframe.api.asSequence
53
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
64
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
75
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
@@ -10,7 +8,6 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Selector
108
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
119
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
1210
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
13-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.toValueType
1411
import kotlin.reflect.KType
1512
import kotlin.reflect.full.withNullability
1613

@@ -70,16 +67,6 @@ internal class SelectingAggregationHandler<in Value : Return & Any, out Return :
7067
)
7168
}
7269

73-
/**
74-
* Aggregates the data in the given column and computes a single resulting value.
75-
* Calls [aggregateSequence].
76-
*/
77-
override fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
78-
aggregateSequence(
79-
values = column.asSequence(),
80-
valueType = column.type().toValueType(),
81-
)
82-
8370
/**
8471
* Give the return type of [selector] given some input type and whether the input is empty.
8572
* Calls the supplied [getReturnType].
@@ -91,7 +78,7 @@ internal class SelectingAggregationHandler<in Value : Return & Any, out Return :
9178
require(it == valueType.withNullability(false) || it == valueType.withNullability(true)) {
9279
"The return type of the selector must be either ${valueType.withNullability(false)} or ${
9380
valueType.withNullability(true)
94-
}"
81+
} but was $it."
9582
}
9683
}
9784

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt

+10-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,17 @@ internal inline fun <T> Aggregatable<T>.remainingColumns(
1717
crossinline predicate: (AnyCol) -> Boolean,
1818
): ColumnsSelector<T, Any?> = remainingColumnsSelector().filter { predicate(it.data) }
1919

20+
/**
21+
* Emulates selecting all columns whose values are comparable to each other.
22+
* These are columns of type `R` where `R : Comparable<R>`.
23+
*
24+
* However, there is no way to denote this generically in types,
25+
* hence the _fake_ type `Comparable<Any>` is used.
26+
* (`Comparable<Nothing>` would be more correct, but then the compiler complains)
27+
*/
2028
@Suppress("UNCHECKED_CAST")
21-
internal fun <T> Aggregatable<T>.intraComparableColumns(): ColumnsSelector<T, Comparable<Any?>> =
22-
remainingColumns { it.valuesAreComparable() } as ColumnsSelector<T, Comparable<Any?>>
29+
internal fun <T> Aggregatable<T>.intraComparableColumns(): ColumnsSelector<T, Comparable<Any>?> =
30+
remainingColumns { it.valuesAreComparable() } as ColumnsSelector<T, Comparable<Any>?>
2331

2432
@Suppress("UNCHECKED_CAST")
2533
internal fun <T> Aggregatable<T>.numberColumns(): ColumnsSelector<T, Number?> =

0 commit comments

Comments
 (0)