Skip to content

Commit d60e4dc

Browse files
committed
enabled core as scala-helpers with VarargUnwrapper. Removed name hack in favor of upcoming IR compiler plugin. Removed spark dependency in scala-helpers
1 parent 2c875ff commit d60e4dc

File tree

9 files changed

+133
-92
lines changed

9 files changed

+133
-92
lines changed

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,20 @@ package org.jetbrains.kotlinx.spark.api
3232
import org.apache.spark.sql.Encoder
3333
import org.apache.spark.sql.Row
3434
import org.apache.spark.sql.catalyst.DefinedByConstructorParams
35-
import org.apache.spark.sql.catalyst.SerializerBuildHelper
3635
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
3736
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders
3837
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.EncoderField
3938
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
4039
import org.apache.spark.sql.catalyst.encoders.OuterScopes
41-
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
4240
import org.apache.spark.sql.types.DataType
4341
import org.apache.spark.sql.types.Decimal
4442
import org.apache.spark.sql.types.Metadata
4543
import org.apache.spark.sql.types.SQLUserDefinedType
44+
import org.apache.spark.sql.types.StructType
4645
import org.apache.spark.sql.types.UDTRegistration
4746
import org.apache.spark.sql.types.UserDefinedType
4847
import org.apache.spark.unsafe.types.CalendarInterval
4948
import scala.reflect.ClassTag
50-
import java.io.Serializable
5149
import kotlin.reflect.KClass
5250
import kotlin.reflect.KMutableProperty
5351
import kotlin.reflect.KType
@@ -113,11 +111,13 @@ private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
113111
@Deprecated("Use kotlinEncoderFor instead", ReplaceWith("kotlinEncoderFor<T>()"))
114112
inline fun <reified T> encoder(): Encoder<T> = kotlinEncoderFor(typeOf<T>())
115113

116-
@Deprecated("Use kotlinEncoderFor to get the schema.", ReplaceWith("kotlinEncoderFor<T>().schema()"))
117-
inline fun <reified T> schema(): DataType = kotlinEncoderFor<T>().schema()
114+
internal fun StructType.unwrap(): DataType =
115+
if (fields().singleOrNull()?.name() == "value") fields().single().dataType()
116+
else this
118117

119-
@Deprecated("Use kotlinEncoderFor to get the schema.", ReplaceWith("kotlinEncoderFor<Any?>(kType).schema()"))
120-
fun schema(kType: KType): DataType = kotlinEncoderFor<Any?>(kType).schema()
118+
inline fun <reified T> schemaFor(): DataType = schemaFor(typeOf<T>())
119+
120+
fun schemaFor(kType: KType): DataType = kotlinEncoderFor<Any?>(kType).schema().unwrap()
121121

122122
object KotlinTypeInference {
123123

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class UDFWrapper0(private val udfName: String) {
5353
@OptIn(ExperimentalStdlibApi::class)
5454
@Deprecated("Use new UDF notation", ReplaceWith("this.register(name, func)"), DeprecationLevel.HIDDEN)
5555
inline fun <reified R> UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 {
56-
register(name, UDF0(func), kotlinEncoderFor<R>().schema())
56+
register(name, UDF0(func), schemaFor<R>())
5757
return UDFWrapper0(name)
5858
}
5959

@@ -78,7 +78,7 @@ class UDFWrapper1(private val udfName: String) {
7878
@Deprecated("Use new UDF notation", ReplaceWith("this.register(name, func)"), DeprecationLevel.HIDDEN)
7979
inline fun <reified T0, reified R> UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 {
8080
T0::class.checkForValidType("T0")
81-
register(name, UDF1(func), kotlinEncoderFor<R>().schema())
81+
register(name, UDF1(func), schemaFor<R>())
8282
return UDFWrapper1(name)
8383
}
8484

@@ -107,7 +107,7 @@ inline fun <reified T0, reified T1, reified R> UDFRegistration.register(
107107
): UDFWrapper2 {
108108
T0::class.checkForValidType("T0")
109109
T1::class.checkForValidType("T1")
110-
register(name, UDF2(func), kotlinEncoderFor<R>().schema())
110+
register(name, UDF2(func), schemaFor<R>())
111111
return UDFWrapper2(name)
112112
}
113113

@@ -137,7 +137,7 @@ inline fun <reified T0, reified T1, reified T2, reified R> UDFRegistration.regis
137137
T0::class.checkForValidType("T0")
138138
T1::class.checkForValidType("T1")
139139
T2::class.checkForValidType("T2")
140-
register(name, UDF3(func), kotlinEncoderFor<R>().schema())
140+
register(name, UDF3(func), schemaFor<R>())
141141
return UDFWrapper3(name)
142142
}
143143

@@ -168,7 +168,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified R> UDFRegist
168168
T1::class.checkForValidType("T1")
169169
T2::class.checkForValidType("T2")
170170
T3::class.checkForValidType("T3")
171-
register(name, UDF4(func), kotlinEncoderFor<R>().schema())
171+
register(name, UDF4(func), schemaFor<R>())
172172
return UDFWrapper4(name)
173173
}
174174

@@ -200,7 +200,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
200200
T2::class.checkForValidType("T2")
201201
T3::class.checkForValidType("T3")
202202
T4::class.checkForValidType("T4")
203-
register(name, UDF5(func), kotlinEncoderFor<R>().schema())
203+
register(name, UDF5(func), schemaFor<R>())
204204
return UDFWrapper5(name)
205205
}
206206

@@ -240,7 +240,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
240240
T3::class.checkForValidType("T3")
241241
T4::class.checkForValidType("T4")
242242
T5::class.checkForValidType("T5")
243-
register(name, UDF6(func), kotlinEncoderFor<R>().schema())
243+
register(name, UDF6(func), schemaFor<R>())
244244
return UDFWrapper6(name)
245245
}
246246

@@ -282,7 +282,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
282282
T4::class.checkForValidType("T4")
283283
T5::class.checkForValidType("T5")
284284
T6::class.checkForValidType("T6")
285-
register(name, UDF7(func), kotlinEncoderFor<R>().schema())
285+
register(name, UDF7(func), schemaFor<R>())
286286
return UDFWrapper7(name)
287287
}
288288

@@ -326,7 +326,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
326326
T5::class.checkForValidType("T5")
327327
T6::class.checkForValidType("T6")
328328
T7::class.checkForValidType("T7")
329-
register(name, UDF8(func), kotlinEncoderFor<R>().schema())
329+
register(name, UDF8(func), schemaFor<R>())
330330
return UDFWrapper8(name)
331331
}
332332

@@ -372,7 +372,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
372372
T6::class.checkForValidType("T6")
373373
T7::class.checkForValidType("T7")
374374
T8::class.checkForValidType("T8")
375-
register(name, UDF9(func), kotlinEncoderFor<R>().schema())
375+
register(name, UDF9(func), schemaFor<R>())
376376
return UDFWrapper9(name)
377377
}
378378

@@ -432,7 +432,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
432432
T7::class.checkForValidType("T7")
433433
T8::class.checkForValidType("T8")
434434
T9::class.checkForValidType("T9")
435-
register(name, UDF10(func), kotlinEncoderFor<R>().schema())
435+
register(name, UDF10(func), schemaFor<R>())
436436
return UDFWrapper10(name)
437437
}
438438

@@ -495,7 +495,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
495495
T8::class.checkForValidType("T8")
496496
T9::class.checkForValidType("T9")
497497
T10::class.checkForValidType("T10")
498-
register(name, UDF11(func), kotlinEncoderFor<R>().schema())
498+
register(name, UDF11(func), schemaFor<R>())
499499
return UDFWrapper11(name)
500500
}
501501

@@ -561,7 +561,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
561561
T9::class.checkForValidType("T9")
562562
T10::class.checkForValidType("T10")
563563
T11::class.checkForValidType("T11")
564-
register(name, UDF12(func), kotlinEncoderFor<R>().schema())
564+
register(name, UDF12(func), schemaFor<R>())
565565
return UDFWrapper12(name)
566566
}
567567

@@ -630,7 +630,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
630630
T10::class.checkForValidType("T10")
631631
T11::class.checkForValidType("T11")
632632
T12::class.checkForValidType("T12")
633-
register(name, UDF13(func), kotlinEncoderFor<R>().schema())
633+
register(name, UDF13(func), schemaFor<R>())
634634
return UDFWrapper13(name)
635635
}
636636

@@ -702,7 +702,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
702702
T11::class.checkForValidType("T11")
703703
T12::class.checkForValidType("T12")
704704
T13::class.checkForValidType("T13")
705-
register(name, UDF14(func), kotlinEncoderFor<R>().schema())
705+
register(name, UDF14(func), schemaFor<R>())
706706
return UDFWrapper14(name)
707707
}
708708

@@ -777,7 +777,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
777777
T12::class.checkForValidType("T12")
778778
T13::class.checkForValidType("T13")
779779
T14::class.checkForValidType("T14")
780-
register(name, UDF15(func), kotlinEncoderFor<R>().schema())
780+
register(name, UDF15(func), schemaFor<R>())
781781
return UDFWrapper15(name)
782782
}
783783

@@ -855,7 +855,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
855855
T13::class.checkForValidType("T13")
856856
T14::class.checkForValidType("T14")
857857
T15::class.checkForValidType("T15")
858-
register(name, UDF16(func), kotlinEncoderFor<R>().schema())
858+
register(name, UDF16(func), schemaFor<R>())
859859
return UDFWrapper16(name)
860860
}
861861

@@ -936,7 +936,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
936936
T14::class.checkForValidType("T14")
937937
T15::class.checkForValidType("T15")
938938
T16::class.checkForValidType("T16")
939-
register(name, UDF17(func), kotlinEncoderFor<R>().schema())
939+
register(name, UDF17(func), schemaFor<R>())
940940
return UDFWrapper17(name)
941941
}
942942

@@ -1020,7 +1020,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
10201020
T15::class.checkForValidType("T15")
10211021
T16::class.checkForValidType("T16")
10221022
T17::class.checkForValidType("T17")
1023-
register(name, UDF18(func), kotlinEncoderFor<R>().schema())
1023+
register(name, UDF18(func), schemaFor<R>())
10241024
return UDFWrapper18(name)
10251025
}
10261026

@@ -1107,7 +1107,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
11071107
T16::class.checkForValidType("T16")
11081108
T17::class.checkForValidType("T17")
11091109
T18::class.checkForValidType("T18")
1110-
register(name, UDF19(func), kotlinEncoderFor<R>().schema())
1110+
register(name, UDF19(func), schemaFor<R>())
11111111
return UDFWrapper19(name)
11121112
}
11131113

@@ -1197,7 +1197,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
11971197
T17::class.checkForValidType("T17")
11981198
T18::class.checkForValidType("T18")
11991199
T19::class.checkForValidType("T19")
1200-
register(name, UDF20(func), kotlinEncoderFor<R>().schema())
1200+
register(name, UDF20(func), schemaFor<R>())
12011201
return UDFWrapper20(name)
12021202
}
12031203

@@ -1290,7 +1290,7 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
12901290
T18::class.checkForValidType("T18")
12911291
T19::class.checkForValidType("T19")
12921292
T20::class.checkForValidType("T20")
1293-
register(name, UDF21(func), kotlinEncoderFor<R>().schema())
1293+
register(name, UDF21(func), schemaFor<R>())
12941294
return UDFWrapper21(name)
12951295
}
12961296

@@ -1386,6 +1386,6 @@ inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified
13861386
T19::class.checkForValidType("T19")
13871387
T20::class.checkForValidType("T20")
13881388
T21::class.checkForValidType("T21")
1389-
register(name, UDF22(func), kotlinEncoderFor<R>().schema())
1389+
register(name, UDF22(func), schemaFor<R>())
13901390
return UDFWrapper22(name)
13911391
}

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UserDefinedFunction.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package org.jetbrains.kotlinx.spark.api
2323

2424
import org.apache.spark.sql.*
2525
import org.apache.spark.sql.types.DataType
26+
import org.apache.spark.sql.types.StructType
2627
import scala.collection.Seq
2728
import java.io.Serializable
2829
import kotlin.reflect.KClass
@@ -31,6 +32,7 @@ import kotlin.reflect.full.isSubclassOf
3132
import kotlin.reflect.full.primaryConstructor
3233
import org.apache.spark.sql.expressions.UserDefinedFunction as SparkUserDefinedFunction
3334

35+
3436
/**
3537
* Checks if [this] is of a valid type for a UDF, otherwise it throws a [TypeOfUDFParameterNotSupportedException]
3638
*/

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UserDefinedFunctionVararg.kt

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ inline fun <reified R> udf(
135135

136136
return withAllowUntypedScalaUDF {
137137
UserDefinedFunctionVararg(
138-
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> ByteArray(i, init::call) }, kotlinEncoderFor<R>().schema())
138+
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> ByteArray(i, init::apply) }, schemaFor<R>())
139139
.let { if (nondeterministic) it.asNondeterministic() else it }
140140
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
141141
encoder = kotlinEncoderFor<R>(),
@@ -334,7 +334,7 @@ inline fun <reified R> udf(
334334

335335
return withAllowUntypedScalaUDF {
336336
UserDefinedFunctionVararg(
337-
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> ShortArray(i, init::call) }, kotlinEncoderFor<R>().schema())
337+
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> ShortArray(i, init::apply) }, schemaFor<R>())
338338
.let { if (nondeterministic) it.asNondeterministic() else it }
339339
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
340340
encoder = kotlinEncoderFor<R>(),
@@ -533,7 +533,7 @@ inline fun <reified R> udf(
533533

534534
return withAllowUntypedScalaUDF {
535535
UserDefinedFunctionVararg(
536-
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> IntArray(i, init::call) }, kotlinEncoderFor<R>().schema())
536+
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> IntArray(i, init::apply) }, schemaFor<R>())
537537
.let { if (nondeterministic) it.asNondeterministic() else it }
538538
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
539539
encoder = kotlinEncoderFor<R>(),
@@ -732,7 +732,7 @@ inline fun <reified R> udf(
732732

733733
return withAllowUntypedScalaUDF {
734734
UserDefinedFunctionVararg(
735-
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> LongArray(i, init::call) }, kotlinEncoderFor<R>().schema())
735+
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> LongArray(i, init::apply) }, schemaFor<R>())
736736
.let { if (nondeterministic) it.asNondeterministic() else it }
737737
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
738738
encoder = kotlinEncoderFor<R>(),
@@ -931,7 +931,7 @@ inline fun <reified R> udf(
931931

932932
return withAllowUntypedScalaUDF {
933933
UserDefinedFunctionVararg(
934-
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> FloatArray(i, init::call) }, kotlinEncoderFor<R>().schema())
934+
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> FloatArray(i, init::apply) }, schemaFor<R>())
935935
.let { if (nondeterministic) it.asNondeterministic() else it }
936936
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
937937
encoder = kotlinEncoderFor<R>(),
@@ -1130,7 +1130,7 @@ inline fun <reified R> udf(
11301130

11311131
return withAllowUntypedScalaUDF {
11321132
UserDefinedFunctionVararg(
1133-
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> DoubleArray(i, init::call) }, kotlinEncoderFor<R>().schema())
1133+
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> DoubleArray(i, init::apply) }, schemaFor<R>())
11341134
.let { if (nondeterministic) it.asNondeterministic() else it }
11351135
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
11361136
encoder = kotlinEncoderFor<R>(),
@@ -1325,11 +1325,9 @@ inline fun <reified R> udf(
13251325
nondeterministic: Boolean = false,
13261326
varargFunc: UDF1<BooleanArray, R>,
13271327
): UserDefinedFunctionVararg<Boolean, R> {
1328-
1329-
13301328
return withAllowUntypedScalaUDF {
13311329
UserDefinedFunctionVararg(
1332-
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> BooleanArray(i, init::call) }, kotlinEncoderFor<R>().schema())
1330+
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> BooleanArray(i, init::apply) }, schemaFor<R>())
13331331
.let { if (nondeterministic) it.asNondeterministic() else it }
13341332
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
13351333
encoder = kotlinEncoderFor<R>(),
@@ -1528,7 +1526,7 @@ inline fun <reified T, reified R> udf(
15281526

15291527
return withAllowUntypedScalaUDF {
15301528
UserDefinedFunctionVararg(
1531-
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> Array<T>(i, init::call) }, kotlinEncoderFor<R>().schema())
1529+
udf = functions.udf(VarargUnwrapper(varargFunc) { i, init -> Array<T>(i, init::apply) }, schemaFor<R>())
15321530
.let { if (nondeterministic) it.asNondeterministic() else it }
15331531
.let { if (typeOf<R>().isMarkedNullable) it else it.asNonNullable() },
15341532
encoder = kotlinEncoderFor<R>(),

0 commit comments

Comments
 (0)