@@ -25,6 +25,8 @@ import scala.quoted.runtime.impl.printers.*
25
25
import scala .reflect .TypeTest
26
26
import dotty .tools .dotc .core .NameKinds .ExceptionBinderName
27
27
import dotty .tools .dotc .transform .TreeChecker
28
+ import dotty .tools .dotc .core .Names
29
+ import dotty .tools .dotc .util .Spans .NoCoord
28
30
29
31
object QuotesImpl {
30
32
@@ -241,9 +243,35 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
241
243
242
244
object ClassDef extends ClassDefModule :
243
245
def apply (cls : Symbol , parents : List [Tree ], body : List [Statement ]): ClassDef =
244
- val untpdCtr = untpd.DefDef (nme.CONSTRUCTOR , Nil , tpd.TypeTree (dotc.core.Symbols .defn.UnitClass .typeRef), tpd.EmptyTree )
246
+ val paramsDefs : List [untpd.ParamClause ] =
247
+ cls.primaryConstructor.paramSymss.map { paramSym =>
248
+ if paramSym.headOption.map(_.isType).getOrElse(false ) then
249
+ paramSym.map(sym => TypeDef (sym))
250
+ else
251
+ paramSym.map(ValDef (_, None ))
252
+ }
253
+ def throwError () =
254
+ throw new RuntimeException (
255
+ " Symbols necessary for creation of the ClassDef tree could not be found."
256
+ )
257
+ val paramsAccessDefs : List [untpd.ParamClause ] =
258
+ cls.primaryConstructor.paramSymss.map { paramSym =>
259
+ if paramSym.headOption.map(_.isType).getOrElse(false ) then
260
+ paramSym.map { symm =>
261
+ def isParamAccessor (memberSym : Symbol ) = memberSym.flags.is(Flags .Param ) && memberSym.name == symm.name
262
+ TypeDef (cls.typeMembers.find(isParamAccessor).getOrElse(throwError()))
263
+ }
264
+ else
265
+ paramSym.map { symm =>
266
+ def isParam (memberSym : Symbol ) = memberSym.flags.is(Flags .ParamAccessor ) && memberSym.name == symm.name
267
+ ValDef (cls.fieldMembers.find(isParam).getOrElse(throwError()), None )
268
+ }
269
+ }
270
+
271
+ val termSymbol : dotc.core.Symbols .TermSymbol = cls.primaryConstructor.asTerm
272
+ val untpdCtr = untpd.DefDef (nme.CONSTRUCTOR , paramsDefs, tpd.TypeTree (dotc.core.Symbols .defn.UnitClass .typeRef), tpd.EmptyTree )
245
273
val ctr = ctx.typeAssigner.assignType(untpdCtr, cls.primaryConstructor)
246
- tpd.ClassDefWithParents (cls.asClass, ctr, parents, body)
274
+ tpd.ClassDefWithParents (cls.asClass, ctr, parents, paramsAccessDefs.flatten ++ body)
247
275
248
276
def copy (original : Tree )(name : String , constr : DefDef , parents : List [Tree ], selfOpt : Option [ValDef ], body : List [Statement ]): ClassDef = {
249
277
val dotc .ast.Trees .TypeDef (_, originalImpl : tpd.Template ) = original : @ unchecked
@@ -2655,8 +2683,134 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
2655
2683
for sym <- decls(cls) do cls.enter(sym)
2656
2684
cls
2657
2685
2658
- def newModule (owner : Symbol , name : String , modFlags : Flags , clsFlags : Flags , parents : List [TypeRepr ], decls : Symbol => List [Symbol ], privateWithin : Symbol ): Symbol =
2659
- assert(parents.nonEmpty && ! parents.head.typeSymbol.is(dotc.core.Flags .Trait ), " First parent must be a class" )
2686
+ def newClass (
2687
+ owner : Symbol ,
2688
+ name : String ,
2689
+ parents : Symbol => List [TypeRepr ],
2690
+ decls : Symbol => List [Symbol ],
2691
+ selfType : Option [TypeRepr ],
2692
+ clsFlags : Flags ,
2693
+ clsPrivateWithin : Symbol ,
2694
+ conParams : List [(String , TypeRepr )]
2695
+ ): Symbol =
2696
+ val (conParamNames, conParamTypes) = conParams.unzip
2697
+ newClass(
2698
+ owner,
2699
+ name,
2700
+ parents,
2701
+ decls,
2702
+ selfType,
2703
+ clsFlags,
2704
+ clsPrivateWithin,
2705
+ Nil ,
2706
+ conMethodType = res => MethodType (conParamNames)(_ => conParamTypes, _ => res),
2707
+ conFlags = Flags .EmptyFlags ,
2708
+ conPrivateWithin = Symbol .noSymbol,
2709
+ conParamFlags = List (for i <- conParamNames yield Flags .EmptyFlags ),
2710
+ conParamPrivateWithins = List (for i <- conParamNames yield Symbol .noSymbol)
2711
+ )
2712
+
2713
+ def newClass (
2714
+ owner : Symbol ,
2715
+ name : String ,
2716
+ parents : Symbol => List [TypeRepr ],
2717
+ decls : Symbol => List [Symbol ],
2718
+ selfType : Option [TypeRepr ],
2719
+ clsFlags : Flags ,
2720
+ clsPrivateWithin : Symbol ,
2721
+ clsAnnotations : List [Term ],
2722
+ conMethodType : TypeRepr => MethodOrPoly ,
2723
+ conFlags : Flags ,
2724
+ conPrivateWithin : Symbol ,
2725
+ conParamFlags : List [List [Flags ]],
2726
+ conParamPrivateWithins : List [List [Symbol ]]
2727
+ ) =
2728
+ assert(! clsPrivateWithin.exists || clsPrivateWithin.isType, " clsPrivateWithin must be a type symbol or `Symbol.noSymbol`" )
2729
+ assert(! conPrivateWithin.exists || conPrivateWithin.isType, " consPrivateWithin must be a type symbol or `Symbol.noSymbol`" )
2730
+ checkValidFlags(clsFlags.toTypeFlags, Flags .validClassFlags)
2731
+ checkValidFlags(conFlags.toTermFlags, Flags .validClassConstructorFlags)
2732
+ val cls = dotc.core.Symbols .newNormalizedClassSymbol(
2733
+ owner,
2734
+ name.toTypeName,
2735
+ clsFlags,
2736
+ parents,
2737
+ selfType.getOrElse(Types .NoType ),
2738
+ clsPrivateWithin,
2739
+ clsAnnotations,
2740
+ NoCoord ,
2741
+ compUnitInfo = null
2742
+ )
2743
+ val methodType : MethodOrPoly = conMethodType(cls.typeRef)
2744
+ def throwShapeException () = throw new Exception (" Shapes of conMethodType and conParamFlags differ." )
2745
+ def checkMethodOrPolyShape (checkedMethodType : TypeRepr , clauseIdx : Int ): Unit =
2746
+ checkedMethodType match
2747
+ case PolyType (params, _, res) if clauseIdx == 0 =>
2748
+ if (conParamFlags.length < clauseIdx) throwShapeException()
2749
+ if (conParamFlags(clauseIdx).length != params.length) throwShapeException()
2750
+ checkMethodOrPolyShape(res, clauseIdx + 1 )
2751
+ case PolyType (_, _, _) => throw new Exception (" Clause interleaving not supported for constructors" )
2752
+ case MethodType (params, _, res) =>
2753
+ if (conParamFlags.length <= clauseIdx) throwShapeException()
2754
+ if (conParamFlags(clauseIdx).length != params.length) throwShapeException()
2755
+ checkMethodOrPolyShape(res, clauseIdx + 1 )
2756
+ case other =>
2757
+ xCheckMacroAssert(
2758
+ other.typeSymbol == cls,
2759
+ " Incorrect type returned from the innermost PolyOrMethod."
2760
+ )
2761
+ (other, methodType) match
2762
+ case (AppliedType (tycon, args), pt : PolyType ) =>
2763
+ xCheckMacroAssert(
2764
+ args.length == pt.typeParams.length &&
2765
+ args.zip(pt.typeParams).forall {
2766
+ case (arg, param) => arg == param.paramRef
2767
+ },
2768
+ " Constructor result type does not correspond to the declared type parameters"
2769
+ )
2770
+ case _ =>
2771
+ xCheckMacroAssert(
2772
+ ! (other.isInstanceOf [AppliedType ] || methodType.isInstanceOf [PolyType ]),
2773
+ " AppliedType has to be the innermost resultTypeExp result if and only if conMethodType returns a PolyType"
2774
+ )
2775
+ checkMethodOrPolyShape(methodType, clauseIdx = 0 )
2776
+
2777
+ cls.enter(dotc.core.Symbols .newSymbol(cls, nme.CONSTRUCTOR , Flags .Synthetic | Flags .Method | conFlags, methodType, conPrivateWithin, dotty.tools.dotc.util.Spans .NoCoord ))
2778
+
2779
+ case class ParamSymbolData (name : String , tpe : TypeRepr , isTypeParam : Boolean , clauseIdx : Int , elementIdx : Int )
2780
+ def getParamSymbolsData (methodType : TypeRepr , clauseIdx : Int ): List [ParamSymbolData ] =
2781
+ methodType match
2782
+ case MethodType (paramInfosExp, resultTypeExp, res) =>
2783
+ paramInfosExp.zip(resultTypeExp).zipWithIndex.map { case ((name, tpe), elementIdx) =>
2784
+ ParamSymbolData (name, tpe, isTypeParam = false , clauseIdx, elementIdx)
2785
+ } ++ getParamSymbolsData(res, clauseIdx + 1 )
2786
+ case pt @ PolyType (paramNames, paramBounds, res) =>
2787
+ paramNames.zip(paramBounds).zipWithIndex.map {case ((name, tpe), elementIdx) =>
2788
+ ParamSymbolData (name, tpe, isTypeParam = true , clauseIdx, elementIdx)
2789
+ } ++ getParamSymbolsData(res, clauseIdx + 1 )
2790
+ case result =>
2791
+ List ()
2792
+ // Maps PolyType indexes to type parameter symbol typerefs
2793
+ val paramRefMap = collection.mutable.HashMap [Int , Symbol ]()
2794
+ val paramRefRemapper = new Types .TypeMap {
2795
+ def apply (tp : Types .Type ) = tp match {
2796
+ case pRef : ParamRef if pRef.binder == methodType => paramRefMap(pRef.paramNum).typeRef
2797
+ case _ => mapOver(tp)
2798
+ }
2799
+ }
2800
+ for case ParamSymbolData (name, tpe, isTypeParam, clauseIdx, elementIdx) <- getParamSymbolsData(methodType, 0 ) do
2801
+ if isTypeParam then
2802
+ checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTypeFlags, Flags .validClassTypeParamFlags)
2803
+ val symbol = dotc.core.Symbols .newSymbol(cls, name.toTypeName, Flags .Param | Flags .Deferred | Flags .Private | Flags .PrivateLocal | Flags .Local | conParamFlags(clauseIdx)(elementIdx), tpe, conParamPrivateWithins(clauseIdx)(elementIdx))
2804
+ paramRefMap.addOne(elementIdx, symbol)
2805
+ cls.enter(symbol)
2806
+ else
2807
+ checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTermFlags, Flags .validClassTermParamFlags)
2808
+ val fixedType = paramRefRemapper(tpe)
2809
+ cls.enter(dotc.core.Symbols .newSymbol(cls, name.toTermName, Flags .ParamAccessor | conParamFlags(clauseIdx)(elementIdx), fixedType, conParamPrivateWithins(clauseIdx)(elementIdx)))
2810
+ for sym <- decls(cls) do cls.enter(sym)
2811
+ cls
2812
+
2813
+ def newModule (owner : Symbol , name : String , modFlags : Flags , clsFlags : Flags , parents : Symbol => List [TypeRepr ], decls : Symbol => List [Symbol ], privateWithin : Symbol ): Symbol =
2660
2814
assert(! privateWithin.exists || privateWithin.isType, " privateWithin must be a type symbol or `Symbol.noSymbol`" )
2661
2815
val mod = dotc.core.Symbols .newNormalizedModuleSymbol(
2662
2816
owner,
@@ -2665,7 +2819,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
2665
2819
clsFlags | dotc.core.Flags .ModuleClassCreationFlags ,
2666
2820
parents,
2667
2821
dotc.core.Scopes .newScope,
2668
- privateWithin)
2822
+ privateWithin,
2823
+ NoCoord ,
2824
+ compUnitInfo = null
2825
+ )
2669
2826
val cls = mod.moduleClass.asClass
2670
2827
cls.enter(dotc.core.Symbols .newConstructor(cls, dotc.core.Flags .Synthetic , Nil , Nil ))
2671
2828
for sym <- decls(cls) do cls.enter(sym)
@@ -3063,6 +3220,18 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
3063
3220
// Keep: aligned with Quotes's `newTypeAlias` doc
3064
3221
private [QuotesImpl ] def validTypeAliasFlags : Flags = Private | Protected | Override | Final | Infix | Local
3065
3222
3223
+ // Keep: aligned with Quotes's `newClass`
3224
+ private [QuotesImpl ] def validClassFlags : Flags = Private | Protected | PrivateLocal | Local | Final | Trait | Abstract | Open
3225
+
3226
+ // Keep: aligned with Quote's 'newClass'
3227
+ private [QuotesImpl ] def validClassConstructorFlags : Flags = Synthetic | Method | Private | Protected | PrivateLocal | Local
3228
+
3229
+ // Keep: aligned with Quotes's `newClass`
3230
+ private [QuotesImpl ] def validClassTypeParamFlags : Flags = Param | Deferred | Private | PrivateLocal | Local
3231
+
3232
+ // Keep: aligned with Quotes's `newClass`
3233
+ private [QuotesImpl ] def validClassTermParamFlags : Flags = ParamAccessor | Private | Protected | PrivateLocal | Local
3234
+
3066
3235
end Flags
3067
3236
3068
3237
given FlagsMethods : FlagsMethods with
0 commit comments