Skip to content

Commit 0d061b6

Browse files
committed
added tests and fixed name hack
1 parent 3e9261f commit 0d061b6

File tree

2 files changed

+52
-12
lines changed
  • kotlin-spark-api/src

2 files changed

+52
-12
lines changed

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt

+12-7
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ object KotlinTypeInference {
112112
// TODO this hack is a WIP and can give errors
113113
// TODO it's to make data classes get column names like "age" with functions like "getAge"
114114
// TODO instead of column names like "getAge"
115-
var DO_NAME_HACK = false
115+
var DO_NAME_HACK = true
116116

117117
/**
118118
* @param kClass the class for which to infer the encoder.
@@ -151,6 +151,7 @@ object KotlinTypeInference {
151151
currentType = kType,
152152
seenTypeSet = emptySet(),
153153
typeVariables = emptyMap(),
154+
isTopLevel = true,
154155
) as AgnosticEncoder<T>
155156

156157

@@ -217,6 +218,7 @@ object KotlinTypeInference {
217218

218219
// how the generic types of the data class (like T, S) are filled in for this instance of the class
219220
typeVariables: Map<String, KType>,
221+
isTopLevel: Boolean = false,
220222
): AgnosticEncoder<*> {
221223
val kClass =
222224
currentType.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType")
@@ -488,6 +490,7 @@ object KotlinTypeInference {
488490

489491
DirtyProductEncoderField(
490492
doNameHack = DO_NAME_HACK,
493+
isTopLevel = isTopLevel,
491494
columnName = paramName,
492495
readMethodName = readMethodName,
493496
writeMethodName = writeMethodName,
@@ -545,6 +548,7 @@ internal open class DirtyProductEncoderField(
545548
private val readMethodName: String, // the name of the method used to read the value
546549
private val writeMethodName: String?,
547550
private val doNameHack: Boolean,
551+
private val isTopLevel: Boolean,
548552
encoder: AgnosticEncoder<*>,
549553
nullable: Boolean,
550554
metadata: Metadata = Metadata.empty(),
@@ -557,20 +561,21 @@ internal open class DirtyProductEncoderField(
557561
/* writeMethod = */ writeMethodName.toOption(),
558562
), Serializable {
559563

560-
private var i = 0
564+
private var isFirstNameCall = true
561565

562566
/**
563567
* This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder]
564-
* creates an [Invoke] using [columnName] first and then calls [columnName] again to retrieve
568+
* creates an [Invoke] using [name] first and then calls [name] again to retrieve
565569
* the name of the column. This way, we can alternate between the two names.
566570
*/
567571
override fun name(): String =
568-
when (doNameHack) {
569-
true -> if (i++ % 2 == 0) readMethodName else columnName
570-
false -> readMethodName
572+
if (doNameHack && !isFirstNameCall) {
573+
columnName
574+
} else {
575+
isFirstNameCall = false
576+
readMethodName
571577
}
572578

573-
574579
override fun canEqual(that: Any?): Boolean = that is AgnosticEncoders.EncoderField
575580

576581
override fun productElement(n: Int): Any =

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt

+40-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ package org.jetbrains.kotlinx.spark.api
2222
import ch.tutteli.atrium.api.fluent.en_GB.*
2323
import ch.tutteli.atrium.api.verbs.expect
2424
import io.kotest.core.spec.style.ShouldSpec
25+
import io.kotest.matchers.collections.shouldContain
26+
import io.kotest.matchers.collections.shouldContainExactly
2527
import io.kotest.matchers.shouldBe
2628
import org.apache.spark.sql.Dataset
2729
import org.apache.spark.sql.types.Decimal
@@ -208,6 +210,39 @@ class EncodingTest : ShouldSpec({
208210
context("schema") {
209211
withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
210212

213+
context("Give proper names to columns of data classes") {
214+
val old = KotlinTypeInference.DO_NAME_HACK
215+
KotlinTypeInference.DO_NAME_HACK = true
216+
217+
should("Be able to serialize pairs") {
218+
val pairs = listOf(
219+
1 to "1",
220+
2 to "2",
221+
3 to "3",
222+
)
223+
val dataset = pairs.toDS()
224+
dataset.show()
225+
dataset.collectAsList() shouldBe pairs
226+
dataset.columns().shouldContainExactly("first", "second")
227+
}
228+
229+
should("Be able to serialize pairs of pairs") {
230+
val pairs = listOf(
231+
1 to (1 to "1"),
232+
2 to (2 to "2"),
233+
3 to (3 to "3"),
234+
)
235+
val dataset = pairs.toDS()
236+
dataset.show()
237+
dataset.printSchema()
238+
dataset.columns().shouldContainExactly("first", "second")
239+
dataset.select("second.*").columns().shouldContainExactly("first", "second")
240+
dataset.collectAsList() shouldBe pairs
241+
}
242+
243+
KotlinTypeInference.DO_NAME_HACK = old
244+
}
245+
211246
should("handle Scala Case class datasets") {
212247
val caseClasses = listOf(
213248
tupleOf(1, "1"),
@@ -253,14 +288,14 @@ class EncodingTest : ShouldSpec({
253288
}
254289

255290

256-
xshould("handle Scala Option datasets") {
291+
should("handle Scala Option datasets") {
257292
val caseClasses = listOf(Some(1), Some(2), Some(3))
258293
val dataset = caseClasses.toDS()
259294
dataset.show()
260295
dataset.collectAsList() shouldBe caseClasses
261296
}
262297

263-
xshould("handle Scala Option Option datasets") {
298+
should("handle Scala Option Option datasets") {
264299
val caseClasses = listOf(
265300
Some(Some(1)),
266301
Some(Some(2)),
@@ -270,7 +305,7 @@ class EncodingTest : ShouldSpec({
270305
dataset.collectAsList() shouldBe caseClasses
271306
}
272307

273-
xshould("handle data class Scala Option datasets") {
308+
should("handle data class Scala Option datasets") {
274309
val caseClasses = listOf(
275310
Some(1) to Some(2),
276311
Some(3) to Some(4),
@@ -280,7 +315,7 @@ class EncodingTest : ShouldSpec({
280315
dataset.collectAsList() shouldBe caseClasses
281316
}
282317

283-
xshould("handle Scala Option data class datasets") {
318+
should("handle Scala Option data class datasets") {
284319
val caseClasses = listOf(
285320
Some(1 to 2),
286321
Some(3 to 4),
@@ -501,7 +536,7 @@ class EncodingTest : ShouldSpec({
501536
expect(result).toContain.inOrder.only.values(5.1 to 6)
502537
}
503538

504-
should("!handle primitive arrays") {
539+
should("handle boxed arrays") {
505540
val result = listOf(arrayOf(1, 2, 3, 4))
506541
.toDS()
507542
.map { it.map { ai -> ai + 1 } }

0 commit comments

Comments
 (0)