@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.SerializerBuildHelper
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.types.DataType
@@ -69,14 +68,15 @@ fun <T : Any> kotlinEncoderFor(
     arguments: List<KTypeProjection> = emptyList(),
     nullable: Boolean = false,
     annotations: List<Annotation> = emptyList()
-): Encoder<T> = ExpressionEncoder.apply(
-    KotlinTypeInference.encoderFor(
-        kClass = kClass,
-        arguments = arguments,
-        nullable = nullable,
-        annotations = annotations,
+): Encoder<T> =
+    applyEncoder(
+        KotlinTypeInference.encoderFor(
+            kClass = kClass,
+            arguments = arguments,
+            nullable = nullable,
+            annotations = annotations,
+        )
     )
-)

 /**
  * Main method of API, which gives you seamless integration with Spark:
@@ -88,15 +88,26 @@ fun <T : Any> kotlinEncoderFor(
  * @return generated encoder
  */
 inline fun <reified T> kotlinEncoderFor(): Encoder<T> =
-    ExpressionEncoder.apply(
-        KotlinTypeInference.encoderFor<T>()
+    kotlinEncoderFor(
+        typeOf<T>()
     )

 fun <T> kotlinEncoderFor(kType: KType): Encoder<T> =
-    ExpressionEncoder.apply(
+    applyEncoder(
         KotlinTypeInference.encoderFor(kType)
     )

+/**
+ * For spark-connect, no ExpressionEncoder is needed, so we can just return the AgnosticEncoder.
+ */
+private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
+    //#if sparkConnect == false
+    return org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.apply(agnosticEncoder)
+    //#else
+    //$return agnosticEncoder
+    //#endif
+}
+

 @Deprecated("Use kotlinEncoderFor instead", ReplaceWith("kotlinEncoderFor<T>()"))
 inline fun <reified T> encoder(): Encoder<T> = kotlinEncoderFor(typeOf<T>())
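For context, a minimal usage sketch of the entry points touched above. The `Person` data class and the `spark` session are hypothetical stand-ins, not part of this change; only the `kotlinEncoderFor` calls come from the code in this diff.

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.SparkSession
import kotlin.reflect.typeOf

// Hypothetical data class used only for illustration.
data class Person(val name: String, val age: Int)

fun example(spark: SparkSession) {
    // Reified overload: infers the encoder from the Kotlin type argument.
    val byReifiedType: Encoder<Person> = kotlinEncoderFor<Person>()

    // KType overload: useful when the type is only known reflectively.
    val byKType: Encoder<Person> = kotlinEncoderFor(typeOf<Person>())

    // Either encoder can back a typed Dataset.
    val ds = spark.createDataset(listOf(Person("Alice", 30)), byReifiedType)
    ds.show()
}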
@@ -112,7 +123,7 @@ object KotlinTypeInference {
     // TODO this hack is a WIP and can give errors
     // TODO it's to make data classes get column names like "age" with functions like "getAge"
     // TODO instead of column names like "getAge"
-    var DO_NAME_HACK = true
+    var DO_NAME_HACK = false

     /**
      * @param kClass the class for which to infer the encoder.
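Since the default is now off, a caller who still wants the experimental column-name behaviour would have to opt in explicitly. A hedged sketch, reusing the hypothetical `Person` from the earlier example:

// Opt-in sketch: DO_NAME_HACK is a mutable `var` on the object, so it can be
// flipped before an encoder is requested.
KotlinTypeInference.DO_NAME_HACK = true
val personEncoder = kotlinEncoderFor<Person>()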
@@ -151,7 +162,6 @@ object KotlinTypeInference {
             currentType = kType,
             seenTypeSet = emptySet(),
             typeVariables = emptyMap(),
-            isTopLevel = true,
         ) as AgnosticEncoder<T>


@@ -218,7 +228,6 @@ object KotlinTypeInference {

         // how the generic types of the data class (like T, S) are filled in for this instance of the class
         typeVariables: Map<String, KType>,
-        isTopLevel: Boolean = false,
     ): AgnosticEncoder<*> {
         val kClass =
             currentType.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType")
@@ -328,7 +337,7 @@ object KotlinTypeInference {
                 AgnosticEncoders.UDTEncoder(udt, udt.javaClass)
             }

-            currentType.isSubtypeOf<scala.Option<*>>() -> {
+            currentType.isSubtypeOf<scala.Option<*>?>() -> {
                 val elementEncoder = encoderFor(
                     currentType = tArguments.first().type!!,
                     seenTypeSet = seenTypeSet,
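The added `?` matters because of how Kotlin reflection treats nullability: a nullable type is only a subtype of a supertype that is itself marked nullable, so the old check would not match properties declared as nullable `Option`s. A standalone illustration of that rule follows; it assumes the project's reified `isSubtypeOf<T>()` helper ultimately compares against `typeOf<T>()`, and uses `List` so it runs with plain kotlin-reflect, without Scala on the classpath.

import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.typeOf

fun main() {
    // Non-nullable type against a non-nullable supertype: matches.
    println(typeOf<List<Int>>().isSubtypeOf(typeOf<List<*>>()))   // true
    // Nullable type against a non-nullable supertype: does not match.
    println(typeOf<List<Int>?>().isSubtypeOf(typeOf<List<*>>()))  // false
    // Nullable type against a nullable supertype: matches again.
    println(typeOf<List<Int>?>().isSubtypeOf(typeOf<List<*>?>())) // true
}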
@@ -506,7 +515,6 @@ object KotlinTypeInference {

                 DirtyProductEncoderField(
                     doNameHack = DO_NAME_HACK,
-                    isTopLevel = isTopLevel,
                     columnName = paramName,
                     readMethodName = readMethodName,
                     writeMethodName = writeMethodName,
@@ -525,7 +533,7 @@ object KotlinTypeInference {
             if (currentType in seenTypeSet) throw IllegalStateException("Circular reference detected for type $currentType")
             val constructorParams = currentType.getScalaConstructorParameters(typeVariables, kClass)

-            val params: List<AgnosticEncoders.EncoderField> = constructorParams.map { (paramName, paramType) ->
+            val params = constructorParams.map { (paramName, paramType) ->
                 val encoder = encoderFor(
                     currentType = paramType,
                     seenTypeSet = seenTypeSet + currentType,
@@ -564,7 +572,6 @@ internal open class DirtyProductEncoderField(
     private val readMethodName: String, // the name of the method used to read the value
     private val writeMethodName: String?,
     private val doNameHack: Boolean,
-    private val isTopLevel: Boolean,
     encoder: AgnosticEncoder<*>,
     nullable: Boolean,
     metadata: Metadata = Metadata.empty(),
@@ -577,18 +584,18 @@ internal open class DirtyProductEncoderField(
         /* writeMethod = */ writeMethodName.toOption(),
     ), Serializable {

-    private var isFirstNameCall = true
+    private var noNameCalls = 0

     /**
      * This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder]
      * creates an [Invoke] using [name] first and then calls [name] again to retrieve
      * the name of the column. This way, we can alternate between the two names.
      */
     override fun name(): String =
-        if (doNameHack && !isFirstNameCall) {
+        if (doNameHack && noNameCalls > 0) {
             columnName
         } else {
-            isFirstNameCall = false
+            noNameCalls++
             readMethodName
         }
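Because the trick in name() is easy to misread, here is a standalone sketch of the same alternation using the new counter. `AlternatingName` is a hypothetical stand-in, not the real DirtyProductEncoderField; only the body of name() mirrors the code above.

// The first call returns the getter name so the Invoke expression built by
// SerializerBuildHelper resolves; every later call returns the column name
// used in the schema (when the hack is enabled).
class AlternatingName(
    private val columnName: String,      // e.g. "age"
    private val readMethodName: String,  // e.g. "getAge"
    private val doNameHack: Boolean,
) {
    private var noNameCalls = 0

    fun name(): String =
        if (doNameHack && noNameCalls > 0) {
            columnName
        } else {
            noNameCalls++
            readMethodName
        }
}

fun main() {
    val field = AlternatingName("age", "getAge", doNameHack = true)
    println(field.name()) // getAge -> used for the Invoke expression
    println(field.name()) // age    -> used as the column name
    println(field.name()) // age    -> stays stable on further calls
}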