Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sum statistics and aggregator improvements #1103

Merged
merged 9 commits into from
Mar 20, 2025
29 changes: 20 additions & 9 deletions core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -1781,8 +1781,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt {
public static final fun isComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isFrameColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isMixedNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isPrimitiveOrMixedNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z
public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
Expand Down Expand Up @@ -3816,6 +3818,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/StdKt {

public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
public static final fun rowSum (Lorg/jetbrains/kotlinx/dataframe/DataRow;)Ljava/lang/Number;
public static final fun rowSumOf (Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/reflect/KType;)Ljava/lang/Number;
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Ljava/lang/Number;
public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
Expand All @@ -3839,6 +3842,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;Ljava/lang/String;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Pivot;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I
public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataRow;
Expand All @@ -3863,8 +3870,15 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt {
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun sumT (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
public static final fun sumTNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
public static final fun sumNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number;
public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I
public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I
public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I
public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I
}

public final class org/jetbrains/kotlinx/dataframe/api/TailKt {
Expand Down Expand Up @@ -5104,6 +5118,9 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt {

public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt {
public static final fun getPrimitiveNumberTypes ()Ljava/util/Set;
public static final fun isMixedNumber (Lkotlin/reflect/KType;)Z
public static final fun isPrimitiveNumber (Lkotlin/reflect/KType;)Z
public static final fun isPrimitiveOrMixedNumber (Lkotlin/reflect/KType;)Z
}

public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt {
Expand Down Expand Up @@ -6153,13 +6170,7 @@ public final class org/jetbrains/kotlinx/dataframe/math/StdKt {
}

public final class org/jetbrains/kotlinx/dataframe/math/SumKt {
public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigDecimal;
public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigInteger;
public static final fun sum (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number;
public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigDecimal;
public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigInteger;
public static final fun sumNullableT (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number;
public static final fun sumOf (Ljava/lang/Iterable;Lkotlin/reflect/KType;Lkotlin/jvm/functions/Function1;)Ljava/lang/Number;
public static final fun sumNullableT (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;)Ljava/lang/Number;
}

public abstract class org/jetbrains/kotlinx/dataframe/schema/ColumnSchema {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
import org.jetbrains.kotlinx.dataframe.impl.isMixedNumber
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
import org.jetbrains.kotlinx.dataframe.type
import org.jetbrains.kotlinx.dataframe.typeClass
import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE
Expand Down Expand Up @@ -48,11 +50,23 @@ public inline fun <reified T> AnyCol.isType(): Boolean = type() == typeOf<T>()
/** Returns `true` when this column's type is a subtype of `Number?` */
public fun AnyCol.isNumber(): Boolean = isSubtypeOf<Number?>()

/** Returns `true` only when this column's type is exactly `Number` or `Number?`. */
public fun AnyCol.isMixedNumber(): Boolean = type().isMixedNumber()

/**
* Returns `true` when this column has the (nullable) type of either:
* [Byte], [Short], [Int], [Long], [Float], or [Double].
*/
public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes
public fun AnyCol.isPrimitiveNumber(): Boolean = type().isPrimitiveNumber()

/**
* Returns `true` when this column has the (nullable) type of either:
* [Byte], [Short], [Int], [Long], [Float], [Double], or [Number].
*
* Careful: Will return `true` if the column contains multiple number types that
* might NOT be primitive.
*/
public fun AnyCol.isPrimitiveOrMixedNumber(): Boolean = type().isPrimitiveOrMixedNumber()

public fun AnyCol.isList(): Boolean = typeClass == List::class

Expand Down
27 changes: 11 additions & 16 deletions core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,15 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
import kotlin.reflect.KProperty
import kotlin.reflect.full.withNullability
import kotlin.reflect.typeOf

/*
* TODO KDocs:
/* TODO KDocs:
* Calculating the mean is supported for all primitive number types.
* Nulls are filtered from columns.
* Nulls are filtered out.
* The return type is always Double, Double.NaN for empty input, never null.
* (May introduce loss of precision for Longs).
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean.
Expand All @@ -48,24 +46,21 @@ public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
// region DataRow

public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
Aggregators.mean(skipNA).aggregateOfRow(this) {
colsOf<Number?> { it.isPrimitiveNumber() }
}
Aggregators.mean(skipNA).aggregateOfRow(this, primitiveOrMixedNumberColumns())

public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
require(typeOf<T>().withNullability(false) in primitiveNumberTypes) {
require(typeOf<T>().isPrimitiveOrMixedNumber()) {
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
}
return Aggregators.mean(skipNA)
.aggregateOfRow(this) { colsOf<T>() }
return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf<T>() }
}

// endregion

// region DataFrame

public fun <T> DataFrame<T>.mean(skipNA: Boolean = skipNA_default): DataRow<T> =
meanFor(skipNA, primitiveNumberColumns())
meanFor(skipNA, primitiveOrMixedNumberColumns())

public fun <T, C : Number> DataFrame<T>.meanFor(
skipNA: Boolean = skipNA_default,
Expand Down Expand Up @@ -116,7 +111,7 @@ public inline fun <T, reified D : Number> DataFrame<T>.meanOf(
@Refine
@Interpretable("GroupByMean1")
public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> =
meanFor(skipNA, primitiveNumberColumns())
meanFor(skipNA, primitiveOrMixedNumberColumns())

@Refine
@Interpretable("GroupByMean0")
Expand Down Expand Up @@ -181,7 +176,7 @@ public inline fun <T, reified R : Number> Grouped<T>.meanOf(
// region Pivot

public fun <T> Pivot<T>.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow<T> =
meanFor(skipNA, separate, primitiveNumberColumns())
meanFor(skipNA, separate, primitiveOrMixedNumberColumns())

public fun <T, C : Number> Pivot<T>.meanFor(
skipNA: Boolean = skipNA_default,
Expand Down Expand Up @@ -224,7 +219,7 @@ public inline fun <T, reified R : Number> Pivot<T>.meanOf(
// region PivotGroupBy

public fun <T> PivotGroupBy<T>.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame<T> =
meanFor(skipNA, separate, primitiveNumberColumns())
meanFor(skipNA, separate, primitiveOrMixedNumberColumns())

public fun <T, C : Number> PivotGroupBy<T>.meanFor(
skipNA: Boolean = skipNA_default,
Expand Down
Loading
Loading