@@ -20,12 +20,11 @@ object DesugarEnums {
20
20
val Simple, Object, Class : Value = Value
21
21
}
22
22
23
- final class EnumConstraints (minKind : CaseKind .Value , maxKind : CaseKind .Value , cases : List [(Int , TermName )]):
24
- require(minKind <= maxKind && ! (cached && cachedValues .isEmpty))
23
+ final case class EnumConstraints (minKind : CaseKind .Value , maxKind : CaseKind .Value , enumCases : List [(Int , RefTree )]):
24
+ require(minKind <= maxKind && ! (cached && enumCases .isEmpty))
25
25
def requiresCreator = minKind == CaseKind .Simple
26
26
def isEnumeration = maxKind < CaseKind .Class
27
27
def cached = minKind < CaseKind .Class
28
- def cachedValues = cases
29
28
end EnumConstraints
30
29
31
30
/** Attachment containing the number of enum cases, the smallest kind that was seen so far,
@@ -47,6 +46,11 @@ object DesugarEnums {
47
46
if (cls.is(Module )) cls.linkedClass else cls
48
47
}
49
48
49
+ def enumCompanion (using Context ): Symbol = {
50
+ val cls = ctx.owner
51
+ if (cls.is(Module )) cls.sourceModule else cls.linkedClass.sourceModule
52
+ }
53
+
50
54
/** Is `tree` an (untyped) enum case? */
51
55
def isEnumCase (tree : Tree )(using Context ): Boolean = tree match {
52
56
case tree : MemberDef => tree.mods.isEnumCase
@@ -109,14 +113,12 @@ object DesugarEnums {
109
113
* case _ => throw new IllegalArgumentException("case not found: " + $name)
110
114
* }
111
115
*/
112
- private def enumScaffolding (enumCases : List [( Int , TermName ) ])(using Context ): List [Tree ] = {
116
+ private def enumScaffolding (enumValues : List [RefTree ])(using Context ): List [Tree ] = {
113
117
val rawEnumClassRef = rawRef(enumClass.typeRef)
114
118
extension (tpe : NamedType ) def ofRawEnum = AppliedTypeTree (ref(tpe), rawEnumClassRef)
115
119
116
- val privateValuesDef =
117
- ValDef (nme.DOLLAR_VALUES , TypeTree (),
118
- ArrayLiteral (enumCases.map((_, name) => Ident (name)), rawEnumClassRef))
119
- .withFlags(Private | Synthetic )
120
+ val privateValuesDef = ValDef (nme.DOLLAR_VALUES , TypeTree (), ArrayLiteral (enumValues, rawEnumClassRef))
121
+ .withFlags(Private | Synthetic )
120
122
121
123
val valuesDef =
122
124
DefDef (nme.values, Nil , Nil , defn.ArrayType .ofRawEnum, valuesDot(nme.clone_))
@@ -127,8 +129,8 @@ object DesugarEnums {
127
129
val msg = Apply (Select (Literal (Constant (" enum case not found: " )), nme.PLUS ), Ident (nme.nameDollar))
128
130
CaseDef (Ident (nme.WILDCARD ), EmptyTree ,
129
131
Throw (New (TypeTree (defn.IllegalArgumentExceptionType ), List (msg :: Nil ))))
130
- val stringCases = enumCases .map((_, name) =>
131
- CaseDef (Literal (Constant (name.toString)), EmptyTree , Ident (name) )
132
+ val stringCases = enumValues .map(enumValue =>
133
+ CaseDef (Literal (Constant (enumValue. name.toString)), EmptyTree , enumValue )
132
134
) ::: defaultCase :: Nil
133
135
Match (Ident (nme.nameDollar), stringCases)
134
136
val valueOfDef = DefDef (nme.valueOf, Nil , List (param(nme.nameDollar, defn.StringType ) :: Nil ),
@@ -141,7 +143,7 @@ object DesugarEnums {
141
143
}
142
144
143
145
private def enumLookupMethods (constraints : EnumConstraints )(using Context ): List [Tree ] =
144
- def scaffolding : List [Tree ] = if constraints.cached then enumScaffolding(constraints.cachedValues ) else Nil
146
+ def scaffolding : List [Tree ] = if constraints.cached then enumScaffolding(constraints.enumCases.map(_._2) ) else Nil
145
147
def valueCtor : List [Tree ] = if constraints.requiresCreator then enumValueCreator :: Nil else Nil
146
148
def byOrdinal : List [Tree ] =
147
149
if isJavaEnum || ! constraints.cached then Nil
@@ -150,8 +152,8 @@ object DesugarEnums {
150
152
val ord = Ident (nme.ordinal)
151
153
val err = Throw (New (TypeTree (defn.IndexOutOfBoundsException .typeRef), List (Select (ord, nme.toString_) :: Nil )))
152
154
CaseDef (ord, EmptyTree , err)
153
- val valueCases = constraints.cachedValues .map((i, name ) =>
154
- CaseDef (Literal (Constant (i)), EmptyTree , Ident (name) )
155
+ val valueCases = constraints.enumCases .map((i, enumValue ) =>
156
+ CaseDef (Literal (Constant (i)), EmptyTree , enumValue )
155
157
) ::: defaultCase :: Nil
156
158
val fromOrdinalDef = DefDef (nme.fromOrdinalDollar, Nil , List (param(nme.ordinalDollar_, defn.IntType ) :: Nil ),
157
159
rawRef(enumClass.typeRef), Match (Ident (nme.ordinalDollar_), valueCases))
@@ -304,7 +306,9 @@ object DesugarEnums {
304
306
case name : TermName => (ordinal, name) :: seenCases
305
307
case _ => seenCases
306
308
if definesLookups then
307
- (ordinal, enumLookupMethods(EnumConstraints (minKind, maxKind, cases.reverse)))
309
+ val companionRef = ref(enumCompanion.termRef)
310
+ val cachedValues = cases.reverse.map((i, name) => (i, Select (companionRef, name)))
311
+ (ordinal, enumLookupMethods(EnumConstraints (minKind, maxKind, cachedValues)))
308
312
else
309
313
ctx.tree.pushAttachment(EnumCaseCount , (ordinal + 1 , minKind, maxKind, cases))
310
314
(ordinal, Nil )
@@ -313,7 +317,7 @@ object DesugarEnums {
313
317
def param (name : TermName , typ : Type )(using Context ): ValDef = param(name, TypeTree (typ))
314
318
def param (name : TermName , tpt : Tree )(using Context ): ValDef = ValDef (name, tpt, EmptyTree ).withFlags(Param )
315
319
316
- private def isJavaEnum (using Context ): Boolean = ctx.owner.linkedClass .derivesFrom(defn.JavaEnumClass )
320
+ private def isJavaEnum (using Context ): Boolean = enumClass .derivesFrom(defn.JavaEnumClass )
317
321
318
322
def ordinalMeth (body : Tree )(using Context ): DefDef =
319
323
DefDef (nme.ordinal, Nil , Nil , TypeTree (defn.IntType ), body)
0 commit comments