Skip to content

Commit ab4c455

Browse files
committed
added encoding for DatePeriod, DateTimePeriod, Instant, LocalDateTime, and LocalDate, Duration not working
1 parent 48db819 commit ab4c455

File tree

8 files changed

+288
-11
lines changed

8 files changed

+288
-11
lines changed

Diff for: kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt

+21-11
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ import org.apache.spark.sql.types.UserDefinedType
4949
import org.apache.spark.unsafe.types.CalendarInterval
5050
import org.jetbrains.kotlinx.spark.api.plugin.annotations.ColumnName
5151
import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify
52+
import org.jetbrains.kotlinx.spark.api.udts.DatePeriodUdt
53+
import org.jetbrains.kotlinx.spark.api.udts.DateTimePeriodUdt
54+
import org.jetbrains.kotlinx.spark.api.udts.InstantUdt
55+
import org.jetbrains.kotlinx.spark.api.udts.LocalDateTimeUdt
56+
import org.jetbrains.kotlinx.spark.api.udts.LocalDateUdt
5257
import scala.reflect.ClassTag
5358
import java.io.Serializable
5459
import java.util.*
@@ -170,12 +175,14 @@ object KotlinTypeInference : Serializable {
170175
* @return an [AgnosticEncoder] for the given [kType].
171176
*/
172177
@Suppress("UNCHECKED_CAST")
fun <T> encoderFor(kType: KType): AgnosticEncoder<T> {
    // Make sure the kotlinx.datetime UDTs are known to Spark before inference runs.
    registerUdts()
    return encoderFor(
        currentType = kType,
        seenTypeSet = emptySet(),
        typeVariables = emptyMap(),
    ) as AgnosticEncoder<T>
}
179186

180187

181188
private inline fun <reified T> KType.isSubtypeOf(): Boolean = isSubtypeOf(typeOf<T>())
@@ -296,6 +303,16 @@ object KotlinTypeInference : Serializable {
296303
private fun <K, V> transitiveMerge(a: Map<K, V>, b: Map<K, V>, valueToKey: (V) -> K?): Map<K, V> =
297304
a + b.mapValues { a.getOrDefault(valueToKey(it.value), it.value) }
298305

306+
/**
 * Registers the kotlinx.datetime [UserDefinedType]s with Spark's [UDTRegistration]
 * so that encoder inference can resolve them.
 *
 * NOTE(review): this is invoked on every [encoderFor] call — presumably
 * [UDTRegistration.register] tolerates repeated registration of the same pair;
 * confirm, or guard with a once-only flag.
 */
private fun registerUdts() {
    UDTRegistration.register(kotlinx.datetime.LocalDate::class.java.name, LocalDateUdt::class.java.name)
    UDTRegistration.register(kotlinx.datetime.Instant::class.java.name, InstantUdt::class.java.name)
    UDTRegistration.register(kotlinx.datetime.LocalDateTime::class.java.name, LocalDateTimeUdt::class.java.name)
    UDTRegistration.register(kotlinx.datetime.DatePeriod::class.java.name, DatePeriodUdt::class.java.name)
    UDTRegistration.register(kotlinx.datetime.DateTimePeriod::class.java.name, DateTimePeriodUdt::class.java.name)
    // TODO kotlin.time.Duration is not yet supported (value class issue, see DurationUdt)
//    UDTRegistration.register(kotlin.time.Duration::class.java.name, DurationUdt::class.java.name)
}
315+
299316
/**
300317
*
301318
*/
@@ -375,19 +392,12 @@ object KotlinTypeInference : Serializable {
375392
currentType.isSubtypeOf<java.math.BigInteger?>() -> AgnosticEncoders.`JavaBigIntEncoder$`.`MODULE$`
376393
currentType.isSubtypeOf<CalendarInterval?>() -> AgnosticEncoders.`CalendarIntervalEncoder$`.`MODULE$`
377394
currentType.isSubtypeOf<java.time.LocalDate?>() -> AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER()
378-
currentType.isSubtypeOf<kotlinx.datetime.LocalDate?>() -> TODO("User java.time.LocalDate for now. We'll create a UDT for this.")
379395
currentType.isSubtypeOf<java.sql.Date?>() -> AgnosticEncoders.STRICT_DATE_ENCODER()
380396
currentType.isSubtypeOf<java.time.Instant?>() -> AgnosticEncoders.STRICT_INSTANT_ENCODER()
381-
currentType.isSubtypeOf<kotlinx.datetime.Instant?>() -> TODO("Use java.time.Instant for now. We'll create a UDT for this.")
382-
currentType.isSubtypeOf<kotlin.time.TimeMark?>() -> TODO("Use java.time.Instant for now. We'll create a UDT for this.")
383397
currentType.isSubtypeOf<java.sql.Timestamp?>() -> AgnosticEncoders.STRICT_TIMESTAMP_ENCODER()
384398
currentType.isSubtypeOf<java.time.LocalDateTime?>() -> AgnosticEncoders.`LocalDateTimeEncoder$`.`MODULE$`
385-
currentType.isSubtypeOf<kotlinx.datetime.LocalDateTime?>() -> TODO("Use java.time.LocalDateTime for now. We'll create a UDT for this.")
386399
currentType.isSubtypeOf<java.time.Duration?>() -> AgnosticEncoders.`DayTimeIntervalEncoder$`.`MODULE$`
387-
currentType.isSubtypeOf<kotlin.time.Duration?>() -> TODO("Use java.time.Duration for now. We'll create a UDT for this.")
388400
currentType.isSubtypeOf<java.time.Period?>() -> AgnosticEncoders.`YearMonthIntervalEncoder$`.`MODULE$`
389-
currentType.isSubtypeOf<kotlinx.datetime.DateTimePeriod?>() -> TODO("Use java.time.Period for now. We'll create a UDT for this.")
390-
currentType.isSubtypeOf<kotlinx.datetime.DatePeriod?>() -> TODO("Use java.time.Period for now. We'll create a UDT for this.")
391401
currentType.isSubtypeOf<Row?>() -> AgnosticEncoders.`UnboundRowEncoder$`.`MODULE$`
392402

393403
// enums
@@ -414,6 +424,8 @@ object KotlinTypeInference : Serializable {
414424
AgnosticEncoders.UDTEncoder(udt, udt.javaClass)
415425
}
416426

427+
currentType.isSubtypeOf<kotlin.time.Duration?>() -> TODO("kotlin.time.Duration is unsupported. Use java.time.Duration for now.")
428+
417429
currentType.isSubtypeOf<scala.Option<*>?>() -> {
418430
val elementEncoder = encoderFor(
419431
currentType = tArguments.first().type!!,
@@ -666,8 +678,6 @@ object KotlinTypeInference : Serializable {
666678
fields.asScalaSeq(),
667679
)
668680
}
669-
670-
// else -> throw IllegalArgumentException("No encoder found for type $currentType")
671681
}
672682
}
673683

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.DatePeriod
import kotlinx.datetime.toJavaPeriod
import kotlinx.datetime.toKotlinDatePeriod
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.sql.types.YearMonthIntervalType

/**
 * Spark [UserDefinedType] for [kotlinx.datetime.DatePeriod], stored as a
 * year-month interval (a single month count).
 *
 * NOTE: Just like `java.time.Period` under Spark's `YearMonthIntervalType`, this is
 * truncated to whole months — the day component does not survive a round trip.
 */
class DatePeriodUdt : UserDefinedType<DatePeriod>() {

    override fun userClass(): Class<DatePeriod> = DatePeriod::class.java

    /** Converts the stored month count back to a [DatePeriod]; `null` stays `null`. */
    override fun deserialize(datum: Any?): DatePeriod? =
        when (datum) {
            null -> null
            is Int -> IntervalUtils.monthsToPeriod(datum).toKotlinDatePeriod()
            else -> throw IllegalArgumentException("Unsupported datum: $datum")
        }

    /** Serializes to a total month count (years folded in); the day component is lost. */
    override fun serialize(obj: DatePeriod?): Int? =
        obj?.let { IntervalUtils.periodToMonths(it.toJavaPeriod()) }

    override fun sqlType(): YearMonthIntervalType = YearMonthIntervalType.apply()
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.DateTimePeriod
import org.apache.spark.sql.types.CalendarIntervalType
import org.apache.spark.sql.types.`CalendarIntervalType$`
import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.unsafe.types.CalendarInterval
import kotlin.time.Duration.Companion.hours
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.Duration.Companion.seconds

/**
 * Spark [UserDefinedType] for [kotlinx.datetime.DateTimePeriod], stored as a
 * [CalendarInterval] (months, days, microseconds).
 *
 * NOTE: [CalendarInterval] only has microsecond precision, so the nanosecond component
 * is truncated to whole microseconds on a round trip. Years are folded into months
 * (`years * 12 + months`) on write and recovered by [DateTimePeriod]'s own normalization
 * on read.
 */
class DateTimePeriodUdt : UserDefinedType<DateTimePeriod>() {

    override fun userClass(): Class<DateTimePeriod> = DateTimePeriod::class.java

    /** Rebuilds a [DateTimePeriod] from a [CalendarInterval]; `null` stays `null`. */
    override fun deserialize(datum: Any?): DateTimePeriod? =
        when (datum) {
            null -> null
            is CalendarInterval ->
                DateTimePeriod(
                    months = datum.months,
                    days = datum.days,
                    nanoseconds = datum.microseconds * 1_000,
                )

            else -> throw IllegalArgumentException("Unsupported datum: $datum")
        }

    /** Serializes to a [CalendarInterval]; the time-of-day part is summed into whole microseconds. */
    override fun serialize(obj: DateTimePeriod?): CalendarInterval? =
        obj?.let {
            CalendarInterval(
                /* months = */ it.months + it.years * 12,
                /* days = */ it.days,
                /* microseconds = */
                (it.hours.hours +
                        it.minutes.minutes +
                        it.seconds.seconds +
                        it.nanoseconds.nanoseconds).inWholeMicroseconds,
            )
        }

    override fun sqlType(): CalendarIntervalType = `CalendarIntervalType$`.`MODULE$`
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package org.jetbrains.kotlinx.spark.api.udts

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DayTimeIntervalType
import org.apache.spark.sql.types.UserDefinedType
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.toJavaDuration
import kotlin.time.toKotlinDuration

/**
 * Spark [UserDefinedType] for [kotlin.time.Duration], stored as a day-time interval
 * (a microsecond count).
 *
 * TODO: currently fails at runtime — likely because [Duration] is a value class, so
 *  Spark may see its raw underlying `Long` instead of a boxed [Duration]; confirm.
 *  The [serialize] overload taking a raw `Long` is an experiment for that case.
 */
class DurationUdt : UserDefinedType<Duration>() {

    override fun userClass(): Class<Duration> = Duration::class.java

    /** Converts the stored microsecond count back to a [Duration]; `null` stays `null`. */
    override fun deserialize(datum: Any?): Duration? =
        when (datum) {
            null -> null
            is Long -> IntervalUtils.microsToDuration(datum).toKotlinDuration()
            else -> throw IllegalArgumentException("Unsupported datum: $datum")
        }

    /**
     * Experimental: serializes [Duration]'s raw `Long` representation to microseconds.
     * Assumes the lowest bit is a unit discriminator (0 = value in nanoseconds,
     * 1 = value in milliseconds) and the remaining bits hold the value — TODO confirm
     * against kotlin.time.Duration's internal encoding.
     */
    fun serialize(obj: Long): Long? {
        val storedInNanos = (obj.toInt() and 1) == 0
        val value = obj shr 1
        val duration = if (storedInNanos) value.nanoseconds else value.milliseconds
        return IntervalUtils.durationToMicros(duration.toJavaDuration())
    }

    /** Serializes a [Duration] to whole microseconds. */
    override fun serialize(obj: Duration): Long? =
        IntervalUtils.durationToMicros(obj.toJavaDuration())

    override fun sqlType(): DataType = DayTimeIntervalType.apply()
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.Instant
import kotlinx.datetime.toJavaInstant
import kotlinx.datetime.toKotlinInstant
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.`TimestampType$`
import org.apache.spark.sql.types.UserDefinedType

/**
 * Spark [UserDefinedType] for [kotlinx.datetime.Instant], stored as a timestamp
 * (microseconds since the epoch).
 */
class InstantUdt : UserDefinedType<Instant>() {

    override fun userClass(): Class<Instant> = Instant::class.java

    /** Converts epoch microseconds back to an [Instant]; `null` stays `null`. */
    override fun deserialize(datum: Any?): Instant? {
        if (datum == null) return null
        require(datum is Long) { "Unsupported datum: $datum" }
        return DateTimeUtils.microsToInstant(datum).toKotlinInstant()
    }

    /** Serializes an [Instant] to epoch microseconds. */
    override fun serialize(obj: Instant?): Long? =
        if (obj == null) null else DateTimeUtils.instantToMicros(obj.toJavaInstant())

    override fun sqlType(): DataType = `TimestampType$`.`MODULE$`
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.LocalDateTime
import kotlinx.datetime.toJavaLocalDateTime
import kotlinx.datetime.toKotlinLocalDateTime
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.`TimestampNTZType$`
import org.apache.spark.sql.types.UserDefinedType

/**
 * Spark [UserDefinedType] for [kotlinx.datetime.LocalDateTime], stored as a
 * timestamp-without-timezone (microsecond count).
 */
class LocalDateTimeUdt : UserDefinedType<LocalDateTime>() {

    override fun userClass(): Class<LocalDateTime> = LocalDateTime::class.java

    /** Converts the stored microsecond count back to a [LocalDateTime]; `null` stays `null`. */
    override fun deserialize(datum: Any?): LocalDateTime? {
        if (datum == null) return null
        require(datum is Long) { "Unsupported datum: $datum" }
        return DateTimeUtils.microsToLocalDateTime(datum).toKotlinLocalDateTime()
    }

    /** Serializes a [LocalDateTime] to microseconds. */
    override fun serialize(obj: LocalDateTime?): Long? =
        if (obj == null) null else DateTimeUtils.localDateTimeToMicros(obj.toJavaLocalDateTime())

    override fun sqlType(): DataType = `TimestampNTZType$`.`MODULE$`
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package org.jetbrains.kotlinx.spark.api.udts

import kotlinx.datetime.LocalDate
import kotlinx.datetime.toJavaLocalDate
import kotlinx.datetime.toKotlinLocalDate
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.`DateType$`
import org.apache.spark.sql.types.UserDefinedType

/**
 * Spark [UserDefinedType] for [kotlinx.datetime.LocalDate], stored as a date
 * (day count since the epoch).
 */
class LocalDateUdt : UserDefinedType<LocalDate>() {

    override fun userClass(): Class<LocalDate> = LocalDate::class.java

    /** Converts the stored epoch-day count back to a [LocalDate]; `null` stays `null`. */
    override fun deserialize(datum: Any?): LocalDate? {
        if (datum == null) return null
        require(datum is Int) { "Unsupported datum: $datum" }
        return DateTimeUtils.daysToLocalDate(datum).toKotlinLocalDate()
    }

    /** Serializes a [LocalDate] to its epoch-day count. */
    override fun serialize(obj: LocalDate?): Int? =
        if (obj == null) null else DateTimeUtils.localDateToDays(obj.toJavaLocalDate())

    override fun sqlType(): DataType = `DateType$`.`MODULE$`
}

0 commit comments

Comments
 (0)