Skip to content

Commit 9e7aab7

Browse files
authored
Add class parameters, flags, and privateWithin and annotations to newClass in reflect API (#21880)
Instead of replacing the one newMethod we have, we instead add two more with varying complexity (similarly to how newMethod is handled). This is also so we can keep the initial newClass implementation (the one creating newClass with public empty primary constructor) intact, which despite being experiemental - already sees use in libraries and projects. Fixes #21739 and addresses some old TODOs (from the stdlibExperimentalDefinitions.scala file).
1 parent ff8451a commit 9e7aab7

File tree

41 files changed

+1029
-45
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1029
-45
lines changed

compiler/src/dotty/tools/dotc/core/Symbols.scala

+54
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,32 @@ object Symbols extends SymUtils {
641641
newClassSymbol(owner, name, flags, completer, privateWithin, coord, compUnitInfo)
642642
}
643643

644+
/** Same as the other `newNormalizedClassSymbol` except that `parents` can be a function returning a list of arbitrary
645+
* types which get normalized into type refs and parameter bindings and annotations can be assigned in the completer.
646+
*/
647+
def newNormalizedClassSymbol(
648+
owner: Symbol,
649+
name: TypeName,
650+
flags: FlagSet,
651+
parentTypes: Symbol => List[Type],
652+
selfInfo: Type,
653+
privateWithin: Symbol,
654+
annotations: List[Tree],
655+
coord: Coord,
656+
compUnitInfo: CompilationUnitInfo | Null)(using Context): ClassSymbol = {
657+
def completer = new LazyType {
658+
def complete(denot: SymDenotation)(using Context): Unit = {
659+
val cls = denot.asClass.classSymbol
660+
val decls = newScope
661+
val parents = parentTypes(cls).map(_.dealias)
662+
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
663+
denot.info = ClassInfo(owner.thisType, cls, parents, decls, selfInfo)
664+
denot.annotations = annotations.map(Annotations.Annotation(_))
665+
}
666+
}
667+
newClassSymbol(owner, name, flags, completer, privateWithin, coord, compUnitInfo)
668+
}
669+
644670
def newRefinedClassSymbol(coord: Coord = NoCoord)(using Context): ClassSymbol =
645671
newCompleteClassSymbol(ctx.owner, tpnme.REFINE_CLASS, NonMember, parents = Nil, newScope, coord = coord)
646672

@@ -718,6 +744,34 @@ object Symbols extends SymUtils {
718744
privateWithin, coord, compUnitInfo)
719745
}
720746

747+
/** Same as `newNormalizedModuleSymbol` except that `parents` can be a function returning a list of arbitrary
748+
* types which get normalized into type refs and parameter bindings.
749+
*/
750+
def newNormalizedModuleSymbol(
751+
owner: Symbol,
752+
name: TermName,
753+
modFlags: FlagSet,
754+
clsFlags: FlagSet,
755+
parentTypes: ClassSymbol => List[Type],
756+
decls: Scope,
757+
privateWithin: Symbol,
758+
coord: Coord,
759+
compUnitInfo: CompilationUnitInfo | Null)(using Context): TermSymbol = {
760+
def completer(module: Symbol) = new LazyType {
761+
def complete(denot: SymDenotation)(using Context): Unit = {
762+
val cls = denot.asClass.classSymbol
763+
val decls = newScope
764+
val parents = parentTypes(cls).map(_.dealias)
765+
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
766+
denot.info = ClassInfo(owner.thisType, cls, parents, decls, TermRef(owner.thisType, module))
767+
}
768+
}
769+
newModuleSymbol(
770+
owner, name, modFlags, clsFlags,
771+
(module, modcls) => completer(module),
772+
privateWithin, coord, compUnitInfo)
773+
}
774+
721775
/** Create a package symbol with associated package class
722776
* from its non-info fields and a lazy type for loading the package's members.
723777
*/

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

+25-21
Original file line numberDiff line numberDiff line change
@@ -854,30 +854,34 @@ object TreeChecker {
854854
val phases = ctx.base.allPhases.toList
855855
val treeChecker = new LocalChecker(previousPhases(phases))
856856

857+
def reportMalformedMacroTree(msg: String | Null, err: Throwable) =
858+
val stack =
859+
if !ctx.settings.Ydebug.value then "\nstacktrace available when compiling with `-Ydebug`"
860+
else if err.getStackTrace == null then " no stacktrace"
861+
else err.getStackTrace.nn.mkString(" ", " \n", "")
862+
report.error(
863+
em"""Malformed tree was found while expanding macro with -Xcheck-macros.
864+
|The tree does not conform to the compiler's tree invariants.
865+
|
866+
|Macro was:
867+
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(original)}
868+
|
869+
|The macro returned:
870+
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(expansion)}
871+
|
872+
|Error:
873+
|$msg
874+
|$stack
875+
|""",
876+
original
877+
)
878+
857879
try treeChecker.typed(expansion)(using checkingCtx)
858880
catch
859881
case err: java.lang.AssertionError =>
860-
val stack =
861-
if !ctx.settings.Ydebug.value then "\nstacktrace available when compiling with `-Ydebug`"
862-
else if err.getStackTrace == null then " no stacktrace"
863-
else err.getStackTrace.nn.mkString(" ", " \n", "")
864-
865-
report.error(
866-
em"""Malformed tree was found while expanding macro with -Xcheck-macros.
867-
|The tree does not conform to the compiler's tree invariants.
868-
|
869-
|Macro was:
870-
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(original)}
871-
|
872-
|The macro returned:
873-
|${scala.quoted.runtime.impl.QuotesImpl.showDecompiledTree(expansion)}
874-
|
875-
|Error:
876-
|${err.getMessage}
877-
|$stack
878-
|""",
879-
original
880-
)
882+
reportMalformedMacroTree(err.getMessage(), err)
883+
case err: UnhandledError =>
884+
reportMalformedMacroTree(err.diagnostic.message, err)
881885

882886
private[TreeChecker] def previousPhases(phases: List[Phase])(using Context): List[Phase] = phases match {
883887
case (phase: MegaPhase) :: phases1 =>

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import collection.mutable
1212
import reporting.*
1313
import Checking.{checkNoPrivateLeaks, checkNoWildcard}
1414
import cc.CaptureSet
15+
import transform.Splicer
1516

1617
trait TypeAssigner {
1718
import tpd.*
@@ -301,7 +302,10 @@ trait TypeAssigner {
301302
if fntpe.isResultDependent then safeSubstMethodParams(fntpe, args.tpes)
302303
else fntpe.resultType // fast path optimization
303304
else
304-
errorType(em"wrong number of arguments at ${ctx.phase.prev} for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos)
305+
val erroringPhase =
306+
if Splicer.inMacroExpansion then i"${ctx.phase} (while expanding macro)"
307+
else ctx.phase.prev.toString
308+
errorType(em"wrong number of arguments at $erroringPhase for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos)
305309
case err: ErrorType =>
306310
err
307311
case t =>

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

+174-5
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import scala.quoted.runtime.impl.printers.*
2525
import scala.reflect.TypeTest
2626
import dotty.tools.dotc.core.NameKinds.ExceptionBinderName
2727
import dotty.tools.dotc.transform.TreeChecker
28+
import dotty.tools.dotc.core.Names
29+
import dotty.tools.dotc.util.Spans.NoCoord
2830

2931
object QuotesImpl {
3032

@@ -241,9 +243,35 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
241243

242244
object ClassDef extends ClassDefModule:
243245
def apply(cls: Symbol, parents: List[Tree], body: List[Statement]): ClassDef =
244-
val untpdCtr = untpd.DefDef(nme.CONSTRUCTOR, Nil, tpd.TypeTree(dotc.core.Symbols.defn.UnitClass.typeRef), tpd.EmptyTree)
246+
val paramsDefs: List[untpd.ParamClause] =
247+
cls.primaryConstructor.paramSymss.map { paramSym =>
248+
if paramSym.headOption.map(_.isType).getOrElse(false) then
249+
paramSym.map(sym => TypeDef(sym))
250+
else
251+
paramSym.map(ValDef(_, None))
252+
}
253+
def throwError() =
254+
throw new RuntimeException(
255+
"Symbols necessary for creation of the ClassDef tree could not be found."
256+
)
257+
val paramsAccessDefs: List[untpd.ParamClause] =
258+
cls.primaryConstructor.paramSymss.map { paramSym =>
259+
if paramSym.headOption.map(_.isType).getOrElse(false) then
260+
paramSym.map { symm =>
261+
def isParamAccessor(memberSym: Symbol) = memberSym.flags.is(Flags.Param) && memberSym.name == symm.name
262+
TypeDef(cls.typeMembers.find(isParamAccessor).getOrElse(throwError()))
263+
}
264+
else
265+
paramSym.map { symm =>
266+
def isParam(memberSym: Symbol) = memberSym.flags.is(Flags.ParamAccessor) && memberSym.name == symm.name
267+
ValDef(cls.fieldMembers.find(isParam).getOrElse(throwError()), None)
268+
}
269+
}
270+
271+
val termSymbol: dotc.core.Symbols.TermSymbol = cls.primaryConstructor.asTerm
272+
val untpdCtr = untpd.DefDef(nme.CONSTRUCTOR, paramsDefs, tpd.TypeTree(dotc.core.Symbols.defn.UnitClass.typeRef), tpd.EmptyTree)
245273
val ctr = ctx.typeAssigner.assignType(untpdCtr, cls.primaryConstructor)
246-
tpd.ClassDefWithParents(cls.asClass, ctr, parents, body)
274+
tpd.ClassDefWithParents(cls.asClass, ctr, parents, paramsAccessDefs.flatten ++ body)
247275

248276
def copy(original: Tree)(name: String, constr: DefDef, parents: List[Tree], selfOpt: Option[ValDef], body: List[Statement]): ClassDef = {
249277
val dotc.ast.Trees.TypeDef(_, originalImpl: tpd.Template) = original: @unchecked
@@ -2655,8 +2683,134 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
26552683
for sym <- decls(cls) do cls.enter(sym)
26562684
cls
26572685

2658-
def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol =
2659-
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
2686+
def newClass(
2687+
owner: Symbol,
2688+
name: String,
2689+
parents: Symbol => List[TypeRepr],
2690+
decls: Symbol => List[Symbol],
2691+
selfType: Option[TypeRepr],
2692+
clsFlags: Flags,
2693+
clsPrivateWithin: Symbol,
2694+
conParams: List[(String, TypeRepr)]
2695+
): Symbol =
2696+
val (conParamNames, conParamTypes) = conParams.unzip
2697+
newClass(
2698+
owner,
2699+
name,
2700+
parents,
2701+
decls,
2702+
selfType,
2703+
clsFlags,
2704+
clsPrivateWithin,
2705+
Nil,
2706+
conMethodType = res => MethodType(conParamNames)(_ => conParamTypes, _ => res),
2707+
conFlags = Flags.EmptyFlags,
2708+
conPrivateWithin = Symbol.noSymbol,
2709+
conParamFlags = List(for i <- conParamNames yield Flags.EmptyFlags),
2710+
conParamPrivateWithins = List(for i <- conParamNames yield Symbol.noSymbol)
2711+
)
2712+
2713+
def newClass(
2714+
owner: Symbol,
2715+
name: String,
2716+
parents: Symbol => List[TypeRepr],
2717+
decls: Symbol => List[Symbol],
2718+
selfType: Option[TypeRepr],
2719+
clsFlags: Flags,
2720+
clsPrivateWithin: Symbol,
2721+
clsAnnotations: List[Term],
2722+
conMethodType: TypeRepr => MethodOrPoly,
2723+
conFlags: Flags,
2724+
conPrivateWithin: Symbol,
2725+
conParamFlags: List[List[Flags]],
2726+
conParamPrivateWithins: List[List[Symbol]]
2727+
) =
2728+
assert(!clsPrivateWithin.exists || clsPrivateWithin.isType, "clsPrivateWithin must be a type symbol or `Symbol.noSymbol`")
2729+
assert(!conPrivateWithin.exists || conPrivateWithin.isType, "consPrivateWithin must be a type symbol or `Symbol.noSymbol`")
2730+
checkValidFlags(clsFlags.toTypeFlags, Flags.validClassFlags)
2731+
checkValidFlags(conFlags.toTermFlags, Flags.validClassConstructorFlags)
2732+
val cls = dotc.core.Symbols.newNormalizedClassSymbol(
2733+
owner,
2734+
name.toTypeName,
2735+
clsFlags,
2736+
parents,
2737+
selfType.getOrElse(Types.NoType),
2738+
clsPrivateWithin,
2739+
clsAnnotations,
2740+
NoCoord,
2741+
compUnitInfo = null
2742+
)
2743+
val methodType: MethodOrPoly = conMethodType(cls.typeRef)
2744+
def throwShapeException() = throw new Exception("Shapes of conMethodType and conParamFlags differ.")
2745+
def checkMethodOrPolyShape(checkedMethodType: TypeRepr, clauseIdx: Int): Unit =
2746+
checkedMethodType match
2747+
case PolyType(params, _, res) if clauseIdx == 0 =>
2748+
if (conParamFlags.length < clauseIdx) throwShapeException()
2749+
if (conParamFlags(clauseIdx).length != params.length) throwShapeException()
2750+
checkMethodOrPolyShape(res, clauseIdx + 1)
2751+
case PolyType(_, _, _) => throw new Exception("Clause interleaving not supported for constructors")
2752+
case MethodType(params, _, res) =>
2753+
if (conParamFlags.length <= clauseIdx) throwShapeException()
2754+
if (conParamFlags(clauseIdx).length != params.length) throwShapeException()
2755+
checkMethodOrPolyShape(res, clauseIdx + 1)
2756+
case other =>
2757+
xCheckMacroAssert(
2758+
other.typeSymbol == cls,
2759+
"Incorrect type returned from the innermost PolyOrMethod."
2760+
)
2761+
(other, methodType) match
2762+
case (AppliedType(tycon, args), pt: PolyType) =>
2763+
xCheckMacroAssert(
2764+
args.length == pt.typeParams.length &&
2765+
args.zip(pt.typeParams).forall {
2766+
case (arg, param) => arg == param.paramRef
2767+
},
2768+
"Constructor result type does not correspond to the declared type parameters"
2769+
)
2770+
case _ =>
2771+
xCheckMacroAssert(
2772+
!(other.isInstanceOf[AppliedType] || methodType.isInstanceOf[PolyType]),
2773+
"AppliedType has to be the innermost resultTypeExp result if and only if conMethodType returns a PolyType"
2774+
)
2775+
checkMethodOrPolyShape(methodType, clauseIdx = 0)
2776+
2777+
cls.enter(dotc.core.Symbols.newSymbol(cls, nme.CONSTRUCTOR, Flags.Synthetic | Flags.Method | conFlags, methodType, conPrivateWithin, dotty.tools.dotc.util.Spans.NoCoord))
2778+
2779+
case class ParamSymbolData(name: String, tpe: TypeRepr, isTypeParam: Boolean, clauseIdx: Int, elementIdx: Int)
2780+
def getParamSymbolsData(methodType: TypeRepr, clauseIdx: Int): List[ParamSymbolData] =
2781+
methodType match
2782+
case MethodType(paramInfosExp, resultTypeExp, res) =>
2783+
paramInfosExp.zip(resultTypeExp).zipWithIndex.map { case ((name, tpe), elementIdx) =>
2784+
ParamSymbolData(name, tpe, isTypeParam = false, clauseIdx, elementIdx)
2785+
} ++ getParamSymbolsData(res, clauseIdx + 1)
2786+
case pt @ PolyType(paramNames, paramBounds, res) =>
2787+
paramNames.zip(paramBounds).zipWithIndex.map {case ((name, tpe), elementIdx) =>
2788+
ParamSymbolData(name, tpe, isTypeParam = true, clauseIdx, elementIdx)
2789+
} ++ getParamSymbolsData(res, clauseIdx + 1)
2790+
case result =>
2791+
List()
2792+
// Maps PolyType indexes to type parameter symbol typerefs
2793+
val paramRefMap = collection.mutable.HashMap[Int, Symbol]()
2794+
val paramRefRemapper = new Types.TypeMap {
2795+
def apply(tp: Types.Type) = tp match {
2796+
case pRef: ParamRef if pRef.binder == methodType => paramRefMap(pRef.paramNum).typeRef
2797+
case _ => mapOver(tp)
2798+
}
2799+
}
2800+
for case ParamSymbolData(name, tpe, isTypeParam, clauseIdx, elementIdx) <- getParamSymbolsData(methodType, 0) do
2801+
if isTypeParam then
2802+
checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTypeFlags, Flags.validClassTypeParamFlags)
2803+
val symbol = dotc.core.Symbols.newSymbol(cls, name.toTypeName, Flags.Param | Flags.Deferred | Flags.Private | Flags.PrivateLocal | Flags.Local | conParamFlags(clauseIdx)(elementIdx), tpe, conParamPrivateWithins(clauseIdx)(elementIdx))
2804+
paramRefMap.addOne(elementIdx, symbol)
2805+
cls.enter(symbol)
2806+
else
2807+
checkValidFlags(conParamFlags(clauseIdx)(elementIdx).toTermFlags, Flags.validClassTermParamFlags)
2808+
val fixedType = paramRefRemapper(tpe)
2809+
cls.enter(dotc.core.Symbols.newSymbol(cls, name.toTermName, Flags.ParamAccessor | conParamFlags(clauseIdx)(elementIdx), fixedType, conParamPrivateWithins(clauseIdx)(elementIdx)))
2810+
for sym <- decls(cls) do cls.enter(sym)
2811+
cls
2812+
2813+
def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: Symbol => List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol =
26602814
assert(!privateWithin.exists || privateWithin.isType, "privateWithin must be a type symbol or `Symbol.noSymbol`")
26612815
val mod = dotc.core.Symbols.newNormalizedModuleSymbol(
26622816
owner,
@@ -2665,7 +2819,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
26652819
clsFlags | dotc.core.Flags.ModuleClassCreationFlags,
26662820
parents,
26672821
dotc.core.Scopes.newScope,
2668-
privateWithin)
2822+
privateWithin,
2823+
NoCoord,
2824+
compUnitInfo = null
2825+
)
26692826
val cls = mod.moduleClass.asClass
26702827
cls.enter(dotc.core.Symbols.newConstructor(cls, dotc.core.Flags.Synthetic, Nil, Nil))
26712828
for sym <- decls(cls) do cls.enter(sym)
@@ -3063,6 +3220,18 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
30633220
// Keep: aligned with Quotes's `newTypeAlias` doc
30643221
private[QuotesImpl] def validTypeAliasFlags: Flags = Private | Protected | Override | Final | Infix | Local
30653222

3223+
// Keep: aligned with Quotes's `newClass`
3224+
private[QuotesImpl] def validClassFlags: Flags = Private | Protected | PrivateLocal | Local | Final | Trait | Abstract | Open
3225+
3226+
// Keep: aligned with Quote's 'newClass'
3227+
private[QuotesImpl] def validClassConstructorFlags: Flags = Synthetic | Method | Private | Protected | PrivateLocal | Local
3228+
3229+
// Keep: aligned with Quotes's `newClass`
3230+
private[QuotesImpl] def validClassTypeParamFlags: Flags = Param | Deferred | Private | PrivateLocal | Local
3231+
3232+
// Keep: aligned with Quotes's `newClass`
3233+
private[QuotesImpl] def validClassTermParamFlags: Flags = ParamAccessor | Private | Protected | PrivateLocal | Local
3234+
30663235
end Flags
30673236

30683237
given FlagsMethods: FlagsMethods with

compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -1379,13 +1379,13 @@ object SourceCode {
13791379
printTypeTree(bounds.low)
13801380
else
13811381
bounds.low match {
1382-
case Inferred() =>
1382+
case Inferred() if bounds.low.tpe.typeSymbol == TypeRepr.of[Nothing].typeSymbol =>
13831383
case low =>
13841384
this += " >: "
13851385
printTypeTree(low)
13861386
}
13871387
bounds.hi match {
1388-
case Inferred() => this
1388+
case Inferred() if bounds.hi.tpe.typeSymbol == TypeRepr.of[Any].typeSymbol => this
13891389
case hi =>
13901390
this += " <: "
13911391
printTypeTree(hi)

0 commit comments

Comments
 (0)