Skip to content

Commit 45ba9b5

Browse files
authored
Merge pull request #735 from Kotlin/pivot-fix
Pivot fix
2 parents c40fb04 + e1668f8 commit 45ba9b5

File tree

8 files changed

+94
-18
lines changed

8 files changed

+94
-18
lines changed

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt

+5-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ internal fun <T, G, R> aggregateGroupBy(
8383
val result = body(builder, builder)
8484
if (result != Unit && result !is NamedValue && result !is AggregatedPivot<*>) builder.yield(
8585
NamedValue.create(
86-
pathOf(defaultAggregateName), result, null, null, true
86+
path = pathOf(defaultAggregateName),
87+
value = result,
88+
type = null,
89+
defaultValue = null,
90+
guessType = true,
8791
)
8892
)
8993
builder.compute()

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/concat.kt

+15-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import org.jetbrains.kotlinx.dataframe.hasNulls
1212
import org.jetbrains.kotlinx.dataframe.impl.columns.guessColumnType
1313
import org.jetbrains.kotlinx.dataframe.impl.commonType
1414
import org.jetbrains.kotlinx.dataframe.impl.getListType
15+
import org.jetbrains.kotlinx.dataframe.impl.guessValueType
1516
import org.jetbrains.kotlinx.dataframe.nrow
1617
import kotlin.reflect.KType
1718
import kotlin.reflect.full.withNullability
@@ -54,7 +55,13 @@ internal fun <T> concatImpl(name: String, columns: List<DataColumn<T>?>, columnS
5455
col.toList()
5556
} else {
5657
val nrow = columnSizes[index]
57-
if (!nulls && nrow > 0 && defaultValue == null) nulls = true
58+
if (!nulls && nrow > 0 && defaultValue == null) {
59+
nulls = true
60+
} else if (defaultValue != null) {
61+
types.add(
62+
guessValueType(sequenceOf(defaultValue))
63+
)
64+
}
5865
List(nrow) { defaultValue }
5966
}
6067
}
@@ -63,7 +70,13 @@ internal fun <T> concatImpl(name: String, columns: List<DataColumn<T>?>, columnS
6370
val baseType = types.commonType()
6471
val tartypeOf = if (guessType || !hasList) baseType.withNullability(nulls)
6572
else getListType(baseType.withNullability(listOfNullable))
66-
return guessColumnType(name, list, tartypeOf, guessType, defaultValue).cast()
73+
return guessColumnType(
74+
name = name,
75+
values = list,
76+
suggestedType = tartypeOf,
77+
suggestedTypeIsUpperBound = guessType,
78+
defaultValue = defaultValue,
79+
).cast()
6780
}
6881
}
6982

core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/pivot.kt

+10-6
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,21 @@ internal fun <T, R> aggregatePivot(
100100
val hasResult = result != null && result != Unit
101101

102102
fun NamedValue.apply(path: ColumnPath) =
103-
copy(path = path, value = this.value ?: default ?: globalDefault, default = default ?: globalDefault)
103+
copy(
104+
path = path,
105+
value = this.value ?: default ?: globalDefault,
106+
default = default ?: globalDefault,
107+
)
104108

105109
val values = builder.values
106110
when {
107111
values.size == 1 && values[0].path.isEmpty() -> aggregator.yield(values[0].apply(path))
108112
values.isEmpty() -> aggregator.yield(
109-
path,
110-
if (hasResult) result else globalDefault,
111-
null,
112-
globalDefault,
113-
true
113+
path = path,
114+
value = if (hasResult) result else globalDefault,
115+
type = null,
116+
default = globalDefault,
117+
guessType = true,
114118
)
115119

116120
else -> {

core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt

+17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.api
22

33
import io.kotest.matchers.shouldBe
4+
import org.jetbrains.kotlinx.dataframe.impl.commonType
45
import org.junit.Test
56
import kotlin.reflect.typeOf
67

@@ -168,4 +169,20 @@ class PivotTests {
168169
}
169170
df1 shouldBe df.pivot { "category2" then "category3" }.groupBy("category1").count()
170171
}
172+
173+
@Test
174+
fun `pivot with default of other type`() {
175+
val df = dataFrameOf("firstName", "lastName", "age", "city", "weight", "isHappy")(
176+
"Alice", "Cooper", 15, "London", 54, true,
177+
"Bob", "Dylan", 45, "Dubai", 87, true,
178+
"Charlie", "Daniels", 20, "Moscow", null, false,
179+
"Charlie", "Chaplin", 40, "Milan", null, true,
180+
"Bob", "Marley", 30, "Tokyo", 68, true,
181+
"Alice", "Wolf", 20, null, 55, false,
182+
"Charlie", "Byrd", 30, "Moscow", 90, true
183+
).group("firstName", "lastName").into("name")
184+
185+
val pivoted = df.pivot("city").groupBy("name").default(0).min()
186+
pivoted["city"]["London"]["isHappy"].type() shouldBe listOf(typeOf<Int>(), typeOf<Boolean>()).commonType()
187+
}
171188
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt

+5-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ internal fun <T, G, R> aggregateGroupBy(
8383
val result = body(builder, builder)
8484
if (result != Unit && result !is NamedValue && result !is AggregatedPivot<*>) builder.yield(
8585
NamedValue.create(
86-
pathOf(defaultAggregateName), result, null, null, true
86+
path = pathOf(defaultAggregateName),
87+
value = result,
88+
type = null,
89+
defaultValue = null,
90+
guessType = true,
8791
)
8892
)
8993
builder.compute()

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/concat.kt

+15-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import org.jetbrains.kotlinx.dataframe.hasNulls
1212
import org.jetbrains.kotlinx.dataframe.impl.columns.guessColumnType
1313
import org.jetbrains.kotlinx.dataframe.impl.commonType
1414
import org.jetbrains.kotlinx.dataframe.impl.getListType
15+
import org.jetbrains.kotlinx.dataframe.impl.guessValueType
1516
import org.jetbrains.kotlinx.dataframe.nrow
1617
import kotlin.reflect.KType
1718
import kotlin.reflect.full.withNullability
@@ -54,7 +55,13 @@ internal fun <T> concatImpl(name: String, columns: List<DataColumn<T>?>, columnS
5455
col.toList()
5556
} else {
5657
val nrow = columnSizes[index]
57-
if (!nulls && nrow > 0 && defaultValue == null) nulls = true
58+
if (!nulls && nrow > 0 && defaultValue == null) {
59+
nulls = true
60+
} else if (defaultValue != null) {
61+
types.add(
62+
guessValueType(sequenceOf(defaultValue))
63+
)
64+
}
5865
List(nrow) { defaultValue }
5966
}
6067
}
@@ -63,7 +70,13 @@ internal fun <T> concatImpl(name: String, columns: List<DataColumn<T>?>, columnS
6370
val baseType = types.commonType()
6471
val tartypeOf = if (guessType || !hasList) baseType.withNullability(nulls)
6572
else getListType(baseType.withNullability(listOfNullable))
66-
return guessColumnType(name, list, tartypeOf, guessType, defaultValue).cast()
73+
return guessColumnType(
74+
name = name,
75+
values = list,
76+
suggestedType = tartypeOf,
77+
suggestedTypeIsUpperBound = guessType,
78+
defaultValue = defaultValue,
79+
).cast()
6780
}
6881
}
6982

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/pivot.kt

+10-6
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,21 @@ internal fun <T, R> aggregatePivot(
100100
val hasResult = result != null && result != Unit
101101

102102
fun NamedValue.apply(path: ColumnPath) =
103-
copy(path = path, value = this.value ?: default ?: globalDefault, default = default ?: globalDefault)
103+
copy(
104+
path = path,
105+
value = this.value ?: default ?: globalDefault,
106+
default = default ?: globalDefault,
107+
)
104108

105109
val values = builder.values
106110
when {
107111
values.size == 1 && values[0].path.isEmpty() -> aggregator.yield(values[0].apply(path))
108112
values.isEmpty() -> aggregator.yield(
109-
path,
110-
if (hasResult) result else globalDefault,
111-
null,
112-
globalDefault,
113-
true
113+
path = path,
114+
value = if (hasResult) result else globalDefault,
115+
type = null,
116+
default = globalDefault,
117+
guessType = true,
114118
)
115119

116120
else -> {

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt

+17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.dataframe.api
22

33
import io.kotest.matchers.shouldBe
4+
import org.jetbrains.kotlinx.dataframe.impl.commonType
45
import org.junit.Test
56
import kotlin.reflect.typeOf
67

@@ -168,4 +169,20 @@ class PivotTests {
168169
}
169170
df1 shouldBe df.pivot { "category2" then "category3" }.groupBy("category1").count()
170171
}
172+
173+
@Test
174+
fun `pivot with default of other type`() {
175+
val df = dataFrameOf("firstName", "lastName", "age", "city", "weight", "isHappy")(
176+
"Alice", "Cooper", 15, "London", 54, true,
177+
"Bob", "Dylan", 45, "Dubai", 87, true,
178+
"Charlie", "Daniels", 20, "Moscow", null, false,
179+
"Charlie", "Chaplin", 40, "Milan", null, true,
180+
"Bob", "Marley", 30, "Tokyo", 68, true,
181+
"Alice", "Wolf", 20, null, 55, false,
182+
"Charlie", "Byrd", 30, "Moscow", 90, true
183+
).group("firstName", "lastName").into("name")
184+
185+
val pivoted = df.pivot("city").groupBy("name").default(0).min()
186+
pivoted["city"]["London"]["isHappy"].type() shouldBe listOf(typeOf<Int>(), typeOf<Boolean>()).commonType()
187+
}
171188
}

0 commit comments

Comments
 (0)