diff --git a/core/api/core.api b/core/api/core.api index 0f002f0ab0..f03122f4c3 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -1781,8 +1781,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt { public static final fun isComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isFrameColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z + public static final fun isMixedNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z + public static final fun isPrimitiveOrMixedNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z @@ -3816,6 +3818,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/StdKt { public final class org/jetbrains/kotlinx/dataframe/api/SumKt { public static final fun rowSum (Lorg/jetbrains/kotlinx/dataframe/DataRow;)Ljava/lang/Number; + public static final fun rowSumOf (Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/reflect/KType;)Ljava/lang/Number; public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/DataRow; public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Ljava/lang/Number; public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; @@ -3839,6 +3842,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt { public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;Ljava/lang/String;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Pivot;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataRow; public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I + public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I + public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I + public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataRow; public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataRow; public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataRow; @@ -3863,8 +3870,15 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt { public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun sumT (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number; - public static final fun sumTNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number; + public static final fun sumNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number; + public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I + public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I + public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I + public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I + public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I + public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I + public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I + public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I } public final class org/jetbrains/kotlinx/dataframe/api/TailKt { @@ -5104,6 +5118,9 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt { public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt { public static final fun getPrimitiveNumberTypes ()Ljava/util/Set; + public static final fun isMixedNumber (Lkotlin/reflect/KType;)Z + public static final fun isPrimitiveNumber (Lkotlin/reflect/KType;)Z + public static final fun isPrimitiveOrMixedNumber (Lkotlin/reflect/KType;)Z } public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt { @@ -6153,13 +6170,7 @@ public final class org/jetbrains/kotlinx/dataframe/math/StdKt { } public final class org/jetbrains/kotlinx/dataframe/math/SumKt { - public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigDecimal; - public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigInteger; - public static final fun sum (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number; - public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigDecimal; - public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigInteger; - public static final fun sumNullableT (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number; - public static final fun sumOf (Ljava/lang/Iterable;Lkotlin/reflect/KType;Lkotlin/jvm/functions/Function1;)Ljava/lang/Number; + public static final fun sumNullableT (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;)Ljava/lang/Number; } public abstract class org/jetbrains/kotlinx/dataframe/schema/ColumnSchema { diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt index 9800f33463..a69d8db723 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt @@ -7,7 +7,9 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.columns.ColumnKind import org.jetbrains.kotlinx.dataframe.columns.FrameColumn import org.jetbrains.kotlinx.dataframe.columns.ValueColumn -import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes +import org.jetbrains.kotlinx.dataframe.impl.isMixedNumber +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber import org.jetbrains.kotlinx.dataframe.type import org.jetbrains.kotlinx.dataframe.typeClass import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE @@ -48,11 +50,23 @@ public inline fun AnyCol.isType(): Boolean = type() == typeOf() /** Returns `true` when this column's type is a subtype of `Number?` */ public fun AnyCol.isNumber(): Boolean = isSubtypeOf() +/** Returns `true` only when this column's type is exactly `Number` or `Number?`. */ +public fun AnyCol.isMixedNumber(): Boolean = type().isMixedNumber() + /** * Returns `true` when this column has the (nullable) type of either: * [Byte], [Short], [Int], [Long], [Float], or [Double]. */ -public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes +public fun AnyCol.isPrimitiveNumber(): Boolean = type().isPrimitiveNumber() + +/** + * Returns `true` when this column has the (nullable) type of either: + * [Byte], [Short], [Int], [Long], [Float], [Double], or [Number]. + * + * Careful: Will return `true` if the column contains multiple number types that + * might NOT be primitive. + */ +public fun AnyCol.isPrimitiveOrMixedNumber(): Boolean = type().isPrimitiveOrMixedNumber() public fun AnyCol.isList(): Boolean = typeClass == List::class diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index bbdb9707d0..3ebb607306 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -18,17 +18,15 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow -import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns +import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns -import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber import kotlin.reflect.KProperty -import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf -/* - * TODO KDocs: +/* TODO KDocs: * Calculating the mean is supported for all primitive number types. - * Nulls are filtered from columns. + * Nulls are filtered out. * The return type is always Double, Double.NaN for empty input, never null. * (May introduce loss of precision for Longs). * For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean. @@ -48,16 +46,13 @@ public inline fun DataColumn.meanOf( // region DataRow public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double = - Aggregators.mean(skipNA).aggregateOfRow(this) { - colsOf { it.isPrimitiveNumber() } - } + Aggregators.mean(skipNA).aggregateOfRow(this, primitiveOrMixedNumberColumns()) public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { - require(typeOf().withNullability(false) in primitiveNumberTypes) { + require(typeOf().isPrimitiveOrMixedNumber()) { "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." } - return Aggregators.mean(skipNA) - .aggregateOfRow(this) { colsOf() } + return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf() } } // endregion @@ -65,7 +60,7 @@ public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipN // region DataFrame public fun DataFrame.mean(skipNA: Boolean = skipNA_default): DataRow = - meanFor(skipNA, primitiveNumberColumns()) + meanFor(skipNA, primitiveOrMixedNumberColumns()) public fun DataFrame.meanFor( skipNA: Boolean = skipNA_default, @@ -116,7 +111,7 @@ public inline fun DataFrame.meanOf( @Refine @Interpretable("GroupByMean1") public fun Grouped.mean(skipNA: Boolean = skipNA_default): DataFrame = - meanFor(skipNA, primitiveNumberColumns()) + meanFor(skipNA, primitiveOrMixedNumberColumns()) @Refine @Interpretable("GroupByMean0") @@ -181,7 +176,7 @@ public inline fun Grouped.meanOf( // region Pivot public fun Pivot.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow = - meanFor(skipNA, separate, primitiveNumberColumns()) + meanFor(skipNA, separate, primitiveOrMixedNumberColumns()) public fun Pivot.meanFor( skipNA: Boolean = skipNA_default, @@ -224,7 +219,7 @@ public inline fun Pivot.meanOf( // region PivotGroupBy public fun PivotGroupBy.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame = - meanFor(skipNA, separate, primitiveNumberColumns()) + meanFor(skipNA, separate, primitiveOrMixedNumberColumns()) public fun PivotGroupBy.meanFor( skipNA: Boolean = skipNA_default, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index af0016ee18..c786320ac8 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -1,3 +1,6 @@ +@file:OptIn(ExperimentalTypeInference::class) +@file:Suppress("LocalVariableName") + package org.jetbrains.kotlinx.dataframe.api import org.jetbrains.kotlinx.dataframe.AnyRow @@ -13,50 +16,97 @@ import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf -import org.jetbrains.kotlinx.dataframe.columns.values -import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators -import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf -import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns +import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow +import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns -import org.jetbrains.kotlinx.dataframe.impl.zero -import org.jetbrains.kotlinx.dataframe.math.sum -import org.jetbrains.kotlinx.dataframe.math.sumOf +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber +import kotlin.experimental.ExperimentalTypeInference +import kotlin.reflect.KClass import kotlin.reflect.KProperty -import kotlin.reflect.full.isSubtypeOf +import kotlin.reflect.KType import kotlin.reflect.typeOf +/* TODO KDocs + * Calculating the sum is supported for all primitive number types. + * Nulls are filtered out. + * The return type is always the same as the input type (never null), except for `Byte` and `Short`, + * which are converted to `Int`. + * Empty input will result in 0 in the supplied number type. + * For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the sum. + */ + // region DataColumn -@JvmName("sumT") -public fun DataColumn.sum(): T = values.sum(type()) +@JvmName("sumShort") +public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as Int + +@JvmName("sumByte") +public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as Int -@JvmName("sumTNullable") -public fun DataColumn.sum(): T = values.sum(type()) +@Suppress("UNCHECKED_CAST") +@JvmName("sumNumber") +public fun DataColumn.sum(): T = Aggregators.sum.aggregate(this) as T -public inline fun DataColumn.sumOf(noinline expression: (T) -> R): R? = - (Aggregators.sum as Aggregator<*, *>).cast().aggregateOf(this, expression) +@JvmName("sumOfShort") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (C) -> Short?): Int = + Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfByte") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (C) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfNumber") +@OverloadResolutionByLambdaReturnType +public inline fun DataColumn.sumOf(crossinline expression: (C) -> V?): V = + Aggregators.sum.aggregateOf(this, expression) as V // endregion // region DataRow -public fun AnyRow.rowSum(): Number = - Aggregators.sum.aggregateCalculatingType( - values = values().filterIsInstance(), - valueTypes = columnTypes().filter { it.isSubtypeOf(typeOf()) }.toSet(), - ) ?: 0 +public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveOrMixedNumberColumns()) + +@JvmName("rowSumOfShort") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Short::class): Int = + rowSumOf(typeOf()) as Int -public inline fun AnyRow.rowSumOf(): T = values().filterIsInstance().sum(typeOf()) +@JvmName("rowSumOfByte") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Byte::class): Int = + rowSumOf(typeOf()) as Int +@JvmName("rowSumOfInt") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Int::class): Int = + rowSumOf(typeOf()) as Int + +@JvmName("rowSumOfLong") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Long::class): Long = + rowSumOf(typeOf()) as Long + +@JvmName("rowSumOfFloat") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Float::class): Float = + rowSumOf(typeOf()) as Float + +@JvmName("rowSumOfDouble") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Double::class): Double = + rowSumOf(typeOf()) as Double + +// unfortunately, we cannot make a `reified T : Number?` due to clashes +public fun AnyRow.rowSumOf(type: KType): Number { + require(type.isPrimitiveOrMixedNumber()) { + "Type $type is not a primitive number type. Mean only supports primitive number types." + } + return Aggregators.sum.aggregateOfRow(this) { colsOf(type) } +} // endregion // region DataFrame -public fun DataFrame.sum(): DataRow = sumFor(numberColumns()) +public fun DataFrame.sum(): DataRow = sumFor(primitiveOrMixedNumberColumns()) public fun DataFrame.sumFor(columns: ColumnsForAggregateSelector): DataRow = Aggregators.sum.aggregateFor(this, columns) @@ -71,28 +121,70 @@ public fun DataFrame.sumFor(vararg columns: ColumnReference DataFrame.sumFor(vararg columns: KProperty): DataRow = sumFor { columns.toColumnSet() } +@JvmName("sumShort") +@OverloadResolutionByLambdaReturnType +public fun DataFrame.sum(columns: ColumnsSelector): Int = + Aggregators.sum.aggregateAll(this, columns) as Int + +@JvmName("sumByte") +@OverloadResolutionByLambdaReturnType +public fun DataFrame.sum(columns: ColumnsSelector): Int = + Aggregators.sum.aggregateAll(this, columns) as Int + +@JvmName("sumNumber") +@OverloadResolutionByLambdaReturnType public inline fun DataFrame.sum(noinline columns: ColumnsSelector): C = - (Aggregators.sum.aggregateAll(this, columns) as C?) ?: C::class.zero() + Aggregators.sum.aggregateAll(this, columns) as C +@JvmName("sumShort") +@AccessApiOverload +public fun DataFrame.sum(vararg columns: ColumnReference): Int = sum { columns.toColumnSet() } + +@JvmName("sumByte") +@AccessApiOverload +public fun DataFrame.sum(vararg columns: ColumnReference): Int = sum { columns.toColumnSet() } + +@JvmName("sumNumber") @AccessApiOverload public inline fun DataFrame.sum(vararg columns: ColumnReference): C = sum { columns.toColumnSet() } -public fun DataFrame.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf() } +public fun DataFrame.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf() } + +@JvmName("sumShort") +@AccessApiOverload +public fun DataFrame.sum(vararg columns: KProperty): Int = sum { columns.toColumnSet() } +@JvmName("sumByte") +@AccessApiOverload +public fun DataFrame.sum(vararg columns: KProperty): Int = sum { columns.toColumnSet() } + +@JvmName("sumNumber") @AccessApiOverload public inline fun DataFrame.sum(vararg columns: KProperty): C = sum { columns.toColumnSet() } -public inline fun DataFrame.sumOf(crossinline expression: RowExpression): C = - rows().sumOf(typeOf()) { expression(it, it) } +@JvmName("sumOfShort") +@OverloadResolutionByLambdaReturnType +public fun DataFrame.sumOf(expression: RowExpression): Int = + Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfByte") +@OverloadResolutionByLambdaReturnType +public fun DataFrame.sumOf(expression: RowExpression): Int = + Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfNumber") +@OverloadResolutionByLambdaReturnType +public inline fun DataFrame.sumOf(crossinline expression: RowExpression): C = + Aggregators.sum.aggregateOf(this, expression) as C // endregion // region GroupBy @Refine @Interpretable("GroupBySum1") -public fun Grouped.sum(): DataFrame = sumFor(numberColumns()) +public fun Grouped.sum(): DataFrame = sumFor(primitiveOrMixedNumberColumns()) @Refine @Interpretable("GroupBySum0") @@ -136,7 +228,7 @@ public inline fun Grouped.sumOf( // region Pivot -public fun Pivot.sum(separate: Boolean = false): DataRow = sumFor(separate, numberColumns()) +public fun Pivot.sum(separate: Boolean = false): DataRow = sumFor(separate, primitiveOrMixedNumberColumns()) public fun Pivot.sumFor( separate: Boolean = false, @@ -166,14 +258,15 @@ public fun Pivot.sum(vararg columns: ColumnReference): Da @AccessApiOverload public fun Pivot.sum(vararg columns: KProperty): DataRow = sum { columns.toColumnSet() } -public inline fun Pivot.sumOf(crossinline expression: RowExpression): DataRow = +public inline fun Pivot.sumOf(crossinline expression: RowExpression): DataRow = delegate { sumOf(expression) } // endregion // region PivotGroupBy -public fun PivotGroupBy.sum(separate: Boolean = false): DataFrame = sumFor(separate, numberColumns()) +public fun PivotGroupBy.sum(separate: Boolean = false): DataFrame = + sumFor(separate, primitiveOrMixedNumberColumns()) public fun PivotGroupBy.sumFor( separate: Boolean = false, @@ -209,7 +302,7 @@ public fun PivotGroupBy.sum(vararg columns: KProperty): D sum { columns.toColumnSet() } public inline fun PivotGroupBy.sumOf( - crossinline expression: RowExpression, + crossinline expression: RowExpression, ): DataFrame = Aggregators.sum.aggregateOf(this, expression) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt index f3f42ced5b..36d087b872 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt @@ -40,8 +40,8 @@ private val unifiedNumberTypeGraphs = mutableMapOf?, second: KClass<*>, options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, -): KClass<*> = +): KClass<*>? = when { first == null -> second - first == second -> first - else -> getUnifiedNumberClassGraph(options).findNearestCommonVertex(first, second) - ?: error("Can not find common number type for $first and $second") } /** @@ -156,28 +153,28 @@ internal fun getUnifiedNumberClass( * * @param options See [UnifiedNumberTypeOptions] * @return The nearest common numeric type between the input types. - * If no common type is found, it returns [Number]. + * If no common type is found, it returns `null`. * @see UnifyingNumbers */ -internal fun Iterable.unifiedNumberType( +internal fun Iterable.unifiedNumberTypeOrNull( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, -): KType = +): KType? = fold(null as KType?) { a, b -> - getUnifiedNumberType(a, b, options) - } ?: typeOf() + getUnifiedNumberTypeOrNull(a, b, options) ?: return null + } -/** @include [unifiedNumberType] */ -internal fun Iterable>.unifiedNumberClass( +/** @include [unifiedNumberTypeOrNull] */ +internal fun Iterable>.unifiedNumberClassOrNull( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, -): KClass<*> = +): KClass<*>? = fold(null as KClass<*>?) { a, b -> - getUnifiedNumberClass(a, b, options) - } ?: Number::class + getUnifiedNumberClassOrNull(a, b, options) ?: return null + } /** * Converts the elements of the given iterable of numbers into a common numeric type based on complexity. * The common numeric type is determined using the provided [commonNumberType] parameter - * or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified. + * or calculated with [Iterable.unifiedNumberTypeOrNull] from the iterable's elements if not explicitly specified. * * @param commonNumberType The desired common numeric type to convert the elements to. * By default, (or if `null`), this is determined using the types of the elements in the iterable. @@ -191,7 +188,12 @@ internal fun Iterable.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, commonNumberType: KType? = null, ): Iterable { - val commonNumberType = commonNumberType ?: this.types().unifiedNumberType(options) + val commonNumberType = commonNumberType ?: this.types().let { types -> + types.unifiedNumberTypeOrNull(options) + ?: throw IllegalArgumentException( + "Cannot find unified number type of types: ${types.joinToString { renderType(it) }}", + ) + } val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { if (it == null) return@map null @@ -216,7 +218,12 @@ internal fun Sequence.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, commonNumberType: KType? = null, ): Sequence { - val commonNumberType = commonNumberType ?: this.asIterable().types().unifiedNumberType(options) + val commonNumberType = commonNumberType ?: this.asIterable().types().let { types -> + types.unifiedNumberTypeOrNull(options) + ?: throw IllegalArgumentException( + "Cannot find unified number type of types: ${types.joinToString { renderType(it) }}", + ) + } val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { if (it == null) return@map null @@ -245,7 +252,28 @@ internal val primitiveNumberTypes: Set = typeOf(), ) -internal fun Any.isPrimitiveNumber(): Boolean = +/** Returns `true` only when this type is exactly `Number` or `Number?`. */ +@PublishedApi +internal fun KType.isMixedNumber(): Boolean = this == typeOf() || this == typeOf() + +/** + * Returns `true` when this type is one of the following (nullable) types: + * [Byte], [Short], [Int], [Long], [Float], or [Double]. + */ +@PublishedApi +internal fun KType.isPrimitiveNumber(): Boolean = this.withNullability(false) in primitiveNumberTypes + +/** + * Returns `true` when this type is one of the following (nullable) types: + * [Byte], [Short], [Int], [Long], [Float], [Double], or [Number]. + * + * Careful: Will return `true` for `Number`. + * This type may arise as a supertype from multiple non-primitive number types. + */ +@PublishedApi +internal fun KType.isPrimitiveOrMixedNumber(): Boolean = isPrimitiveNumber() || isMixedNumber() + +internal fun Number.isPrimitiveNumber(): Boolean = this is Byte || this is Short || this is Int || diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt index d6b4b96ed4..d8e24f6d02 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt @@ -477,7 +477,7 @@ internal fun guessValueType( it.isSubclassOf(Number::class) && it != nothingClass } if (usedNumberClasses.isNotEmpty()) { - val unifiedNumberClass = usedNumberClasses.unifiedNumberClass() as KClass + val unifiedNumberClass = usedNumberClasses.unifiedNumberClassOrNull() as KClass classes -= usedNumberClasses classes += unifiedNumberClass } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt index 31784c6036..8cfc4e92a7 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt @@ -12,7 +12,7 @@ import org.jetbrains.kotlinx.dataframe.impl.nothingType import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.impl.renderType import org.jetbrains.kotlinx.dataframe.impl.types -import org.jetbrains.kotlinx.dataframe.impl.unifiedNumberType +import org.jetbrains.kotlinx.dataframe.impl.unifiedNumberTypeOrNull import kotlin.reflect.KType import kotlin.reflect.full.isSubtypeOf import kotlin.reflect.full.starProjectedType @@ -95,11 +95,15 @@ internal class TwoStepNumbersAggregator( calculateReturnTypeOrNull(type = type.withNullability(false), emptyInput = colsEmpty) } if (typesAfterStepOne.anyNull()) return null - val commonType = (typesAfterStepOne as List) - .toSet() - .unifiedNumberType(PRIMITIVES_ONLY) - .withNullability(false) - return commonType + val typeSet = (typesAfterStepOne as List).toSet() + val unifiedType = typeSet.unifiedNumberTypeOrNull(PRIMITIVES_ONLY) + ?.withNullability(false) + ?: throw IllegalArgumentException( + "Cannot calculate the $name of the number types: ${typeSet.joinToString { renderType(it) }}. " + + "Note, only primitive number types are supported in statistics.", + ) + + return unifiedType } /** @@ -151,24 +155,28 @@ internal class TwoStepNumbersAggregator( @Suppress("UNCHECKED_CAST") override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { val valueTypes = valueTypes?.takeUnless { it.isEmpty() } ?: values.types() - val commonType = valueTypes.unifiedNumberType(PRIMITIVES_ONLY) + val unifiedType = valueTypes.unifiedNumberTypeOrNull(PRIMITIVES_ONLY) + ?: throw IllegalArgumentException( + "Cannot calculate the $name of the number types: ${valueTypes.joinToString { renderType(it) }}. " + + "Note, only primitive number types are supported in statistics.", + ) - if (commonType.isSubtypeOf(typeOf()) && + if (unifiedType.isSubtypeOf(typeOf()) && (typeOf() in valueTypes || typeOf() in valueTypes) ) { logger.warn { "Number unification of Long -> Double happened during aggregation. Loss of precision may have occurred." } } - if (commonType.withNullability(false) !in primitiveNumberTypes && !commonType.isNothing) { + if (unifiedType.withNullability(false) !in primitiveNumberTypes && !unifiedType.isNothing) { throw IllegalArgumentException( - "Cannot calculate $name of ${renderType(commonType)}, only primitive numbers are supported.", + "Cannot calculate $name of ${renderType(unifiedType)}, only primitive numbers are supported.", ) } return super.aggregate( - values = values.convertToUnifiedNumberType(commonNumberType = commonType), - type = commonType, + values = values.convertToUnifiedNumberType(commonNumberType = unifiedType), + type = unifiedType, ) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt index 66f598b02d..136efa687e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt @@ -2,11 +2,13 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation import org.jetbrains.kotlinx.dataframe.AnyCol import org.jetbrains.kotlinx.dataframe.ColumnsSelector +import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.aggregation.Aggregatable import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue +import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.isNumber -import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber +import org.jetbrains.kotlinx.dataframe.api.isPrimitiveOrMixedNumber import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType @@ -24,8 +26,11 @@ internal fun Aggregatable.numberColumns(): ColumnsSelector = remainingColumns { it.isNumber() } as ColumnsSelector @Suppress("UNCHECKED_CAST") -internal fun Aggregatable.primitiveNumberColumns(): ColumnsSelector = - remainingColumns { it.isPrimitiveNumber() } as ColumnsSelector +internal fun Aggregatable.primitiveOrMixedNumberColumns(): ColumnsSelector = + remainingColumns { it.isPrimitiveOrMixedNumber() } as ColumnsSelector + +internal fun DataRow.primitiveOrMixedNumberColumns(): ColumnsSelector = + { cols { it.isPrimitiveOrMixedNumber() }.cast() } internal fun NamedValue.toColumnWithPath() = path to createColumnGuessingType( diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt index 6ec6459ff0..70e55e2f8b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt @@ -41,7 +41,7 @@ internal fun Aggregator<*, R>.aggregateFor( } internal fun AggregateInternalDsl.aggregateFor( - columns: ColumnsForAggregateSelector, + columns: ColumnsForAggregateSelector, aggregator: Aggregator, ) { val cols = df.getAggregateColumns(columns) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt index 53abbbe9b5..d6adab9304 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt @@ -1,7 +1,7 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes -import org.jetbrains.kotlinx.dataframe.AnyRow import org.jetbrains.kotlinx.dataframe.ColumnsSelector +import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.api.getColumns import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator @@ -14,7 +14,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator * @param columns selector of which columns inside the [row] to aggregate */ @PublishedApi -internal fun Aggregator.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R { +internal fun Aggregator.aggregateOfRow(row: DataRow, columns: ColumnsSelector): R { val filteredColumns = row.df().getColumns(columns) return aggregateCalculatingType( values = filteredColumns.mapNotNull { row[it] }, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt index b7217d3b1e..7467f2da3f 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt @@ -15,18 +15,18 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.internal import org.jetbrains.kotlinx.dataframe.impl.emptyPath @PublishedApi -internal fun Aggregator<*, R>.aggregateAll(data: DataFrame, columns: ColumnsSelector): R = +internal fun Aggregator<*, R>.aggregateAll(data: DataFrame, columns: ColumnsSelector): R = data.aggregateAll(cast2(), columns) internal fun Aggregator<*, R>.aggregateAll( data: Grouped, name: String?, - columns: ColumnsSelector, + columns: ColumnsSelector, ): DataFrame = data.aggregateAll(cast(), columns, name) internal fun Aggregator<*, R>.aggregateAll( data: PivotGroupBy, - columns: ColumnsSelector, + columns: ColumnsSelector, ): DataFrame = data.aggregateAll(cast(), columns) internal fun DataFrame.aggregateAll(aggregator: Aggregator, columns: ColumnsSelector): R = @@ -34,7 +34,7 @@ internal fun DataFrame.aggregateAll(aggregator: Aggregator, c internal fun Grouped.aggregateAll( aggregator: Aggregator, - columns: ColumnsSelector, + columns: ColumnsSelector, name: String?, ): DataFrame = aggregateInternal { @@ -48,7 +48,7 @@ internal fun Grouped.aggregateAll( internal fun PivotGroupBy.aggregateAll( aggregator: Aggregator, - columns: ColumnsSelector, + columns: ColumnsSelector, ): DataFrame = aggregate { val cols = get(columns) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt index 07de30db44..7ef56e1147 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt @@ -1,139 +1,60 @@ package org.jetbrains.kotlinx.dataframe.math import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnTypeOrNull -import java.math.BigDecimal -import java.math.BigInteger +import org.jetbrains.kotlinx.dataframe.impl.nothingType +import org.jetbrains.kotlinx.dataframe.impl.renderType import kotlin.reflect.KType import kotlin.reflect.full.withNullability +import kotlin.reflect.typeOf +import kotlin.sequences.filterNotNull +internal fun Iterable.sum(type: KType): Number = asSequence().sum(type) + +@Suppress("UNCHECKED_CAST") +@JvmName("sumNullableT") @PublishedApi -internal fun Iterable.sumOf(type: KType, selector: (T) -> R?): R { +internal fun Sequence.sum(type: KType): Number { if (type.isMarkedNullable) { - val seq = asSequence().mapNotNull(selector).asIterable() - return seq.sum(type) + return filterNotNull().sum(type.withNullability(false)) } - return when (type.classifier) { - Double::class -> sumOf(selector as ((T) -> Double)) as R - - // careful, conversion to Double to Float occurs! TODO, Issue #558 - Float::class -> sumOf { (selector as ((T) -> Float))(it).toDouble() }.toFloat() as R + return when (type.withNullability(false)) { + typeOf() -> (this as Sequence).sum() - Int::class -> sumOf(selector as ((T) -> Int)) as R + typeOf() -> (this as Sequence).sum() - // careful, conversion to Int occurs! TODO, Issue #558 - Short::class -> sumOf { (selector as ((T) -> Short))(it).toInt() }.toShort() as R + typeOf() -> (this as Sequence).sum() - // careful, conversion to Int occurs! TODO, Issue #558 - Byte::class -> sumOf { (selector as ((T) -> Byte))(it).toInt() }.toByte() as R + // Note: returns Int + typeOf() -> (this as Sequence).sum() - Long::class -> sumOf(selector as ((T) -> Long)) as R + // Note: returns Int + typeOf() -> (this as Sequence).sum() - BigDecimal::class -> sumOf(selector as ((T) -> BigDecimal)) as R + typeOf() -> (this as Sequence).sum() - BigInteger::class -> sumOf(selector as ((T) -> BigInteger)) as R + typeOf() -> + error("Encountered non-specific Number type in sum function. This should not occur.") - Number::class -> sumOf { (selector as ((T) -> Number))(it).toDouble() } as R + nothingType -> 0.0 - Nothing::class -> 0.0 as R - - else -> throw IllegalArgumentException("sumOf is not supported for $type") + else -> throw IllegalArgumentException( + "Unable to compute the sum for ${renderType(type)}, Only primitive numbers are supported.", + ) } } -@PublishedApi -internal fun Iterable.sum(type: KType): T = - when (type.classifier) { - Double::class -> (this as Iterable).sum() as T - - Float::class -> (this as Iterable).sum() as T - - Int::class -> (this as Iterable).sum() as T - - // TODO result should be Int, but same type as input is returned, Issue #558 - Short::class -> (this as Iterable).sum().toShort() as T - - // TODO result should be Int, but same type as input is returned, Issue #558 - Byte::class -> (this as Iterable).sum().toByte() as T - - Long::class -> (this as Iterable).sum() as T - - BigDecimal::class -> (this as Iterable).sum() as T - - BigInteger::class -> (this as Iterable).sum() as T - - Number::class -> (this as Iterable).map { it.toDouble() }.sum() as T - - Nothing::class -> 0.0 as T - - else -> throw IllegalArgumentException("sum is not supported for $type") - } - -@JvmName("sumNullableT") -@PublishedApi -internal fun Iterable.sum(type: KType): T = - when (type.classifier) { - Double::class -> (this as Iterable).asSequence().filterNotNull().sum() as T - - Float::class -> (this as Iterable).asSequence().filterNotNull().sum() as T - - Int::class -> (this as Iterable).asSequence().filterNotNull().sum() as T - - // TODO result should be Int, but same type as input is returned, Issue #558 - Short::class -> (this as Iterable).asSequence().filterNotNull().sum().toShort() as T - - // TODO result should be Int, but same type as input is returned, Issue #558 - Byte::class -> (this as Iterable).asSequence().filterNotNull().sum().toByte() as T - - Long::class -> (this as Iterable).asSequence().filterNotNull().sum() as T - - BigDecimal::class -> (this as Iterable).asSequence().filterNotNull().sum() as T - - BigInteger::class -> (this as Iterable).asSequence().filterNotNull().sum() as T - - Number::class -> (this as Iterable).asSequence().filterNotNull().map { it.toDouble() }.sum() as T - - Nothing::class -> 0.0 as T - - else -> throw IllegalArgumentException("sum is not supported for $type") - } - /** T: Number? -> T */ internal val sumTypeConversion: CalculateReturnTypeOrNull = { type, _ -> - type.withNullability(false) -} + when (val type = type.withNullability(false)) { + // type changes to Int + typeOf(), typeOf() -> typeOf() -@PublishedApi -internal fun Iterable.sum(): BigDecimal { - var sum: BigDecimal = BigDecimal.ZERO - for (element in this) { - sum += element - } - return sum -} - -@PublishedApi -internal fun Sequence.sum(): BigDecimal { - var sum: BigDecimal = BigDecimal.ZERO - for (element in this) { - sum += element - } - return sum -} + // type remains the same + typeOf(), typeOf(), typeOf(), typeOf(), typeOf() -> type -@PublishedApi -internal fun Iterable.sum(): BigInteger { - var sum: BigInteger = BigInteger.ZERO - for (element in this) { - sum += element - } - return sum -} + nothingType -> typeOf() -@PublishedApi -internal fun Sequence.sum(): BigInteger { - var sum: BigInteger = BigInteger.ZERO - for (element in this) { - sum += element + else -> + error("Unable to compute the sum for ${renderType(type)}, Only primitive numbers are supported.") } - return sum } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt index 1980e5da70..0e158e73d4 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt @@ -2,9 +2,10 @@ package org.jetbrains.kotlinx.dataframe.statistics import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe -import org.jetbrains.kotlinx.dataframe.DataColumn +import io.kotest.matchers.string.shouldContain import org.jetbrains.kotlinx.dataframe.api.columnOf import org.jetbrains.kotlinx.dataframe.api.dataFrameOf +import org.jetbrains.kotlinx.dataframe.api.isEmpty import org.jetbrains.kotlinx.dataframe.api.rowSum import org.jetbrains.kotlinx.dataframe.api.sum import org.jetbrains.kotlinx.dataframe.api.sumOf @@ -45,7 +46,7 @@ class SumTests { fun `test multiple columns`() { val value1 by columnOf(1, 2, 3) val value2 by columnOf(4.0, 5.0, 6.0) - val value3: DataColumn by columnOf(7.0, 8, null) + val value3 by columnOf(7.0, 8, null) val df = dataFrameOf(value1, value2, value3) val expected1 = 6 val expected2 = 15.0 @@ -88,8 +89,21 @@ class SumTests { @Test fun `unknown number type`() { + columnOf(1.toBigDecimal(), 2.toBigDecimal()).toDataFrame() + .sum() + .isEmpty() shouldBe true + } + + @Test + fun `mixed numbers`() { + // mixed number types are picked up implicitly + columnOf(1.0, 2).toDataFrame() + .sum()[0] shouldBe 3.0 + + // in the slight case a mixed number column contains unsupported numbers + // we give a helpful exception telling about primitive support only shouldThrow { - columnOf(1.toBigDecimal(), 2.toBigDecimal()).toDataFrame().sum() - } + columnOf(1.0, 2, 3.0.toBigDecimal()).toDataFrame().sum()[0] + }.message?.lowercase() shouldContain "primitive" } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt index 3518039dbf..3b40948732 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt @@ -11,7 +11,7 @@ import org.jetbrains.kotlinx.dataframe.impl.commonParents import org.jetbrains.kotlinx.dataframe.impl.commonType import org.jetbrains.kotlinx.dataframe.impl.commonTypeListifyValues import org.jetbrains.kotlinx.dataframe.impl.createType -import org.jetbrains.kotlinx.dataframe.impl.getUnifiedNumberClass +import org.jetbrains.kotlinx.dataframe.impl.getUnifiedNumberClassOrNull import org.jetbrains.kotlinx.dataframe.impl.guessValueType import org.jetbrains.kotlinx.dataframe.impl.isArray import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveArray @@ -426,40 +426,40 @@ class UtilTests { @Test fun `common number types`() { // Same type - getUnifiedNumberClass(Int::class, Int::class) shouldBe Int::class - getUnifiedNumberClass(Double::class, Double::class) shouldBe Double::class + getUnifiedNumberClassOrNull(Int::class, Int::class) shouldBe Int::class + getUnifiedNumberClassOrNull(Double::class, Double::class) shouldBe Double::class // Direct parent-child relationships - getUnifiedNumberClass(Int::class, UShort::class) shouldBe Int::class - getUnifiedNumberClass(Long::class, UInt::class) shouldBe Long::class - getUnifiedNumberClass(Double::class, Float::class) shouldBe Double::class - getUnifiedNumberClass(UShort::class, Short::class) shouldBe Int::class - getUnifiedNumberClass(UByte::class, Byte::class) shouldBe Short::class + getUnifiedNumberClassOrNull(Int::class, UShort::class) shouldBe Int::class + getUnifiedNumberClassOrNull(Long::class, UInt::class) shouldBe Long::class + getUnifiedNumberClassOrNull(Double::class, Float::class) shouldBe Double::class + getUnifiedNumberClassOrNull(UShort::class, Short::class) shouldBe Int::class + getUnifiedNumberClassOrNull(UByte::class, Byte::class) shouldBe Short::class - getUnifiedNumberClass(UByte::class, UShort::class) shouldBe UShort::class + getUnifiedNumberClassOrNull(UByte::class, UShort::class) shouldBe UShort::class // Multi-level relationships - getUnifiedNumberClass(Byte::class, Int::class) shouldBe Int::class - getUnifiedNumberClass(UByte::class, Long::class) shouldBe Long::class - getUnifiedNumberClass(Short::class, Double::class) shouldBe Double::class - getUnifiedNumberClass(UInt::class, Int::class) shouldBe Long::class + getUnifiedNumberClassOrNull(Byte::class, Int::class) shouldBe Int::class + getUnifiedNumberClassOrNull(UByte::class, Long::class) shouldBe Long::class + getUnifiedNumberClassOrNull(Short::class, Double::class) shouldBe Double::class + getUnifiedNumberClassOrNull(UInt::class, Int::class) shouldBe Long::class // Top-level types - getUnifiedNumberClass(BigDecimal::class, Double::class) shouldBe BigDecimal::class - getUnifiedNumberClass(BigInteger::class, Long::class) shouldBe BigInteger::class - getUnifiedNumberClass(BigDecimal::class, BigInteger::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(BigDecimal::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(BigInteger::class, Long::class) shouldBe BigInteger::class + getUnifiedNumberClassOrNull(BigDecimal::class, BigInteger::class) shouldBe BigDecimal::class // Distant relationships - getUnifiedNumberClass(Byte::class, BigDecimal::class) shouldBe BigDecimal::class - getUnifiedNumberClass(UByte::class, Double::class) shouldBe Double::class + getUnifiedNumberClassOrNull(Byte::class, BigDecimal::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(UByte::class, Double::class) shouldBe Double::class // Complex type promotions - getUnifiedNumberClass(Int::class, Float::class) shouldBe Double::class - getUnifiedNumberClass(Long::class, Double::class) shouldBe BigDecimal::class - getUnifiedNumberClass(ULong::class, Double::class) shouldBe BigDecimal::class - getUnifiedNumberClass(BigInteger::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(Int::class, Float::class) shouldBe Double::class + getUnifiedNumberClassOrNull(Long::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(ULong::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(BigInteger::class, Double::class) shouldBe BigDecimal::class // Edge case with null - getUnifiedNumberClass(null, Int::class) shouldBe Int::class + getUnifiedNumberClassOrNull(null, Int::class) shouldBe Int::class } }