Skip to content

Commit 7536df9

Browse files
Merge pull request #9032 from dotty-staging/fix-#9028
Fix #9028: Introduce super traits
2 parents 3c56b3e + 116c878 commit 7536df9

File tree

26 files changed

+249
-50
lines changed

26 files changed

+249
-50
lines changed

compiler/src/dotty/tools/dotc/ast/untpd.scala

+2
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
200200
case class Inline()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Inline)
201201

202202
case class Transparent()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.EmptyFlags)
203+
204+
case class Super()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.SuperTrait)
203205
}
204206

205207
/** Modifiers and annotations for definitions

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+36-18
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,12 @@ trait ConstraintHandling[AbstractContext] {
300300
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
301301
* 2. If `inst` is a union type, approximate the union type from above by an intersection
302302
* of all common base types, provided the result is a subtype of `bound`.
303-
* 3. (currently not enabled, see #9028) If `inst` is an intersection with some restricted base types, drop
304-
* the restricted base types from the intersection, provided the result is a subtype of `bound`.
303+
* 3. If `inst` is an intersection such that some operands are super trait instances
304+
* and others are not, replace as many super trait instances as possible with Any
305+
* as long as the result is still a subtype of `bound`. But fall back to the
306+
* original type if the resulting widened type is a supertype of all dropped
307+
* types (since in this case the type was not a true intersection of super traits
308+
* and other types to start with).
305309
*
306310
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
307311
* Also, if the result of these widenings is a TypeRef to a module class,
@@ -313,21 +317,38 @@ trait ConstraintHandling[AbstractContext] {
313317
*/
314318
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =
315319

316-
def isRestricted(tp: Type) = tp.typeSymbol == defn.EnumValueClass // for now, to be generalized later
320+
def dropSuperTraits(tp: Type): Type =
321+
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
322+
var dropped: List[Type] = List() // the types dropped so far, last one on top
323+
324+
def dropOneSuperTrait(tp: Type): Type =
325+
val tpd = tp.dealias
326+
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !kept.contains(tpd) then
327+
dropped = tpd :: dropped
328+
defn.AnyType
329+
else tpd match
330+
case AndType(tp1, tp2) =>
331+
val tp1w = dropOneSuperTrait(tp1)
332+
if tp1w ne tp1 then tp1w & tp2
333+
else
334+
val tp2w = dropOneSuperTrait(tp2)
335+
if tp2w ne tp2 then tp1 & tp2w
336+
else tpd
337+
case _ =>
338+
tp
317339

318-
def dropRestricted(tp: Type): Type = tp.dealias match
319-
case tpd @ AndType(tp1, tp2) =>
320-
if isRestricted(tp1) then tp2
321-
else if isRestricted(tp2) then tp1
340+
def recur(tp: Type): Type =
341+
val tpw = dropOneSuperTrait(tp)
342+
if tpw eq tp then tp
343+
else if tpw <:< bound then recur(tpw)
322344
else
323-
val tpw = tpd.derivedAndType(dropRestricted(tp1), dropRestricted(tp2))
324-
if tpw ne tpd then tpw else tp
325-
case _ =>
326-
tp
345+
kept += dropped.head
346+
dropped = dropped.tail
347+
recur(tp)
327348

328-
def widenRestricted(tp: Type) =
329-
val tpw = dropRestricted(tp)
330-
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
349+
val tpw = recur(tp)
350+
if (tpw eq tp) || dropped.forall(_ frozen_<:< tpw) then tp else tpw
351+
end dropSuperTraits
331352

332353
def widenOr(tp: Type) =
333354
val tpw = tp.widenUnion
@@ -343,10 +364,7 @@ trait ConstraintHandling[AbstractContext] {
343364

344365
val wideInst =
345366
if isSingleton(bound) then inst
346-
else /*widenRestricted*/(widenOr(widenSingle(inst)))
347-
// widenRestricted is currently not called since it's special cased in `dropEnumValue`
348-
// in `Namer`. It's left in here in case we want to generalize the scheme to other
349-
// "protected inheritance" classes.
367+
else dropSuperTraits(widenOr(widenSingle(inst)))
350368
wideInst match
351369
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
352370
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)

compiler/src/dotty/tools/dotc/core/Definitions.scala

+16-1
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,6 @@ class Definitions {
640640
@tu lazy val EnumClass: ClassSymbol = ctx.requiredClass("scala.Enum")
641641
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
642642

643-
@tu lazy val EnumValueClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValue")
644643
@tu lazy val EnumValuesClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValues")
645644
@tu lazy val ProductClass: ClassSymbol = ctx.requiredClass("scala.Product")
646645
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
@@ -809,6 +808,7 @@ class Definitions {
809808
@tu lazy val ScalaStrictFPAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.strictfp")
810809
@tu lazy val ScalaStaticAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.static")
811810
@tu lazy val SerialVersionUIDAnnot: ClassSymbol = ctx.requiredClass("scala.SerialVersionUID")
811+
@tu lazy val SuperTraitAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.superTrait")
812812
@tu lazy val TASTYSignatureAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.internal.TASTYSignature")
813813
@tu lazy val TASTYLongSignatureAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.internal.TASTYLongSignature")
814814
@tu lazy val TailrecAnnot: ClassSymbol = ctx.requiredClass("scala.annotation.tailrec")
@@ -1313,6 +1313,21 @@ class Definitions {
13131313
def isInfix(sym: Symbol)(implicit ctx: Context): Boolean =
13141314
(sym eq Object_eq) || (sym eq Object_ne)
13151315

1316+
@tu lazy val assumedSuperTraits =
1317+
Set(ComparableClass, ProductClass, SerializableClass,
1318+
// add these for now, until we had a chance to retrofit 2.13 stdlib
1319+
// we should do a more through sweep through it then.
1320+
ctx.requiredClass("scala.collection.SortedOps"),
1321+
ctx.requiredClass("scala.collection.StrictOptimizedSortedSetOps"),
1322+
ctx.requiredClass("scala.collection.generic.DefaultSerializable"),
1323+
ctx.requiredClass("scala.collection.generic.IsIterable"),
1324+
ctx.requiredClass("scala.collection.generic.IsIterableOnce"),
1325+
ctx.requiredClass("scala.collection.generic.IsMap"),
1326+
ctx.requiredClass("scala.collection.generic.IsSeq"),
1327+
ctx.requiredClass("scala.collection.generic.Subtractable"),
1328+
ctx.requiredClass("scala.collection.immutable.StrictOptimizedSeqOps")
1329+
)
1330+
13161331
// ----- primitive value class machinery ------------------------------------------
13171332

13181333
/** This class would also be obviated by the implicit function type design */

compiler/src/dotty/tools/dotc/core/Flags.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,8 @@ object Flags {
222222
/** Labeled with `final` modifier */
223223
val (Final @ _, _, _) = newFlags(6, "final")
224224

225-
/** A method symbol */
226-
val (_, Method @ _, HigherKinded @ _) = newFlags(7, "<method>", "<higher kinded>") // TODO drop HigherKinded
225+
/** A method symbol / a super trait */
226+
val (_, Method @ _, SuperTrait @ _) = newFlags(7, "<method>", "super")
227227

228228
/** A (term or type) parameter to a class or method */
229229
val (Param @ _, TermParam @ _, TypeParam @ _) = newFlags(8, "<param>")
@@ -439,7 +439,7 @@ object Flags {
439439
*/
440440
val FromStartFlags: FlagSet = commonFlags(
441441
Module, Package, Deferred, Method, Case, Enum,
442-
HigherKinded, Param, ParamAccessor,
442+
SuperTrait, Param, ParamAccessor,
443443
Scala2SpecialFlags, MutableOrOpen, Opaque, Touched, JavaStatic,
444444
OuterOrCovariant, LabelOrContravariant, CaseAccessor,
445445
Extension, NonMember, Implicit, Given, Permanent, Synthetic,

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+6
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,12 @@ object SymDenotations {
11701170
final def isEffectivelySealed(using Context): Boolean =
11711171
isOneOf(FinalOrSealed) || isClass && !isOneOf(EffectivelyOpenFlags)
11721172

1173+
final def isSuperTrait(using Context): Boolean =
1174+
isClass
1175+
&& (is(SuperTrait)
1176+
|| defn.assumedSuperTraits.contains(symbol.asClass)
1177+
|| hasAnnotation(defn.SuperTraitAnnot))
1178+
11731179
/** The class containing this denotation which has the given effective name. */
11741180
final def enclosingClassNamed(name: Name)(implicit ctx: Context): Symbol = {
11751181
val cls = enclosingClass

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ class TreePickler(pickler: TastyPickler) {
705705
if (flags.is(Sealed)) writeModTag(SEALED)
706706
if (flags.is(Abstract)) writeModTag(ABSTRACT)
707707
if (flags.is(Trait)) writeModTag(TRAIT)
708+
if flags.is(SuperTrait) then writeModTag(SUPERTRAIT)
708709
if (flags.is(Covariant)) writeModTag(COVARIANT)
709710
if (flags.is(Contravariant)) writeModTag(CONTRAVARIANT)
710711
if (flags.is(Opaque)) writeModTag(OPAQUE)

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ class TreeUnpickler(reader: TastyReader,
639639
case STATIC => addFlag(JavaStatic)
640640
case OBJECT => addFlag(Module)
641641
case TRAIT => addFlag(Trait)
642+
case SUPERTRAIT => addFlag(SuperTrait)
642643
case ENUM => addFlag(Enum)
643644
case LOCAL => addFlag(Local)
644645
case SYNTHETIC => addFlag(Synthetic)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -3435,7 +3435,7 @@ object Parsers {
34353435
}
34363436
}
34373437

3438-
/** TmplDef ::= ([‘case’] ‘class’ | ‘trait’) ClassDef
3438+
/** TmplDef ::= ([‘case’] ‘class’ | [‘super’] ‘trait’) ClassDef
34393439
* | [‘case’] ‘object’ ObjectDef
34403440
* | ‘enum’ EnumDef
34413441
* | ‘given’ GivenDef
@@ -3445,6 +3445,8 @@ object Parsers {
34453445
in.token match {
34463446
case TRAIT =>
34473447
classDef(start, posMods(start, addFlag(mods, Trait)))
3448+
case SUPERTRAIT =>
3449+
classDef(start, posMods(start, addFlag(mods, Trait | SuperTrait)))
34483450
case CLASS =>
34493451
classDef(start, posMods(start, mods))
34503452
case CASECLASS =>

compiler/src/dotty/tools/dotc/parsing/Scanners.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,8 @@ object Scanners {
586586
currentRegion = r.outer
587587
case _ =>
588588

589-
/** - Join CASE + CLASS => CASECLASS, CASE + OBJECT => CASEOBJECT, SEMI + ELSE => ELSE, COLON + <EOL> => COLONEOL
589+
/** - Join CASE + CLASS => CASECLASS, CASE + OBJECT => CASEOBJECT, SUPER + TRAIT => SUPERTRAIT
590+
* SEMI + ELSE => ELSE, COLON + <EOL> => COLONEOL
590591
* - Insert missing OUTDENTs at EOF
591592
*/
592593
def postProcessToken(): Unit = {
@@ -602,6 +603,10 @@ object Scanners {
602603
if (token == CLASS) fuse(CASECLASS)
603604
else if (token == OBJECT) fuse(CASEOBJECT)
604605
else reset()
606+
case SUPER =>
607+
lookAhead()
608+
if token == TRAIT then fuse(SUPERTRAIT)
609+
else reset()
605610
case SEMI =>
606611
lookAhead()
607612
if (token != ELSE) reset()

compiler/src/dotty/tools/dotc/parsing/Tokens.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ object Tokens extends TokensCommon {
184184
final val ERASED = 63; enter(ERASED, "erased")
185185
final val GIVEN = 64; enter(GIVEN, "given")
186186
final val EXPORT = 65; enter(EXPORT, "export")
187-
final val MACRO = 66; enter(MACRO, "macro") // TODO: remove
187+
final val SUPERTRAIT = 66; enter(SUPERTRAIT, "super trait")
188+
final val MACRO = 67; enter(MACRO, "macro") // TODO: remove
188189

189190
/** special symbols */
190191
final val NEWLINE = 78; enter(NEWLINE, "end of statement", "new line")
@@ -233,7 +234,7 @@ object Tokens extends TokensCommon {
233234
final val canStartTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet(
234235
THIS, SUPER, USCORE, LPAREN, AT)
235236

236-
final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT)
237+
final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT, SUPERTRAIT)
237238

238239
final val dclIntroTokens: TokenSet = BitSet(DEF, VAL, VAR, TYPE, GIVEN)
239240

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+7-2
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
725725
if mdef.hasType then Modifiers(mdef.symbol) else mdef.rawMods
726726

727727
private def Modifiers(sym: Symbol): Modifiers = untpd.Modifiers(
728-
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags),
728+
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags | SuperTrait else ModifierFlags),
729729
if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY,
730730
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree))
731731

@@ -835,7 +835,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
835835
}
836836

837837
protected def templateText(tree: TypeDef, impl: Template): Text = {
838-
val decl = modText(tree.mods, tree.symbol, keywordStr(if (tree.mods.is(Trait)) "trait" else "class"), isType = true)
838+
val kw =
839+
if tree.mods.is(SuperTrait) then "super trait"
840+
else if tree.mods.is(Trait) then "trait"
841+
else "class"
842+
val decl = modText(tree.mods, tree.symbol, keywordStr(kw), isType = true)
839843
( decl ~~ typeText(nameIdText(tree)) ~ withEnclosingDef(tree) { toTextTemplate(impl) }
840844
// ~ (if (tree.hasType && printDebug) i"[decls = ${tree.symbol.info.decls}]" else "") // uncomment to enable
841845
)
@@ -941,6 +945,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
941945
else if (sym.isPackageObject) "package object"
942946
else if (flags.is(Module) && flags.is(Case)) "case object"
943947
else if (sym.isClass && flags.is(Case)) "case class"
948+
else if sym.isClass && flags.is(SuperTrait) then "super trait"
944949
else super.keyString(sym)
945950
}
946951

compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class ExtractSemanticDB extends Phase:
101101

102102
private def excludeChildren(sym: Symbol)(using Context): Boolean =
103103
!sym.exists
104-
|| sym.isAllOf(HigherKinded | Param)
104+
|| sym.is(Param) && sym.info.bounds.hi.isInstanceOf[Types.HKTypeLambda]
105105

106106
/** Uses of this symbol where the reference has given span should be excluded from semanticdb */
107107
private def excludeUse(qualifier: Option[Symbol], sym: Symbol)(using Context): Boolean =

compiler/src/dotty/tools/dotc/typer/Namer.scala

+1-17
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,6 @@ class Namer { typer: Typer =>
404404
case _: TypeBoundsTree | _: MatchTypeTree =>
405405
flags |= Deferred // Typedefs with Match rhs classify as abstract
406406
case LambdaTypeTree(_, body) =>
407-
flags |= HigherKinded
408407
analyzeRHS(body)
409408
case _ =>
410409
if rhs.isEmpty || flags.is(Opaque) then flags |= Deferred
@@ -1459,19 +1458,6 @@ class Namer { typer: Typer =>
14591458
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
14601459
def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable)
14611460

1462-
def isEnumValue(tp: Type) = tp.typeSymbol == defn.EnumValueClass
1463-
1464-
// Drop EnumValue parents from inferred types of enum constants
1465-
def dropEnumValue(tp: Type): Type = tp.dealias match
1466-
case tpd @ AndType(tp1, tp2) =>
1467-
if isEnumValue(tp1) then tp2
1468-
else if isEnumValue(tp2) then tp1
1469-
else
1470-
val tpw = tpd.derivedAndType(dropEnumValue(tp1), dropEnumValue(tp2))
1471-
if tpw ne tpd then tpw else tp
1472-
case _ =>
1473-
tp
1474-
14751461
// Widen rhs type and eliminate `|' but keep ConstantTypes if
14761462
// definition is inline (i.e. final in Scala2) and keep module singleton types
14771463
// instead of widening to the underlying module class types.
@@ -1480,9 +1466,7 @@ class Namer { typer: Typer =>
14801466
def widenRhs(tp: Type): Type =
14811467
tp.widenTermRefExpr.simplified match
14821468
case ctp: ConstantType if isInlineVal => ctp
1483-
case tp =>
1484-
val tp1 = ctx.typeComparer.widenInferred(tp, rhsProto)
1485-
if sym.is(Enum) then dropEnumValue(tp1) else tp1
1469+
case tp => ctx.typeComparer.widenInferred(tp, rhsProto)
14861470

14871471
// Replace aliases to Unit by Unit itself. If we leave the alias in
14881472
// it would be erased to BoxedUnit.

docs/docs/internals/syntax.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ VarDef ::= PatDef
388388
DefDef ::= DefSig [‘:’ Type] ‘=’ Expr DefDef(_, name, tparams, vparamss, tpe, expr)
389389
| ‘this’ DefParamClause DefParamClauses ‘=’ ConstrExpr DefDef(_, <init>, Nil, vparamss, EmptyTree, expr | Block)
390390
391-
TmplDef ::= ([‘case’] ‘class’ | ‘trait’) ClassDef
391+
TmplDef ::= ([‘case’] ‘class’ | [‘super’] ‘trait’) ClassDef
392392
| [‘case’] ‘object’ ObjectDef
393393
| ‘enum’ EnumDef
394394
| ‘given’ GivenDef

0 commit comments

Comments
 (0)