@@ -20,10 +20,17 @@ object DesugarEnums {
20
20
val Simple, Object, Class : Value = Value
21
21
}
22
22
23
+ final case class EnumConstraints (minKind : CaseKind .Value , maxKind : CaseKind .Value , enumCases : List [(Int , RefTree )]):
24
+ require(minKind <= maxKind && ! (cached && enumCases.isEmpty))
25
+ def requiresCreator = minKind == CaseKind .Simple
26
+ def isEnumeration = maxKind < CaseKind .Class
27
+ def cached = minKind < CaseKind .Class
28
+ end EnumConstraints
29
+
23
30
/** Attachment containing the number of enum cases, the smallest kind that was seen so far,
24
31
* and a list of all the value cases with their ordinals.
25
32
*/
26
- val EnumCaseCount : Property .Key [(Int , CaseKind .Value , List [(Int , TermName )])] = Property .Key ()
33
+ val EnumCaseCount : Property .Key [(Int , CaseKind .Value , CaseKind . Value , List [(Int , TermName )])] = Property .Key ()
27
34
28
35
/** Attachment signalling that when this definition is desugared, it should add any additional
29
36
* lookup methods for enums.
@@ -39,6 +46,11 @@ object DesugarEnums {
39
46
if (cls.is(Module )) cls.linkedClass else cls
40
47
}
41
48
49
+ def enumCompanion (using Context ): Symbol = {
50
+ val cls = ctx.owner
51
+ if (cls.is(Module )) cls.sourceModule else cls.linkedClass.sourceModule
52
+ }
53
+
42
54
/** Is `tree` an (untyped) enum case? */
43
55
def isEnumCase (tree : Tree )(using Context ): Boolean = tree match {
44
56
case tree : MemberDef => tree.mods.isEnumCase
@@ -84,65 +96,73 @@ object DesugarEnums {
84
96
private def valuesDot (name : PreName )(implicit src : SourceFile ) =
85
97
Select (Ident (nme.DOLLAR_VALUES ), name.toTermName)
86
98
87
- private def registerCall (using Context ): Tree =
88
- Apply (valuesDot(" register" ), This (EmptyTypeIdent ) :: Nil )
99
+ private def ArrayLiteral (values : List [Tree ], tpt : Tree )(using Context ): Tree =
100
+ val clazzOf = TypeApply (ref(defn.Predef_classOf .termRef), tpt :: Nil )
101
+ val ctag = Apply (TypeApply (ref(defn.ClassTagModule_apply .termRef), tpt :: Nil ), clazzOf :: Nil )
102
+ val apply = Select (ref(defn.ArrayModule .termRef), nme.apply)
103
+ Apply (Apply (TypeApply (apply, tpt :: Nil ), values), ctag :: Nil )
89
104
90
- /** The following lists of definitions for an enum type E:
105
+ /** The following lists of definitions for an enum type E and known value cases e_0, ..., e_n :
91
106
*
92
- * private val $values = new EnumValues [E]
93
- * def values = $values.values.toArray
94
- * def valueOf($name: String) =
95
- * try $values.fromName($name) catch
96
- * {
97
- * case ex$:NoSuchElementException =>
98
- * throw new IllegalArgumentException("key not found: ".concat( name) )
99
- * }
107
+ * private val $values = Array [E](e_0,...,e_n)(ClassTag[E](classOf[E]))
108
+ * def values = $values.clone
109
+ * def valueOf($name: String) = $name match {
110
+ * case "e_0" => e_0
111
+ * ...
112
+ * case "e_n" => e_n
113
+ * case _ => throw new IllegalArgumentException("case not found: " + $ name)
114
+ * }
100
115
*/
101
- private def enumScaffolding (using Context ): List [Tree ] = {
116
+ private def enumScaffolding (enumValues : List [ RefTree ])( using Context ): List [Tree ] = {
102
117
val rawEnumClassRef = rawRef(enumClass.typeRef)
103
118
extension (tpe : NamedType ) def ofRawEnum = AppliedTypeTree (ref(tpe), rawEnumClassRef)
119
+
120
+ val lazyFlagOpt = if enumCompanion.owner.isStatic then EmptyFlags else Lazy
121
+ val privateValuesDef = ValDef (nme.DOLLAR_VALUES , TypeTree (), ArrayLiteral (enumValues, rawEnumClassRef))
122
+ .withFlags(Private | Synthetic | lazyFlagOpt)
123
+
104
124
val valuesDef =
105
- DefDef (nme.values, Nil , Nil , defn.ArrayType .ofRawEnum, Select ( valuesDot(nme.values), nme.toArray ))
125
+ DefDef (nme.values, Nil , Nil , defn.ArrayType .ofRawEnum, valuesDot(nme.clone_ ))
106
126
.withFlags(Synthetic )
107
- val privateValuesDef =
108
- ValDef (nme.DOLLAR_VALUES , TypeTree (), New (defn.EnumValuesClass .typeRef.ofRawEnum, ListOfNil ))
109
- .withFlags(Private | Synthetic )
110
-
111
- val valuesOfExnMessage = Apply (
112
- Select (Literal (Constant (" key not found: " )), " concat" .toTermName),
113
- Ident (nme.nameDollar) :: Nil )
114
- val valuesOfBody = Try (
115
- expr = Apply (valuesDot(" fromName" ), Ident (nme.nameDollar) :: Nil ),
116
- cases = CaseDef (
117
- pat = Typed (Ident (nme.DEFAULT_EXCEPTION_NAME ), TypeTree (defn.NoSuchElementExceptionType )),
118
- guard = EmptyTree ,
119
- body = Throw (New (TypeTree (defn.IllegalArgumentExceptionType ), List (valuesOfExnMessage :: Nil )))
120
- ) :: Nil ,
121
- finalizer = EmptyTree
122
- )
127
+
128
+ val valuesOfBody : Tree =
129
+ val defaultCase =
130
+ val msg = Apply (Select (Literal (Constant (" enum case not found: " )), nme.PLUS ), Ident (nme.nameDollar))
131
+ CaseDef (Ident (nme.WILDCARD ), EmptyTree ,
132
+ Throw (New (TypeTree (defn.IllegalArgumentExceptionType ), List (msg :: Nil ))))
133
+ val stringCases = enumValues.map(enumValue =>
134
+ CaseDef (Literal (Constant (enumValue.name.toString)), EmptyTree , enumValue)
135
+ ) ::: defaultCase :: Nil
136
+ Match (Ident (nme.nameDollar), stringCases)
123
137
val valueOfDef = DefDef (nme.valueOf, Nil , List (param(nme.nameDollar, defn.StringType ) :: Nil ),
124
138
TypeTree (), valuesOfBody)
125
139
.withFlags(Synthetic )
126
140
127
- valuesDef ::
128
141
privateValuesDef ::
142
+ valuesDef ::
129
143
valueOfDef :: Nil
130
144
}
131
145
132
- private def enumLookupMethods (cases : List [(Int , TermName )])(using Context ): List [Tree ] =
133
- if isJavaEnum || cases.isEmpty then Nil
134
- else
135
- val defaultCase =
136
- val ord = Ident (nme.ordinal)
137
- val err = Throw (New (TypeTree (defn.IndexOutOfBoundsException .typeRef), List (Select (ord, nme.toString_) :: Nil )))
138
- CaseDef (ord, EmptyTree , err)
139
- val valueCases = cases.map((i, name) =>
140
- CaseDef (Literal (Constant (i)), EmptyTree , Ident (name))
141
- ) ::: defaultCase :: Nil
142
- val fromOrdinalDef = DefDef (nme.fromOrdinalDollar, Nil , List (param(nme.ordinalDollar_, defn.IntType ) :: Nil ),
143
- rawRef(enumClass.typeRef), Match (Ident (nme.ordinalDollar_), valueCases))
144
- .withFlags(Synthetic | Private )
145
- fromOrdinalDef :: Nil
146
+ private def enumLookupMethods (constraints : EnumConstraints )(using Context ): List [Tree ] =
147
+ def scaffolding : List [Tree ] = if constraints.cached then enumScaffolding(constraints.enumCases.map(_._2)) else Nil
148
+ def valueCtor : List [Tree ] = if constraints.requiresCreator then enumValueCreator :: Nil else Nil
149
+ def byOrdinal : List [Tree ] =
150
+ if isJavaEnum || ! constraints.cached then Nil
151
+ else
152
+ val defaultCase =
153
+ val ord = Ident (nme.ordinal)
154
+ val err = Throw (New (TypeTree (defn.IndexOutOfBoundsException .typeRef), List (Select (ord, nme.toString_) :: Nil )))
155
+ CaseDef (ord, EmptyTree , err)
156
+ val valueCases = constraints.enumCases.map((i, enumValue) =>
157
+ CaseDef (Literal (Constant (i)), EmptyTree , enumValue)
158
+ ) ::: defaultCase :: Nil
159
+ val fromOrdinalDef = DefDef (nme.fromOrdinalDollar, Nil , List (param(nme.ordinalDollar_, defn.IntType ) :: Nil ),
160
+ rawRef(enumClass.typeRef), Match (Ident (nme.ordinalDollar_), valueCases))
161
+ .withFlags(Synthetic | Private )
162
+ fromOrdinalDef :: Nil
163
+
164
+ scaffolding ::: valueCtor ::: byOrdinal
165
+ end enumLookupMethods
146
166
147
167
/** A creation method for a value of enum type `E`, which is defined as follows:
148
168
*
@@ -167,7 +187,7 @@ object DesugarEnums {
167
187
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue ) :: Nil ,
168
188
derived = Nil ,
169
189
self = EmptyValDef ,
170
- body = fieldMethods ::: registerCall :: Nil
190
+ body = fieldMethods
171
191
).withAttachment(ExtendsSingletonMirror , ()))
172
192
DefDef (nme.DOLLAR_NEW , Nil ,
173
193
List (List (param(nme.ordinalDollar_, defn.IntType ), param(nme.nameDollar, defn.StringType ))),
@@ -279,27 +299,26 @@ object DesugarEnums {
279
299
* unless that scaffolding was already generated by a previous call to `nextEnumKind`.
280
300
*/
281
301
def nextOrdinal (name : Name , kind : CaseKind .Value , definesLookups : Boolean )(using Context ): (Int , List [Tree ]) = {
282
- val (ordinal, seenKind, seenCases) = ctx.tree.removeAttachment(EnumCaseCount ).getOrElse((0 , CaseKind .Class , Nil ))
283
- val minKind = if kind < seenKind then kind else seenKind
302
+ val (ordinal, seenMinKind, seenMaxKind, seenCases) =
303
+ ctx.tree.removeAttachment(EnumCaseCount ).getOrElse((0 , CaseKind .Class , CaseKind .Simple , Nil ))
304
+ val minKind = if kind < seenMinKind then kind else seenMinKind
305
+ val maxKind = if kind > seenMaxKind then kind else seenMaxKind
284
306
val cases = name match
285
307
case name : TermName => (ordinal, name) :: seenCases
286
308
case _ => seenCases
287
- ctx.tree.pushAttachment(EnumCaseCount , (ordinal + 1 , minKind, cases))
288
- val scaffolding0 =
289
- if (kind >= seenKind) Nil
290
- else if (kind == CaseKind .Object ) enumScaffolding
291
- else if (seenKind == CaseKind .Object ) enumValueCreator :: Nil
292
- else enumScaffolding :+ enumValueCreator
293
- val scaffolding =
294
- if definesLookups then scaffolding0 ::: enumLookupMethods(cases.reverse)
295
- else scaffolding0
296
- (ordinal, scaffolding)
309
+ if definesLookups then
310
+ val companionRef = ref(enumCompanion.termRef)
311
+ val cachedValues = cases.reverse.map((i, name) => (i, Select (companionRef, name)))
312
+ (ordinal, enumLookupMethods(EnumConstraints (minKind, maxKind, cachedValues)))
313
+ else
314
+ ctx.tree.pushAttachment(EnumCaseCount , (ordinal + 1 , minKind, maxKind, cases))
315
+ (ordinal, Nil )
297
316
}
298
317
299
- def param (name : TermName , typ : Type )(using Context ) =
300
- ValDef (name, TypeTree (typ) , EmptyTree ).withFlags(Param )
318
+ def param (name : TermName , typ : Type )(using Context ): ValDef = param(name, TypeTree (typ))
319
+ def param (name : TermName , tpt : Tree )( using Context ) : ValDef = ValDef (name, tpt , EmptyTree ).withFlags(Param )
301
320
302
- private def isJavaEnum (using Context ): Boolean = ctx.owner.linkedClass .derivesFrom(defn.JavaEnumClass )
321
+ private def isJavaEnum (using Context ): Boolean = enumClass .derivesFrom(defn.JavaEnumClass )
303
322
304
323
def ordinalMeth (body : Tree )(using Context ): DefDef =
305
324
DefDef (nme.ordinal, Nil , Nil , TypeTree (defn.IntType ), body)
@@ -325,10 +344,10 @@ object DesugarEnums {
325
344
val enumLabelDef = enumLabelLit(name.toString)
326
345
val impl1 = cpy.Template (impl)(
327
346
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue ),
328
- body = ordinalDef ::: enumLabelDef :: registerCall :: Nil
347
+ body = ordinalDef ::: enumLabelDef :: Nil
329
348
).withAttachment(ExtendsSingletonMirror , ())
330
349
val vdef = ValDef (name, TypeTree (), New (impl1)).withMods(mods.withAddedFlags(EnumValue , span))
331
- flatTree(scaffolding ::: vdef :: Nil ).withSpan(span)
350
+ flatTree(vdef :: scaffolding ).withSpan(span)
332
351
}
333
352
}
334
353
@@ -344,6 +363,6 @@ object DesugarEnums {
344
363
val (tag, scaffolding) = nextOrdinal(name, CaseKind .Simple , definesLookups)
345
364
val creator = Apply (Ident (nme.DOLLAR_NEW ), List (Literal (Constant (tag)), Literal (Constant (name.toString))))
346
365
val vdef = ValDef (name, enumClassRef, creator).withMods(mods.withAddedFlags(EnumValue , span))
347
- flatTree(scaffolding ::: vdef :: Nil ).withSpan(span)
366
+ flatTree(vdef :: scaffolding ).withSpan(span)
348
367
}
349
368
}
0 commit comments