Skip to content

Commit 7502a80

Browse files
committed
made mean and sum use isPrimitiveOrMixedNumber(). Number unification can now return null, so we can give helpful errors from the aggregator
1 parent 9aaf84d commit 7502a80

File tree

11 files changed

+163
-94
lines changed

11 files changed

+163
-94
lines changed

core/api/core.api

+5
Original file line numberDiff line numberDiff line change
@@ -1781,8 +1781,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt {
17811781
public static final fun isComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17821782
public static final fun isFrameColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17831783
public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
1784+
public static final fun isMixedNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17841785
public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17851786
public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
1787+
public static final fun isPrimitiveOrMixedNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17861788
public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z
17871789
public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
17881790
public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z
@@ -5116,6 +5118,9 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt {
51165118

51175119
public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt {
51185120
public static final fun getPrimitiveNumberTypes ()Ljava/util/Set;
5121+
public static final fun isMixedNumber (Lkotlin/reflect/KType;)Z
5122+
public static final fun isPrimitiveNumber (Lkotlin/reflect/KType;)Z
5123+
public static final fun isPrimitiveOrMixedNumber (Lkotlin/reflect/KType;)Z
51195124
}
51205125

51215126
public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt {

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt

+16-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
77
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
88
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
99
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
10-
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
10+
import org.jetbrains.kotlinx.dataframe.impl.isMixedNumber
11+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
12+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
1113
import org.jetbrains.kotlinx.dataframe.type
1214
import org.jetbrains.kotlinx.dataframe.typeClass
1315
import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE
@@ -48,11 +50,23 @@ public inline fun <reified T> AnyCol.isType(): Boolean = type() == typeOf<T>()
4850
/** Returns `true` when this column's type is a subtype of `Number?` */
4951
public fun AnyCol.isNumber(): Boolean = isSubtypeOf<Number?>()
5052

53+
/** Returns `true` only when this column's type is exactly `Number` or `Number?`. */
54+
public fun AnyCol.isMixedNumber(): Boolean = type().isMixedNumber()
55+
5156
/**
5257
* Returns `true` when this column has the (nullable) type of either:
5358
* [Byte], [Short], [Int], [Long], [Float], or [Double].
5459
*/
55-
public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes
60+
public fun AnyCol.isPrimitiveNumber(): Boolean = type().isPrimitiveNumber()
61+
62+
/**
63+
* Returns `true` when this column has the (nullable) type of either:
64+
* [Byte], [Short], [Int], [Long], [Float], [Double], or [Number].
65+
*
66+
* Careful: Will return `true` if the column contains multiple number types that
67+
* might NOT be primitive.
68+
*/
69+
public fun AnyCol.isPrimitiveOrMixedNumber(): Boolean = type().isPrimitiveOrMixedNumber()
5670

5771
public fun AnyCol.isList(): Boolean = typeClass == List::class
5872

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

+8-9
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
1818
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
1919
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2020
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
21-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
21+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns
2222
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
23-
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
23+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
2424
import kotlin.reflect.KProperty
25-
import kotlin.reflect.full.withNullability
2625
import kotlin.reflect.typeOf
2726

2827
/* TODO KDocs:
@@ -47,10 +46,10 @@ public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
4746
// region DataRow
4847

4948
public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
50-
Aggregators.mean(skipNA).aggregateOfRow(this, primitiveNumberColumns())
49+
Aggregators.mean(skipNA).aggregateOfRow(this, primitiveOrMixedNumberColumns())
5150

5251
public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
53-
require(typeOf<T>().withNullability(false) in primitiveNumberTypes) {
52+
require(typeOf<T>().isPrimitiveOrMixedNumber()) {
5453
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
5554
}
5655
return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf<T>() }
@@ -61,7 +60,7 @@ public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipN
6160
// region DataFrame
6261

6362
public fun <T> DataFrame<T>.mean(skipNA: Boolean = skipNA_default): DataRow<T> =
64-
meanFor(skipNA, primitiveNumberColumns())
63+
meanFor(skipNA, primitiveOrMixedNumberColumns())
6564

6665
public fun <T, C : Number> DataFrame<T>.meanFor(
6766
skipNA: Boolean = skipNA_default,
@@ -112,7 +111,7 @@ public inline fun <T, reified D : Number> DataFrame<T>.meanOf(
112111
@Refine
113112
@Interpretable("GroupByMean1")
114113
public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> =
115-
meanFor(skipNA, primitiveNumberColumns())
114+
meanFor(skipNA, primitiveOrMixedNumberColumns())
116115

117116
@Refine
118117
@Interpretable("GroupByMean0")
@@ -177,7 +176,7 @@ public inline fun <T, reified R : Number> Grouped<T>.meanOf(
177176
// region Pivot
178177

179178
public fun <T> Pivot<T>.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow<T> =
180-
meanFor(skipNA, separate, primitiveNumberColumns())
179+
meanFor(skipNA, separate, primitiveOrMixedNumberColumns())
181180

182181
public fun <T, C : Number> Pivot<T>.meanFor(
183182
skipNA: Boolean = skipNA_default,
@@ -220,7 +219,7 @@ public inline fun <T, reified R : Number> Pivot<T>.meanOf(
220219
// region PivotGroupBy
221220

222221
public fun <T> PivotGroupBy<T>.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame<T> =
223-
meanFor(skipNA, separate, primitiveNumberColumns())
222+
meanFor(skipNA, separate, primitiveOrMixedNumberColumns())
224223

225224
public fun <T, C : Number> PivotGroupBy<T>.meanFor(
226225
skipNA: Boolean = skipNA_default,

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

+9-9
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
2121
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
2222
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2323
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
24-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
24+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns
2525
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
26-
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
26+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
2727
import kotlin.experimental.ExperimentalTypeInference
2828
import kotlin.reflect.KClass
2929
import kotlin.reflect.KProperty
3030
import kotlin.reflect.KType
31-
import kotlin.reflect.full.withNullability
3231
import kotlin.reflect.typeOf
3332

3433
/* TODO KDocs
@@ -70,7 +69,7 @@ public inline fun <C, reified V : Number> DataColumn<C>.sumOf(crossinline expres
7069

7170
// region DataRow
7271

73-
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveNumberColumns())
72+
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveOrMixedNumberColumns())
7473

7574
@JvmName("rowSumOfShort")
7675
public inline fun <reified T : Short?> AnyRow.rowSumOf(_kClass: KClass<Short> = Short::class): Int =
@@ -98,7 +97,7 @@ public inline fun <reified T : Double?> AnyRow.rowSumOf(_kClass: KClass<Double>
9897

9998
// unfortunately, we cannot make a `reified T : Number?` due to clashes
10099
public fun AnyRow.rowSumOf(type: KType): Number {
101-
require(type.withNullability(false) in primitiveNumberTypes) {
100+
require(type.isPrimitiveOrMixedNumber()) {
102101
"Type $type is not a primitive number type. Mean only supports primitive number types."
103102
}
104103
return Aggregators.sum.aggregateOfRow(this) { colsOf(type) }
@@ -107,7 +106,7 @@ public fun AnyRow.rowSumOf(type: KType): Number {
107106

108107
// region DataFrame
109108

110-
public fun <T> DataFrame<T>.sum(): DataRow<T> = sumFor(primitiveNumberColumns())
109+
public fun <T> DataFrame<T>.sum(): DataRow<T> = sumFor(primitiveOrMixedNumberColumns())
111110

112111
public fun <T, C : Number> DataFrame<T>.sumFor(columns: ColumnsForAggregateSelector<T, C?>): DataRow<T> =
113112
Aggregators.sum.aggregateFor(this, columns)
@@ -185,7 +184,7 @@ public inline fun <T, reified C : Number> DataFrame<T>.sumOf(crossinline express
185184
// region GroupBy
186185
@Refine
187186
@Interpretable("GroupBySum1")
188-
public fun <T> Grouped<T>.sum(): DataFrame<T> = sumFor(primitiveNumberColumns())
187+
public fun <T> Grouped<T>.sum(): DataFrame<T> = sumFor(primitiveOrMixedNumberColumns())
189188

190189
@Refine
191190
@Interpretable("GroupBySum0")
@@ -229,7 +228,7 @@ public inline fun <T, reified R : Number> Grouped<T>.sumOf(
229228

230229
// region Pivot
231230

232-
public fun <T> Pivot<T>.sum(separate: Boolean = false): DataRow<T> = sumFor(separate, primitiveNumberColumns())
231+
public fun <T> Pivot<T>.sum(separate: Boolean = false): DataRow<T> = sumFor(separate, primitiveOrMixedNumberColumns())
233232

234233
public fun <T, R : Number> Pivot<T>.sumFor(
235234
separate: Boolean = false,
@@ -266,7 +265,8 @@ public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression:
266265

267266
// region PivotGroupBy
268267

269-
public fun <T> PivotGroupBy<T>.sum(separate: Boolean = false): DataFrame<T> = sumFor(separate, primitiveNumberColumns())
268+
public fun <T> PivotGroupBy<T>.sum(separate: Boolean = false): DataFrame<T> =
269+
sumFor(separate, primitiveOrMixedNumberColumns())
270270

271271
public fun <T, R : Number> PivotGroupBy<T>.sumFor(
272272
separate: Boolean = false,

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt

+53-25
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ private val unifiedNumberTypeGraphs = mutableMapOf<UnifiedNumberTypeOptions, Dir
4040
* by calling [DirectedAcyclicGraph.findNearestCommonVertex].
4141
*
4242
* @param options See [UnifiedNumberTypeOptions]
43-
* @see getUnifiedNumberClass
44-
* @see unifiedNumberClass
43+
* @see getUnifiedNumberClassOrNull
44+
* @see unifiedNumberClassOrNull
4545
* @see UnifyingNumbers
4646
*/
4747
internal fun getUnifiedNumberTypeGraph(
@@ -107,11 +107,11 @@ internal fun getUnifiedNumberClassGraph(
107107
* If no common number type is found, `null` is returned.
108108
* @see UnifyingNumbers
109109
*/
110-
internal fun getUnifiedNumberType(
110+
internal fun getUnifiedNumberTypeOrNull(
111111
first: KType?,
112112
second: KType,
113113
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
114-
): KType {
114+
): KType? {
115115
if (first == null) return second
116116

117117
val firstWithoutNullability = first.withNullability(false)
@@ -121,7 +121,7 @@ internal fun getUnifiedNumberType(
121121
firstWithoutNullability
122122
} else {
123123
getUnifiedNumberTypeGraph(options).findNearestCommonVertex(firstWithoutNullability, secondWithoutNullability)
124-
?: error("Can not find common number type for $first and $second")
124+
?: return null
125125
}
126126

127127
return if (first.isMarkedNullable || second.isMarkedNullable) {
@@ -131,20 +131,17 @@ internal fun getUnifiedNumberType(
131131
}
132132
}
133133

134-
/** @include [getUnifiedNumberType] */
134+
/** @include [getUnifiedNumberTypeOrNull] */
135135
@Suppress("IntroduceWhenSubject")
136-
internal fun getUnifiedNumberClass(
136+
internal fun getUnifiedNumberClassOrNull(
137137
first: KClass<*>?,
138138
second: KClass<*>,
139139
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
140-
): KClass<*> =
140+
): KClass<*>? =
141141
when {
142142
first == null -> second
143-
144143
first == second -> first
145-
146144
else -> getUnifiedNumberClassGraph(options).findNearestCommonVertex(first, second)
147-
?: error("Can not find common number type for $first and $second")
148145
}
149146

150147
/**
@@ -156,28 +153,28 @@ internal fun getUnifiedNumberClass(
156153
*
157154
* @param options See [UnifiedNumberTypeOptions]
158155
* @return The nearest common numeric type between the input types.
159-
* If no common type is found, it returns [Number].
156+
* If no common type is found, it returns `null`.
160157
* @see UnifyingNumbers
161158
*/
162-
internal fun Iterable<KType>.unifiedNumberType(
159+
internal fun Iterable<KType>.unifiedNumberTypeOrNull(
163160
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
164-
): KType =
161+
): KType? =
165162
fold(null as KType?) { a, b ->
166-
getUnifiedNumberType(a, b, options)
167-
} ?: typeOf<Number>()
163+
getUnifiedNumberTypeOrNull(a, b, options) ?: return null
164+
}
168165

169-
/** @include [unifiedNumberType] */
170-
internal fun Iterable<KClass<*>>.unifiedNumberClass(
166+
/** @include [unifiedNumberTypeOrNull] */
167+
internal fun Iterable<KClass<*>>.unifiedNumberClassOrNull(
171168
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
172-
): KClass<*> =
169+
): KClass<*>? =
173170
fold(null as KClass<*>?) { a, b ->
174-
getUnifiedNumberClass(a, b, options)
175-
} ?: Number::class
171+
getUnifiedNumberClassOrNull(a, b, options) ?: return null
172+
}
176173

177174
/**
178175
* Converts the elements of the given iterable of numbers into a common numeric type based on complexity.
179176
* The common numeric type is determined using the provided [commonNumberType] parameter
180-
* or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
177+
* or calculated with [Iterable.unifiedNumberTypeOrNull] from the iterable's elements if not explicitly specified.
181178
*
182179
* @param commonNumberType The desired common numeric type to convert the elements to.
183180
* By default (or if `null`), this is determined using the types of the elements in the iterable.
@@ -191,7 +188,12 @@ internal fun Iterable<Number?>.convertToUnifiedNumberType(
191188
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
192189
commonNumberType: KType? = null,
193190
): Iterable<Number?> {
194-
val commonNumberType = commonNumberType ?: this.types().unifiedNumberType(options)
191+
val commonNumberType = commonNumberType ?: this.types().let { types ->
192+
types.unifiedNumberTypeOrNull(options)
193+
?: throw IllegalArgumentException(
194+
"Cannot find unified number type of types: ${types.joinToString { renderType(it) }}",
195+
)
196+
}
195197
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
196198
return map {
197199
if (it == null) return@map null
@@ -216,7 +218,12 @@ internal fun Sequence<Number?>.convertToUnifiedNumberType(
216218
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
217219
commonNumberType: KType? = null,
218220
): Sequence<Number?> {
219-
val commonNumberType = commonNumberType ?: this.asIterable().types().unifiedNumberType(options)
221+
val commonNumberType = commonNumberType ?: this.asIterable().types().let { types ->
222+
types.unifiedNumberTypeOrNull(options)
223+
?: throw IllegalArgumentException(
224+
"Cannot find unified number type of types: ${types.joinToString { renderType(it) }}",
225+
)
226+
}
220227
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
221228
return map {
222229
if (it == null) return@map null
@@ -245,7 +252,28 @@ internal val primitiveNumberTypes: Set<KType> =
245252
typeOf<Double>(),
246253
)
247254

248-
internal fun Any.isPrimitiveNumber(): Boolean =
255+
/** Returns `true` only when this type is exactly `Number` or `Number?`. */
256+
@PublishedApi
257+
internal fun KType.isMixedNumber(): Boolean = this == typeOf<Number>() || this == typeOf<Number?>()
258+
259+
/**
260+
* Returns `true` when this type is one of the following (nullable) types:
261+
* [Byte], [Short], [Int], [Long], [Float], or [Double].
262+
*/
263+
@PublishedApi
264+
internal fun KType.isPrimitiveNumber(): Boolean = this.withNullability(false) in primitiveNumberTypes
265+
266+
/**
267+
* Returns `true` when this type is one of the following (nullable) types:
268+
* [Byte], [Short], [Int], [Long], [Float], [Double], or [Number].
269+
*
270+
* Careful: Will return `true` for `Number`.
271+
* This type may arise as a supertype from multiple non-primitive number types.
272+
*/
273+
@PublishedApi
274+
internal fun KType.isPrimitiveOrMixedNumber(): Boolean = isPrimitiveNumber() || isMixedNumber()
275+
276+
internal fun Number.isPrimitiveNumber(): Boolean =
249277
this is Byte ||
250278
this is Short ||
251279
this is Int ||

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ internal fun guessValueType(
477477
it.isSubclassOf(Number::class) && it != nothingClass
478478
}
479479
if (usedNumberClasses.isNotEmpty()) {
480-
val unifiedNumberClass = usedNumberClasses.unifiedNumberClass() as KClass<Number>
480+
val unifiedNumberClass = usedNumberClasses.unifiedNumberClassOrNull() as KClass<Number>
481481
classes -= usedNumberClasses
482482
classes += unifiedNumberClass
483483
}

0 commit comments

Comments
 (0)