Skip to content

Commit 50d39ec

Browse files
authored
1 parent b3fea31 commit 50d39ec

File tree

10 files changed

+1261
-4
lines changed

10 files changed

+1261
-4
lines changed

.github/workflows/actions.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
fail-fast: false
3434
matrix:
3535
java-version: [8, 11]
36-
scala-version: [2.12.18, 2.13.12, 3.2.2, 3.3.1]
36+
scala-version: [2.12.18, 2.13.12, 3.2.2, 3.3.3]
3737
runs-on: ubuntu-latest
3838
steps:
3939
- uses: actions/checkout@v2
@@ -47,7 +47,7 @@ jobs:
4747
strategy:
4848
fail-fast: false
4949
matrix:
50-
scala-version: [2.12.18, 2.13.12, 3.2.2]
50+
scala-version: [2.12.18, 2.13.12, 3.3.3]
5151
runs-on: ubuntu-latest
5252
steps:
5353
- uses: actions/checkout@v2
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
package ammonite.compiler
2+
3+
import ammonite.util.{ImportData, Imports, Name => AmmName, Printer, Util}
4+
5+
import dotty.tools.dotc
6+
import dotty.tools.dotc.core.StdNames.nme
7+
import dotc.ast.Trees._
8+
import dotc.ast.{tpd, untpd}
9+
import dotc.core.Flags
10+
import dotc.core.Contexts._
11+
import dotc.core.Names.Name
12+
import dotc.core.Phases.Phase
13+
import dotc.core.Symbols.{NoSymbol, Symbol, newSymbol}
14+
import dotc.core.Types.{TermRef, Type, TypeTraverser}
15+
16+
import scala.collection.mutable
17+
18+
class AmmonitePhase(
19+
userCodeNestingLevel: => Int,
20+
needsUsedEarlierDefinitions: => Boolean
21+
) extends Phase:
22+
import tpd._
23+
24+
def phaseName: String = "ammonite"
25+
26+
private var myImports = new mutable.ListBuffer[(Boolean, String, String, Seq[AmmName])]
27+
private var usedEarlierDefinitions0 = new mutable.ListBuffer[String]
28+
29+
def importData: Seq[ImportData] =
30+
val grouped = myImports
31+
.toList
32+
.distinct
33+
.groupBy { case (a, b, c, d) => (b, c, d) }
34+
.mapValues(_.map(_._1))
35+
36+
val open = for {
37+
((fromName, toName, importString), items) <- grouped
38+
if !CompilerUtil.ignoredNames(fromName)
39+
} yield {
40+
val importType = items match{
41+
case Seq(true) => ImportData.Type
42+
case Seq(false) => ImportData.Term
43+
case Seq(_, _) => ImportData.TermType
44+
}
45+
46+
ImportData(AmmName(fromName), AmmName(toName), importString, importType)
47+
}
48+
49+
open.toVector.sortBy(x => Util.encodeScalaSourcePath(x.prefix))
50+
51+
def usedEarlierDefinitions: Seq[String] =
52+
usedEarlierDefinitions0.toList.distinct
53+
54+
private def saneSym(name: Name, sym: Symbol)(using Context): Boolean =
55+
!name.decode.toString.contains('$') &&
56+
sym.exists &&
57+
// !sym.is(Flags.Synthetic) &&
58+
!scala.util.Try(sym.is(Flags.Private)).toOption.getOrElse(true) &&
59+
!scala.util.Try(sym.is(Flags.Protected)).toOption.getOrElse(true) &&
60+
// sym.is(Flags.Public) &&
61+
!CompilerUtil.ignoredSyms(sym.toString) &&
62+
!CompilerUtil.ignoredNames(name.decode.toString)
63+
64+
private def saneSym(sym: Symbol)(using Context): Boolean =
65+
saneSym(sym.name, sym)
66+
67+
private def processTree(t: tpd.Tree)(using Context): Unit = {
68+
val sym = t.symbol
69+
val name = t match {
70+
case t: tpd.ValDef => t.name
71+
case _ => sym.name
72+
}
73+
if (saneSym(name, sym)) {
74+
val name = sym.name.decode.toString
75+
myImports.addOne((sym.isType, name, name, Nil))
76+
}
77+
}
78+
79+
private def processImport(i: tpd.Import)(using Context): Unit = {
80+
val expr = i.expr
81+
val selectors = i.selectors
82+
83+
// Most of that logic was adapted from AmmonitePlugin, the Scala 2 counterpart
84+
// of this file.
85+
86+
val prefix =
87+
val (_ :: nameListTail, symbolHead :: _) = {
88+
def rec(expr: tpd.Tree): List[(Name, Symbol)] = {
89+
expr match {
90+
case s @ tpd.Select(lhs, _) => (s.symbol.name -> s.symbol) :: rec(lhs)
91+
case i @ tpd.Ident(name) => List(name -> i.symbol)
92+
case t @ tpd.This(pkg) => List(pkg.name -> t.symbol)
93+
}
94+
}
95+
rec(expr).reverse.unzip
96+
}
97+
98+
val headFullPath = symbolHead.fullName.decode.toString.split('.')
99+
.map(n => if (n.endsWith("$")) n.stripSuffix("$") else n) // meh
100+
// prefix package imports with `_root_` to try and stop random
101+
// variables from interfering with them. If someone defines a value
102+
// called `_root_`, this will still break, but that's their problem
103+
val rootPrefix = if(symbolHead.denot.is(Flags.Package)) Seq("_root_") else Nil
104+
val tailPath = nameListTail.map(_.decode.toString)
105+
106+
(rootPrefix ++ headFullPath ++ tailPath).map(AmmName(_))
107+
108+
def isMask(sel: untpd.ImportSelector) = sel.name != nme.WILDCARD && sel.rename == nme.WILDCARD
109+
110+
val renameMap =
111+
112+
/**
113+
* A map of each name importable from `expr`, to a `Seq[Boolean]`
114+
* containing a `true` if there's a type-symbol you can import, `false`
115+
* if there's a non-type symbol and both if there are both type and
116+
* non-type symbols that are importable for that name
117+
*/
118+
val importableIsTypes =
119+
expr.tpe
120+
.allMembers
121+
.map(_.symbol)
122+
.filter(saneSym(_))
123+
.groupBy(_.name.decode.toString)
124+
.mapValues(_.map(_.isType).toVector)
125+
126+
val renamings = for{
127+
t @ untpd.ImportSelector(name, renameTree, _) <- selectors
128+
if !isMask(t)
129+
// getOrElse just in case...
130+
isType <- importableIsTypes.getOrElse(name.name.decode.toString, Nil)
131+
Ident(rename) <- Option(renameTree)
132+
} yield ((isType, rename.decode.toString), name.name.decode.toString)
133+
134+
renamings.toMap
135+
136+
137+
def isUnimportableUnlessRenamed(sym: Symbol): Boolean =
138+
sym eq NoSymbol
139+
140+
@scala.annotation.tailrec
141+
def transformImport(selectors: List[untpd.ImportSelector], sym: Symbol): List[Symbol] =
142+
selectors match {
143+
case Nil => Nil
144+
case sel :: Nil if sel.isWildcard =>
145+
if (isUnimportableUnlessRenamed(sym)) Nil
146+
else List(sym)
147+
case (sel @ untpd.ImportSelector(from, to, _)) :: _
148+
if from.name == (if (from.isTerm) sym.name.toTermName else sym.name.toTypeName) =>
149+
if (isMask(sel)) Nil
150+
else List(
151+
newSymbol(sym.owner, sel.rename, sym.flags, sym.info, sym.privateWithin, sym.coord)
152+
)
153+
case _ :: rest => transformImport(rest, sym)
154+
}
155+
156+
val symNames =
157+
for {
158+
sym <- expr.tpe.allMembers.map(_.symbol).flatMap(transformImport(selectors, _))
159+
if saneSym(sym)
160+
} yield (sym.isType, sym.name.decode.toString)
161+
162+
val syms = for {
163+
// For some reason `info.allImportedSymbols` does not show imported
164+
// type aliases when they are imported directly e.g.
165+
//
166+
// import scala.reflect.macros.Context
167+
//
168+
// As opposed to via import scala.reflect.macros._.
169+
// Thus we need to combine allImportedSymbols with the renameMap
170+
(isType, sym) <- (symNames.toList ++ renameMap.keys).distinct
171+
} yield (isType, renameMap.getOrElse((isType, sym), sym), sym, prefix)
172+
173+
myImports ++= syms
174+
}
175+
176+
private def updateUsedEarlierDefinitions(
177+
wrapperSym: Symbol,
178+
stats: List[tpd.Tree]
179+
)(using Context): Unit = {
180+
/*
181+
* We list the variables from the first wrapper
182+
* used from the user code.
183+
*
184+
* E.g. if, after wrapping, the code looks like
185+
* ```
186+
* class cmd2 {
187+
*
188+
* val cmd0 = ???
189+
* val cmd1 = ???
190+
*
191+
* import cmd0.{
192+
* n
193+
* }
194+
*
195+
* class Helper {
196+
* // user-typed code
197+
* val n0 = n + 1
198+
* }
199+
* }
200+
* ```
201+
* this would process the tree of `val n0 = n + 1`, find `n` as a tree like
202+
* `cmd2.this.cmd0.n`, and put `cmd0` in `uses`.
203+
*/
204+
205+
val typeTraverser: TypeTraverser = new TypeTraverser {
206+
def traverse(tpe: Type) = tpe match {
207+
case tr: TermRef if tr.prefix.typeSymbol == wrapperSym =>
208+
tr.designator match {
209+
case n: Name => usedEarlierDefinitions0 += n.decode.toString
210+
case s: Symbol => usedEarlierDefinitions0 += s.name.decode.toString
211+
case _ => // can this happen?
212+
}
213+
case _ =>
214+
traverseChildren(tpe)
215+
}
216+
}
217+
218+
val traverser: TreeTraverser = new TreeTraverser {
219+
def traverse(tree: Tree)(using Context) = tree match {
220+
case tpd.Select(node, name) if node.symbol == wrapperSym =>
221+
usedEarlierDefinitions0 += name.decode.toString
222+
case tt @ tpd.TypeTree() =>
223+
typeTraverser.traverse(tt.tpe)
224+
case _ =>
225+
traverseChildren(tree)
226+
}
227+
}
228+
229+
for (tree <- stats)
230+
traverser.traverse(tree)
231+
}
232+
233+
private def unpkg(tree: tpd.Tree): List[tpd.Tree] =
234+
tree match {
235+
case PackageDef(_, elems) => elems.flatMap(unpkg)
236+
case _ => List(tree)
237+
}
238+
239+
def run(using Context): Unit =
240+
val elems = unpkg(ctx.compilationUnit.tpdTree)
241+
def mainStats(trees: List[tpd.Tree]): List[tpd.Tree] =
242+
trees
243+
.reverseIterator
244+
.collectFirst {
245+
case TypeDef(name, rhs0: Template) => rhs0.body
246+
}
247+
.getOrElse(Nil)
248+
249+
val rootStats = mainStats(elems)
250+
val stats = (1 until userCodeNestingLevel)
251+
.foldLeft(rootStats)((trees, _) => mainStats(trees))
252+
253+
if (needsUsedEarlierDefinitions) {
254+
val wrapperSym = elems.last.symbol
255+
updateUsedEarlierDefinitions(wrapperSym, stats)
256+
}
257+
258+
stats.foreach {
259+
case i: Import => processImport(i)
260+
case t: tpd.DefDef => processTree(t)
261+
case t: tpd.ValDef => processTree(t)
262+
case t: tpd.TypeDef => processTree(t)
263+
case _ =>
264+
}

0 commit comments

Comments
 (0)