Skip to content

Commit 62ca3fc

Browse files
authored
Merge pull request #15663 from som-snytt/issue/13885-repl-parser-more-phaselike
2 parents 19eff87 + 5f4653d commit 62ca3fc

File tree

7 files changed

+167
-145
lines changed

7 files changed

+167
-145
lines changed

compiler/src/dotty/tools/repl/ReplCompilationUnit.scala

-8
This file was deleted.

compiler/src/dotty/tools/repl/ReplCompiler.scala

+145-113
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ import dotty.tools.dotc.transform.PostTyper
1515
import dotty.tools.dotc.typer.ImportInfo.{withRootImports, RootRef}
1616
import dotty.tools.dotc.typer.TyperPhase
1717
import dotty.tools.dotc.util.Spans._
18-
import dotty.tools.dotc.util.{ParsedComment, SourceFile}
18+
import dotty.tools.dotc.util.{ParsedComment, Property, SourceFile}
1919
import dotty.tools.dotc.{CompilationUnit, Compiler, Run}
2020
import dotty.tools.repl.results._
2121

2222
import scala.collection.mutable
23+
import scala.util.chaining.given
2324

2425
/** This subclass of `Compiler` is adapted for use in the REPL.
2526
*
@@ -29,12 +30,14 @@ import scala.collection.mutable
2930
* - provides utility to query the type of an expression
3031
* - provides utility to query the documentation of an expression
3132
*/
32-
class ReplCompiler extends Compiler {
33+
class ReplCompiler extends Compiler:
3334

3435
override protected def frontendPhases: List[List[Phase]] = List(
35-
List(new TyperPhase(addRootImports = false)),
36-
List(new CollectTopLevelImports),
37-
List(new PostTyper),
36+
List(Parser()),
37+
List(ReplPhase()),
38+
List(TyperPhase(addRootImports = false)),
39+
List(CollectTopLevelImports()),
40+
List(PostTyper()),
3841
)
3942

4043
def newRun(initCtx: Context, state: State): Run =
@@ -46,7 +49,7 @@ class ReplCompiler extends Compiler {
4649

4750
def importPreviousRun(id: Int)(using Context) = {
4851
// we first import the wrapper object id
49-
val path = nme.EMPTY_PACKAGE ++ "." ++ objectNames(id)
52+
val path = nme.EMPTY_PACKAGE ++ "." ++ ReplCompiler.objectNames(id)
5053
val ctx0 = ctx.fresh
5154
.setNewScope
5255
.withRootImports(RootRef(() => requiredModuleRef(path)) :: Nil)
@@ -67,117 +70,29 @@ class ReplCompiler extends Compiler {
6770
}
6871
run.suppressions.initSuspendedMessages(state.context.run)
6972
run
73+
end newRun
7074

71-
private val objectNames = mutable.Map.empty[Int, TermName]
75+
private def packaged(stats: List[untpd.Tree])(using Context): untpd.PackageDef =
76+
import untpd.*
77+
PackageDef(Ident(nme.EMPTY_PACKAGE), stats)
7278

73-
private case class Definitions(stats: List[untpd.Tree], state: State)
74-
75-
private def definitions(trees: List[untpd.Tree], state: State): Definitions = inContext(state.context) {
76-
import untpd._
77-
78-
// If trees is of the form `{ def1; def2; def3 }` then `List(def1, def2, def3)`
79-
val flattened = trees match {
80-
case List(Block(stats, expr)) =>
81-
if (expr eq EmptyTree) stats // happens when expr is not an expression
82-
else stats :+ expr
83-
case _ =>
84-
trees
85-
}
86-
87-
var valIdx = state.valIndex
88-
val defs = new mutable.ListBuffer[Tree]
89-
90-
/** If the user inputs a definition whose name is of the form REPL_RES_PREFIX and a number,
91-
* such as `val res9 = 1`, we bump `valIdx` to avoid name clashes. lampepfl/dotty#3536 */
92-
def maybeBumpValIdx(tree: Tree): Unit = tree match
93-
case apply: Apply => for a <- apply.args do maybeBumpValIdx(a)
94-
case tuple: Tuple => for t <- tuple.trees do maybeBumpValIdx(t)
95-
case patDef: PatDef => for p <- patDef.pats do maybeBumpValIdx(p)
96-
case tree: NameTree => tree.name.show.stripPrefix(str.REPL_RES_PREFIX).toIntOption match
97-
case Some(n) if n >= valIdx => valIdx = n + 1
98-
case _ =>
99-
case _ =>
100-
101-
flattened.foreach {
102-
case expr @ Assign(id: Ident, _) =>
103-
// special case simple reassignment (e.g. x = 3)
104-
// in order to print the new value in the REPL
105-
val assignName = (id.name ++ str.REPL_ASSIGN_SUFFIX).toTermName
106-
val assign = ValDef(assignName, TypeTree(), id).withSpan(expr.span)
107-
defs += expr += assign
108-
case expr if expr.isTerm =>
109-
val resName = (str.REPL_RES_PREFIX + valIdx).toTermName
110-
valIdx += 1
111-
val vd = ValDef(resName, TypeTree(), expr).withSpan(expr.span)
112-
defs += vd
113-
case other =>
114-
maybeBumpValIdx(other)
115-
defs += other
116-
}
117-
118-
Definitions(
119-
defs.toList,
120-
state.copy(
121-
objectIndex = state.objectIndex + 1,
122-
valIndex = valIdx
123-
)
124-
)
125-
}
126-
127-
/** Wrap trees in an object and add imports from the previous compilations
128-
*
129-
* The resulting structure is something like:
130-
*
131-
* ```
132-
* package <none> {
133-
* object rs$line$nextId {
134-
* import rs$line${i <- 0 until nextId}._
135-
*
136-
* <trees>
137-
* }
138-
* }
139-
* ```
140-
*/
141-
private def wrapped(defs: Definitions, objectTermName: TermName, span: Span): untpd.PackageDef =
142-
inContext(defs.state.context) {
143-
import untpd._
144-
145-
val tmpl = Template(emptyConstructor, Nil, Nil, EmptyValDef, defs.stats)
146-
val module = ModuleDef(objectTermName, tmpl)
147-
.withSpan(span)
79+
final def compile(parsed: Parsed)(using state: State): Result[(CompilationUnit, State)] =
80+
assert(!parsed.trees.isEmpty)
14881

149-
PackageDef(Ident(nme.EMPTY_PACKAGE), List(module))
82+
given Context = state.context
83+
val unit = ReplCompilationUnit(ctx.source).tap { unit =>
84+
unit.untpdTree = packaged(parsed.trees)
85+
unit.untpdTree.putAttachment(ReplCompiler.ReplState, state)
15086
}
151-
152-
private def createUnit(defs: Definitions, span: Span)(using Context): CompilationUnit = {
153-
val objectName = ctx.source.file.toString
154-
assert(objectName.startsWith(str.REPL_SESSION_LINE))
155-
assert(objectName.endsWith(defs.state.objectIndex.toString))
156-
val objectTermName = ctx.source.file.toString.toTermName
157-
objectNames.update(defs.state.objectIndex, objectTermName)
158-
159-
val unit = new ReplCompilationUnit(ctx.source)
160-
unit.untpdTree = wrapped(defs, objectTermName, span)
161-
unit
162-
}
163-
164-
private def runCompilationUnit(unit: CompilationUnit, state: State): Result[(CompilationUnit, State)] = {
165-
val ctx = state.context
16687
ctx.run.nn.compileUnits(unit :: Nil)
167-
ctx.run.nn.printSummary() // this outputs "2 errors found" like normal - but we might decide that's needlessly noisy for the REPL
168-
169-
if (!ctx.reporter.hasErrors) (unit, state).result
170-
else ctx.reporter.removeBufferedMessages(using ctx).errors
171-
}
88+
ctx.run.nn.printSummary() // "2 errors found"
17289

173-
final def compile(parsed: Parsed)(implicit state: State): Result[(CompilationUnit, State)] = {
174-
assert(!parsed.trees.isEmpty)
175-
val defs = definitions(parsed.trees, state)
176-
val unit = createUnit(defs, Span(0, parsed.trees.last.span.end))(using state.context)
177-
runCompilationUnit(unit, defs.state)
178-
}
90+
val newState = unit.tpdTree.getAttachment(ReplCompiler.ReplState).get
91+
if !ctx.reporter.hasErrors then (unit, newState).result
92+
else ctx.reporter.removeBufferedMessages.errors
93+
end compile
17994

180-
final def typeOf(expr: String)(implicit state: State): Result[String] =
95+
final def typeOf(expr: String)(using state: State): Result[String] =
18196
typeCheck(expr).map { tree =>
18297
given Context = state.context
18398
tree.rhs match {
@@ -190,7 +105,7 @@ class ReplCompiler extends Compiler {
190105
}
191106
}
192107

193-
def docOf(expr: String)(implicit state: State): Result[String] = inContext(state.context) {
108+
def docOf(expr: String)(using state: State): Result[String] = inContext(state.context) {
194109

195110
/** Extract the "selected" symbol from `tree`.
196111
*
@@ -237,7 +152,7 @@ class ReplCompiler extends Compiler {
237152
}
238153
}
239154

240-
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(implicit state: State): Result[tpd.ValDef] = {
155+
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {
241156

242157
def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
243158
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
@@ -300,4 +215,121 @@ class ReplCompiler extends Compiler {
300215
}
301216
}
302217
}
303-
}
218+
object ReplCompiler:
219+
val ReplState: Property.StickyKey[State] = Property.StickyKey()
220+
val objectNames = mutable.Map.empty[Int, TermName]
221+
end ReplCompiler
222+
223+
class ReplCompilationUnit(source: SourceFile) extends CompilationUnit(source):
224+
override def isSuspendable: Boolean = false
225+
226+
/** A placeholder phase that receives parse trees..
227+
*
228+
* It is called "parser" for the convenience of collective muscle memory.
229+
*
230+
* This enables -Vprint:parser.
231+
*/
232+
class Parser extends Phase:
233+
def phaseName: String = "parser"
234+
def run(using Context): Unit = ()
235+
end Parser
236+
237+
/** A phase that assembles wrapped parse trees from user input.
238+
*
239+
* Ths `ReplState` attachment indicates Repl wrapping is required.
240+
*
241+
* This enables -Vprint:repl so that users can see how their code snippet was wrapped.
242+
*/
243+
class ReplPhase extends Phase:
244+
def phaseName: String = "repl"
245+
246+
def run(using Context): Unit =
247+
ctx.compilationUnit.untpdTree match
248+
case pkg @ PackageDef(_, stats) =>
249+
pkg.getAttachment(ReplCompiler.ReplState).foreach {
250+
case given State =>
251+
val defs = definitions(stats)
252+
val res = wrapped(defs, Span(0, stats.last.span.end))
253+
res.putAttachment(ReplCompiler.ReplState, defs.state)
254+
ctx.compilationUnit.untpdTree = res
255+
}
256+
case _ =>
257+
end run
258+
259+
private case class Definitions(stats: List[untpd.Tree], state: State)
260+
261+
private def definitions(trees: List[untpd.Tree])(using Context, State): Definitions =
262+
import untpd.*
263+
264+
// If trees is of the form `{ def1; def2; def3 }` then `List(def1, def2, def3)`
265+
val flattened = trees match {
266+
case List(Block(stats, expr)) =>
267+
if (expr eq EmptyTree) stats // happens when expr is not an expression
268+
else stats :+ expr
269+
case _ =>
270+
trees
271+
}
272+
273+
val state = summon[State]
274+
var valIdx = state.valIndex
275+
val defs = mutable.ListBuffer.empty[Tree]
276+
277+
/** If the user inputs a definition whose name is of the form REPL_RES_PREFIX and a number,
278+
* such as `val res9 = 1`, we bump `valIdx` to avoid name clashes. lampepfl/dotty#3536 */
279+
def maybeBumpValIdx(tree: Tree): Unit = tree match
280+
case apply: Apply => for a <- apply.args do maybeBumpValIdx(a)
281+
case tuple: Tuple => for t <- tuple.trees do maybeBumpValIdx(t)
282+
case patDef: PatDef => for p <- patDef.pats do maybeBumpValIdx(p)
283+
case tree: NameTree => tree.name.show.stripPrefix(str.REPL_RES_PREFIX).toIntOption match
284+
case Some(n) if n >= valIdx => valIdx = n + 1
285+
case _ =>
286+
case _ =>
287+
288+
flattened.foreach {
289+
case expr @ Assign(id: Ident, _) =>
290+
// special case simple reassignment (e.g. x = 3)
291+
// in order to print the new value in the REPL
292+
val assignName = (id.name ++ str.REPL_ASSIGN_SUFFIX).toTermName
293+
val assign = ValDef(assignName, TypeTree(), id).withSpan(expr.span)
294+
defs += expr += assign
295+
case expr if expr.isTerm =>
296+
val resName = (str.REPL_RES_PREFIX + valIdx).toTermName
297+
valIdx += 1
298+
val vd = ValDef(resName, TypeTree(), expr).withSpan(expr.span)
299+
defs += vd
300+
case other =>
301+
maybeBumpValIdx(other)
302+
defs += other
303+
}
304+
305+
Definitions(defs.toList, state.copy(objectIndex = state.objectIndex + 1, valIndex = valIdx))
306+
end definitions
307+
308+
/** Wrap trees in an object and add imports from the previous compilations.
309+
*
310+
* The resulting structure is something like:
311+
*
312+
* ```
313+
* package <none> {
314+
* object rs$line$nextId {
315+
* import rs$line${i <- 0 until nextId}.*
316+
*
317+
* <trees>
318+
* }
319+
* }
320+
* ```
321+
*/
322+
private def wrapped(defs: Definitions, span: Span)(using Context): untpd.PackageDef =
323+
import untpd.*
324+
325+
val objectName = ctx.source.file.toString
326+
assert(objectName.startsWith(str.REPL_SESSION_LINE))
327+
assert(objectName.endsWith(defs.state.objectIndex.toString))
328+
val objectTermName = objectName.toTermName
329+
ReplCompiler.objectNames.update(defs.state.objectIndex, objectTermName)
330+
331+
val tmpl = Template(emptyConstructor, Nil, Nil, EmptyValDef, defs.stats)
332+
val module = ModuleDef(objectTermName, tmpl).withSpan(span)
333+
334+
PackageDef(Ident(nme.EMPTY_PACKAGE), List(module))
335+
end ReplPhase

0 commit comments

Comments
 (0)