Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine refining annotations #22574

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,31 +189,6 @@ object Trees {

override def toText(printer: Printer): Text = printer.toText(this)

def sameTree(that: Tree[?]): Boolean = {
def isSame(x: Any, y: Any): Boolean =
x.asInstanceOf[AnyRef].eq(y.asInstanceOf[AnyRef]) || {
x match {
case x: Tree[?] =>
y match {
case y: Tree[?] => x.sameTree(y)
case _ => false
}
case x: List[?] =>
y match {
case y: List[?] => x.corresponds(y)(isSame)
case _ => false
}
case _ =>
false
}
}
this.getClass == that.getClass && {
val it1 = this.productIterator
val it2 = that.productIterator
it1.corresponds(it2)(isSame)
}
}

override def hashCode(): Int = System.identityHashCode(this)
override def equals(that: Any): Boolean = this eq that.asInstanceOf[AnyRef]
}
Expand Down
54 changes: 54 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,60 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else
applyOverloaded(tree, nme.EQ, that :: Nil, Nil, defn.BooleanType)

def sameTree(that: Tree, thisParamSyms: List[Symbol] = Nil, thatParamRefs: List[TermRef] = Nil)(using Context): Boolean =
def recur(tree1: Tree, tree2: Tree) =
tree1.sameTree(tree2, thisParamSyms, thatParamRefs)

def sameTrees(trees1: List[Tree], trees2: List[Tree]) =
trees1.corresponds(trees2)(recur)

def sameType(tp1: Type, tp2: Type) =
(tp1 frozen_=:= tp2) || (tp1.subst(thisParamSyms, thatParamRefs) frozen_=:= tp2)

val res = tree match
case Literal(_) | Ident(_) =>
sameType(tree.tpe, that.tpe)
case Select(qual1, name1) =>
that match
case Select(qual2, name2) => name1 == name2 && recur(qual1, qual2)
case _ => false
case Apply(fun1, args1) =>
that match
case Apply(fun2, args2) => recur(fun1, fun2) && sameTrees(args1, args2)
case _ => false
case TypeApply(fun1, args1) =>
that match
case TypeApply(fun2, args2) =>
recur(fun1, fun2) && args1.corresponds(args2)((arg1, arg2) => sameType(arg1.tpe, arg2.tpe))
case _ => false
case tpt1: TypeTree =>
that match
case tpt2: TypeTree => sameType(tpt1.tpe, tpt2.tpe)
case _ => false
case Typed(expr1, tpt1) =>
that match
case Typed(expr2, tpt2) => recur(expr1, expr2) && sameType(tpt1.tpe, tpt2.tpe)
case _ => false
case New(tpt1) =>
that match
case New(tpt2) => sameType(tpt1.tpe, tpt2.tpe)
case _ => false
case closureDef(def1) =>
that match
case closureDef(def2) =>
val newThisParamSyms = def1.symbol.paramSymss.flatten ++ thisParamSyms
val newThatParamRefs = def2.symbol.paramSymss.flatten.map(_.termRef) ++ thatParamRefs
def1.rhs.sameTree(def2.rhs, newThisParamSyms, newThatParamRefs)
case _ => false
case Block(stats1, expr1) =>
that match
case Block(stats2, expr2) => sameTrees(stats1, stats2) && recur(expr1, expr2)
case _ => false
case _ => false

res


/** `tree.isInstanceOf[tp]`, with special treatment of singleton types */
def isInstance(tp: Type)(using Context): Tree = tp.dealias match {
case ConstantType(c) if c.tag == StringTag =>
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ object Annotations {
def argumentConstantString(i: Int)(using Context): Option[String] =
for (case Constant(s: String) <- argumentConstant(i)) yield s

/** All type and term argument trees of this annotation in a single flat list */
private def allArguments(using Context): List[Tree] = tpd.allArguments(tree)

/** The tree evaluation is in progress. */
def isEvaluating: Boolean = false

Expand Down Expand Up @@ -88,7 +91,8 @@ object Annotations {
def ensureCompleted(using Context): Unit = tree

def sameAnnotation(that: Annotation)(using Context): Boolean =
symbol == that.symbol && tree.sameTree(that.tree)
def sameArg(arg1: Tree, arg2: Tree): Boolean = tpd.stripNamedArg(arg1).sameTree(tpd.stripNamedArg(arg2))
symbol == that.symbol && allArguments.corresponds(that.allArguments)(sameArg)

def hasOneOfMetaAnnotation(metaSyms: Set[Symbol], orNoneOf: Set[Symbol] = Set.empty)(using Context): Boolean = atPhaseNoLater(erasurePhase) {
def go(metaSyms: Set[Symbol]) =
Expand Down
40 changes: 40 additions & 0 deletions tests/neg/annot-refining-infer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
class MyAnnotation(x: Any) extends scala.annotation.RefiningAnnotation

def id[T](x: T): T = x
def id2[T](x: T, y: T): T = x

def foo1[T](x: T, g: T => Unit): T = x
def foo2[T](x: T, y: T, g: T => Unit): T = x
def foo3[T](g: T => Unit, x: T, y: T): T = x
def foo4[T](x: T, g: T => Unit, h: T => Unit): T = x

def take42[T](x: T @MyAnnotation(42)): Unit = ()
def take43[T](x: T @MyAnnotation(43)): Unit = ()
def take42or43[S](x: S @MyAnnotation(42) | S @MyAnnotation(43)): Unit = ()
def take42or43Int(x: Int @MyAnnotation(42) | Int @MyAnnotation(43)): Unit = ()

def main =
val c42: Int @MyAnnotation(42) = ???
val c43: Int @MyAnnotation(43) = ???

val v01 = id2[Int @MyAnnotation(42) | Int @MyAnnotation(43)](c42, c43)
val v02: Int @MyAnnotation(42) | Int @MyAnnotation(43) = c42
val v03: Int @MyAnnotation(42) | Int @MyAnnotation(43) = id2(c42, c43)

val v04 = foo1(c42, take42)
val v05: Int @MyAnnotation(42) = v13
val v06 = foo1(c42, take43) // error
val v07 = foo1(c42, take42or43)

val v08 = foo2(c42, c42, take42)
val v09: Int @MyAnnotation(42) = v15
val v10 = foo2(c42, c43, take42) // error
val v11 = foo2(c42, c43, take42or43) // error
val v12 = foo2(c42, c43, take42or43Int)

val v13 = foo3(take42or43, c42, c43) // error
val v14 = foo3(take42or43Int, c42, c43)

val v15 = foo4(c42, take42, take42)
val v16: Int @MyAnnotation(42) = v15
val v17 = foo4(c42, take42, take43) // error
92 changes: 92 additions & 0 deletions tests/neg/annot-refining-sub.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import compiletime.ops.int.+

class annot1(a: Any) extends scala.annotation.RefiningAnnotation
class annot2(a: Any, b: Any) extends scala.annotation.RefiningAnnotation
class annot3(a: Any, b: Any = 3) extends scala.annotation.RefiningAnnotation
class annot4[Int] extends scala.annotation.RefiningAnnotation

class Box[T](val a: Int)
case class Box2[T](val a: Int)
class Box3:
type T

def id[T](x: T): T = x

type BoxAlias = Box[Int]
type Box2Alias = Box2[Int]

object O:
val d: Int = 42

def main =
val c: Int = 42
val o: O.type = O

val v1: Int @annot1(1) = ??? : Int @annot1(1)
val v2: Int @annot1(c) = ??? : Int @annot1(c)
val v3: Int @annot1(O.d) = ??? : Int @annot1(O.d)
val v4: Int @annot1(O.d) = ??? : Int @annot1(o.d)
val v5: Int @annot1((1, 2)) = ??? : Int @annot1((1, 2))
val v6: Int @annot1(1 + 2) = ??? : Int @annot1(1 + 2)
val v7: Int @annot1(1 + 2) = ??? : Int @annot1(2 + 1) // error: no constant folding
val v8: Int @annot1(1 + c) = ??? : Int @annot1(1 + c)
val v9: Int @annot1(1 + c) = ??? : Int @annot1(c + 1) // error: no algebraic simplification
val v10: Int @annot1(Box(1)) = ??? : Int @annot1(Box(1))
val v11: Int @annot1(Box(c)) = ??? : Int @annot1(Box(c))
val v12: Int @annot1(Box2(1)) = ??? : Int @annot1(Box2(1))
val v13: Int @annot1(Box2(c)) = ??? : Int @annot1(Box2(c))
val v14: Int @annot1(c: Int) = ??? : Int @annot1(c: Int)
val v15: Int @annot1(c) = ??? : Int @annot1(c: Int) // error
val v16: Int @annot1(c: Int) = ??? : Int @annot1(c) // error
val v17: Int @annot1(id[Int]) = ??? : Int @annot1(id[Int])
val v18: Int @annot1(id[Int]) = ??? : Int @annot1(id[String]) // error
val v19: Int @annot1(id[Any]) = ??? : Int @annot1(id[Int]) // error
val v20: Int @annot1(Box(1)) = ??? : Int @annot1(Box(1): BoxAlias) // error
val v21: Int @annot1(Box(c): BoxAlias) = ??? : Int @annot1(Box(c)) // error
val v22: Int @annot1(Box2(1)) = ??? : Int @annot1(Box2(1): Box2Alias) // error
val v23: Int @annot1(Box2(c): Box2Alias) = ??? : Int @annot1(Box2(c)) // error
val v24: Int @annot1(Box3()) = ??? : Int @annot1(Box3())
val v25: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3 {type T = Int})
val v26: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3 {type T = String}) // error
val v27: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3) // error
val v28: Int @annot1(a=c) = ??? : Int @annot1(a=c)
val v29: Int @annot1(a=c) = ??? : Int @annot1(c)
val v30: Int @annot1(c) = ??? : Int @annot1(a=c)
val v31: Int @annot1((d: Int) => d) = ??? : Int @annot1((d: Int) => d)
val v32: Int @annot1((d: Int) => d) = ??? : Int @annot1((e: Int) => e)
val v33: Int @annot1((e: Int) => e) = ??? : Int @annot1((d: Int) => d)
val v34: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1)
val v35: Int @annot1((d: Int) => id(d)) = ??? : Int @annot1((e: Int) => id(e))
val v36: Int @annot1((d: Int) => id[d.type]) = ??? : Int @annot1((e: Int) => id[e.type])
val v37: Int @annot1((d: Box3) => id[d.T]) = ??? : Int @annot1((e: Box3) => id[e.T])
val v38: Int @annot1((d: Int) => (d: Int) => d) = ??? : Int @annot1((e: Int) => (e: Int) => e)
val v39: Int @annot1((d: Int) => ((e: Int) => e)(2)) = ??? : Int @annot1((e: Int) => ((e: Int) => e)(2))
val v40: Int @annot2(1, 2) = ??? : Int @annot2(1, 2)
val v41: Int @annot2(c, c) = ??? : Int @annot2(c, c)
val v42: Int @annot2(c, c) = ??? : Int @annot2(a=c, b=c)
val v43: Int @annot2(a=c, c) = ??? : Int @annot2(c, b=c)
val v44: Int @annot2(a=c, b=c) = ??? : Int @annot2(c, c)

val v45: Int @annot3(1) = ??? : Int @annot3(1)
val v46: Int @annot3(c) = ??? : Int @annot3(c)
val v47: Int @annot3(1) = ??? : Int @annot3(1, 3) // error: default arg tree is different, fix in the future?
val v48: Int @annot3(1, 3) = ??? : Int @annot3(1) // error: same as above
val v49: Int @annot3(c) = ??? : Int @annot3(c, 3) // error: same as above
val v50: Int @annot3(c, 3) = ??? : Int @annot3(c) // error: same as above

val v51: Int @annot4[1] = ??? : Int @annot4[1]
val v52: Int @annot4[c.type] = ??? : Int @annot4[c.type]
val v53: Int @annot4[O.d.type] = ??? : Int @annot4[O.d.type]
val v54: Int @annot4[O.d.type] = ??? : Int @annot4[o.d.type]
val v55: Int @annot4[Int] = ??? : Int @annot4[Int]
val v56: Int @annot4[Int] = ??? : Int @annot4[1] // error
val v57: Int @annot4[(1, 2)] = ??? : Int @annot4[(1, 2)]
val v58: Int @annot4[1 + 2] = ??? : Int @annot4[1 + 2]
val v59: Int @annot4[1 + 2] = ??? : Int @annot4[2 + 1]
val v60: Int @annot4[1 + c.type] = ??? : Int @annot4[1 + c.type]
val v61: Int @annot4[1 + c.type] = ??? : Int @annot4[c.type + 1] // error
val v62: Int @annot4[Box[Int]] = ??? : Int @annot4[Box[Int]]
val v63: Int @annot4[Box[String]] = ??? : Int @annot4[Box[Int]] // error
val v64: Int @annot4[Box2[Int]] = ??? : Int @annot4[Box2[Int]]
val v65: Int @annot4[Box2[String]] = ??? : Int @annot4[Box2[Int]] // error
val v66: Int @annot4[1] = ??? : Int @annot4[Int] // error
Loading