2929
3030package org.jetbrains.kotlinx.spark.api
3131
32+ import org.apache.commons.lang3.reflect.TypeUtils.*
3233import org.apache.spark.sql.Encoder
3334import org.apache.spark.sql.Row
3435import org.apache.spark.sql.catalyst.DefinedByConstructorParams
3536import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
3637import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders
3738import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.EncoderField
39+ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.JavaBeanEncoder
3840import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
3941import org.apache.spark.sql.catalyst.encoders.OuterScopes
4042import org.apache.spark.sql.types.DataType
@@ -49,19 +51,23 @@ import org.jetbrains.kotlinx.spark.api.plugin.annotations.ColumnName
4951import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify
5052import scala.reflect.ClassTag
5153import java.io.Serializable
54+ import java.util.*
55+ import javax.annotation.Nonnull
5256import kotlin.reflect.KClass
5357import kotlin.reflect.KMutableProperty
5458import kotlin.reflect.KProperty1
5559import kotlin.reflect.KType
5660import kotlin.reflect.KTypeProjection
5761import kotlin.reflect.full.createType
62+ import kotlin.reflect.full.declaredMemberFunctions
5863import kotlin.reflect.full.declaredMemberProperties
5964import kotlin.reflect.full.hasAnnotation
6065import kotlin.reflect.full.isSubclassOf
6166import kotlin.reflect.full.isSubtypeOf
6267import kotlin.reflect.full.primaryConstructor
6368import kotlin.reflect.full.staticFunctions
6469import kotlin.reflect.full.withNullability
70+ import kotlin.reflect.jvm.javaGetter
6571import kotlin.reflect.jvm.javaMethod
6672import kotlin.reflect.jvm.jvmName
6773import kotlin.reflect.typeOf
@@ -163,6 +169,7 @@ object KotlinTypeInference : Serializable {
163169 *
164170 * @return an [AgnosticEncoder] for the given [kType].
165171 */
172+ @Suppress(" UNCHECKED_CAST" )
166173 fun <T > encoderFor (kType : KType ): AgnosticEncoder <T > =
167174 encoderFor(
168175 currentType = kType,
@@ -562,10 +569,6 @@ object KotlinTypeInference : Serializable {
562569 }
563570
564571 kClass.isData -> {
565- // TODO provide warnings for non-Sparkify annotated classes
566- // TODO especially Pair and Triple, promote people to use Tuple2 and Tuple3 or use "getFirst" etc. as column name
567-
568- if (currentType in seenTypeSet) throw IllegalStateException (" Circular reference detected for type $currentType " )
569572 val constructor = kClass.primaryConstructor!!
570573 val kParameters = constructor .parameters
571574 // todo filter for transient?
@@ -586,7 +589,7 @@ object KotlinTypeInference : Serializable {
586589 )
587590
588591 val paramName = param.name
589- val readMethodName = prop.getter.javaMethod !! .name
592+ val readMethodName = prop.javaGetter !! .name
590593 val writeMethodName = (prop as ? KMutableProperty <* >)?.setter?.javaMethod?.name
591594
592595 EncoderField (
@@ -636,13 +639,87 @@ object KotlinTypeInference : Serializable {
636639 }
637640
638641 // java bean class
639- // currentType.classifier is KClass<*> -> {
640- // TODO()
641- //
642- // JavaBeanEncoder()
643- // }
642+ else -> {
643+ if (currentType in seenTypeSet)
644+ throw IllegalStateException (" Circular reference detected for type $currentType " )
645+
646+ val properties = getJavaBeanReadableProperties(kClass)
647+ val fields = properties.map {
648+ val encoder = encoderFor(
649+ currentType = it.type,
650+ seenTypeSet = seenTypeSet + currentType,
651+ typeVariables = typeVariables,
652+ )
653+
654+ EncoderField (
655+ /* name = */ it.propName,
656+ /* enc = */ encoder,
657+ /* nullable = */ encoder.nullable() && ! it.hasNonnull,
658+ /* metadata = */ Metadata .empty(),
659+ /* readMethod = */ it.getterName.toOption(),
660+ /* writeMethod = */ it.setterName.toOption(),
661+ )
662+ }
663+
664+ JavaBeanEncoder <Any >(
665+ ClassTag .apply (jClass),
666+ fields.asScalaSeq(),
667+ )
668+ }
669+
670+ // else -> throw IllegalArgumentException("No encoder found for type $currentType")
671+ }
672+ }
644673
645- else -> throw IllegalArgumentException (" No encoder found for type $currentType " )
/**
 * One readable bean-style property discovered on a class through reflection.
 *
 * Instances are turned into Spark `EncoderField`s when building a `JavaBeanEncoder`.
 *
 * @property propName column/field name handed to the encoder field.
 * @property getterName name of the JVM method used to read the value.
 * @property setterName name of the JVM method used to write the value, or `null` when there is none.
 * @property type Kotlin type of the property; an encoder is resolved for it recursively.
 * @property hasNonnull whether the accessor/property carries a `Nonnull` annotation, in which case
 *   the resulting field is marked as not nullable.
 */
private data class JavaReadableProperty(
    val propName: String,
    val getterName: String,
    val setterName: String?,
    val type: KType,
    val hasNonnull: Boolean,
)
681+
/**
 * Collects the readable bean-style properties declared directly on [klass].
 *
 * Two sources are combined, in this order:
 *  1. declared member functions shaped like bean getters (`getXxx()` / `isXxx()`),
 *     paired with a matching `setXxx(value)` when one is declared;
 *  2. declared Kotlin member properties that are backed by a JVM getter method.
 *
 * Only members declared on [klass] itself are inspected; inherited members are not included.
 *
 * @param klass the class to inspect.
 * @return the discovered properties: getter-function based ones first, then Kotlin properties.
 */
private fun getJavaBeanReadableProperties(klass: KClass<*>): List<JavaReadableProperty> {
    val functions = klass.declaredMemberFunctions

    // A bean setter is `setXxx(value)`: the instance receiver plus exactly one value parameter.
    // Collected once up front so each getter can be matched without rescanning all functions.
    val setterNames = functions
        .filter { it.name.startsWith("set") && it.name.length > "set".length && it.parameters.size == 2 }
        .map { it.name }
        .toSet()

    val properties = functions.mapNotNull { getter ->
        // Strip exactly one accessor prefix; chaining both removals would mangle names like "getisland".
        val suffix = when {
            getter.name.startsWith("get") -> getter.name.removePrefix("get")
            getter.name.startsWith("is") -> getter.name.removePrefix("is")
            else -> return@mapNotNull null
        }

        // A getter has a property name after the prefix, takes no value parameters
        // (only the instance receiver) and actually returns something.
        val isGetterShaped = suffix.isNotEmpty() &&
            getter.parameters.size == 1 &&
            getter.returnType.classifier != Unit::class
        if (!isGetterShaped) return@mapNotNull null

        val propName = suffix.replaceFirstChar { it.lowercase() }
        val expectedSetterName = "set${propName.replaceFirstChar { it.uppercase() }}"

        JavaReadableProperty(
            propName = propName,
            getterName = getter.name,
            setterName = expectedSetterName.takeIf { it in setterNames },
            type = getter.returnType,
            hasNonnull = getter.hasAnnotation<Nonnull>(),
        )
    }

    // Aside from java get/set functions, attempt to get kotlin properties as well, for non data classes
    val kotlinProps = klass.declaredMemberProperties.mapNotNull { prop ->
        // Keep only properties that are backed by a JVM getter method.
        val getterMethod = prop.getter.javaMethod ?: return@mapNotNull null

        JavaReadableProperty(
            propName = prop.name,
            // Take the accessor names from the JVM methods themselves instead of rebuilding them
            // from the property name; hand-built names are wrong for e.g. a property called
            // `issue` (JVM getter `getIssue`) or accessors renamed with `@JvmName`.
            getterName = getterMethod.name,
            setterName = (prop as? KMutableProperty<*>)?.setter?.javaMethod?.name,
            type = prop.returnType,
            hasNonnull = prop.hasAnnotation<Nonnull>(),
        )
    }

    return properties + kotlinProps
}
648725}
0 commit comments