
Commit e234f40
updating all references to the old encoder<>() function in favor of the new kotlinEncoderFor<>()
Jolanrensen committed Mar 17, 2024
1 parent 4896354 commit e234f40
Showing 10 changed files with 1,821 additions and 1,845 deletions.
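
The change itself is mechanical: every call site of the old top-level encoder<T>() helper now uses kotlinEncoderFor<T>() instead. A minimal before/after sketch, assuming a hypothetical Person data class and an active SparkSession (neither is part of this commit):

// before (Person and people are illustrative names)
val before: Dataset<Person> = spark.createDataset(people, encoder<Person>())
// after
val after: Dataset<Person> = spark.createDataset(people, kotlinEncoderFor<Person>())
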
@@ -270,10 +270,10 @@ private object MyAverage : Aggregator<Employee, Average, Double>() {
override fun finish(reduction: Average): Double = reduction.sum.toDouble() / reduction.count

// Specifies the Encoder for the intermediate value type
-override fun bufferEncoder(): Encoder<Average> = encoder()
+override fun bufferEncoder(): Encoder<Average> = kotlinEncoderFor()

// Specifies the Encoder for the final output value type
-override fun outputEncoder(): Encoder<Double> = encoder()
+override fun outputEncoder(): Encoder<Double> = kotlinEncoderFor()

}
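
With both encoders now supplied by kotlinEncoderFor(), the aggregator is applied exactly as before. A hedged usage sketch, assuming a Dataset<Employee> named ds (illustrative); Spark's Aggregator.toColumn() is the standard way to apply a typed aggregator:

// compute the average over the whole dataset; yields a Dataset<Double>
val avg = ds.select(MyAverage.toColumn())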

@@ -443,7 +443,7 @@ operator fun Column.get(key: Any): Column = getItem(key)
* @see typed
*/
@Suppress("UNCHECKED_CAST")
-inline fun <DsType, reified U> Column.`as`(): TypedColumn<DsType, U> = `as`(encoder<U>()) as TypedColumn<DsType, U>
+inline fun <DsType, reified U> Column.`as`(): TypedColumn<DsType, U> = `as`(kotlinEncoderFor<U>()) as TypedColumn<DsType, U>

/**
* Provides a type hint about the expected return value of this column. This information can
@@ -458,7 +458,7 @@ inline fun <DsType, reified U> Column.`as`(): TypedColumn<DsType, U> = `as`(encoder<U>()) as TypedColumn<DsType, U>
* @see typed
*/
@Suppress("UNCHECKED_CAST")
-inline fun <DsType, reified U> TypedColumn<DsType, *>.`as`(): TypedColumn<DsType, U> = `as`(encoder<U>()) as TypedColumn<DsType, U>
+inline fun <DsType, reified U> TypedColumn<DsType, *>.`as`(): TypedColumn<DsType, U> = `as`(kotlinEncoderFor<U>()) as TypedColumn<DsType, U>
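
A hedged usage sketch of this type-hint extension, assuming a Dataset<Row> named df with an integer "age" column (all names illustrative):

// attach a Kotlin-side type to an untyped column
val age: TypedColumn<Row, Int> = df.col("age").`as`<Row, Int>()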

/**
* Provides a type hint about the expected return value of this column. This information can
@@ -37,7 +37,6 @@ import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.api.java.function.ReduceFunction
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.*
-import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
import scala.Tuple2
import scala.Tuple3
import scala.Tuple4
@@ -49,7 +48,7 @@ import kotlin.reflect.KProperty1
* Utility method to create dataset from list
*/
inline fun <reified T> SparkSession.toDS(list: List<T>): Dataset<T> =
-createDataset(list, encoder<T>())
+createDataset(list, kotlinEncoderFor<T>())

/**
* Utility method to create dataframe from list
@@ -61,26 +60,26 @@ inline fun <reified T> SparkSession.toDF(list: List<T>, vararg colNames: String)
* Utility method to create dataset from *array or vararg arguments
*/
inline fun <reified T> SparkSession.dsOf(vararg t: T): Dataset<T> =
-createDataset(t.toList(), encoder<T>())
+createDataset(t.toList(), kotlinEncoderFor<T>())

/**
* Utility method to create dataframe from *array or vararg arguments
*/
inline fun <reified T> SparkSession.dfOf(vararg t: T): Dataset<Row> =
-createDataset(t.toList(), encoder<T>()).toDF()
+createDataset(t.toList(), kotlinEncoderFor<T>()).toDF()

/**
* Utility method to create dataframe from *array or vararg arguments with given column names
*/
inline fun <reified T> SparkSession.dfOf(colNames: Array<String>, vararg t: T): Dataset<Row> =
-createDataset(t.toList(), encoder<T>())
+createDataset(t.toList(), kotlinEncoderFor<T>())
.run { if (colNames.isEmpty()) toDF() else toDF(*colNames) }

/**
* Utility method to create dataset from list
*/
inline fun <reified T> List<T>.toDS(spark: SparkSession): Dataset<T> =
-spark.createDataset(this, encoder<T>())
+spark.createDataset(this, kotlinEncoderFor<T>())

/**
* Utility method to create dataframe from list
@@ -104,13 +103,13 @@ inline fun <reified T> Array<T>.toDF(spark: SparkSession, vararg colNames: Strin
* Utility method to create dataset from RDD
*/
inline fun <reified T> RDD<T>.toDS(spark: SparkSession): Dataset<T> =
-spark.createDataset(this, encoder<T>())
+spark.createDataset(this, kotlinEncoderFor<T>())

/**
* Utility method to create dataset from JavaRDD
*/
inline fun <reified T> JavaRDDLike<T, *>.toDS(spark: SparkSession): Dataset<T> =
-spark.createDataset(this.rdd(), encoder<T>())
+spark.createDataset(this.rdd(), kotlinEncoderFor<T>())

/**
* Utility method to create Dataset<Row> (Dataframe) from JavaRDD.
@@ -132,37 +131,37 @@ inline fun <reified T> RDD<T>.toDF(spark: SparkSession, vararg colNames: String)
* Returns a new Dataset that contains the result of applying [func] to each element.
*/
inline fun <reified T, reified R> Dataset<T>.map(noinline func: (T) -> R): Dataset<R> =
-map(MapFunction(func), encoder<R>())
+map(MapFunction(func), kotlinEncoderFor<R>())
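
A hedged usage sketch, inside a withSpark block with a Dataset<Int> named nums (illustrative):

// the result encoder is derived from the reified type argument R
val doubled: Dataset<Int> = nums.map { it * 2 }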

/**
* (Kotlin-specific)
* Returns a new Dataset by first applying a function to all elements of this Dataset,
* and then flattening the results.
*/
inline fun <T, reified R> Dataset<T>.flatMap(noinline func: (T) -> Iterator<R>): Dataset<R> =
-flatMap(func, encoder<R>())
+flatMap(func, kotlinEncoderFor<R>())

/**
* (Kotlin-specific)
* Returns a new Dataset by flattening. This means that a Dataset of an iterable such as
* `listOf(listOf(1, 2, 3), listOf(4, 5, 6))` will be flattened to a Dataset of `listOf(1, 2, 3, 4, 5, 6)`.
*/
inline fun <reified T, I : Iterable<T>> Dataset<I>.flatten(): Dataset<T> =
-flatMap(FlatMapFunction { it.iterator() }, encoder<T>())
+flatMap(FlatMapFunction { it.iterator() }, kotlinEncoderFor<T>())
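
A hedged usage sketch of flatten, inside a withSpark block (dsOf comes from this same API), mirroring the example in the KDoc above:

// Dataset<List<Int>> -> Dataset<Int> containing 1, 2, 3, 4, 5, 6
val flat = dsOf(listOf(1, 2, 3), listOf(4, 5, 6)).flatten()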

/**
* (Kotlin-specific)
* Returns a [KeyValueGroupedDataset] where the data is grouped by the given key [func].
*/
inline fun <T, reified R> Dataset<T>.groupByKey(noinline func: (T) -> R): KeyValueGroupedDataset<R, T> =
-groupByKey(MapFunction(func), encoder<R>())
+groupByKey(MapFunction(func), kotlinEncoderFor<R>())

/**
* (Kotlin-specific)
* Returns a new Dataset that contains the result of applying [func] to each partition.
*/
inline fun <T, reified R> Dataset<T>.mapPartitions(noinline func: (Iterator<T>) -> Iterator<R>): Dataset<R> =
-mapPartitions(func, encoder<R>())
+mapPartitions(func, kotlinEncoderFor<R>())

/**
* (Kotlin-specific)
@@ -193,15 +192,6 @@ inline fun <reified T1, T2> Dataset<Tuple2<T1, T2>>.takeKeys(): Dataset<T1> = ma
*/
inline fun <reified T1, T2> Dataset<Pair<T1, T2>>.takeKeys(): Dataset<T1> = map { it.first }

-/**
- * (Kotlin-specific)
- * Maps the Dataset to only retain the "keys" or [Arity2._1] values.
- */
-@Suppress("DEPRECATION")
-@JvmName("takeKeysArity2")
-@Deprecated("Use Scala tuples instead.", ReplaceWith(""))
-inline fun <reified T1, T2> Dataset<Arity2<T1, T2>>.takeKeys(): Dataset<T1> = map { it._1 }

/**
* (Kotlin-specific)
* Maps the Dataset to only retain the "values" or [Tuple2._2] values.
@@ -215,22 +205,13 @@ inline fun <T1, reified T2> Dataset<Tuple2<T1, T2>>.takeValues(): Dataset<T2> =
*/
inline fun <T1, reified T2> Dataset<Pair<T1, T2>>.takeValues(): Dataset<T2> = map { it.second }

-/**
- * (Kotlin-specific)
- * Maps the Dataset to only retain the "values" or [Arity2._2] values.
- */
-@Suppress("DEPRECATION")
-@JvmName("takeValuesArity2")
-@Deprecated("Use Scala tuples instead.", ReplaceWith(""))
-inline fun <T1, reified T2> Dataset<Arity2<T1, T2>>.takeValues(): Dataset<T2> = map { it._2 }

/** DEPRECATED: Use [as] or [to] for this. */
@Deprecated(
message = "Deprecated, since we already have `as`() and to().",
replaceWith = ReplaceWith("this.to<R>()"),
level = DeprecationLevel.ERROR,
)
-inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
+inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(kotlinEncoderFor<R>())

/**
* (Kotlin-specific)
@@ -252,7 +233,7 @@ inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
*
* @see to as alias for [as]
*/
-inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
+inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(kotlinEncoderFor<R>())

/**
* (Kotlin-specific)
@@ -274,7 +255,7 @@ inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
*
* @see as as alias for [to]
*/
-inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())
+inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(kotlinEncoderFor<R>())
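
A hedged sketch of the two equivalent conversions, assuming a Dataset<Row> named df whose schema matches a hypothetical Person data class:

val people: Dataset<Person> = df.to<Person>()
val samePeople: Dataset<Person> = df.`as`<Person>()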

/**
* (Kotlin-specific)
@@ -292,12 +273,16 @@ inline fun <reified T> Dataset<T>.forEachPartition(noinline func: (Iterator<T>)
/**
* It's hard to call `Dataset.debugCodegen` from kotlin, so here is utility for that
*/
-fun <T> Dataset<T>.debugCodegen(): Dataset<T> = also { KSparkExtensions.debugCodegen(it) }
+fun <T> Dataset<T>.debugCodegen(): Dataset<T> = also {
+    org.apache.spark.sql.execution.debug.`package$`.`MODULE$`.DebugQuery(it).debugCodegen()
+}

/**
* It's hard to call `Dataset.debug` from kotlin, so here is utility for that
*/
-fun <T> Dataset<T>.debug(): Dataset<T> = also { KSparkExtensions.debug(it) }
+fun <T> Dataset<T>.debug(): Dataset<T> = also {
+    org.apache.spark.sql.execution.debug.`package$`.`MODULE$`.DebugQuery(it).debug()
+}
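
The backticked names are the usual trick for reaching a Scala object from Kotlin: a Scala object Foo compiles to a JVM class Foo$ whose singleton lives in a static MODULE$ field. That is what lets these helpers call Spark's debug package object directly instead of going through the removed KSparkExtensions shim. The generic pattern, as a sketch:

// Scala:  object Foo { def bar(): Unit = ... }
// Kotlin: `Foo$`.`MODULE$`.bar()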


/**
@@ -370,18 +355,6 @@ fun <T1, T2> Dataset<Tuple2<T1, T2>>.sortByKey(): Dataset<Tuple2<T1, T2>> = sort
@JvmName("sortByTuple2Value")
fun <T1, T2> Dataset<Tuple2<T1, T2>>.sortByValue(): Dataset<Tuple2<T1, T2>> = sort("_2")

-/** Returns a dataset sorted by the first (`_1`) value of each [Arity2] inside. */
-@Suppress("DEPRECATION")
-@Deprecated("Use Scala tuples instead.", ReplaceWith(""))
-@JvmName("sortByArity2Key")
-fun <T1, T2> Dataset<Arity2<T1, T2>>.sortByKey(): Dataset<Arity2<T1, T2>> = sort("_1")
-
-/** Returns a dataset sorted by the second (`_2`) value of each [Arity2] inside. */
-@Suppress("DEPRECATION")
-@Deprecated("Use Scala tuples instead.", ReplaceWith(""))
-@JvmName("sortByArity2Value")
-fun <T1, T2> Dataset<Arity2<T1, T2>>.sortByValue(): Dataset<Arity2<T1, T2>> = sort("_2")

/** Returns a dataset sorted by the first (`first`) value of each [Pair] inside. */
@JvmName("sortByPairKey")
fun <T1, T2> Dataset<Pair<T1, T2>>.sortByKey(): Dataset<Pair<T1, T2>> = sort("first")
@@ -51,7 +51,7 @@ import scala.Tuple2
* ```
*/
inline fun <KEY, VALUE, reified R> KeyValueGroupedDataset<KEY, VALUE>.mapValues(noinline func: (VALUE) -> R): KeyValueGroupedDataset<KEY, R> =
-mapValues(MapFunction(func), encoder<R>())
+mapValues(MapFunction(func), kotlinEncoderFor<R>())

/**
* (Kotlin-specific)
@@ -70,7 +70,7 @@ inline fun <KEY, VALUE, reified R> KeyValueGroupedDataset<KEY, VALUE>.mapValues(
* constraints of their cluster.
*/
inline fun <KEY, VALUE, reified R> KeyValueGroupedDataset<KEY, VALUE>.mapGroups(noinline func: (KEY, Iterator<VALUE>) -> R): Dataset<R> =
-mapGroups(MapGroupsFunction(func), encoder<R>())
+mapGroups(MapGroupsFunction(func), kotlinEncoderFor<R>())

/**
* (Kotlin-specific)
@@ -104,7 +104,7 @@ inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
noinline func: (key: K, values: Iterator<V>) -> Iterator<U>,
): Dataset<U> = flatMapGroups(
FlatMapGroupsFunction(func),
-encoder<U>(),
+kotlinEncoderFor<U>(),
)


@@ -127,8 +127,8 @@ inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWi
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U,
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
-encoder<S>(),
-encoder<U>(),
+kotlinEncoderFor<S>(),
+kotlinEncoderFor<U>(),
)

/**
@@ -152,8 +152,8 @@ inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWi
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U,
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
-encoder<S>(),
-encoder<U>(),
+kotlinEncoderFor<S>(),
+kotlinEncoderFor<U>(),
timeoutConf,
)

@@ -181,8 +181,8 @@ inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGrou
): Dataset<U> = flatMapGroupsWithState(
FlatMapGroupsWithStateFunction(func),
outputMode,
-encoder<S>(),
-encoder<U>(),
+kotlinEncoderFor<S>(),
+kotlinEncoderFor<U>(),
timeoutConf,
)

@@ -199,5 +199,5 @@ inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
): Dataset<R> = cogroup(
other,
CoGroupFunction(func),
-encoder<R>(),
+kotlinEncoderFor<R>(),
)
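
A hedged end-to-end sketch of the grouped operations above, assuming a hypothetical Employee class with a department field and a Dataset<Employee> named employees:

val perDepartment: Dataset<Pair<String, Int>> = employees
    .groupByKey { it.department }   // KeyValueGroupedDataset<String, Employee>
    .mapGroups { dept, people -> dept to people.asSequence().count() }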
@@ -34,7 +34,6 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaRDDLike
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
@@ -45,7 +44,6 @@ import org.apache.spark.streaming.Durations
import org.apache.spark.streaming.api.java.JavaStreamingContext
import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
import org.jetbrains.kotlinx.spark.api.tuples.*
-import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
import java.io.Serializable

/**
@@ -76,7 +74,7 @@ class KSparkSession(val spark: SparkSession) {
inline fun <reified T> dsOf(vararg arg: T): Dataset<T> = spark.dsOf(*arg)

/** Creates new empty dataset of type [T]. */
-inline fun <reified T> emptyDataset(): Dataset<T> = spark.emptyDataset(encoder<T>())
+inline fun <reified T> emptyDataset(): Dataset<T> = spark.emptyDataset(kotlinEncoderFor<T>())

/** Utility method to create dataframe from *array or vararg arguments */
inline fun <reified T> dfOf(vararg arg: T): Dataset<Row> = spark.dfOf(*arg)
@@ -227,7 +225,7 @@ enum class SparkLogLevel {
* Returns the Spark context associated with this Spark session.
*/
val SparkSession.sparkContext: SparkContext
-get() = KSparkExtensions.sparkContext(this)
+get() = sparkContext()

/**
* Wrapper for spark creation which allows setting different spark params.
@@ -339,7 +337,7 @@ inline fun withSpark(sparkConf: SparkConf, logLevel: SparkLogLevel = ERROR, func
fun withSparkStreaming(
batchDuration: Duration = Durations.seconds(1L),
checkpointPath: String? = null,
-hadoopConf: Configuration = SparkHadoopUtil.get().conf(),
+hadoopConf: Configuration = getDefaultHadoopConf(),
createOnError: Boolean = false,
props: Map<String, Any> = emptyMap(),
master: String = SparkConf().get("spark.master", "local[*]"),
@@ -386,6 +384,18 @@ fun withSparkStreaming(
ssc.stop()
}

+// calling org.apache.spark.deploy.`SparkHadoopUtil$`.`MODULE$`.get().conf()
+private fun getDefaultHadoopConf(): Configuration {
+    val klass = Class.forName("org.apache.spark.deploy.SparkHadoopUtil$")
+    val moduleField = klass.getField("MODULE$").also { it.isAccessible = true }
+    val module = moduleField.get(null)
+    val getMethod = klass.getMethod("get").also { it.isAccessible = true }
+    val sparkHadoopUtil = getMethod.invoke(module)
+    val confMethod = sparkHadoopUtil.javaClass.getMethod("conf").also { it.isAccessible = true }
+    val conf = confMethod.invoke(sparkHadoopUtil) as Configuration
+
+    return conf
+}
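
SparkHadoopUtil is not part of Spark's public API, so the helper above reaches its Scala singleton reflectively at runtime rather than linking against it at compile time. A hedged usage sketch of withSparkStreaming picking up that default configuration, assuming the streaming block exposes ssc as in this project's examples (the socket source is illustrative):

withSparkStreaming(batchDuration = Durations.seconds(5)) {
    val lines = ssc.socketTextStream("localhost", 9999)
    lines.print()
}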

/**
* Broadcast a read-only variable to the cluster, returning a
@@ -396,7 +406,7 @@ fun withSparkStreaming(
* @return `Broadcast` object, a read-only variable cached on each machine
*/
inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
-sparkContext.broadcast(value, encoder<T>().clsTag())
+sparkContext.broadcast(value, kotlinEncoderFor<T>().clsTag())
} catch (e: ClassNotFoundException) {
JavaSparkContext(sparkContext).broadcast(value)
}
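
A hedged usage sketch of broadcast, inside a withSpark block (the lookup table is illustrative):

// ship a small read-only map to every executor once
val lookup: Broadcast<Map<String, Int>> = spark.broadcast(mapOf("a" to 1, "b" to 2))
val resolved = dsOf("a", "b").map { lookup.value()[it] ?: 0 }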
@@ -416,7 +426,7 @@ inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
DeprecationLevel.WARNING
)
inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = try {
-broadcast(value, encoder<T>().clsTag())
+broadcast(value, kotlinEncoderFor<T>().clsTag())
} catch (e: ClassNotFoundException) {
JavaSparkContext(this).broadcast(value)
}