1
+ package org.jetbrains.kotlinx.spark.api.compilerPlugin.fir
2
+
3
+ import org.jetbrains.kotlin.GeneratedDeclarationKey
4
+ import org.jetbrains.kotlin.fir.FirSession
5
+ import org.jetbrains.kotlin.fir.declarations.utils.isData
6
+ import org.jetbrains.kotlin.fir.extensions.FirDeclarationGenerationExtension
7
+ import org.jetbrains.kotlin.fir.extensions.MemberGenerationContext
8
+ import org.jetbrains.kotlin.fir.plugin.createMemberFunction
9
+ import org.jetbrains.kotlin.fir.render
10
+ import org.jetbrains.kotlin.fir.resolve.getSuperTypes
11
+ import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
12
+ import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
13
+ import org.jetbrains.kotlin.fir.types.toClassSymbol
14
+ import org.jetbrains.kotlin.name.CallableId
15
+ import org.jetbrains.kotlin.name.Name
16
+
17
+ class DataClassSparkifyFunctionsGenerator (
18
+ session : FirSession ,
19
+ private val sparkifyAnnotationFqNames : List <String >,
20
+ private val productFqNames : List <String >,
21
+ ) : FirDeclarationGenerationExtension(session) {
22
+
23
+ companion object {
24
+ fun builder (
25
+ sparkifyAnnotationFqNames : List <String >,
26
+ productFqNames : List <String >
27
+ ): (FirSession ) -> FirDeclarationGenerationExtension = {
28
+ DataClassSparkifyFunctionsGenerator (
29
+ session = it,
30
+ sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
31
+ productFqNames = productFqNames,
32
+ )
33
+ }
34
+
35
+ // functions to generate
36
+ val canEqual = Name .identifier(" canEqual" )
37
+ val productElement = Name .identifier(" productElement" )
38
+ val productArity = Name .identifier(" productArity" )
39
+ }
40
+
41
+ override fun generateFunctions (
42
+ callableId : CallableId ,
43
+ context : MemberGenerationContext ?
44
+ ): List <FirNamedFunctionSymbol > {
45
+ val owner = context?.owner ? : return emptyList()
46
+
47
+ val functionName = callableId.callableName
48
+ val superTypes = owner.getSuperTypes(session)
49
+ val superProduct = superTypes.first {
50
+ it.toString().endsWith(" Product" )
51
+ }.toClassSymbol(session)!!
52
+ val superEquals = superTypes.first {
53
+ it.toString().endsWith(" Equals" )
54
+ }.toClassSymbol(session)!!
55
+
56
+ val function = when (functionName) {
57
+ canEqual -> {
58
+ val func = createMemberFunction(
59
+ owner = owner,
60
+ key = Key ,
61
+ name = functionName,
62
+ returnType = session.builtinTypes.booleanType.type,
63
+ ) {
64
+ valueParameter(
65
+ name = Name .identifier(" that" ),
66
+ type = session.builtinTypes.nullableAnyType.type,
67
+ )
68
+ }
69
+ // val superFunction = superEquals.declarationSymbols.first {
70
+ // it is FirNamedFunctionSymbol && it.name == functionName
71
+ // } as FirNamedFunctionSymbol
72
+ // overrides(func, superFunction)
73
+ func
74
+ }
75
+
76
+ productElement -> {
77
+ createMemberFunction(
78
+ owner = owner,
79
+ key = Key ,
80
+ name = functionName,
81
+ returnType = session.builtinTypes.nullableAnyType.type,
82
+ ) {
83
+ valueParameter(
84
+ name = Name .identifier(" n" ),
85
+ type = session.builtinTypes.intType.type,
86
+ )
87
+ }
88
+ }
89
+
90
+ productArity -> {
91
+ createMemberFunction(
92
+ owner = owner,
93
+ key = Key ,
94
+ name = functionName,
95
+ returnType = session.builtinTypes.intType.type,
96
+ )
97
+ }
98
+
99
+ else -> {
100
+ return emptyList()
101
+ }
102
+ }
103
+
104
+ return listOf (function.symbol)
105
+ }
106
+
107
+ override fun getCallableNamesForClass (classSymbol : FirClassSymbol <* >, context : MemberGenerationContext ): Set <Name > =
108
+ if (classSymbol.isData && classSymbol.annotations.any { " Sparkify" in it.render() }) {
109
+ setOf (canEqual, productElement, productArity)
110
+ } else {
111
+ emptySet()
112
+ }
113
+
114
+ object Key : GeneratedDeclarationKey()
115
+ }
0 commit comments