Skip to content

Commit 3a71ce3

Browse files
authored
Merge pull request #1091 from Kotlin/mean
Mean statistics fixes
2 parents b339b6a + 33e35bc commit 3a71ce3

File tree

49 files changed

+583
-516
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+583
-516
lines changed

core/api/core.api

+10-3
Original file line numberDiff line numberDiff line change
@@ -1777,12 +1777,12 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnArithmeticsKt {
17771777
}
17781778

17791779
public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt {
1780-
public static final fun isBigNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17811780
public static final fun isColumnGroup (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17821781
public static final fun isComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17831782
public static final fun isFrameColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17841783
public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17851784
public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
1785+
public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17861786
public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z
17871787
public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17881788
public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
@@ -2842,8 +2842,6 @@ public final class org/jetbrains/kotlinx/dataframe/api/MeanKt {
28422842
public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
28432843
public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
28442844
public static synthetic fun meanFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
2845-
public static final fun meanOrNull (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Z)Ljava/lang/Double;
2846-
public static synthetic fun meanOrNull$default (Lorg/jetbrains/kotlinx/dataframe/DataColumn;ZILjava/lang/Object;)Ljava/lang/Double;
28472845
public static final fun rowMean (Lorg/jetbrains/kotlinx/dataframe/DataRow;Z)D
28482846
public static synthetic fun rowMean$default (Lorg/jetbrains/kotlinx/dataframe/DataRow;ZILjava/lang/Object;)D
28492847
}
@@ -3967,6 +3965,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/TypeConversionsKt {
39673965
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
39683966
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnSet;
39693967
public static final fun asComparable (Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn;)Lorg/jetbrains/kotlinx/dataframe/columns/SingleColumn;
3968+
public static final fun asComparableNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
39703969
public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
39713970
public static final fun asDataColumn (Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
39723971
public static final fun asDataFrame (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
@@ -5103,6 +5102,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt {
51035102
public static final fun suggestIfNull (Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;
51045103
}
51055104

5105+
public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt {
5106+
public static final fun getPrimitiveNumberTypes ()Ljava/util/Set;
5107+
}
5108+
51065109
public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt {
51075110
public static final fun getValuesType (Ljava/util/List;Lkotlin/reflect/KType;Lorg/jetbrains/kotlinx/dataframe/api/Infer;)Lkotlin/reflect/KType;
51085111
public static final synthetic fun guessValueType (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;Z)Lkotlin/reflect/KType;
@@ -5210,6 +5213,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/OfRowE
52105213
public static final fun aggregateOfDelegated (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/api/Grouped;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
52115214
}
52125215

5216+
public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/RowKt {
5217+
public static final fun aggregateOfRow (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
5218+
}
5219+
52135220
public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/WithinAllColumnsKt {
52145221
public static final fun aggregateAll (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
52155222
}

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt

+7-3
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
77
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
88
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
99
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
10+
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
1011
import org.jetbrains.kotlinx.dataframe.type
1112
import org.jetbrains.kotlinx.dataframe.typeClass
1213
import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE
1314
import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE_REPLACE
1415
import org.jetbrains.kotlinx.dataframe.util.IS_INTER_COMPARABLE_IMPORT
15-
import java.math.BigDecimal
16-
import java.math.BigInteger
1716
import kotlin.contracts.ExperimentalContracts
1817
import kotlin.contracts.contract
1918
import kotlin.reflect.KType
@@ -46,9 +45,14 @@ public inline fun <reified T> AnyCol.isSubtypeOf(): Boolean = isSubtypeOf(typeOf
4645

4746
public inline fun <reified T> AnyCol.isType(): Boolean = type() == typeOf<T>()
4847

48+
/** Returns `true` when this column's type is a subtype of `Number?` */
4949
public fun AnyCol.isNumber(): Boolean = isSubtypeOf<Number?>()
5050

51-
public fun AnyCol.isBigNumber(): Boolean = isSubtypeOf<BigInteger?>() || isSubtypeOf<BigDecimal?>()
51+
/**
 * Checks whether the type of this column, with nullability stripped, is one of the primitive number types:
 * [Byte], [Short], [Int], [Long], [Float], or [Double].
 *
 * @return `true` if the non-nullable column type is contained in the set of primitive number types.
 */
public fun AnyCol.isPrimitiveNumber(): Boolean {
    // Nullability is ignored so that e.g. `Int?` columns are treated the same as `Int` columns.
    val nonNullableType = type().withNullability(false)
    return primitiveNumberTypes.contains(nonNullableType)
}
5256

5357
public fun AnyCol.isList(): Boolean = typeClass == List::class
5458

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/map.kt

+1-4
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@ public inline fun <C, reified R> ColumnReference<C>.map(
3131

3232
// region DataColumn
3333

34-
public inline fun <T, reified R> DataColumn<T>.map(
35-
infer: Infer = Infer.Nulls,
36-
crossinline transform: (T) -> R,
37-
): DataColumn<R> {
34+
public inline fun <T, reified R> DataColumn<T>.map(infer: Infer = Infer.Nulls, transform: (T) -> R): DataColumn<R> {
3835
val newValues = Array(size()) { transform(get(it)) }.asList()
3936
return DataColumn.createByType(name(), newValues, typeOf<R>(), infer)
4037
}

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

+36-16
Original file line numberDiff line numberDiff line change
@@ -18,41 +18,60 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast2
1818
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
1919
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
2020
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
21+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
2122
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of
22-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
23+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
2324
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
24-
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
25-
import org.jetbrains.kotlinx.dataframe.math.mean
25+
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
2626
import kotlin.reflect.KProperty
27+
import kotlin.reflect.full.withNullability
2728
import kotlin.reflect.typeOf
2829

30+
/*
31+
* TODO KDocs:
32+
* Calculating the mean is supported for all primitive number types.
33+
* Nulls are filtered from columns.
34+
* The return type is always Double, Double.NaN for empty input, never null.
35+
* (May introduce loss of precision for Longs).
36+
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean.
37+
*/
38+
2939
// region DataColumn
3040

3141
public fun <T : Number> DataColumn<T?>.mean(skipNA: Boolean = skipNA_default): Double =
32-
meanOrNull(skipNA).suggestIfNull("mean")
33-
34-
public fun <T : Number> DataColumn<T?>.meanOrNull(skipNA: Boolean = skipNA_default): Double? =
3542
Aggregators.mean(skipNA).aggregate(this)
3643

3744
public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
3845
skipNA: Boolean = skipNA_default,
3946
noinline expression: (T) -> R?,
40-
): Double = Aggregators.mean(skipNA).cast2<R?, Double>().aggregateOf(this, expression) ?: Double.NaN
47+
): Double =
48+
Aggregators.mean(skipNA)
49+
.cast2<R?, Double>()
50+
.aggregateOf(this, expression)
4151

4252
// endregion
4353

4454
// region DataRow
4555

4656
public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
47-
values().filterIsInstance<Number>().map { it.toDouble() }.mean(skipNA)
48-
49-
public inline fun <reified T : Number> AnyRow.rowMeanOf(): Double = values().filterIsInstance<T>().mean(typeOf<T>())
57+
Aggregators.mean(skipNA).aggregateOfRow(this) {
58+
colsOf<Number?> { it.isPrimitiveNumber() }
59+
}
60+
61+
/**
 * Calculates the mean over the cells of this row that are selected by `colsOf<T>()`.
 *
 * @param T the column type to select; ignoring nullability, it must be one of the primitive number types.
 * @param skipNA passed on to the mean aggregator.
 * @return the value produced by the mean aggregator for the selected columns of this row.
 * @throws IllegalArgumentException if [T] (ignoring nullability) is not a primitive number type.
 */
public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
    // Validate the requested type up front, before any aggregation is attempted.
    val requestedType = typeOf<T>().withNullability(false)
    require(requestedType in primitiveNumberTypes) {
        "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
    }
    val meanAggregator = Aggregators.mean(skipNA)
    return meanAggregator.aggregateOfRow(this) { colsOf<T>() }
}
5068

5169
// endregion
5270

5371
// region DataFrame
5472

55-
public fun <T> DataFrame<T>.mean(skipNA: Boolean = skipNA_default): DataRow<T> = meanFor(skipNA, numberColumns())
73+
public fun <T> DataFrame<T>.mean(skipNA: Boolean = skipNA_default): DataRow<T> =
74+
meanFor(skipNA, primitiveNumberColumns())
5675

5776
public fun <T, C : Number> DataFrame<T>.meanFor(
5877
skipNA: Boolean = skipNA_default,
@@ -77,7 +96,7 @@ public fun <T, C : Number> DataFrame<T>.meanFor(
7796
public fun <T, C : Number> DataFrame<T>.mean(
7897
skipNA: Boolean = skipNA_default,
7998
columns: ColumnsSelector<T, C?>,
80-
): Double = Aggregators.mean(skipNA).aggregateAll(this, columns) as Double? ?: Double.NaN
99+
): Double = Aggregators.mean(skipNA).aggregateAll(this, columns)
81100

82101
public fun <T> DataFrame<T>.mean(vararg columns: String, skipNA: Boolean = skipNA_default): Double =
83102
mean(skipNA) { columns.toNumberColumns() }
@@ -95,14 +114,15 @@ public fun <T, C : Number> DataFrame<T>.mean(vararg columns: KProperty<C?>, skip
95114
public inline fun <T, reified D : Number> DataFrame<T>.meanOf(
96115
skipNA: Boolean = skipNA_default,
97116
noinline expression: RowExpression<T, D?>,
98-
): Double = Aggregators.mean(skipNA).of(this, expression) ?: Double.NaN
117+
): Double = Aggregators.mean(skipNA).of(this, expression)
99118

100119
// endregion
101120

102121
// region GroupBy
103122
@Refine
104123
@Interpretable("GroupByMean1")
105-
public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> = meanFor(skipNA, numberColumns())
124+
public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> =
125+
meanFor(skipNA, primitiveNumberColumns())
106126

107127
@Refine
108128
@Interpretable("GroupByMean0")
@@ -167,7 +187,7 @@ public inline fun <T, reified R : Number> Grouped<T>.meanOf(
167187
// region Pivot
168188

169189
public fun <T> Pivot<T>.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow<T> =
170-
meanFor(skipNA, separate, numberColumns())
190+
meanFor(skipNA, separate, primitiveNumberColumns())
171191

172192
public fun <T, C : Number> Pivot<T>.meanFor(
173193
skipNA: Boolean = skipNA_default,
@@ -210,7 +230,7 @@ public inline fun <T, reified R : Number> Pivot<T>.meanOf(
210230
// region PivotGroupBy
211231

212232
public fun <T> PivotGroupBy<T>.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame<T> =
213-
meanFor(skipNA, separate, numberColumns())
233+
meanFor(skipNA, separate, primitiveNumberColumns())
214234

215235
public fun <T, C : Number> PivotGroupBy<T>.meanFor(
216236
skipNA: Boolean = skipNA_default,

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt

+7-1
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,17 @@ public fun DataColumn<Any>.asNumbers(): ValueColumn<Number> {
8383
return this as ValueColumn<Number>
8484
}
8585

86-
public fun <T> DataColumn<T>.asComparable(): DataColumn<Comparable<T>> {
86+
public fun <T : Any> DataColumn<T>.asComparable(): DataColumn<Comparable<T>> {
8787
require(valuesAreComparable())
8888
return this as DataColumn<Comparable<T>>
8989
}
9090

91+
/**
 * Casts this column of nullable values to a column of nullable [Comparable] values.
 *
 * The cast itself is unchecked; the only runtime check is the [valuesAreComparable] precondition.
 *
 * @return this same column instance, typed as `DataColumn<Comparable<T>?>`.
 * @throws IllegalArgumentException if [valuesAreComparable] returns `false` for this column.
 */
@JvmName("asComparableNullable")
public fun <T : Any?> DataColumn<T?>.asComparable(): DataColumn<Comparable<T>?> {
    // Provide an explicit message: the default "Failed requirement." gives callers nothing to act on.
    require(valuesAreComparable()) {
        "Column '${name()}' cannot be used as a comparable column: its values are not comparable."
    }
    @Suppress("UNCHECKED_CAST")
    return this as DataColumn<Comparable<T>?>
}
96+
9197
public fun <T> ColumnReference<T?>.castToNotNullable(): ColumnReference<T> = cast()
9298

9399
public fun <T> DataColumn<T?>.castToNotNullable(): DataColumn<T> {

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt

+60-10
Original file line numberDiff line numberDiff line change
@@ -234,18 +234,21 @@ internal fun Iterable<KClass<*>>.unifiedNumberClass(
234234
* or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
235235
*
236236
* @param commonNumberType The desired common numeric type to convert the elements to.
237-
* This is determined by default using the types of the elements in the iterable.
237+
* By default (or if `null`), this is determined using the types of the elements in the iterable.
238238
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
239239
* @throws IllegalStateException if an element cannot be converted to the common number type.
240240
* @see UnifyingNumbers
241241
*/
242242
@Suppress("UNCHECKED_CAST")
243-
internal fun Iterable<Number>.convertToUnifiedNumberType(
243+
@JvmName("convertNullableIterableToUnifiedNumberType")
244+
internal fun Iterable<Number?>.convertToUnifiedNumberType(
244245
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
245-
commonNumberType: KType = this.types().unifiedNumberType(options),
246-
): Iterable<Number> {
246+
commonNumberType: KType? = null,
247+
): Iterable<Number?> {
248+
val commonNumberType = commonNumberType ?: this.filterNotNull().types().unifiedNumberType(options)
247249
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
248250
return map {
251+
if (it == null) return@map null
249252
converter(it) ?: error("Can not convert $it to $commonNumberType")
250253
}
251254
}
@@ -255,23 +258,62 @@ internal fun Iterable<Number>.convertToUnifiedNumberType(
255258
* or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
256259
*
257260
* @param commonNumberType The desired common numeric type to convert the elements to.
258-
* This is determined by default using the types of the elements in the iterable.
261+
* By default (or if `null`), this is determined using the types of the elements in the iterable.
259262
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
260263
* @throws IllegalStateException if an element cannot be converted to the common number type.
261264
* @see UnifyingNumbers */
262-
@JvmName("convertToUnifiedNumberTypeSequence")
263265
@Suppress("UNCHECKED_CAST")
264-
internal fun Sequence<Number>.convertToUnifiedNumberType(
266+
@JvmName("convertIterableToUnifiedNumberType")
267+
internal fun Iterable<Number>.convertToUnifiedNumberType(
268+
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
269+
commonNumberType: KType? = null,
270+
): Iterable<Number> =
271+
(this as Iterable<Number?>)
272+
.convertToUnifiedNumberType(options, commonNumberType) as Iterable<Number>
273+
274+
/** Converts the elements of the given sequence of (nullable) numbers into a common numeric type based on complexity.
275+
* The common numeric type is determined using the provided [commonNumberType] parameter
276+
* or calculated with [Iterable.unifiedNumberType][kotlin.collections.Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
277+
*
278+
* @param commonNumberType The desired common numeric type to convert the elements to.
279+
* By default (or if `null`), this is determined using the types of the non-null elements in the sequence.
280+
* @return A new sequence of numbers where each non-null element is converted to the specified or inferred common number type.
281+
* @throws IllegalStateException if an element cannot be converted to the common number type.
282+
* @see UnifyingNumbers */
283+
@Suppress("UNCHECKED_CAST")
284+
@JvmName("convertNullableSequenceToUnifiedNumberType")
285+
internal fun Sequence<Number?>.convertToUnifiedNumberType(
265286
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
266-
commonNumberType: KType = asIterable().types().unifiedNumberType(options),
267-
): Sequence<Number> {
287+
commonNumberType: KType? = null,
288+
): Sequence<Number?> {
289+
val commonNumberType = commonNumberType ?: this.filterNotNull().asIterable().types().unifiedNumberType(options)
268290
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
269291
return map {
292+
if (it == null) return@map null
270293
converter(it) ?: error("Can not convert $it to $commonNumberType")
271294
}
272295
}
273296

274-
internal val primitiveNumberTypes =
297+
/**
 * Converts the elements of this sequence of non-null numbers into a common numeric type.
 *
 * This overload delegates to the `Sequence<Number?>` overload of `convertToUnifiedNumberType`
 * and casts the result back to a sequence of non-null numbers.
 *
 * @param options Passed through unchanged to the nullable overload.
 * @param commonNumberType The desired common numeric type to convert the elements to.
 *   By default (or if `null`), the nullable overload determines it from the types of the elements in the sequence.
 * @return A sequence of numbers where each element is converted to the specified or inferred common number type.
 * @throws IllegalStateException if an element cannot be converted to the common number type.
 * @see UnifyingNumbers
 */
@Suppress("UNCHECKED_CAST")
// JVM name follows the pattern of the sibling overloads (convertIterable…, convertNullableSequence…);
// the previous value contained a stray '=' character.
@JvmName("convertSequenceToUnifiedNumberType")
internal fun Sequence<Number>.convertToUnifiedNumberType(
    options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
    commonNumberType: KType? = null,
): Sequence<Number> =
    (this as Sequence<Number?>)
        .convertToUnifiedNumberType(options, commonNumberType) as Sequence<Number>
314+
315+
@PublishedApi
316+
internal val primitiveNumberTypes: Set<KType> =
275317
setOf(
276318
typeOf<Byte>(),
277319
typeOf<Short>(),
@@ -280,3 +322,11 @@ internal val primitiveNumberTypes =
280322
typeOf<Float>(),
281323
typeOf<Double>(),
282324
)
325+
326+
/**
 * Tells whether this value is an instance of one of the primitive number classes.
 *
 * @return `true` if this is a [Byte], [Short], [Int], [Long], [Float], or [Double]; `false` otherwise.
 */
internal fun Any.isPrimitiveNumber(): Boolean =
    when (this) {
        is Byte, is Short, is Int, is Long, is Float, is Double -> true
        else -> false
    }

0 commit comments

Comments
 (0)