Skip to content

Commit c5bb595

Browse files
authored
Merge pull request #15874 from dotty-staging/mirror-for-inner-classes-backport
[Backport] Support Mirrors for local and inner classes.
2 parents 8817b38 + 3974e4a commit c5bb595

32 files changed

+1779
-94
lines changed

compiler/src/dotty/tools/dotc/core/TypeOps.scala

+69
Original file line numberDiff line numberDiff line change
@@ -855,4 +855,73 @@ object TypeOps:
855855
def stripTypeVars(tp: Type)(using Context): Type =
856856
new StripTypeVarsMap().apply(tp)
857857

858+
/** computes a prefix for `child`, derived from its common prefix with `pre`
859+
* - `pre` is assumed to be the prefix of `parent` at a given callsite.
860+
* - `child` is assumed to be the sealed child of `parent`, and reachable according to `whyNotGenericSum`.
861+
*/
862+
def childPrefix(pre: Type, parent: Symbol, child: Symbol)(using Context): Type =
863+
// Example, given this class hierarchy, we can see how this should work
864+
// when summoning a mirror for `wrapper.Color`:
865+
//
866+
// package example
867+
// object Outer3:
868+
// class Wrapper:
869+
// sealed trait Color
870+
// val wrapper = new Wrapper
871+
// object Inner:
872+
// case object Red extends wrapper.Color
873+
// case object Green extends wrapper.Color
874+
// case object Blue extends wrapper.Color
875+
//
876+
// summon[Mirror.SumOf[wrapper.Color]]
877+
// ^^^^^^^^^^^^^
878+
// > pre = example.Outer3.wrapper.type
879+
// > parent = sealed trait example.Outer3.Wrapper.Color
880+
// > child = module val example.Outer3.Innner.Red
881+
// > parentOwners = [example, Outer3, Wrapper] // computed from definition
882+
// > childOwners = [example, Outer3, Inner] // computed from definition
883+
// > parentRest = [Wrapper] // strip common owners from `childOwners`
884+
// > childRest = [Inner] // strip common owners from `parentOwners`
885+
// > commonPrefix = example.Outer3.type // i.e. parentRest has only 1 element, use 1st subprefix of `pre`.
886+
// > childPrefix = example.Outer3.Inner.type // select all symbols in `childRest` from `commonPrefix`
887+
888+
/** unwind the prefix into a sequence of sub-prefixes, selecting the one at `limit`
889+
* @return `NoType` if there is an unrecognised prefix type.
890+
*/
891+
def subPrefixAt(pre: Type, limit: Int): Type =
892+
def go(pre: Type, limit: Int): Type =
893+
if limit == 0 then pre // EXIT: No More prefix
894+
else pre match
895+
case pre: ThisType => go(pre.tref.prefix, limit - 1)
896+
case pre: TermRef => go(pre.prefix, limit - 1)
897+
case _:SuperType | NoPrefix => pre.ensuring(limit == 1) // EXIT: can't rewind further than this
898+
case _ => NoType // EXIT: unrecognized prefix
899+
go(pre, limit)
900+
end subPrefixAt
901+
902+
/** Successively select each symbol in the `suffix` from `pre`, such that they are reachable. */
903+
def selectAll(pre: Type, suffix: Seq[Symbol]): Type =
904+
suffix.foldLeft(pre)((pre, sym) =>
905+
pre.select(
906+
if sym.isType && sym.is(Module) then sym.sourceModule
907+
else sym
908+
)
909+
)
910+
911+
def stripCommonPrefix(xs: List[Symbol], ys: List[Symbol]): (List[Symbol], List[Symbol]) = (xs, ys) match
912+
case (x :: xs1, y :: ys1) if x eq y => stripCommonPrefix(xs1, ys1)
913+
case _ => (xs, ys)
914+
915+
val (parentRest, childRest) = stripCommonPrefix(
916+
parent.owner.ownersIterator.toList.reverse,
917+
child.owner.ownersIterator.toList.reverse
918+
)
919+
920+
val commonPrefix = subPrefixAt(pre, parentRest.size) // unwind parent owners up to common prefix
921+
922+
if commonPrefix.exists then selectAll(commonPrefix, childRest)
923+
else NoType
924+
925+
end childPrefix
926+
858927
end TypeOps

compiler/src/dotty/tools/dotc/transform/PostInlining.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ class PostInlining extends MacroTransform, IdentityDenotTransformer:
2626
override def transform(tree: Tree)(using Context): Tree =
2727
super.transform(tree) match
2828
case tree1: Template
29-
if tree1.hasAttachment(ExtendsSingletonMirror)
30-
|| tree1.hasAttachment(ExtendsProductMirror)
31-
|| tree1.hasAttachment(ExtendsSumMirror) =>
29+
if tree1.hasAttachment(ExtendsSingletonMirror) || tree1.hasAttachment(ExtendsSumOrProductMirror) =>
3230
synthMbr.addMirrorSupport(tree1)
3331
case tree1 => tree1
3432

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

+21-7
Original file line numberDiff line numberDiff line change
@@ -163,28 +163,42 @@ object SymUtils:
163163
* and also the location of the generated mirror.
164164
* - all of its children are generic products, singletons, or generic sums themselves.
165165
*/
166-
def whyNotGenericSum(using Context): String =
166+
def whyNotGenericSum(pre: Type)(using Context): String =
167167
if (!self.is(Sealed))
168168
s"it is not a sealed ${self.kindString}"
169169
else if (!self.isOneOf(AbstractOrTrait))
170170
"it is not an abstract class"
171171
else {
172172
val children = self.children
173173
val companionMirror = self.useCompanionAsSumMirror
174+
val ownerScope = if pre.isInstanceOf[SingletonType] then pre.classSymbol else NoSymbol
174175
def problem(child: Symbol) = {
175176

176-
def isAccessible(sym: Symbol): Boolean =
177-
(self.isContainedIn(sym) && (companionMirror || ctx.owner.isContainedIn(sym)))
178-
|| sym.is(Module) && isAccessible(sym.owner)
177+
def accessibleMessage(sym: Symbol): String =
178+
def inherits(sym: Symbol, scope: Symbol): Boolean =
179+
!scope.is(Package) && (scope.derivesFrom(sym) || inherits(sym, scope.owner))
180+
def isVisibleToParent(sym: Symbol): Boolean =
181+
self.isContainedIn(sym) || sym.is(Module) && isVisibleToParent(sym.owner)
182+
def isVisibleToScope(sym: Symbol): Boolean =
183+
def isReachable: Boolean = ctx.owner.isContainedIn(sym)
184+
def isMemberOfPrefix: Boolean =
185+
ownerScope.exists && inherits(sym, ownerScope)
186+
isReachable || isMemberOfPrefix || sym.is(Module) && isVisibleToScope(sym.owner)
187+
if !isVisibleToParent(sym) then i"to its parent $self"
188+
else if !companionMirror && !isVisibleToScope(sym) then i"to call site ${ctx.owner}"
189+
else ""
190+
end accessibleMessage
191+
192+
val childAccessible = accessibleMessage(child.owner)
179193

180194
if (child == self) "it has anonymous or inaccessible subclasses"
181-
else if (!isAccessible(child.owner)) i"its child $child is not accessible"
195+
else if (!childAccessible.isEmpty) i"its child $child is not accessible $childAccessible"
182196
else if (!child.isClass) "" // its a singleton enum value
183197
else {
184198
val s = child.whyNotGenericProduct
185199
if s.isEmpty then s
186200
else if child.is(Sealed) then
187-
val s = child.whyNotGenericSum
201+
val s = child.whyNotGenericSum(pre)
188202
if s.isEmpty then s
189203
else i"its child $child is not a generic sum because $s"
190204
else
@@ -195,7 +209,7 @@ object SymUtils:
195209
else children.map(problem).find(!_.isEmpty).getOrElse("")
196210
}
197211

198-
def isGenericSum(using Context): Boolean = whyNotGenericSum.isEmpty
212+
def isGenericSum(pre: Type)(using Context): Boolean = whyNotGenericSum(pre).isEmpty
199213

200214
/** If this is a constructor, its owner: otherwise this. */
201215
final def skipConstructor(using Context): Symbol =

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

+67-42
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ import NullOpsDecorator._
1818

1919
object SyntheticMembers {
2020

21+
enum MirrorImpl:
22+
case OfProduct(pre: Type)
23+
case OfSum(childPres: List[Type])
24+
2125
/** Attachment marking an anonymous class as a singleton case that will extend from Mirror.Singleton */
2226
val ExtendsSingletonMirror: Property.StickyKey[Unit] = new Property.StickyKey
2327

2428
/** Attachment recording that an anonymous class should extend Mirror.Product */
25-
val ExtendsProductMirror: Property.StickyKey[Unit] = new Property.StickyKey
26-
27-
/** Attachment recording that an anonymous class should extend Mirror.Sum */
28-
val ExtendsSumMirror: Property.StickyKey[Unit] = new Property.StickyKey
29+
val ExtendsSumOrProductMirror: Property.StickyKey[MirrorImpl] = new Property.StickyKey
2930
}
3031

3132
/** Synthetic method implementations for case classes, case objects,
@@ -483,32 +484,41 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
483484
* type MirroredMonoType = C[?]
484485
* ```
485486
*/
486-
def fromProductBody(caseClass: Symbol, param: Tree)(using Context): Tree = {
487-
val (classRef, methTpe) =
488-
caseClass.primaryConstructor.info match {
487+
def fromProductBody(caseClass: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
488+
def extractParams(tpe: Type): List[Type] =
489+
tpe.asInstanceOf[MethodType].paramInfos
490+
491+
def computeFromCaseClass: (Type, List[Type]) =
492+
val (baseRef, baseInfo) =
493+
val rawRef = caseClass.typeRef
494+
val rawInfo = caseClass.primaryConstructor.info
495+
optInfo match
496+
case Some(info) =>
497+
(rawRef.asSeenFrom(info.pre, caseClass.owner), rawInfo.asSeenFrom(info.pre, caseClass.owner))
498+
case _ =>
499+
(rawRef, rawInfo)
500+
baseInfo match
489501
case tl: PolyType =>
490502
val (tl1, tpts) = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)
491503
val targs =
492504
for (tpt <- tpts) yield
493505
tpt.tpe match {
494506
case tvar: TypeVar => tvar.instantiate(fromBelow = false)
495507
}
496-
(caseClass.typeRef.appliedTo(targs), tl.instantiate(targs))
508+
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
497509
case methTpe =>
498-
(caseClass.typeRef, methTpe)
499-
}
500-
methTpe match {
501-
case methTpe: MethodType =>
502-
val elems =
503-
for ((formal, idx) <- methTpe.paramInfos.zipWithIndex) yield {
504-
val elem =
505-
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
506-
.ensureConforms(formal.translateFromRepeated(toArray = false))
507-
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
508-
}
509-
New(classRef, elems)
510-
}
511-
}
510+
(baseRef, extractParams(methTpe))
511+
end computeFromCaseClass
512+
513+
val (classRefApplied, paramInfos) = computeFromCaseClass
514+
val elems =
515+
for ((formal, idx) <- paramInfos.zipWithIndex) yield
516+
val elem =
517+
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
518+
.ensureConforms(formal.translateFromRepeated(toArray = false))
519+
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
520+
New(classRefApplied, elems)
521+
end fromProductBody
512522

513523
/** For an enum T:
514524
*
@@ -526,24 +536,36 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
526536
* a wildcard for each type parameter. The normalized type of an object
527537
* O is O.type.
528538
*/
529-
def ordinalBody(cls: Symbol, param: Tree)(using Context): Tree =
530-
if (cls.is(Enum)) param.select(nme.ordinal).ensureApplied
531-
else {
539+
def ordinalBody(cls: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfSum])(using Context): Tree =
540+
if cls.is(Enum) then
541+
param.select(nme.ordinal).ensureApplied
542+
else
543+
def computeChildTypes: List[Type] =
544+
def rawRef(child: Symbol): Type =
545+
if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
546+
optInfo match
547+
case Some(info) => info
548+
.childPres
549+
.lazyZip(cls.children)
550+
.map((pre, child) => rawRef(child).asSeenFrom(pre, child.owner))
551+
case _ =>
552+
cls.children.map(rawRef)
553+
end computeChildTypes
554+
val childTypes = computeChildTypes
532555
val cases =
533-
for ((child, idx) <- cls.children.zipWithIndex) yield {
534-
val patType = if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
556+
for (patType, idx) <- childTypes.zipWithIndex yield
535557
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
536558
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
537-
}
559+
538560
Match(param.annotated(New(defn.UncheckedAnnot.typeRef, Nil)), cases)
539-
}
561+
end ordinalBody
540562

541563
/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent
542564
* and `MirroredMonoType` and `ordinal` members.
543565
* - If `impl` is the companion of a generic product, add `deriving.Mirror.Product` parent
544566
* and `MirroredMonoType` and `fromProduct` members.
545-
* - If `impl` is marked with one of the attachments ExtendsSingletonMirror, ExtendsProductMirror,
546-
* or ExtendsSumMirror, remove the attachment and generate the corresponding mirror support,
567+
* - If `impl` is marked with one of the attachments ExtendsSingletonMirror or ExtendsSumOfProductMirror,
568+
* remove the attachment and generate the corresponding mirror support,
547569
* On this case the represented class or object is referred to in a pre-existing `MirroredMonoType`
548570
* member of the template.
549571
*/
@@ -580,30 +602,33 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
580602
}
581603
def makeSingletonMirror() =
582604
addParent(defn.Mirror_SingletonClass.typeRef)
583-
def makeProductMirror(cls: Symbol) = {
605+
def makeProductMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfProduct]) = {
584606
addParent(defn.Mirror_ProductClass.typeRef)
585607
addMethod(nme.fromProduct, MethodType(defn.ProductClass.typeRef :: Nil, monoType.typeRef), cls,
586-
fromProductBody(_, _).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
608+
fromProductBody(_, _, optInfo).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
587609
}
588-
def makeSumMirror(cls: Symbol) = {
610+
def makeSumMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfSum]) = {
589611
addParent(defn.Mirror_SumClass.typeRef)
590612
addMethod(nme.ordinal, MethodType(monoType.typeRef :: Nil, defn.IntType), cls,
591-
ordinalBody(_, _))
613+
ordinalBody(_, _, optInfo))
592614
}
593615

594616
if (clazz.is(Module)) {
595617
if (clazz.is(Case)) makeSingletonMirror()
596-
else if (linked.isGenericProduct) makeProductMirror(linked)
597-
else if (linked.isGenericSum) makeSumMirror(linked)
618+
else if (linked.isGenericProduct) makeProductMirror(linked, None)
619+
else if (linked.isGenericSum(NoType)) makeSumMirror(linked, None)
598620
else if (linked.is(Sealed))
599-
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum}")
621+
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum(NoType)}")
600622
}
601623
else if (impl.removeAttachment(ExtendsSingletonMirror).isDefined)
602624
makeSingletonMirror()
603-
else if (impl.removeAttachment(ExtendsProductMirror).isDefined)
604-
makeProductMirror(monoType.typeRef.dealias.classSymbol)
605-
else if (impl.removeAttachment(ExtendsSumMirror).isDefined)
606-
makeSumMirror(monoType.typeRef.dealias.classSymbol)
625+
else
626+
impl.removeAttachment(ExtendsSumOrProductMirror).match
627+
case Some(prodImpl: MirrorImpl.OfProduct) =>
628+
makeProductMirror(monoType.typeRef.dealias.classSymbol, Some(prodImpl))
629+
case Some(sumImpl: MirrorImpl.OfSum) =>
630+
makeSumMirror(monoType.typeRef.dealias.classSymbol, Some(sumImpl))
631+
case _ =>
607632

608633
cpy.Template(impl)(parents = newParents, body = newBody)
609634
}

0 commit comments

Comments
 (0)