@@ -49,6 +49,11 @@ import org.apache.spark.sql.types.UserDefinedType
49
49
import org.apache.spark.unsafe.types.CalendarInterval
50
50
import org.jetbrains.kotlinx.spark.api.plugin.annotations.ColumnName
51
51
import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify
52
+ import org.jetbrains.kotlinx.spark.api.udts.DatePeriodUdt
53
+ import org.jetbrains.kotlinx.spark.api.udts.DateTimePeriodUdt
54
+ import org.jetbrains.kotlinx.spark.api.udts.InstantUdt
55
+ import org.jetbrains.kotlinx.spark.api.udts.LocalDateTimeUdt
56
+ import org.jetbrains.kotlinx.spark.api.udts.LocalDateUdt
52
57
import scala.reflect.ClassTag
53
58
import java.io.Serializable
54
59
import java.util.*
@@ -170,12 +175,14 @@ object KotlinTypeInference : Serializable {
170
175
* @return an [AgnosticEncoder] for the given [kType].
171
176
*/
172
177
@Suppress(" UNCHECKED_CAST" )
173
- fun <T > encoderFor (kType : KType ): AgnosticEncoder <T > =
174
- encoderFor(
178
+ fun <T > encoderFor (kType : KType ): AgnosticEncoder <T > {
179
+ registerUdts()
180
+ return encoderFor(
175
181
currentType = kType,
176
182
seenTypeSet = emptySet(),
177
183
typeVariables = emptyMap(),
178
184
) as AgnosticEncoder <T >
185
+ }
179
186
180
187
181
188
private inline fun <reified T > KType.isSubtypeOf (): Boolean = isSubtypeOf(typeOf<T >())
@@ -296,6 +303,16 @@ object KotlinTypeInference : Serializable {
296
303
private fun <K , V > transitiveMerge (a : Map <K , V >, b : Map <K , V >, valueToKey : (V ) -> K ? ): Map <K , V > =
297
304
a + b.mapValues { a.getOrDefault(valueToKey(it.value), it.value) }
298
305
306
+ private fun registerUdts () {
307
+ UDTRegistration .register(kotlinx.datetime.LocalDate ::class .java.name, LocalDateUdt ::class .java.name)
308
+ UDTRegistration .register(kotlinx.datetime.Instant ::class .java.name, InstantUdt ::class .java.name)
309
+ UDTRegistration .register(kotlinx.datetime.LocalDateTime ::class .java.name, LocalDateTimeUdt ::class .java.name)
310
+ UDTRegistration .register(kotlinx.datetime.DatePeriod ::class .java.name, DatePeriodUdt ::class .java.name)
311
+ UDTRegistration .register(kotlinx.datetime.DateTimePeriod ::class .java.name, DateTimePeriodUdt ::class .java.name)
312
+ // TODO
313
+ // UDTRegistration.register(kotlin.time.Duration::class.java.name, DurationUdt::class.java.name)
314
+ }
315
+
299
316
/* *
300
317
*
301
318
*/
@@ -375,19 +392,12 @@ object KotlinTypeInference : Serializable {
375
392
currentType.isSubtypeOf< java.math.BigInteger ? > () -> AgnosticEncoders .`JavaBigIntEncoder $`.`MODULE $`
376
393
currentType.isSubtypeOf<CalendarInterval ?>() -> AgnosticEncoders .`CalendarIntervalEncoder $`.`MODULE $`
377
394
currentType.isSubtypeOf< java.time.LocalDate ? > () -> AgnosticEncoders .STRICT_LOCAL_DATE_ENCODER ()
378
- currentType.isSubtypeOf< kotlinx.datetime.LocalDate ? > () -> TODO (" User java.time.LocalDate for now. We'll create a UDT for this." )
379
395
currentType.isSubtypeOf< java.sql.Date ? > () -> AgnosticEncoders .STRICT_DATE_ENCODER ()
380
396
currentType.isSubtypeOf< java.time.Instant ? > () -> AgnosticEncoders .STRICT_INSTANT_ENCODER ()
381
- currentType.isSubtypeOf< kotlinx.datetime.Instant ? > () -> TODO (" Use java.time.Instant for now. We'll create a UDT for this." )
382
- currentType.isSubtypeOf< kotlin.time.TimeMark ? > () -> TODO (" Use java.time.Instant for now. We'll create a UDT for this." )
383
397
currentType.isSubtypeOf< java.sql.Timestamp ? > () -> AgnosticEncoders .STRICT_TIMESTAMP_ENCODER ()
384
398
currentType.isSubtypeOf< java.time.LocalDateTime ? > () -> AgnosticEncoders .`LocalDateTimeEncoder $`.`MODULE $`
385
- currentType.isSubtypeOf< kotlinx.datetime.LocalDateTime ? > () -> TODO (" Use java.time.LocalDateTime for now. We'll create a UDT for this." )
386
399
currentType.isSubtypeOf< java.time.Duration ? > () -> AgnosticEncoders .`DayTimeIntervalEncoder $`.`MODULE $`
387
- currentType.isSubtypeOf< kotlin.time.Duration ? > () -> TODO (" Use java.time.Duration for now. We'll create a UDT for this." )
388
400
currentType.isSubtypeOf< java.time.Period ? > () -> AgnosticEncoders .`YearMonthIntervalEncoder $`.`MODULE $`
389
- currentType.isSubtypeOf< kotlinx.datetime.DateTimePeriod ? > () -> TODO (" Use java.time.Period for now. We'll create a UDT for this." )
390
- currentType.isSubtypeOf< kotlinx.datetime.DatePeriod ? > () -> TODO (" Use java.time.Period for now. We'll create a UDT for this." )
391
401
currentType.isSubtypeOf<Row ?>() -> AgnosticEncoders .`UnboundRowEncoder $`.`MODULE $`
392
402
393
403
// enums
@@ -414,6 +424,8 @@ object KotlinTypeInference : Serializable {
414
424
AgnosticEncoders .UDTEncoder (udt, udt.javaClass)
415
425
}
416
426
427
+ currentType.isSubtypeOf< kotlin.time.Duration ? > () -> TODO (" kotlin.time.Duration is unsupported. Use java.time.Duration for now." )
428
+
417
429
currentType.isSubtypeOf< scala.Option <* >? > () -> {
418
430
val elementEncoder = encoderFor(
419
431
currentType = tArguments.first().type!! ,
@@ -666,8 +678,6 @@ object KotlinTypeInference : Serializable {
666
678
fields.asScalaSeq(),
667
679
)
668
680
}
669
-
670
- // else -> throw IllegalArgumentException("No encoder found for type $currentType")
671
681
}
672
682
}
673
683
0 commit comments