Commit 9149607

disable name hack by default again, add JCP case for auto-applying the expression encoder when not using spark-connect

1 parent 0c8f4b1

3 files changed: +173 -24 lines

buildSrc/src/main/kotlin/Versions.kt (+6 -1)
@@ -9,6 +9,10 @@ object Versions {
     inline val scala get() = System.getProperty("scala") as String
     inline val sparkMinor get() = spark.substringBeforeLast('.')
     inline val scalaCompat get() = scala.substringBeforeLast('.')
+
+    // TODO
+    const val sparkConnect = false
+
     const val jupyter = "0.12.0-32-1"
 
     const val kotest = "5.5.4"
@@ -25,14 +29,15 @@ object Versions {
     const val jacksonDatabind = "2.13.4.2"
     const val kotlinxDateTime = "0.6.0-RC.2"
 
-    inline val versionMap
+    inline val versionMap: Map<String, String>
         get() = mapOf(
             "kotlin" to kotlin,
             "scala" to scala,
             "scalaCompat" to scalaCompat,
             "spark" to spark,
             "sparkMinor" to sparkMinor,
             "version" to project,
+            "sparkConnect" to sparkConnect.toString(),
         )
 
 }
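Note that `versionMap` is a `Map<String, String>`, which is why the new flag is stringified via `sparkConnect.toString()` before being exposed to the build. A minimal sketch of how a consumer could parse the flag back out (the function is hypothetical, not part of the build scripts):

    // Hypothetical consumer: the Boolean travels through versionMap as a String,
    // so it has to be parsed back out where it is used.
    fun isSparkConnectBuild(versionMap: Map<String, String>): Boolean =
        versionMap["sparkConnect"]?.toBooleanStrictOrNull() ?: false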

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt (+28 -21)
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.SerializerBuildHelper
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.types.DataType
@@ -69,14 +68,15 @@ fun <T : Any> kotlinEncoderFor(
     arguments: List<KTypeProjection> = emptyList(),
     nullable: Boolean = false,
     annotations: List<Annotation> = emptyList()
-): Encoder<T> = ExpressionEncoder.apply(
-    KotlinTypeInference.encoderFor(
-        kClass = kClass,
-        arguments = arguments,
-        nullable = nullable,
-        annotations = annotations,
+): Encoder<T> =
+    applyEncoder(
+        KotlinTypeInference.encoderFor(
+            kClass = kClass,
+            arguments = arguments,
+            nullable = nullable,
+            annotations = annotations,
+        )
     )
-)
 
 /**
  * Main method of API, which gives you seamless integration with Spark:
@@ -88,15 +88,26 @@ fun <T : Any> kotlinEncoderFor(
  * @return generated encoder
  */
 inline fun <reified T> kotlinEncoderFor(): Encoder<T> =
-    ExpressionEncoder.apply(
-        KotlinTypeInference.encoderFor<T>()
+    kotlinEncoderFor(
+        typeOf<T>()
     )
 
 fun <T> kotlinEncoderFor(kType: KType): Encoder<T> =
-    ExpressionEncoder.apply(
+    applyEncoder(
         KotlinTypeInference.encoderFor(kType)
     )
 
+/**
+ * For spark-connect, no ExpressionEncoder is needed, so we can just return the AgnosticEncoder.
+ */
+private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
+    //#if sparkConnect == false
+    return org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.apply(agnosticEncoder)
+    //#else
+    //$return agnosticEncoder
+    //#endif
+}
+
 
 @Deprecated("Use kotlinEncoderFor instead", ReplaceWith("kotlinEncoderFor<T>()"))
 inline fun <reified T> encoder(): Encoder<T> = kotlinEncoderFor(typeOf<T>())
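The `//#if sparkConnect == false` / `//#else` / `//#endif` lines in `applyEncoder` are directives for the Java Comment Preprocessor (JCP) named in the commit message, driven by the `sparkConnect` value from `Versions.versionMap`; the `//$` prefix marks a line JCP uncomments when the `#else` branch is active. Assuming standard JCP behaviour, the two preprocessed variants would look like this (only one survives into any given build):

    // sparkConnect == false (classic Spark): wrap the AgnosticEncoder in an ExpressionEncoder.
    private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
        return org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.apply(agnosticEncoder)
    }

    // sparkConnect == true (spark-connect): the AgnosticEncoder already implements Encoder,
    // so it is returned directly.
    private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
        return agnosticEncoder
    }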
@@ -112,7 +123,7 @@ object KotlinTypeInference {
     // TODO this hack is a WIP and can give errors
     // TODO it's to make data classes get column names like "age" with functions like "getAge"
     // TODO instead of column names like "getAge"
-    var DO_NAME_HACK = true
+    var DO_NAME_HACK = false
 
     /**
      * @param kClass the class for which to infer the encoder.
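With the hack off by default again, callers who want data-class column names like "age" instead of "getAge" must opt in and restore the previous value afterwards, as the tests below do. A minimal sketch of a scoped toggle (`withNameHack` is a hypothetical helper, not part of the API):

    // Hypothetical helper: enable the name hack for one block, then restore the old value.
    fun <T> withNameHack(block: () -> T): T {
        val old = KotlinTypeInference.DO_NAME_HACK
        KotlinTypeInference.DO_NAME_HACK = true
        return try {
            block()
        } finally {
            KotlinTypeInference.DO_NAME_HACK = old
        }
    }

    data class Person(val age: Int, val name: String)

    // Columns come out as "age"/"name" instead of "getAge"/"getName".
    val personEncoder = withNameHack { kotlinEncoderFor<Person>() }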
@@ -151,7 +162,6 @@ object KotlinTypeInference {
             currentType = kType,
             seenTypeSet = emptySet(),
             typeVariables = emptyMap(),
-            isTopLevel = true,
         ) as AgnosticEncoder<T>
 
 
@@ -218,7 +228,6 @@ object KotlinTypeInference {
 
         // how the generic types of the data class (like T, S) are filled in for this instance of the class
         typeVariables: Map<String, KType>,
-        isTopLevel: Boolean = false,
     ): AgnosticEncoder<*> {
         val kClass =
             currentType.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType")
@@ -328,7 +337,7 @@ object KotlinTypeInference {
                 AgnosticEncoders.UDTEncoder(udt, udt.javaClass)
             }
 
-            currentType.isSubtypeOf<scala.Option<*>>() -> {
+            currentType.isSubtypeOf<scala.Option<*>?>() -> {
                 val elementEncoder = encoderFor(
                     currentType = tArguments.first().type!!,
                     seenTypeSet = seenTypeSet,
@@ -506,7 +515,6 @@ object KotlinTypeInference {
 
                     DirtyProductEncoderField(
                         doNameHack = DO_NAME_HACK,
-                        isTopLevel = isTopLevel,
                         columnName = paramName,
                         readMethodName = readMethodName,
                         writeMethodName = writeMethodName,
@@ -525,7 +533,7 @@ object KotlinTypeInference {
                 if (currentType in seenTypeSet) throw IllegalStateException("Circular reference detected for type $currentType")
                 val constructorParams = currentType.getScalaConstructorParameters(typeVariables, kClass)
 
-                val params: List<AgnosticEncoders.EncoderField> = constructorParams.map { (paramName, paramType) ->
+                val params = constructorParams.map { (paramName, paramType) ->
                     val encoder = encoderFor(
                         currentType = paramType,
                         seenTypeSet = seenTypeSet + currentType,
@@ -564,7 +572,6 @@ internal open class DirtyProductEncoderField(
     private val readMethodName: String, // the name of the method used to read the value
     private val writeMethodName: String?,
     private val doNameHack: Boolean,
-    private val isTopLevel: Boolean,
     encoder: AgnosticEncoder<*>,
     nullable: Boolean,
     metadata: Metadata = Metadata.empty(),
@@ -577,18 +584,18 @@ internal open class DirtyProductEncoderField(
         /* writeMethod = */ writeMethodName.toOption(),
     ), Serializable {
 
-    private var isFirstNameCall = true
+    private var noNameCalls = 0
 
     /**
      * This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder]
      * creates an [Invoke] using [name] first and then calls [name] again to retrieve
      * the name of the column. This way, we can alternate between the two names.
      */
     override fun name(): String =
-        if (doNameHack && !isFirstNameCall) {
+        if (doNameHack && noNameCalls > 0) {
             columnName
         } else {
-            isFirstNameCall = false
+            noNameCalls++
             readMethodName
         }
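The KDoc above `name()` describes the trick this change hardens: `ProductEncoder` in `SerializerBuildHelper` calls `name()` once while building the `Invoke` reader (which needs the getter name, e.g. "getAge") and again for the column name (where the clean name, e.g. "age", is wanted). Replacing the Boolean with a call counter means every call after the first keeps returning the column name. A standalone sketch of the alternation, with illustrative names:

    // Minimal standalone model of DirtyProductEncoderField.name(); names are illustrative.
    class AlternatingName(
        private val columnName: String,     // e.g. "age", wanted in the schema
        private val readMethodName: String, // e.g. "getAge", wanted for the Invoke reader
        private val doNameHack: Boolean,
    ) {
        private var noNameCalls = 0

        fun name(): String =
            if (doNameHack && noNameCalls > 0) {
                columnName
            } else {
                noNameCalls++
                readMethodName
            }
    }

    fun main() {
        val field = AlternatingName("age", "getAge", doNameHack = true)
        println(field.name()) // first call: "getAge" (used to build the reader)
        println(field.name()) // subsequent calls: "age" (used as the column name)
    }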

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt (+139 -2)
@@ -22,9 +22,9 @@ package org.jetbrains.kotlinx.spark.api
 import ch.tutteli.atrium.api.fluent.en_GB.*
 import ch.tutteli.atrium.api.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
-import io.kotest.matchers.collections.shouldContain
 import io.kotest.matchers.collections.shouldContainExactly
 import io.kotest.matchers.shouldBe
+import io.kotest.matchers.string.shouldContain
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -210,7 +210,7 @@ class EncodingTest : ShouldSpec({
     context("schema") {
         withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
 
-            context("Give proper names to columns of data classe") {
+            context("Give proper names to columns of data classes") {
                 val old = KotlinTypeInference.DO_NAME_HACK
                 KotlinTypeInference.DO_NAME_HACK = true
 
@@ -240,6 +240,142 @@
                     dataset.collectAsList() shouldBe pairs
                 }
 
+                should("Be able to serialize pairs of pairs of pairs") {
+                    val pairs = listOf(
+                        1 to (1 to (1 to "1")),
+                        2 to (2 to (2 to "2")),
+                        3 to (3 to (3 to "3")),
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.columns().shouldContainExactly("first", "second")
+                    dataset.select("second.*").columns().shouldContainExactly("first", "second")
+                    dataset.select("second.second.*").columns().shouldContainExactly("first", "second")
+                    dataset.collectAsList() shouldBe pairs
+                }
+
+                should("Be able to serialize lists of pairs") {
+                    val pairs = listOf(
+                        listOf(1 to "1", 2 to "2"),
+                        listOf(3 to "3", 4 to "4"),
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.schema().toString().let {
+                        it shouldContain "first"
+                        it shouldContain "second"
+                    }
+                    dataset.collectAsList() shouldBe pairs
+                }
+
+                should("Be able to serialize lists of lists of pairs") {
+                    val pairs = listOf(
+                        listOf(
+                            listOf(1 to "1", 2 to "2"),
+                            listOf(3 to "3", 4 to "4")
+                        )
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.schema().toString().let {
+                        it shouldContain "first"
+                        it shouldContain "second"
+                    }
+                    dataset.collectAsList() shouldBe pairs
+                }
+
+                should("Be able to serialize lists of lists of lists of pairs") {
+                    val pairs = listOf(
+                        listOf(
+                            listOf(
+                                listOf(1 to "1", 2 to "2"),
+                                listOf(3 to "3", 4 to "4"),
+                            )
+                        )
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.schema().toString().let {
+                        it shouldContain "first"
+                        it shouldContain "second"
+                    }
+                    dataset.collectAsList() shouldBe pairs
+                }
+
+                should("Be able to serialize lists of lists of lists of pairs of pairs") {
+                    val pairs = listOf(
+                        listOf(
+                            listOf(
+                                listOf(1 to ("1" to 3.0), 2 to ("2" to 3.0)),
+                                listOf(3 to ("3" to 3.0), 4 to ("4" to 3.0)),
+                            )
+                        )
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.schema().toString().let {
+                        it shouldContain "first"
+                        it shouldContain "second"
+                    }
+                    dataset.collectAsList() shouldBe pairs
+                }
+
+                should("Be able to serialize arrays of pairs") {
+                    val pairs = arrayOf(
+                        arrayOf(1 to "1", 2 to "2"),
+                        arrayOf(3 to "3", 4 to "4"),
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.schema().toString().let {
+                        it shouldContain "first"
+                        it shouldContain "second"
+                    }
+                    dataset.collectAsList() shouldBe pairs
+                }
+
+                should("Be able to serialize arrays of arrays of pairs") {
+                    val pairs = arrayOf(
+                        arrayOf(
+                            arrayOf(1 to "1", 2 to "2"),
+                            arrayOf(3 to "3", 4 to "4")
+                        )
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.schema().toString().let {
+                        it shouldContain "first"
+                        it shouldContain "second"
+                    }
+                    dataset.collectAsList() shouldBe pairs
+                }
+
+                should("Be able to serialize arrays of arrays of arrays of pairs") {
+                    val pairs = arrayOf(
+                        arrayOf(
+                            arrayOf(
+                                arrayOf(1 to "1", 2 to "2"),
+                                arrayOf(3 to "3", 4 to "4"),
+                            )
+                        )
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.schema().toString().let {
+                        it shouldContain "first"
+                        it shouldContain "second"
+                    }
+                    dataset.collectAsList() shouldBe pairs
+                }
+
                 KotlinTypeInference.DO_NAME_HACK = old
             }
 
@@ -351,6 +487,7 @@
             listOf(SomeClass(intArrayOf(1, 2, 3), 4)),
             listOf(SomeClass(intArrayOf(3, 2, 1), 0)),
         )
+        dataset.printSchema()
 
         val (first, second) = dataset.collectAsList()
 
