Skip to content

Commit b42c7b4

Browse files
committed
Fixed AggregatorBase only filtering nulls when the type says they exist.
unifying numbers can now handle null/nothing in the input.
1 parent 7d72bdf commit b42c7b4

File tree

6 files changed

+56
-21
lines changed

6 files changed

+56
-21
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/UnifyingNumbers.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import org.jetbrains.kotlinx.dataframe.impl.UnifiedNumberTypeOptions
2121
* potentially losing a little precision, but a warning will be given.
2222
*
2323
* See [UnifiedNumberTypeOptions] for these settings.
24+
*
25+
* At the bottom of the graph is [Nothing]. This can be interpreted as `null`.
2426
*/
2527
internal interface UnifyingNumbers {
2628

@@ -40,6 +42,9 @@ internal interface UnifyingNumbers {
4042
* | / |
4143
* | / |
4244
* UByte Byte
45+
* \\ /
46+
* \\ /
47+
* Nothing?
4348
* ```
4449
*/
4550
interface Graph

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ internal fun getUnifiedNumberTypeGraph(
8282

8383
addEdge(typeOf<Short>(), typeOf<UByte>())
8484
addEdge(typeOf<Short>(), typeOf<Byte>())
85+
86+
addEdge(typeOf<UByte>(), nothingType)
87+
addEdge(typeOf<Byte>(), nothingType)
8588
}
8689
}
8790

@@ -121,7 +124,11 @@ internal fun getUnifiedNumberType(
121124
?: error("Can not find common number type for $first and $second")
122125
}
123126

124-
return if (first.isMarkedNullable || second.isMarkedNullable) result.withNullability(true) else result
127+
return if (first.isMarkedNullable || second.isMarkedNullable) {
128+
result.withNullability(true)
129+
} else {
130+
result
131+
}
125132
}
126133

127134
/** @include [getUnifiedNumberType] */
@@ -184,7 +191,7 @@ internal fun Iterable<Number?>.convertToUnifiedNumberType(
184191
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
185192
commonNumberType: KType? = null,
186193
): Iterable<Number?> {
187-
val commonNumberType = commonNumberType ?: this.filterNotNull().types().unifiedNumberType(options)
194+
val commonNumberType = commonNumberType ?: this.types().unifiedNumberType(options)
188195
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
189196
return map {
190197
if (it == null) return@map null
@@ -209,7 +216,7 @@ internal fun Sequence<Number?>.convertToUnifiedNumberType(
209216
options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT,
210217
commonNumberType: KType? = null,
211218
): Sequence<Number?> {
212-
val commonNumberType = commonNumberType ?: this.filterNotNull().asIterable().types().unifiedNumberType(options)
219+
val commonNumberType = commonNumberType ?: this.asIterable().types().unifiedNumberType(options)
213220
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
214221
return map {
215222
if (it == null) return@map null

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -667,19 +667,26 @@ internal fun Any.isBigNumber(): Boolean = this is BigInteger || this is BigDecim
667667
*
668668
* The [KClass] is determined by retrieving the runtime class of each element.
669669
*
670+
* [Nothing::class][Nothing] is used for elements that are `null`.
671+
*
670672
* @return A set of [KClass] objects representing the runtime types of elements in the iterable.
671673
*/
672-
internal fun Iterable<Any>.classes(): Set<KClass<*>> = mapTo(mutableSetOf()) { it::class }
674+
internal fun Iterable<Any?>.classes(): Set<KClass<*>> =
675+
mapTo(mutableSetOf()) {
676+
if (it == null) Nothing::class else it::class
677+
}
673678

674679
/**
675680
* Returns a set of [KType] objects representing the star-projected types of the runtime classes
676681
* of all unique elements in the iterable.
677682
*
678-
* The method internally relies on the [classes] function to collect the runtime classes of the
679-
* elements in the iterable and then maps each class to its star-projected type.
680-
*
681683
* This can be a heavy operation!
682684
*
685+
* [typeOf<Nothing?>()][nullableNothingType] is used for elements that are `null`.
686+
*
683687
* @return A set of [KType] objects corresponding to the star-projected runtime types of elements in the iterable.
684688
*/
685-
internal fun Iterable<Any>.types(): Set<KType> = classes().mapTo(mutableSetOf()) { it.createStarProjectedType(false) }
689+
internal fun Iterable<Any?>.types(): Set<KType> =
690+
mapTo(mutableSetOf()) {
691+
if (it == null) nullableNothingType else it::class.createStarProjectedType(false)
692+
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ internal interface Aggregator<in Value, out Return> {
2626
* Base function of [Aggregator].
2727
*
2828
* Aggregates the given values, taking [type] into account,
29-
* filtering nulls, and computes a single resulting value.
29+
* filtering nulls (only if [type.isMarkedNullable][KType.isMarkedNullable]),
30+
* and computes a single resulting value.
3031
*
3132
* When using [AggregatorBase], this can be supplied by the [AggregatorBase.aggregator] argument.
3233
*
@@ -55,7 +56,7 @@ internal interface Aggregator<in Value, out Return> {
5556
* @param valueTypes The types of the values.
5657
* If provided, this can be used to avoid calculating the types of [values] at runtime with reflection.
5758
* It should contain all types of [values].
58-
* If `null`, the types of [values] will be calculated at runtime (heavy!).
59+
* If `null` or empty, the types of [values] will be calculated at runtime (heavy!).
5960
*/
6061
fun aggregateCalculatingType(values: Iterable<Value?>, valueTypes: Set<KType>? = null): Return
6162

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorBase.kt

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.DataColumn
44
import org.jetbrains.kotlinx.dataframe.api.asIterable
55
import org.jetbrains.kotlinx.dataframe.api.asSequence
66
import org.jetbrains.kotlinx.dataframe.impl.commonType
7+
import org.jetbrains.kotlinx.dataframe.impl.nothingType
78
import kotlin.reflect.KType
89
import kotlin.reflect.full.withNullability
910

@@ -26,7 +27,8 @@ internal abstract class AggregatorBase<in Value, out Return>(
2627
* Base function of [Aggregator].
2728
*
2829
* Aggregates the given values, taking [type] into account,
29-
* filtering nulls, and computes a single resulting value.
30+
* filtering nulls (only if [type.isMarkedNullable][KType.isMarkedNullable]),
31+
* and computes a single resulting value.
3032
*
3133
* When using [AggregatorBase], this can be supplied by the [AggregatorBase.aggregator] argument.
3234
*
@@ -35,7 +37,13 @@ internal abstract class AggregatorBase<in Value, out Return>(
3537
@Suppress("UNCHECKED_CAST")
3638
override fun aggregate(values: Iterable<Value?>, type: KType): Return =
3739
aggregator(
38-
values.asSequence().filterNotNull().asIterable(), // TODO make dependant on type's nullability
40+
// values =
41+
if (type.isMarkedNullable) {
42+
values.asSequence().filterNotNull().asIterable()
43+
} else {
44+
values as Iterable<Value & Any>
45+
},
46+
// type =
3947
type.withNullability(false),
4048
)
4149

@@ -66,7 +74,7 @@ internal abstract class AggregatorBase<in Value, out Return>(
6674

6775
/** @include [Aggregator.aggregateCalculatingType] */
6876
override fun aggregateCalculatingType(values: Iterable<Value?>, valueTypes: Set<KType>?): Return {
69-
val commonType = if (valueTypes != null) {
77+
val commonType = if (valueTypes != null && valueTypes.isNotEmpty()) {
7078
valueTypes.commonType(false)
7179
} else {
7280
var hasNulls = false
@@ -78,7 +86,11 @@ internal abstract class AggregatorBase<in Value, out Return>(
7886
it.javaClass.kotlin
7987
}
8088
}
81-
classes.commonType(hasNulls)
89+
if (classes.isEmpty()) {
90+
nothingType(hasNulls)
91+
} else {
92+
classes.commonType(hasNulls)
93+
}
8294
}
8395
return aggregate(values, commonType)
8496
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.documentation.UnifyingNumbers
77
import org.jetbrains.kotlinx.dataframe.impl.UnifiedNumberTypeOptions.Companion.PRIMITIVES_ONLY
88
import org.jetbrains.kotlinx.dataframe.impl.anyNull
99
import org.jetbrains.kotlinx.dataframe.impl.convertToUnifiedNumberType
10+
import org.jetbrains.kotlinx.dataframe.impl.isNothing
1011
import org.jetbrains.kotlinx.dataframe.impl.nothingType
1112
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
1213
import org.jetbrains.kotlinx.dataframe.impl.renderType
@@ -106,6 +107,8 @@ internal class TwoStepNumbersAggregator<out Return : Number?>(
106107
*
107108
* Aggregates the given values, taking [type] into account, and computes a single resulting value.
108109
*
110+
* Nulls are filtered out (only if [type.isMarkedNullable][KType.isMarkedNullable]).
111+
*
109112
* Uses [aggregator] to compute the result.
110113
*
111114
* This function is modified to call [aggregateCalculatingType] when it encounters mixed number types.
@@ -143,21 +146,21 @@ internal class TwoStepNumbersAggregator<out Return : Number?>(
143146
* @param valueTypes The types of the numbers.
144147
* If provided, this can be used to avoid calculating the types of [values] at runtime with reflection.
145148
* It should contain all types of [values].
146-
* If `null`, the types of [values] will be calculated at runtime (heavy!).
149+
* If `null` or empty, the types of [values] will be calculated at runtime (heavy!).
147150
*/
148151
@Suppress("UNCHECKED_CAST")
149152
override fun aggregateCalculatingType(values: Iterable<Number?>, valueTypes: Set<KType>?): Return {
150-
val valueTypes = valueTypes ?: values.filterNotNull().types()
151-
val commonType = valueTypes
152-
.unifiedNumberType(PRIMITIVES_ONLY)
153-
.withNullability(false)
153+
val valueTypes = valueTypes?.takeUnless { it.isEmpty() } ?: values.types()
154+
val commonType = valueTypes.unifiedNumberType(PRIMITIVES_ONLY)
154155

155-
if (commonType == typeOf<Double>() && (typeOf<ULong>() in valueTypes || typeOf<Long>() in valueTypes)) {
156+
if (commonType.isSubtypeOf(typeOf<Double?>()) &&
157+
(typeOf<ULong>() in valueTypes || typeOf<Long>() in valueTypes)
158+
) {
156159
logger.warn {
157160
"Number unification of Long -> Double happened during aggregation. Loss of precision may have occurred."
158161
}
159162
}
160-
if (commonType !in primitiveNumberTypes && commonType != nothingType) {
163+
if (commonType.withNullability(false) !in primitiveNumberTypes && !commonType.isNothing) {
161164
throw IllegalArgumentException(
162165
"Cannot calculate $name of ${renderType(commonType)}, only primitive numbers are supported.",
163166
)

0 commit comments

Comments
 (0)