Skip to content

Commit 8effc33

Browse files
committed
added median tests, added HybridAggregationHandler for median, median now always returns null when empty
1 parent 9b356c8 commit 8effc33

File tree

9 files changed

+425
-72
lines changed

9 files changed

+425
-72
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt

+5-4
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ import kotlin.experimental.ExperimentalTypeInference
2929
import kotlin.reflect.KProperty
3030

3131
/* TODO KDocs
32-
* numbers -> Double
33-
* comparable -> itself
32+
* numbers -> Double or null
33+
* comparable -> itself or null
3434
*
3535
* TODO cases where the lambda dictates the return type require explicit type arguments for
3636
* non-number, comparable overloads: https://youtrack.jetbrains.com/issue/KT-76683
3737
* so, `df.median { intCol }` works, but needs `df.median<_, String> { stringCol }`
38+
* This needs to be explained by KDocs
3839
*
3940
* medianBy is new for all overloads :)
4041
*/
@@ -281,7 +282,7 @@ public inline fun <T, reified C : Comparable<C & Any>?> DataFrame<T>.medianBy(
281282
public inline fun <T, reified C : Comparable<C & Any>?> DataFrame<T>.medianByOrNull(
282283
skipNaN: Boolean = skipNaNDefault,
283284
crossinline expression: RowExpression<T, C>,
284-
): DataRow<T>? = Aggregators.min<C>(skipNaN).aggregateByOrNull(this, expression)
285+
): DataRow<T>? = Aggregators.medianCommon<C>(skipNaN).aggregateByOrNull(this, expression)
285286

286287
public fun <T> DataFrame<T>.medianByOrNull(column: String, skipNaN: Boolean = skipNaNDefault): DataRow<T>? =
287288
medianByOrNull(column.toColumnOf<Comparable<Any>?>(), skipNaN)
@@ -290,7 +291,7 @@ public fun <T> DataFrame<T>.medianByOrNull(column: String, skipNaN: Boolean = sk
290291
public inline fun <T, reified C : Comparable<C & Any>?> DataFrame<T>.medianByOrNull(
291292
column: ColumnReference<C>,
292293
skipNaN: Boolean = skipNaNDefault,
293-
): DataRow<T>? = Aggregators.min<C>(skipNaN).aggregateByOrNull(this, column)
294+
): DataRow<T>? = Aggregators.medianCommon<C>(skipNaN).aggregateByOrNull(this, column)
294295

295296
@AccessApiOverload
296297
public inline fun <T, reified C : Comparable<C & Any>?> DataFrame<T>.medianByOrNull(

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

33
import org.jetbrains.kotlinx.dataframe.DataColumn
4+
import org.jetbrains.kotlinx.dataframe.api.asSequence
45
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
56
import kotlin.reflect.KType
67

@@ -28,7 +29,11 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
2829
* Aggregates the data in the given column and computes a single resulting value.
2930
* Calls [aggregateSequence].
3031
*/
31-
fun aggregateSingleColumn(column: DataColumn<Value?>): Return
32+
fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
33+
aggregateSequence(
34+
values = column.asSequence(),
35+
valueType = column.type().toValueType(),
36+
)
3237

3338
/**
3439
* Function that can give the return type of [aggregateSequence] as [KType], given the type of the input.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
22

33
import org.jetbrains.kotlinx.dataframe.api.skipNaNDefault
4+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.HybridAggregationHandler
45
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.ReducingAggregationHandler
56
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
67
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.AnyInputHandler
@@ -39,12 +40,12 @@ internal object Aggregators {
3940
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
4041
)
4142

42-
private fun <Value : Return & Any, Return : Any?> flattenSelectingForAny(
43+
private fun <Value : Any, Return : Any?> flattenHybridForAny(
4344
getReturnType: CalculateReturnType,
4445
indexOfResult: IndexOfResult<Value>,
45-
selector: Selector<Value, Return>,
46+
reducer: Reducer<Value, Return>,
4647
) = Aggregator(
47-
aggregationHandler = SelectingAggregationHandler(selector, indexOfResult, getReturnType),
48+
aggregationHandler = HybridAggregationHandler(reducer, indexOfResult, getReturnType),
4849
inputHandler = AnyInputHandler(),
4950
multipleColumnsHandler = FlatteningMultipleColumnsHandler(),
5051
)
@@ -153,7 +154,7 @@ internal object Aggregators {
153154
}
154155
}
155156

156-
// T : primitive Number? -> Double
157+
// T : primitive Number? -> Double?
157158
// T : Comparable<T & Any>? -> T?
158159
fun <T> medianCommon(skipNaN: Boolean): Aggregator<T & Any, T?>
159160
where T : Comparable<T & Any>? =
@@ -164,19 +165,18 @@ internal object Aggregators {
164165
where T : Comparable<T & Any>? =
165166
medianCommon<T>(skipNaNDefault).cast2()
166167

167-
// T : primitive Number? -> Double
168+
// T : primitive Number? -> Double?
168169
fun <T> medianNumbers(
169170
skipNaN: Boolean,
170-
): Aggregator<T & Any, Double>
171+
): Aggregator<T & Any, Double?>
171172
where T : Comparable<T & Any>?, T : Number? =
172173
medianCommon<T>(skipNaN).cast2()
173174

174-
// T: Comparable<T>? -> T
175175
@Suppress("UNCHECKED_CAST")
176176
private val median by withOneOption { skipNaN: Boolean ->
177-
flattenSelectingForAny<Comparable<Any>, Comparable<Any>?>(
177+
flattenHybridForAny<Comparable<Any>, Comparable<Any>?>(
178178
getReturnType = medianConversion,
179-
selector = { type -> medianOrNull(type, skipNaN) as Comparable<Any>? },
179+
reducer = { type -> medianOrNull(type, skipNaN) as Comparable<Any>? },
180180
indexOfResult = { type -> indexOfMedian(type, skipNaN) },
181181
)
182182
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers
2+
3+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
4+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
5+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
6+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.IndexOfResult
7+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Reducer
8+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
9+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
10+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
11+
import kotlin.reflect.KType
12+
import kotlin.reflect.full.withNullability
13+
14+
/**
15+
* Implementation of [AggregatorAggregationHandler] which functions like a selector and reducer:
16+
* it takes a sequence of values and returns a single value, which is likely part of the input, but not necessarily.
17+
*
18+
* In practice, this means the handler implements both [indexOfAggregationResultSingleSequence]
19+
* (meaning it can give an index of the result in the input), and [aggregateSequence] with a return type that is
20+
* potentially different from the input.
21+
* The return value of [aggregateSequence] and the value at the index retrieved from [indexOfAggregationResultSingleSequence]
22+
* may differ.
23+
*
24+
* @param reducer This function actually does the selection/reduction.
25+
* Before it is called, nulls are filtered out. The type of the values is passed as [KType] to the [reducer].
26+
* @param indexOfResult This function must be supplied to give the index of the result in the input values.
27+
* @param getReturnType This function must be supplied to give the return type of [reducer] given some input type and
28+
* whether the input is empty.
29+
* When selecting, the return type is always `typeOf<Value>()` or `typeOf<Value?>()`; when reducing, it can be anything.
30+
* @see [ReducingAggregationHandler]
31+
*/
32+
internal class HybridAggregationHandler<in Value : Any, out Return : Any?>(
33+
val reducer: Reducer<Value, Return>,
34+
val indexOfResult: IndexOfResult<Value>,
35+
val getReturnType: CalculateReturnType,
36+
) : AggregatorAggregationHandler<Value, Return> {
37+
38+
/**
39+
* Function that can give the index of the aggregation result in the input [values].
40+
* Calls the supplied [indexOfResult] after preprocessing the input via [aggregator], which must be set beforehand.
41+
*/
42+
@Suppress("UNCHECKED_CAST")
43+
override fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int {
44+
val (values, valueType) = aggregator!!.preprocessAggregation(values, valueType)
45+
return indexOfResult(values, valueType)
46+
}
47+
48+
/**
49+
* Base function of [Aggregator].
50+
*
51+
* Aggregates the given values, taking [valueType] into account,
52+
* filtering nulls (only if the preprocessed value type [is marked nullable][KType.isMarkedNullable]),
53+
* and computes a single resulting value.
54+
*
55+
* When the exact [valueType] is unknown, use [calculateValueType] or [aggregateCalculatingValueType].
56+
*
57+
* Calls the supplied [reducer] with the non-null values and the non-nullable type; [aggregator] must be set beforehand.
58+
*/
59+
@Suppress("UNCHECKED_CAST")
60+
override fun aggregateSequence(values: Sequence<Value?>, valueType: ValueType): Return {
61+
val (values, valueType) = aggregator!!.preprocessAggregation(values, valueType)
62+
return reducer(
63+
// values =
64+
if (valueType.isMarkedNullable) {
65+
values.filterNotNull()
66+
} else {
67+
values as Sequence<Value>
68+
},
69+
// type =
70+
valueType.withNullability(false),
71+
)
72+
}
73+
74+
/**
75+
* Give the return type of [reducer] given some input type and whether the input is empty.
76+
* Calls the supplied [getReturnType] with the non-nullable version of [valueType].
77+
*/
78+
override fun calculateReturnType(valueType: KType, emptyInput: Boolean): KType =
79+
getReturnType(valueType.withNullability(false), emptyInput)
80+
81+
override var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>? = null // null until assigned; both sequence functions above dereference it with !!
82+
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/aggregationHandlers/ReducingAggregationHandler.kt

-14
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers
22

3-
import org.jetbrains.kotlinx.dataframe.DataColumn
4-
import org.jetbrains.kotlinx.dataframe.api.asSequence
53
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
64
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
75
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
86
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Reducer
97
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
108
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
119
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
12-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.toValueType
1310
import kotlin.reflect.KType
1411
import kotlin.reflect.full.withNullability
1512

@@ -54,17 +51,6 @@ internal class ReducingAggregationHandler<in Value : Any, out Return : Any?>(
5451
)
5552
}
5653

57-
/**
58-
* Aggregates the data in the given column and computes a single resulting value.
59-
* Calls [aggregateSequence].
60-
*/
61-
@Suppress("UNCHECKED_CAST")
62-
override fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
63-
aggregateSequence(
64-
values = column.asSequence(),
65-
valueType = column.type().toValueType(),
66-
)
67-
6854
/** This function always returns `-1` because the result of a reducer is not in the input values. */
6955
override fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int = -1
7056

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/aggregationHandlers/SelectingAggregationHandler.kt

+1-14
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers
22

3-
import org.jetbrains.kotlinx.dataframe.DataColumn
4-
import org.jetbrains.kotlinx.dataframe.api.asSequence
53
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
64
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
75
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
@@ -10,7 +8,6 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Selector
108
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
119
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
1210
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
13-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.toValueType
1411
import kotlin.reflect.KType
1512
import kotlin.reflect.full.withNullability
1613

@@ -70,16 +67,6 @@ internal class SelectingAggregationHandler<in Value : Return & Any, out Return :
7067
)
7168
}
7269

73-
/**
74-
* Aggregates the data in the given column and computes a single resulting value.
75-
* Calls [aggregateSequence].
76-
*/
77-
override fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
78-
aggregateSequence(
79-
values = column.asSequence(),
80-
valueType = column.type().toValueType(),
81-
)
82-
8370
/**
8471
* Give the return type of [selector] given some input type and whether the input is empty.
8572
* Calls the supplied [getReturnType].
@@ -91,7 +78,7 @@ internal class SelectingAggregationHandler<in Value : Return & Any, out Return :
9178
require(it == valueType.withNullability(false) || it == valueType.withNullability(true)) {
9279
"The return type of the selector must be either ${valueType.withNullability(false)} or ${
9380
valueType.withNullability(true)
94-
}"
81+
} but was $it."
9582
}
9683
}
9784

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/median.kt

+11-19
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ private val logger = KotlinLogging.logger { }
2222

2323
/**
2424
* Returns the median of the comparable input:
25-
* - `null` if empty and primitive number
26-
* - `Double.NaN` if empty and primitive number
25+
* - `null` if empty
2726
* - `Double` if primitive number
2827
* - `Double.NaN` if ![skipNaN] and contains NaN
2928
* - (lower) middle else
@@ -64,19 +63,19 @@ internal fun <T : Comparable<T>> Sequence<T>.medianOrNull(type: KType, skipNaN:
6463
}.toList()
6564

6665
val size = list.size
67-
if (size == 0) return if (type.isPrimitiveNumber()) Double.NaN else null
66+
if (size == 0) return null
67+
68+
if (size == 1) {
69+
val single = list.single()
70+
return if (type.isPrimitiveNumber()) (single as Number).toDouble() else single
71+
}
6872

6973
val isOdd = size % 2 != 0
7074

7175
val middleIndex = (size - 1) / 2
7276
val lower = list.quickSelect(middleIndex)
7377
val upper = list.quickSelect(middleIndex + 1)
7478

75-
// check for quickSelect
76-
if (isOdd && lower.compareTo(upper) != 0) {
77-
error("lower and upper median are not equal while list-size is odd. This should not happen.")
78-
}
79-
8079
return when {
8180
isOdd && type.isPrimitiveNumber() -> (lower as Number).toDouble()
8281
isOdd -> lower
@@ -91,7 +90,7 @@ internal fun <T : Comparable<T>> Sequence<T>.medianOrNull(type: KType, skipNaN:
9190
}
9291

9392
/**
94-
* Primitive Number -> Double
93+
* Primitive Number -> Double?
9594
* T : Comparable<T> -> T?
9695
*/
9796
internal val medianConversion: CalculateReturnType = { type, isEmpty ->
@@ -101,10 +100,10 @@ internal val medianConversion: CalculateReturnType = { type, isEmpty ->
101100

102101
// closest rank method, preferring lower middle,
103102
// number 3 of Hyndman and Fan "Sample quantiles in statistical packages"
104-
type.isIntraComparable() -> type.withNullability(isEmpty)
103+
type.isIntraComparable() -> type
105104

106105
else -> error("Can not calculate median for type ${renderType(type)}")
107-
}
106+
}.withNullability(isEmpty)
108107
}
109108

110109
/**
@@ -156,17 +155,10 @@ internal fun <T : Comparable<T & Any>?> Sequence<T>.indexOfMedian(type: KType, s
156155

157156
val size = list.size
158157
if (size == 0) return -1
159-
160-
val isOdd = size % 2 != 0
158+
if (size == 1) return 0
161159

162160
val middleIndex = (size - 1) / 2
163161
val lower = list.quickSelect(middleIndex)
164-
val upper = list.quickSelect(middleIndex + 1)
165-
166-
// check for quickSelect
167-
if (isOdd && lower.compareTo(upper) != 0) {
168-
error("lower and upper median are not equal while list-size is odd. This should not happen.")
169-
}
170162

171163
return lower.index
172164
}

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/Analyze.kt

+5-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.api.aggregate
66
import org.jetbrains.kotlinx.dataframe.api.asComparable
77
import org.jetbrains.kotlinx.dataframe.api.asGroupBy
88
import org.jetbrains.kotlinx.dataframe.api.asNumbers
9+
import org.jetbrains.kotlinx.dataframe.api.cast
910
import org.jetbrains.kotlinx.dataframe.api.colsOf
1011
import org.jetbrains.kotlinx.dataframe.api.column
1112
import org.jetbrains.kotlinx.dataframe.api.columnGroup
@@ -455,7 +456,7 @@ class Analyze : TestBase() {
455456
df.max { name.firstName and name.lastName }
456457
df.sum { age and weight }
457458
df.mean { cols(1, 3).asNumbers() }
458-
df.median { name.cols().asComparable() }
459+
df.median<_, String> { name.cols().cast() }
459460
// SampleEnd
460461
}
461462

@@ -480,7 +481,7 @@ class Analyze : TestBase() {
480481
df.sum(age, weight)
481482

482483
df.mean { cols(1, 3).asNumbers() }
483-
df.median { name.cols().asComparable() }
484+
df.median<_, String> { name.cols().cast<String>() }
484485
// SampleEnd
485486
}
486487

@@ -498,7 +499,7 @@ class Analyze : TestBase() {
498499
df.sum { "age"<Int>() and "weight"<Int?>() }
499500

500501
df.mean { cols(1, 3).asNumbers() }
501-
df.median { name.cols().asComparable() }
502+
df.median<_, String> { name.cols().cast() }
502503
// SampleEnd
503504
}
504505

@@ -535,7 +536,7 @@ class Analyze : TestBase() {
535536
df.sum(age, weight)
536537

537538
df.mean { cols(1, 3).asNumbers() }
538-
df.median { name.cols().asComparable() }
539+
df.median<_, String> { name.cols().cast() }
539540
// SampleEnd
540541
}
541542

0 commit comments

Comments
 (0)