
Commit df021c0

Fixing more tests. We can now remain on Kotlin 2.0 if we set -Xlambdas=class, which can be done via the Gradle plugin.
1 parent: 7069a9a

File tree: 12 files changed (+103, -52 lines)


buildSrc/src/main/kotlin/Versions.kt

Lines changed: 1 addition & 2 deletions
@@ -2,8 +2,7 @@ object Versions : Dsl<Versions> {
     const val project = "2.0.0-SNAPSHOT"
     const val kotlinSparkApiGradlePlugin = "2.0.0-SNAPSHOT"
     const val groupID = "org.jetbrains.kotlinx.spark"
-    // const val kotlin = "2.0.0-Beta5" // todo issues with NonSerializable lambdas
-    const val kotlin = "1.9.23"
+    const val kotlin = "2.0.0-Beta5"
     const val jvmTarget = "8"
     const val jupyterJvmTarget = "8"
     inline val spark get() = System.getProperty("spark") as String

gradle-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/gradlePlugin/SparkKotlinCompilerGradlePlugin.kt

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ class SparkKotlinCompilerGradlePlugin : KotlinCompilerPluginSupportPlugin {
         compilerOptions {
             // Make sure the parameters of data classes are visible to scala
             javaParameters.set(true)
+
+            // Avoid NotSerializableException by making lambdas serializable
+            freeCompilerArgs.add("-Xlambdas=class")
         }
     }
 }
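
Background on the new flag: Kotlin 2.0 compiles lambdas via invokedynamic by default, and the resulting lambda objects do not implement java.io.Serializable, while Spark must serialize closures before shipping them to executors. -Xlambdas=class makes the compiler emit each lambda as an anonymous class again, which can be serialized. A minimal sketch of the failure mode (hypothetical example, not from this commit; assumes a live JavaSparkContext):

    import org.apache.spark.api.java.JavaSparkContext

    fun demo(jsc: JavaSparkContext) {
        // Spark serializes this closure to send it to executors. Under
        // Kotlin 2.0's default invokedynamic lambdas this can fail with
        // NotSerializableException; compiled with -Xlambdas=class, the
        // lambda becomes a serializable anonymous class.
        val doubled = jsc.parallelize(listOf(1, 2, 3))
            .map { it * 2 }
            .collect()
        println(doubled) // [2, 4, 6]
    }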

kotlin-spark-api/build.gradle.kts

Lines changed: 0 additions & 1 deletion
@@ -147,7 +147,6 @@ tasks.compileTestKotlin {
     kotlin {
         jvmToolchain {
             languageVersion = JavaLanguageVersion.of(Versions.jvmTarget)
-
         }
     }

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt

Lines changed: 5 additions & 1 deletion
@@ -46,6 +46,7 @@ import org.apache.spark.sql.types.UDTRegistration
 import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.unsafe.types.CalendarInterval
 import scala.reflect.ClassTag
+import java.io.Serializable
 import kotlin.reflect.KClass
 import kotlin.reflect.KMutableProperty
 import kotlin.reflect.KType
@@ -122,7 +123,10 @@ fun schemaFor(kType: KType): DataType = kotlinEncoderFor<Any?>(kType).schema().u
 @Deprecated("Use schemaFor instead", ReplaceWith("schemaFor(kType)"))
 fun schema(kType: KType) = schemaFor(kType)

-object KotlinTypeInference {
+object KotlinTypeInference : Serializable {
+
+    // https://blog.stylingandroid.com/kotlin-serializable-objects/
+    private fun readResolve(): Any = KotlinTypeInference

     /**
      * @param kClass the class for which to infer the encoder.
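
The readResolve function is the standard trick for making a Kotlin object survive Java serialization as a singleton (see the linked blog post): without it, deserialization would construct a second instance, breaking identity. A self-contained sketch of the pattern, with illustrative names:

    import java.io.*

    object Settings : Serializable {
        // Redirect deserialization back to the canonical instance.
        private fun readResolve(): Any = Settings
    }

    fun main() {
        val bytes = ByteArrayOutputStream()
            .also { ObjectOutputStream(it).writeObject(Settings) }
            .toByteArray()
        val copy = ObjectInputStream(ByteArrayInputStream(bytes)).readObject()
        println(copy === Settings) // true, thanks to readResolve
    }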

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/RddDouble.kt

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ inline fun <reified T : Number> JavaRDD<T>.toJavaDoubleRDD(): JavaDoubleRDD =

 /** Utility method to convert [JavaDoubleRDD] to [JavaRDD]<[Double]>. */
 @Suppress("UNCHECKED_CAST")
-fun JavaDoubleRDD.toDoubleRDD(): JavaRDD<Double> =
+inline fun JavaDoubleRDD.toDoubleRDD(): JavaRDD<Double> =
     JavaDoubleRDD.toRDD(this).toJavaRDD() as JavaRDD<Double>

 /** Add up the elements in this RDD. */

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt

Lines changed: 3 additions & 2 deletions
@@ -44,6 +44,7 @@ import org.apache.spark.streaming.Durations
 import org.apache.spark.streaming.api.java.JavaStreamingContext
 import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
 import org.jetbrains.kotlinx.spark.api.tuples.*
+import scala.reflect.ClassTag
 import java.io.Serializable

 /**
@@ -406,7 +407,7 @@ private fun getDefaultHadoopConf(): Configuration {
 * @return `Broadcast` object, a read-only variable cached on each machine
 */
 inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
-    sparkContext.broadcast(value, kotlinEncoderFor<T>().clsTag())
+    sparkContext.broadcast(value, ClassTag.apply(T::class.java))
 } catch (e: ClassNotFoundException) {
     JavaSparkContext(sparkContext).broadcast(value)
 }
@@ -426,7 +427,7 @@ inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
     DeprecationLevel.WARNING
 )
 inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = try {
-    broadcast(value, kotlinEncoderFor<T>().clsTag())
+    broadcast(value, ClassTag.apply(T::class.java))
 } catch (e: ClassNotFoundException) {
     JavaSparkContext(this).broadcast(value)
 }
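
Swapping kotlinEncoderFor<T>().clsTag() for ClassTag.apply(T::class.java) means broadcasting no longer requires an encoder to exist for T: the ClassTag is built directly from the reified runtime class. A small sketch of the idea (classTagOf is a hypothetical helper, not part of the API):

    import scala.reflect.ClassTag

    // Builds the Scala ClassTag that Spark's broadcast() expects directly
    // from the reified type, without touching encoder machinery.
    inline fun <reified T> classTagOf(): ClassTag<T> = ClassTag.apply(T::class.java)

    // Usage, assuming a live SparkSession `spark`:
    // val b = spark.sparkContext.broadcast(listOf(1, 2, 3), classTagOf())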

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UserDefinedFunction.kt

Lines changed: 2 additions & 2 deletions
@@ -69,9 +69,9 @@ class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName:
 )

 @JvmName("arrayColumnAsSeq")
-fun <DsType, T> TypedColumn<DsType, Array<T>>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
+inline fun <DsType, reified T> TypedColumn<DsType, Array<T>>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
 @JvmName("iterableColumnAsSeq")
-fun <DsType, T, I : Iterable<T>> TypedColumn<DsType, I>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
+inline fun <DsType, reified T, I : Iterable<T>> TypedColumn<DsType, I>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
 @JvmName("byteArrayColumnAsSeq")
 fun <DsType> TypedColumn<DsType, ByteArray>.asSeq(): TypedColumn<DsType, Seq<Byte>> = typed()
 @JvmName("charArrayColumnAsSeq")
kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt

Lines changed: 13 additions & 8 deletions
@@ -41,6 +41,9 @@ import java.time.Period

 class EncodingTest : ShouldSpec({

+    @Sparkify
+    data class SparkifiedPair<T, U>(val first: T, val second: U)
+
     context("encoders") {
         withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {

@@ -134,8 +137,8 @@ class EncodingTest : ShouldSpec({
         }

         should("be able to serialize Date") {
-            val datePair = Date.valueOf("2020-02-10") to 5
-            val dataset: Dataset<Pair<Date, Int>> = dsOf(datePair)
+            val datePair = SparkifiedPair(Date.valueOf("2020-02-10"), 5)
+            val dataset: Dataset<SparkifiedPair<Date, Int>> = dsOf(datePair)
             dataset.collectAsList() shouldBe listOf(datePair)
         }

@@ -213,6 +216,8 @@ class EncodingTest : ShouldSpec({

     context("Give proper names to columns of data classes") {

+        infix fun <A, B> A.to(other: B) = SparkifiedPair(this, other)
+
         should("Be able to serialize pairs") {
             val pairs = listOf(
                 1 to "1",
@@ -653,25 +658,25 @@ class EncodingTest : ShouldSpec({
         }

         should("handle arrays of generics") {
-            data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>)
+            data class Test<Z>(val id: Long, val data: Array<SparkifiedPair<Z, Int>>)

-            val result = listOf(Test(1, arrayOf(5.1 to 6, 6.1 to 7)))
+            val result = listOf(Test(1, arrayOf(SparkifiedPair(5.1, 6), SparkifiedPair(6.1, 7))))
                 .toDS()
                 .map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
                 .map { it.second }
                 .collectAsList()
-            expect(result).toContain.inOrder.only.values(5.1 to 6)
+            expect(result).toContain.inOrder.only.values(SparkifiedPair(5.1, 6))
         }

         should("handle lists of generics") {
-            data class Test<Z>(val id: Long, val data: List<Pair<Z, Int>>)
+            data class Test<Z>(val id: Long, val data: List<SparkifiedPair<Z, Int>>)

-            val result = listOf(Test(1, listOf(5.1 to 6, 6.1 to 7)))
+            val result = listOf(Test(1, listOf(SparkifiedPair(5.1, 6), SparkifiedPair(6.1, 7))))
                 .toDS()
                 .map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
                 .map { it.second }
                 .collectAsList()
-            expect(result).toContain.inOrder.only.values(5.1 to 6)
+            expect(result).toContain.inOrder.only.values(SparkifiedPair(5.1, 6))
         }

         should("handle boxed arrays") {

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/RddTest.kt

Lines changed: 6 additions & 2 deletions
@@ -6,11 +6,15 @@ import io.kotest.matchers.shouldBe
 import org.apache.spark.api.java.JavaRDD
 import org.jetbrains.kotlinx.spark.api.tuples.*
 import scala.Tuple2
+import java.io.Serializable

-class RddTest : ShouldSpec({
+class RddTest : Serializable, ShouldSpec({
     context("RDD extension functions") {

-        withSpark(logLevel = SparkLogLevel.DEBUG) {
+        withSpark(
+            props = mapOf("spark.sql.codegen.wholeStage" to false),
+            logLevel = SparkLogLevel.DEBUG,
+        ) {

             context("Key/value") {
                 should("work with spark example") {

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/StreamingTest.kt

Lines changed: 14 additions & 6 deletions
@@ -39,6 +39,7 @@ import scala.Tuple2
 import java.io.File
 import java.io.Serializable
 import java.nio.charset.StandardCharsets
+import java.nio.file.Files
 import java.util.*
 import java.util.concurrent.atomic.AtomicBoolean

@@ -201,10 +202,10 @@ class StreamingTest : ShouldSpec({

 private val scalaCompatVersion = SCALA_COMPAT_VERSION
 private val sparkVersion = SPARK_VERSION
-private fun createTempDir() = File.createTempFile(
-    System.getProperty("java.io.tmpdir"),
-    "spark_${scalaCompatVersion}_${sparkVersion}"
-).apply { deleteOnExit() }
+private fun createTempDir() =
+    Files.createTempDirectory("spark_${scalaCompatVersion}_${sparkVersion}")
+        .toFile()
+        .also { it.deleteOnExit() }

 private fun checkpointFile(checkpointDir: String, checkpointTime: Time): Path {
     val klass = Class.forName("org.apache.spark.streaming.Checkpoint$")
@@ -215,7 +216,10 @@ private fun checkpointFile(checkpointDir: String, checkpointTime: Time): Path {
     return checkpointFileMethod.invoke(module, checkpointDir, checkpointTime) as Path
 }

-private fun getCheckpointFiles(checkpointDir: String, fs: scala.Option<FileSystem>): scala.collection.immutable.Seq<Path> {
+private fun getCheckpointFiles(
+    checkpointDir: String,
+    fs: scala.Option<FileSystem>
+): scala.collection.immutable.Seq<Path> {
     val klass = Class.forName("org.apache.spark.streaming.Checkpoint$")
     val moduleField = klass.getField("MODULE$").also { it.isAccessible = true }
     val module = moduleField.get(null)
@@ -227,7 +231,11 @@ private fun getCheckpointFiles(checkpointDir: String, fs: scala.Option<FileSyste
 private fun createCorruptedCheckpoint(): String {
     val checkpointDirectory = createTempDir().absolutePath
     val fakeCheckpointFile = checkpointFile(checkpointDirectory, Time(1000))
-    FileUtils.write(File(fakeCheckpointFile.toString()), "spark_corrupt_${scalaCompatVersion}_${sparkVersion}", StandardCharsets.UTF_8)
+    FileUtils.write(
+        File(fakeCheckpointFile.toString()),
+        "spark_corrupt_${scalaCompatVersion}_${sparkVersion}",
+        StandardCharsets.UTF_8
+    )
     assert(getCheckpointFiles(checkpointDirectory, (null as FileSystem?).toOption()).nonEmpty())
     return checkpointDirectory
 }
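
The createTempDir rewrite fixes a real bug rather than style: File.createTempFile creates an empty file (and the old call passed the tmpdir path as the name prefix), while Files.createTempDirectory creates an actual directory, which is what the checkpoint tests need. A contrasting sketch:

    import java.io.File
    import java.nio.file.Files

    fun main() {
        val notADir: File = File.createTempFile("prefix", ".tmp")
        println(notADir.isDirectory) // false: an empty file

        val dir: File = Files.createTempDirectory("prefix").toFile()
        println(dir.isDirectory) // true: a usable directory
    }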

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/TypeInferenceTest.kt

Lines changed: 26 additions & 17 deletions
@@ -30,16 +30,22 @@ import org.jetbrains.kotlinx.spark.api.struct.model.ElementType.ComplexElement
 import org.jetbrains.kotlinx.spark.api.struct.model.ElementType.SimpleElement
 import org.jetbrains.kotlinx.spark.api.struct.model.Struct
 import org.jetbrains.kotlinx.spark.api.struct.model.StructField
-import kotlin.reflect.typeOf
-

 @OptIn(ExperimentalStdlibApi::class)
 class TypeInferenceTest : ShouldSpec({
+    @Sparkify
+    data class SparkifiedPair<T, U>(val first: T, val second: U)
+
+    @Sparkify
+    data class SparkifiedTriple<T, U, V>(val first: T, val second: U, val third: V)
+
     context("org.jetbrains.spark.api.org.jetbrains.spark.api.schema") {
-        @Sparkify data class Test2<T>(val vala2: T, val para2: Pair<T, String>)
-        @Sparkify data class Test<T>(val vala: T, val tripl1: Triple<T, Test2<Long>, T>)
+        @Sparkify
+        data class Test2<T>(val vala2: T, val para2: SparkifiedPair<T, String>)
+        @Sparkify
+        data class Test<T>(val vala: T, val tripl1: SparkifiedTriple<T, Test2<Long>, T>)

-        val struct = Struct.fromJson(schemaFor<Pair<String, Test<Int>>>().prettyJson())!!
+        val struct = Struct.fromJson(schemaFor<SparkifiedPair<String, Test<Int>>>().prettyJson())!!
         should("contain correct typings") {
             expect(struct.fields).notToEqualNull().toContain.inAnyOrder.only.entries(
                 hasField("first", "string"),
@@ -65,12 +71,15 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("org.jetbrains.spark.api.org.jetbrains.spark.api.schema with more complex data") {
-        @Sparkify data class Single<T>(val vala3: T)
         @Sparkify
-        data class Test2<T>(val vala2: T, val para2: Pair<T, Single<Double>>)
-        @Sparkify data class Test<T>(val vala: T, val tripl1: Triple<T, Test2<Long>, T>)
+        data class Single<T>(val vala3: T)
+
+        @Sparkify
+        data class Test2<T>(val vala2: T, val para2: SparkifiedPair<T, Single<Double>>)
+        @Sparkify
+        data class Test<T>(val vala: T, val tripl1: SparkifiedTriple<T, Test2<Long>, T>)

-        val struct = Struct.fromJson(schemaFor<Pair<String, Test<Int>>>().prettyJson())!!
+        val struct = Struct.fromJson(schemaFor<SparkifiedPair<String, Test<Int>>>().prettyJson())!!
         should("contain correct typings") {
             expect(struct.fields).notToEqualNull().toContain.inAnyOrder.only.entries(
                 hasField("first", "string"),
@@ -99,7 +108,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("org.jetbrains.spark.api.org.jetbrains.spark.api.schema without generics") {
-        data class Test(val a: String, val b: Int, val c: Double)
+        @Sparkify data class Test(val a: String, val b: Int, val c: Double)

         val struct = Struct.fromJson(schemaFor<Test>().prettyJson())!!
         should("return correct types too") {
@@ -120,7 +129,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("type with list of Pairs int to long") {
-        val struct = Struct.fromJson(schemaFor<List<Pair<Int, Long>>>().prettyJson())!!
+        val struct = Struct.fromJson(schemaFor<List<SparkifiedPair<Int, Long>>>().prettyJson())!!
         should("return correct types too") {
             expect(struct) {
                 isOfType("array")
@@ -134,7 +143,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("type with list of generic data class with E generic name") {
-        data class Test<E>(val e: E)
+        @Sparkify data class Test<E>(val e: E)

         val struct = Struct.fromJson(schemaFor<List<Test<String>>>().prettyJson())!!
         should("return correct types too") {
@@ -180,7 +189,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("data class with props in order lon → lat") {
-        data class Test(val lon: Double, val lat: Double)
+        @Sparkify data class Test(val lon: Double, val lat: Double)

         val struct = Struct.fromJson(schemaFor<Test>().prettyJson())!!
         should("Not change order of fields") {
@@ -191,7 +200,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("data class with nullable list inside") {
-        data class Sample(val optionList: List<Int>?)
+        @Sparkify data class Sample(val optionList: List<Int>?)

         val struct = Struct.fromJson(schemaFor<Sample>().prettyJson())!!

@@ -223,8 +232,8 @@ class TypeInferenceTest : ShouldSpec({
             .feature("element name", { name() }) { toEqual("optionList") }
             .feature("field type", { dataType() }, {
                 this
-                    .isA<ArrayType>()
-                    .feature("element type", { elementType() }) { isA<IntegerType>() }
+                    .toBeAnInstanceOf<ArrayType>()
+                    .feature("element type", { elementType() }) { toBeAnInstanceOf<IntegerType>() }
                     .feature("element nullable", { containsNull() }) { toEqual(expected = false) }
             })
             .feature("optionList nullable", { nullable() }) { toEqual(true) }
@@ -258,5 +267,5 @@ private fun hasStruct(

 private fun hasField(name: String, type: String): Expect<StructField>.() -> Unit = {
     feature { f(it::name) }.toEqual(name)
-    feature { f(it::type) }.isA<TypeName>().feature { f(it::value) }.toEqual(type)
+    feature { f(it::type) }.toBeAnInstanceOf<TypeName>().feature { f(it::value) }.toEqual(type)
 }
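
Beyond replacing Pair/Triple with @Sparkify equivalents, the remaining local data classes also gained @Sparkify so that schemaFor can expose their constructor parameters as columns; the isA → toBeAnInstanceOf renames track Atrium's renamed expectation API. A short sketch of the schema side, assuming the project's compiler plugin is applied:

    @Sparkify
    data class City(val name: String, val population: Long)

    fun printCitySchema() {
        // Expected fields: "name" as string, "population" as long,
        // in declaration order.
        println(schemaFor<City>().prettyJson())
    }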
