Skip to content

Commit 436c442

Browse files
Automated commit of generated code
1 parent 4554833 commit 436c442

File tree

13 files changed

+318
-240
lines changed

13 files changed

+318
-240
lines changed

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
77
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
88
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
99
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
10-
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
10+
import org.jetbrains.kotlinx.dataframe.impl.isMixedNumber
11+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
12+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
1113
import org.jetbrains.kotlinx.dataframe.type
1214
import org.jetbrains.kotlinx.dataframe.typeClass
1315
import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE
@@ -48,11 +50,23 @@ public inline fun <reified T> AnyCol.isType(): Boolean = type() == typeOf<T>()
4850
/** Returns `true` when this column's type is a subtype of `Number?` */
4951
public fun AnyCol.isNumber(): Boolean = isSubtypeOf<Number?>()
5052

53+
/** Returns `true` only when this column's type is exactly `Number` or `Number?`. */
54+
public fun AnyCol.isMixedNumber(): Boolean = type().isMixedNumber()
55+
5156
/**
5257
* Returns `true` when this column has the (nullable) type of either:
5358
* [Byte], [Short], [Int], [Long], [Float], or [Double].
5459
*/
55-
public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes
60+
public fun AnyCol.isPrimitiveNumber(): Boolean = type().isPrimitiveNumber()
61+
62+
/**
63+
* Returns `true` when this column has the (nullable) type of either:
64+
* [Byte], [Short], [Int], [Long], [Float], [Double], or [Number].
65+
*
66+
* Careful: Will return `true` if the column contains multiple number types that
67+
* might NOT be primitive.
68+
*/
69+
public fun AnyCol.isPrimitiveOrMixedNumber(): Boolean = type().isPrimitiveOrMixedNumber()
5670

5771
public fun AnyCol.isList(): Boolean = typeClass == List::class
5872

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,15 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
1818
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
1919
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2020
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
21-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns
21+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns
2222
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
23-
import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes
23+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
2424
import kotlin.reflect.KProperty
25-
import kotlin.reflect.full.withNullability
2625
import kotlin.reflect.typeOf
2726

28-
/*
29-
* TODO KDocs:
27+
/* TODO KDocs:
3028
* Calculating the mean is supported for all primitive number types.
31-
* Nulls are filtered from columns.
29+
* Nulls are filtered out.
3230
* The return type is always Double, Double.NaN for empty input, never null.
3331
* (May introduce loss of precision for Longs).
3432
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean.
@@ -48,24 +46,21 @@ public inline fun <T, reified R : Number> DataColumn<T>.meanOf(
4846
// region DataRow
4947

5048
public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double =
51-
Aggregators.mean(skipNA).aggregateOfRow(this) {
52-
colsOf<Number?> { it.isPrimitiveNumber() }
53-
}
49+
Aggregators.mean(skipNA).aggregateOfRow(this, primitiveOrMixedNumberColumns())
5450

5551
public inline fun <reified T : Number?> AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double {
56-
require(typeOf<T>().withNullability(false) in primitiveNumberTypes) {
52+
require(typeOf<T>().isPrimitiveOrMixedNumber()) {
5753
"Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types."
5854
}
59-
return Aggregators.mean(skipNA)
60-
.aggregateOfRow(this) { colsOf<T>() }
55+
return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf<T>() }
6156
}
6257

6358
// endregion
6459

6560
// region DataFrame
6661

6762
public fun <T> DataFrame<T>.mean(skipNA: Boolean = skipNA_default): DataRow<T> =
68-
meanFor(skipNA, primitiveNumberColumns())
63+
meanFor(skipNA, primitiveOrMixedNumberColumns())
6964

7065
public fun <T, C : Number> DataFrame<T>.meanFor(
7166
skipNA: Boolean = skipNA_default,
@@ -116,7 +111,7 @@ public inline fun <T, reified D : Number> DataFrame<T>.meanOf(
116111
@Refine
117112
@Interpretable("GroupByMean1")
118113
public fun <T> Grouped<T>.mean(skipNA: Boolean = skipNA_default): DataFrame<T> =
119-
meanFor(skipNA, primitiveNumberColumns())
114+
meanFor(skipNA, primitiveOrMixedNumberColumns())
120115

121116
@Refine
122117
@Interpretable("GroupByMean0")
@@ -181,7 +176,7 @@ public inline fun <T, reified R : Number> Grouped<T>.meanOf(
181176
// region Pivot
182177

183178
public fun <T> Pivot<T>.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow<T> =
184-
meanFor(skipNA, separate, primitiveNumberColumns())
179+
meanFor(skipNA, separate, primitiveOrMixedNumberColumns())
185180

186181
public fun <T, C : Number> Pivot<T>.meanFor(
187182
skipNA: Boolean = skipNA_default,
@@ -224,7 +219,7 @@ public inline fun <T, reified R : Number> Pivot<T>.meanOf(
224219
// region PivotGroupBy
225220

226221
public fun <T> PivotGroupBy<T>.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame<T> =
227-
meanFor(skipNA, separate, primitiveNumberColumns())
222+
meanFor(skipNA, separate, primitiveOrMixedNumberColumns())
228223

229224
public fun <T, C : Number> PivotGroupBy<T>.meanFor(
230225
skipNA: Boolean = skipNA_default,

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

Lines changed: 123 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
@file:OptIn(ExperimentalTypeInference::class)
2+
@file:Suppress("LocalVariableName")
3+
14
package org.jetbrains.kotlinx.dataframe.api
25

36
import org.jetbrains.kotlinx.dataframe.AnyRow
@@ -13,50 +16,97 @@ import org.jetbrains.kotlinx.dataframe.annotations.Refine
1316
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1417
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
1518
import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf
16-
import org.jetbrains.kotlinx.dataframe.columns.values
17-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
1819
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators
19-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast
2020
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll
2121
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor
2222
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
23-
import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns
23+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
24+
import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns
2425
import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns
25-
import org.jetbrains.kotlinx.dataframe.impl.zero
26-
import org.jetbrains.kotlinx.dataframe.math.sum
27-
import org.jetbrains.kotlinx.dataframe.math.sumOf
26+
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
27+
import kotlin.experimental.ExperimentalTypeInference
28+
import kotlin.reflect.KClass
2829
import kotlin.reflect.KProperty
29-
import kotlin.reflect.full.isSubtypeOf
30+
import kotlin.reflect.KType
3031
import kotlin.reflect.typeOf
3132

33+
/* TODO KDocs
34+
* Calculating the sum is supported for all primitive number types.
35+
* Nulls are filtered out.
36+
* The return type is always the same as the input type (never null), except for `Byte` and `Short`,
37+
* which are converted to `Int`.
38+
* Empty input will result in 0 in the supplied number type.
39+
* For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the sum.
40+
*/
41+
3242
// region DataColumn
3343

34-
@JvmName("sumT")
35-
public fun <T : Number> DataColumn<T>.sum(): T = values.sum(type())
44+
@JvmName("sumShort")
45+
public fun DataColumn<Short?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
46+
47+
@JvmName("sumByte")
48+
public fun DataColumn<Byte?>.sum(): Int = Aggregators.sum.aggregate(this) as Int
3649

37-
@JvmName("sumTNullable")
38-
public fun <T : Number> DataColumn<T?>.sum(): T = values.sum(type())
50+
@Suppress("UNCHECKED_CAST")
51+
@JvmName("sumNumber")
52+
public fun <T : Number> DataColumn<T?>.sum(): T = Aggregators.sum.aggregate(this) as T
3953

40-
public inline fun <T, reified R : Number> DataColumn<T>.sumOf(noinline expression: (T) -> R): R? =
41-
(Aggregators.sum as Aggregator<*, *>).cast<R>().aggregateOf(this, expression)
54+
@JvmName("sumOfShort")
55+
@OverloadResolutionByLambdaReturnType
56+
public fun <C> DataColumn<C>.sumOf(expression: (C) -> Short?): Int =
57+
Aggregators.sum.aggregateOf(this, expression) as Int
58+
59+
@JvmName("sumOfByte")
60+
@OverloadResolutionByLambdaReturnType
61+
public fun <C> DataColumn<C>.sumOf(expression: (C) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int
62+
63+
@JvmName("sumOfNumber")
64+
@OverloadResolutionByLambdaReturnType
65+
public inline fun <C, reified V : Number> DataColumn<C>.sumOf(crossinline expression: (C) -> V?): V =
66+
Aggregators.sum.aggregateOf(this, expression) as V
4267

4368
// endregion
4469

4570
// region DataRow
4671

47-
public fun AnyRow.rowSum(): Number =
48-
Aggregators.sum.aggregateCalculatingType(
49-
values = values().filterIsInstance<Number>(),
50-
valueTypes = columnTypes().filter { it.isSubtypeOf(typeOf<Number?>()) }.toSet(),
51-
) ?: 0
72+
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveOrMixedNumberColumns())
73+
74+
@JvmName("rowSumOfShort")
75+
public inline fun <reified T : Short?> AnyRow.rowSumOf(_kClass: KClass<Short> = Short::class): Int =
76+
rowSumOf(typeOf<T>()) as Int
5277

53-
public inline fun <reified T : Number> AnyRow.rowSumOf(): T = values().filterIsInstance<T>().sum(typeOf<T>())
78+
@JvmName("rowSumOfByte")
79+
public inline fun <reified T : Byte?> AnyRow.rowSumOf(_kClass: KClass<Byte> = Byte::class): Int =
80+
rowSumOf(typeOf<T>()) as Int
5481

82+
@JvmName("rowSumOfInt")
83+
public inline fun <reified T : Int?> AnyRow.rowSumOf(_kClass: KClass<Int> = Int::class): Int =
84+
rowSumOf(typeOf<T>()) as Int
85+
86+
@JvmName("rowSumOfLong")
87+
public inline fun <reified T : Long?> AnyRow.rowSumOf(_kClass: KClass<Long> = Long::class): Long =
88+
rowSumOf(typeOf<T>()) as Long
89+
90+
@JvmName("rowSumOfFloat")
91+
public inline fun <reified T : Float?> AnyRow.rowSumOf(_kClass: KClass<Float> = Float::class): Float =
92+
rowSumOf(typeOf<T>()) as Float
93+
94+
@JvmName("rowSumOfDouble")
95+
public inline fun <reified T : Double?> AnyRow.rowSumOf(_kClass: KClass<Double> = Double::class): Double =
96+
rowSumOf(typeOf<T>()) as Double
97+
98+
// unfortunately, we cannot make a `reified T : Number?` due to clashes
99+
public fun AnyRow.rowSumOf(type: KType): Number {
100+
require(type.isPrimitiveOrMixedNumber()) {
101+
"Type $type is not a primitive number type. Mean only supports primitive number types."
102+
}
103+
return Aggregators.sum.aggregateOfRow(this) { colsOf(type) }
104+
}
55105
// endregion
56106

57107
// region DataFrame
58108

59-
public fun <T> DataFrame<T>.sum(): DataRow<T> = sumFor(numberColumns())
109+
public fun <T> DataFrame<T>.sum(): DataRow<T> = sumFor(primitiveOrMixedNumberColumns())
60110

61111
public fun <T, C : Number> DataFrame<T>.sumFor(columns: ColumnsForAggregateSelector<T, C?>): DataRow<T> =
62112
Aggregators.sum.aggregateFor(this, columns)
@@ -71,28 +121,70 @@ public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: ColumnReference<C
71121
public fun <T, C : Number> DataFrame<T>.sumFor(vararg columns: KProperty<C?>): DataRow<T> =
72122
sumFor { columns.toColumnSet() }
73123

124+
@JvmName("sumShort")
125+
@OverloadResolutionByLambdaReturnType
126+
public fun <T> DataFrame<T>.sum(columns: ColumnsSelector<T, Short?>): Int =
127+
Aggregators.sum.aggregateAll(this, columns) as Int
128+
129+
@JvmName("sumByte")
130+
@OverloadResolutionByLambdaReturnType
131+
public fun <T> DataFrame<T>.sum(columns: ColumnsSelector<T, Byte?>): Int =
132+
Aggregators.sum.aggregateAll(this, columns) as Int
133+
134+
@JvmName("sumNumber")
135+
@OverloadResolutionByLambdaReturnType
74136
public inline fun <T, reified C : Number> DataFrame<T>.sum(noinline columns: ColumnsSelector<T, C?>): C =
75-
(Aggregators.sum.aggregateAll(this, columns) as C?) ?: C::class.zero()
137+
Aggregators.sum.aggregateAll(this, columns) as C
76138

139+
@JvmName("sumShort")
140+
@AccessApiOverload
141+
public fun <T> DataFrame<T>.sum(vararg columns: ColumnReference<Short?>): Int = sum { columns.toColumnSet() }
142+
143+
@JvmName("sumByte")
144+
@AccessApiOverload
145+
public fun <T> DataFrame<T>.sum(vararg columns: ColumnReference<Byte?>): Int = sum { columns.toColumnSet() }
146+
147+
@JvmName("sumNumber")
77148
@AccessApiOverload
78149
public inline fun <T, reified C : Number> DataFrame<T>.sum(vararg columns: ColumnReference<C?>): C =
79150
sum { columns.toColumnSet() }
80151

81-
public fun <T> DataFrame<T>.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf() }
152+
public fun <T> DataFrame<T>.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf<Number?>() }
153+
154+
@JvmName("sumShort")
155+
@AccessApiOverload
156+
public fun <T> DataFrame<T>.sum(vararg columns: KProperty<Short?>): Int = sum { columns.toColumnSet() }
82157

158+
@JvmName("sumByte")
159+
@AccessApiOverload
160+
public fun <T> DataFrame<T>.sum(vararg columns: KProperty<Byte?>): Int = sum { columns.toColumnSet() }
161+
162+
@JvmName("sumNumber")
83163
@AccessApiOverload
84164
public inline fun <T, reified C : Number> DataFrame<T>.sum(vararg columns: KProperty<C?>): C =
85165
sum { columns.toColumnSet() }
86166

87-
public inline fun <T, reified C : Number?> DataFrame<T>.sumOf(crossinline expression: RowExpression<T, C>): C =
88-
rows().sumOf(typeOf<C>()) { expression(it, it) }
167+
@JvmName("sumOfShort")
168+
@OverloadResolutionByLambdaReturnType
169+
public fun <T> DataFrame<T>.sumOf(expression: RowExpression<T, Short?>): Int =
170+
Aggregators.sum.aggregateOf(this, expression) as Int
171+
172+
@JvmName("sumOfByte")
173+
@OverloadResolutionByLambdaReturnType
174+
public fun <T> DataFrame<T>.sumOf(expression: RowExpression<T, Byte?>): Int =
175+
Aggregators.sum.aggregateOf(this, expression) as Int
176+
177+
@JvmName("sumOfNumber")
178+
@OverloadResolutionByLambdaReturnType
179+
public inline fun <T, reified C : Number> DataFrame<T>.sumOf(crossinline expression: RowExpression<T, C?>): C =
180+
Aggregators.sum.aggregateOf(this, expression) as C
89181

90182
// endregion
91183

92184
// region GroupBy
93185
@Refine
94186
@Interpretable("GroupBySum1")
95-
public fun <T> Grouped<T>.sum(): DataFrame<T> = sumFor(numberColumns())
187+
public fun <T> Grouped<T>.sum(): DataFrame<T> = sumFor(primitiveOrMixedNumberColumns())
96188

97189
@Refine
98190
@Interpretable("GroupBySum0")
@@ -136,7 +228,7 @@ public inline fun <T, reified R : Number> Grouped<T>.sumOf(
136228

137229
// region Pivot
138230

139-
public fun <T> Pivot<T>.sum(separate: Boolean = false): DataRow<T> = sumFor(separate, numberColumns())
231+
public fun <T> Pivot<T>.sum(separate: Boolean = false): DataRow<T> = sumFor(separate, primitiveOrMixedNumberColumns())
140232

141233
public fun <T, R : Number> Pivot<T>.sumFor(
142234
separate: Boolean = false,
@@ -166,14 +258,15 @@ public fun <T, C : Number> Pivot<T>.sum(vararg columns: ColumnReference<C?>): Da
166258
@AccessApiOverload
167259
public fun <T, C : Number> Pivot<T>.sum(vararg columns: KProperty<C?>): DataRow<T> = sum { columns.toColumnSet() }
168260

169-
public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression: RowExpression<T, R>): DataRow<T> =
261+
public inline fun <T, reified R : Number> Pivot<T>.sumOf(crossinline expression: RowExpression<T, R?>): DataRow<T> =
170262
delegate { sumOf(expression) }
171263

172264
// endregion
173265

174266
// region PivotGroupBy
175267

176-
public fun <T> PivotGroupBy<T>.sum(separate: Boolean = false): DataFrame<T> = sumFor(separate, numberColumns())
268+
public fun <T> PivotGroupBy<T>.sum(separate: Boolean = false): DataFrame<T> =
269+
sumFor(separate, primitiveOrMixedNumberColumns())
177270

178271
public fun <T, R : Number> PivotGroupBy<T>.sumFor(
179272
separate: Boolean = false,
@@ -209,7 +302,7 @@ public fun <T, C : Number> PivotGroupBy<T>.sum(vararg columns: KProperty<C?>): D
209302
sum { columns.toColumnSet() }
210303

211304
public inline fun <T, reified R : Number> PivotGroupBy<T>.sumOf(
212-
crossinline expression: RowExpression<T, R>,
305+
crossinline expression: RowExpression<T, R?>,
213306
): DataFrame<T> = Aggregators.sum.aggregateOf(this, expression)
214307

215308
// endregion

0 commit comments

Comments
 (0)