Skip to content

Commit 2fa7b5f

Browse files
committed
Fix compiler API & syntax incompatibilities
1 parent ac46a05 commit 2fa7b5f

File tree

4 files changed

+1228
-0
lines changed

4 files changed

+1228
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
package ammonite.compiler
2+
3+
import ammonite.util.{ImportData, Imports, Printer, Util, Name as AmmName}
4+
import dotty.tools.dotc
5+
import dotty.tools.dotc.ast.Trees.*
6+
import dotty.tools.dotc.ast.{tpd, untpd}
7+
import dotty.tools.dotc.core.Contexts.*
8+
import dotty.tools.dotc.core.Flags
9+
import dotty.tools.dotc.core.Names.Name
10+
import dotty.tools.dotc.core.Phases.Phase
11+
import dotty.tools.dotc.core.StdNames.nme
12+
import dotty.tools.dotc.core.Symbols.{NoSymbol, Symbol, newSymbol}
13+
import dotty.tools.dotc.core.Types.{TermRef, Type, TypeTraverser}
14+
15+
import scala.collection.mutable
16+
17+
class AmmonitePhase(
18+
userCodeNestingLevel: => Int,
19+
needsUsedEarlierDefinitions: => Boolean
20+
) extends Phase:
21+
import tpd.*
22+
23+
def phaseName: String = "ammonite"
24+
25+
private var myImports = new mutable.ListBuffer[(Boolean, String, String, Seq[AmmName])]
26+
private var usedEarlierDefinitions0 = new mutable.ListBuffer[String]
27+
28+
def importData: Seq[ImportData] =
29+
val grouped = myImports
30+
.toList
31+
.distinct
32+
.groupBy { case (a, b, c, d) => (b, c, d) }
33+
.mapValues(_.map(_._1))
34+
35+
val open = for {
36+
((fromName, toName, importString), items) <- grouped
37+
if !CompilerUtil.ignoredNames(fromName)
38+
} yield {
39+
val importType = items match{
40+
case Seq(true) => ImportData.Type
41+
case Seq(false) => ImportData.Term
42+
case Seq(_, _) => ImportData.TermType
43+
}
44+
45+
ImportData(AmmName(fromName), AmmName(toName), importString, importType)
46+
}
47+
48+
open.toVector.sortBy(x => Util.encodeScalaSourcePath(x.prefix))
49+
50+
def usedEarlierDefinitions: Seq[String] =
51+
usedEarlierDefinitions0.toList.distinct
52+
53+
private def saneSym(name: Name, sym: Symbol)(using Context): Boolean =
54+
!name.decode.toString.contains('$') &&
55+
sym.exists &&
56+
// !sym.is(Flags.Synthetic) &&
57+
!scala.util.Try(sym.is(Flags.Private)).toOption.getOrElse(true) &&
58+
!scala.util.Try(sym.is(Flags.Protected)).toOption.getOrElse(true) &&
59+
// sym.is(Flags.Public) &&
60+
!CompilerUtil.ignoredSyms(sym.toString) &&
61+
!CompilerUtil.ignoredNames(name.decode.toString)
62+
63+
private def saneSym(sym: Symbol)(using Context): Boolean =
64+
saneSym(sym.name, sym)
65+
66+
private def processTree(t: tpd.Tree)(using Context): Unit = {
67+
val sym = t.symbol
68+
val name = t match {
69+
case t: tpd.ValDef => t.name
70+
case _ => sym.name
71+
}
72+
if (saneSym(name, sym)) {
73+
val name = sym.name.decode.toString
74+
myImports.addOne((sym.isType, name, name, Nil))
75+
}
76+
}
77+
78+
private def processImport(i: tpd.Import)(using Context): Unit = {
79+
val expr = i.expr
80+
val selectors = i.selectors
81+
82+
// Most of that logic was adapted from AmmonitePlugin, the Scala 2 counterpart
83+
// of this file.
84+
85+
val prefix =
86+
val (_ :: nameListTail, symbolHead :: _) = {
87+
def rec(expr: tpd.Tree): List[(Name, Symbol)] = {
88+
expr match {
89+
case s @ tpd.Select(lhs, _) => (s.symbol.name -> s.symbol) :: rec(lhs)
90+
case i @ tpd.Ident(name) => List(name -> i.symbol)
91+
case t @ tpd.This(pkg) => List(pkg.name -> t.symbol)
92+
}
93+
}
94+
rec(expr).reverse.unzip
95+
}
96+
97+
val headFullPath = symbolHead.fullName.decode.toString.split('.')
98+
.map(n => if (n.endsWith("$")) n.stripSuffix("$") else n) // meh
99+
// prefix package imports with `_root_` to try and stop random
100+
// variables from interfering with them. If someone defines a value
101+
// called `_root_`, this will still break, but that's their problem
102+
val rootPrefix = if(symbolHead.denot.is(Flags.Package)) Seq("_root_") else Nil
103+
val tailPath = nameListTail.map(_.decode.toString)
104+
105+
(rootPrefix ++ headFullPath ++ tailPath).map(AmmName(_))
106+
107+
def isMask(sel: untpd.ImportSelector) = sel.name != nme.WILDCARD && sel.rename == nme.WILDCARD
108+
109+
val renameMap =
110+
111+
/**
112+
* A map of each name importable from `expr`, to a `Seq[Boolean]`
113+
* containing a `true` if there's a type-symbol you can import, `false`
114+
* if there's a non-type symbol and both if there are both type and
115+
* non-type symbols that are importable for that name
116+
*/
117+
val importableIsTypes =
118+
expr.tpe
119+
.allMembers
120+
.map(_.symbol)
121+
.filter(saneSym(_))
122+
.groupBy(_.name.decode.toString)
123+
.mapValues(_.map(_.isType).toVector)
124+
125+
val renamings = for{
126+
t @ untpd.ImportSelector(name, renameTree, _) <- selectors
127+
if !isMask(t)
128+
// getOrElse just in case...
129+
isType <- importableIsTypes.getOrElse(name.name.decode.toString, Nil)
130+
case Ident(rename) <- Option(renameTree)
131+
} yield ((isType, rename.decode.toString), name.name.decode.toString)
132+
133+
renamings.toMap
134+
135+
136+
def isUnimportableUnlessRenamed(sym: Symbol): Boolean =
137+
sym eq NoSymbol
138+
139+
@scala.annotation.tailrec
140+
def transformImport(selectors: List[untpd.ImportSelector], sym: Symbol): List[Symbol] =
141+
selectors match {
142+
case Nil => Nil
143+
case sel :: Nil if sel.isWildcard =>
144+
if (isUnimportableUnlessRenamed(sym)) Nil
145+
else List(sym)
146+
case (sel @ untpd.ImportSelector(from, to, _)) :: _
147+
if from.name == (if (from.isTerm) sym.name.toTermName else sym.name.toTypeName) =>
148+
if (isMask(sel)) Nil
149+
else List(
150+
newSymbol(sym.owner, sel.rename, sym.flags, sym.info, sym.privateWithin, sym.coord)
151+
)
152+
case _ :: rest => transformImport(rest, sym)
153+
}
154+
155+
val symNames =
156+
for {
157+
sym <- expr.tpe.allMembers.map(_.symbol).flatMap(transformImport(selectors, _))
158+
if saneSym(sym)
159+
} yield (sym.isType, sym.name.decode.toString)
160+
161+
val syms = for {
162+
// For some reason `info.allImportedSymbols` does not show imported
163+
// type aliases when they are imported directly e.g.
164+
//
165+
// import scala.reflect.macros.Context
166+
//
167+
// As opposed to via import scala.reflect.macros._.
168+
// Thus we need to combine allImportedSymbols with the renameMap
169+
(isType, sym) <- (symNames.toList ++ renameMap.keys).distinct
170+
} yield (isType, renameMap.getOrElse((isType, sym), sym), sym, prefix)
171+
172+
myImports ++= syms
173+
}
174+
175+
private def updateUsedEarlierDefinitions(
176+
wrapperSym: Symbol,
177+
stats: List[tpd.Tree]
178+
)(using Context): Unit = {
179+
/*
180+
* We list the variables from the first wrapper
181+
* used from the user code.
182+
*
183+
* E.g. if, after wrapping, the code looks like
184+
* ```
185+
* class cmd2 {
186+
*
187+
* val cmd0 = ???
188+
* val cmd1 = ???
189+
*
190+
* import cmd0.{
191+
* n
192+
* }
193+
*
194+
* class Helper {
195+
* // user-typed code
196+
* val n0 = n + 1
197+
* }
198+
* }
199+
* ```
200+
* this would process the tree of `val n0 = n + 1`, find `n` as a tree like
201+
* `cmd2.this.cmd0.n`, and put `cmd0` in `uses`.
202+
*/
203+
204+
val typeTraverser: TypeTraverser = new TypeTraverser {
205+
def traverse(tpe: Type) = tpe match {
206+
case tr: TermRef if tr.prefix.typeSymbol == wrapperSym =>
207+
tr.designator match {
208+
case n: Name => usedEarlierDefinitions0 += n.decode.toString
209+
case s: Symbol => usedEarlierDefinitions0 += s.name.decode.toString
210+
case _ => // can this happen?
211+
}
212+
case _ =>
213+
traverseChildren(tpe)
214+
}
215+
}
216+
217+
val traverser: TreeTraverser = new TreeTraverser {
218+
def traverse(tree: Tree)(using Context) = tree match {
219+
case tpd.Select(node, name) if node.symbol == wrapperSym =>
220+
usedEarlierDefinitions0 += name.decode.toString
221+
case tt @ tpd.TypeTree() =>
222+
typeTraverser.traverse(tt.tpe)
223+
case _ =>
224+
traverseChildren(tree)
225+
}
226+
}
227+
228+
for (tree <- stats)
229+
traverser.traverse(tree)
230+
}
231+
232+
private def unpkg(tree: tpd.Tree): List[tpd.Tree] =
233+
tree match {
234+
case PackageDef(_, elems) => elems.flatMap(unpkg)
235+
case _ => List(tree)
236+
}
237+
238+
def run(using Context): Unit =
239+
val elems = unpkg(ctx.compilationUnit.tpdTree)
240+
def mainStats(trees: List[tpd.Tree]): List[tpd.Tree] =
241+
trees
242+
.reverseIterator
243+
.collectFirst {
244+
case TypeDef(name, rhs0: Template) => rhs0.body
245+
}
246+
.getOrElse(Nil)
247+
248+
val rootStats = mainStats(elems)
249+
val stats = (1 until userCodeNestingLevel)
250+
.foldLeft(rootStats)((trees, _) => mainStats(trees))
251+
252+
if (needsUsedEarlierDefinitions) {
253+
val wrapperSym = elems.last.symbol
254+
updateUsedEarlierDefinitions(wrapperSym, stats)
255+
}
256+
257+
stats.foreach {
258+
case i: Import => processImport(i)
259+
case t: tpd.DefDef => processTree(t)
260+
case t: tpd.ValDef => processTree(t)
261+
case t: tpd.TypeDef => processTree(t)
262+
case _ =>
263+
}

0 commit comments

Comments
 (0)