diff --git a/compiler/src/dotty/tools/repl/ReplCompilationUnit.scala b/compiler/src/dotty/tools/repl/ReplCompilationUnit.scala deleted file mode 100644 index d3f2cee7e207..000000000000 --- a/compiler/src/dotty/tools/repl/ReplCompilationUnit.scala +++ /dev/null @@ -1,8 +0,0 @@ -package dotty.tools.repl - -import dotty.tools.dotc.CompilationUnit -import dotty.tools.dotc.util.SourceFile - - -class ReplCompilationUnit(source: SourceFile) extends CompilationUnit(source): - override def isSuspendable: Boolean = false diff --git a/compiler/src/dotty/tools/repl/ReplCompiler.scala b/compiler/src/dotty/tools/repl/ReplCompiler.scala index fb71d4bbb805..aab7ea46e9ed 100644 --- a/compiler/src/dotty/tools/repl/ReplCompiler.scala +++ b/compiler/src/dotty/tools/repl/ReplCompiler.scala @@ -15,11 +15,12 @@ import dotty.tools.dotc.transform.PostTyper import dotty.tools.dotc.typer.ImportInfo.{withRootImports, RootRef} import dotty.tools.dotc.typer.TyperPhase import dotty.tools.dotc.util.Spans._ -import dotty.tools.dotc.util.{ParsedComment, SourceFile} +import dotty.tools.dotc.util.{ParsedComment, Property, SourceFile} import dotty.tools.dotc.{CompilationUnit, Compiler, Run} import dotty.tools.repl.results._ import scala.collection.mutable +import scala.util.chaining.given /** This subclass of `Compiler` is adapted for use in the REPL. * @@ -29,12 +30,14 @@ import scala.collection.mutable * - provides utility to query the type of an expression * - provides utility to query the documentation of an expression */ -class ReplCompiler extends Compiler { +class ReplCompiler extends Compiler: override protected def frontendPhases: List[List[Phase]] = List( - List(new TyperPhase(addRootImports = false)), - List(new CollectTopLevelImports), - List(new PostTyper), + List(Parser()), + List(ReplPhase()), + List(TyperPhase(addRootImports = false)), + List(CollectTopLevelImports()), + List(PostTyper()), ) def newRun(initCtx: Context, state: State): Run = @@ -46,7 +49,7 @@ class ReplCompiler extends Compiler { def importPreviousRun(id: Int)(using Context) = { // we first import the wrapper object id - val path = nme.EMPTY_PACKAGE ++ "." ++ objectNames(id) + val path = nme.EMPTY_PACKAGE ++ "." ++ ReplCompiler.objectNames(id) val ctx0 = ctx.fresh .setNewScope .withRootImports(RootRef(() => requiredModuleRef(path)) :: Nil) @@ -67,117 +70,29 @@ class ReplCompiler extends Compiler { } run.suppressions.initSuspendedMessages(state.context.run) run + end newRun - private val objectNames = mutable.Map.empty[Int, TermName] + private def packaged(stats: List[untpd.Tree])(using Context): untpd.PackageDef = + import untpd.* + PackageDef(Ident(nme.EMPTY_PACKAGE), stats) - private case class Definitions(stats: List[untpd.Tree], state: State) - - private def definitions(trees: List[untpd.Tree], state: State): Definitions = inContext(state.context) { - import untpd._ - - // If trees is of the form `{ def1; def2; def3 }` then `List(def1, def2, def3)` - val flattened = trees match { - case List(Block(stats, expr)) => - if (expr eq EmptyTree) stats // happens when expr is not an expression - else stats :+ expr - case _ => - trees - } - - var valIdx = state.valIndex - val defs = new mutable.ListBuffer[Tree] - - /** If the user inputs a definition whose name is of the form REPL_RES_PREFIX and a number, - * such as `val res9 = 1`, we bump `valIdx` to avoid name clashes. lampepfl/dotty#3536 */ - def maybeBumpValIdx(tree: Tree): Unit = tree match - case apply: Apply => for a <- apply.args do maybeBumpValIdx(a) - case tuple: Tuple => for t <- tuple.trees do maybeBumpValIdx(t) - case patDef: PatDef => for p <- patDef.pats do maybeBumpValIdx(p) - case tree: NameTree => tree.name.show.stripPrefix(str.REPL_RES_PREFIX).toIntOption match - case Some(n) if n >= valIdx => valIdx = n + 1 - case _ => - case _ => - - flattened.foreach { - case expr @ Assign(id: Ident, _) => - // special case simple reassignment (e.g. x = 3) - // in order to print the new value in the REPL - val assignName = (id.name ++ str.REPL_ASSIGN_SUFFIX).toTermName - val assign = ValDef(assignName, TypeTree(), id).withSpan(expr.span) - defs += expr += assign - case expr if expr.isTerm => - val resName = (str.REPL_RES_PREFIX + valIdx).toTermName - valIdx += 1 - val vd = ValDef(resName, TypeTree(), expr).withSpan(expr.span) - defs += vd - case other => - maybeBumpValIdx(other) - defs += other - } - - Definitions( - defs.toList, - state.copy( - objectIndex = state.objectIndex + 1, - valIndex = valIdx - ) - ) - } - - /** Wrap trees in an object and add imports from the previous compilations - * - * The resulting structure is something like: - * - * ``` - * package { - * object rs$line$nextId { - * import rs$line${i <- 0 until nextId}._ - * - * - * } - * } - * ``` - */ - private def wrapped(defs: Definitions, objectTermName: TermName, span: Span): untpd.PackageDef = - inContext(defs.state.context) { - import untpd._ - - val tmpl = Template(emptyConstructor, Nil, Nil, EmptyValDef, defs.stats) - val module = ModuleDef(objectTermName, tmpl) - .withSpan(span) + final def compile(parsed: Parsed)(using state: State): Result[(CompilationUnit, State)] = + assert(!parsed.trees.isEmpty) - PackageDef(Ident(nme.EMPTY_PACKAGE), List(module)) + given Context = state.context + val unit = ReplCompilationUnit(ctx.source).tap { unit => + unit.untpdTree = packaged(parsed.trees) + unit.untpdTree.putAttachment(ReplCompiler.ReplState, state) } - - private def createUnit(defs: Definitions, span: Span)(using Context): CompilationUnit = { - val objectName = ctx.source.file.toString - assert(objectName.startsWith(str.REPL_SESSION_LINE)) - assert(objectName.endsWith(defs.state.objectIndex.toString)) - val objectTermName = ctx.source.file.toString.toTermName - objectNames.update(defs.state.objectIndex, objectTermName) - - val unit = new ReplCompilationUnit(ctx.source) - unit.untpdTree = wrapped(defs, objectTermName, span) - unit - } - - private def runCompilationUnit(unit: CompilationUnit, state: State): Result[(CompilationUnit, State)] = { - val ctx = state.context ctx.run.nn.compileUnits(unit :: Nil) - ctx.run.nn.printSummary() // this outputs "2 errors found" like normal - but we might decide that's needlessly noisy for the REPL - - if (!ctx.reporter.hasErrors) (unit, state).result - else ctx.reporter.removeBufferedMessages(using ctx).errors - } + ctx.run.nn.printSummary() // "2 errors found" - final def compile(parsed: Parsed)(implicit state: State): Result[(CompilationUnit, State)] = { - assert(!parsed.trees.isEmpty) - val defs = definitions(parsed.trees, state) - val unit = createUnit(defs, Span(0, parsed.trees.last.span.end))(using state.context) - runCompilationUnit(unit, defs.state) - } + val newState = unit.tpdTree.getAttachment(ReplCompiler.ReplState).get + if !ctx.reporter.hasErrors then (unit, newState).result + else ctx.reporter.removeBufferedMessages.errors + end compile - final def typeOf(expr: String)(implicit state: State): Result[String] = + final def typeOf(expr: String)(using state: State): Result[String] = typeCheck(expr).map { tree => given Context = state.context tree.rhs match { @@ -190,7 +105,7 @@ class ReplCompiler extends Compiler { } } - def docOf(expr: String)(implicit state: State): Result[String] = inContext(state.context) { + def docOf(expr: String)(using state: State): Result[String] = inContext(state.context) { /** Extract the "selected" symbol from `tree`. * @@ -237,7 +152,7 @@ class ReplCompiler extends Compiler { } } - final def typeCheck(expr: String, errorsAllowed: Boolean = false)(implicit state: State): Result[tpd.ValDef] = { + final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = { def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = { def wrap(trees: List[untpd.Tree]): untpd.PackageDef = { @@ -300,4 +215,121 @@ class ReplCompiler extends Compiler { } } } -} +object ReplCompiler: + val ReplState: Property.StickyKey[State] = Property.StickyKey() + val objectNames = mutable.Map.empty[Int, TermName] +end ReplCompiler + +class ReplCompilationUnit(source: SourceFile) extends CompilationUnit(source): + override def isSuspendable: Boolean = false + +/** A placeholder phase that receives parse trees.. + * + * It is called "parser" for the convenience of collective muscle memory. + * + * This enables -Vprint:parser. + */ +class Parser extends Phase: + def phaseName: String = "parser" + def run(using Context): Unit = () +end Parser + +/** A phase that assembles wrapped parse trees from user input. + * + * Ths `ReplState` attachment indicates Repl wrapping is required. + * + * This enables -Vprint:repl so that users can see how their code snippet was wrapped. + */ +class ReplPhase extends Phase: + def phaseName: String = "repl" + + def run(using Context): Unit = + ctx.compilationUnit.untpdTree match + case pkg @ PackageDef(_, stats) => + pkg.getAttachment(ReplCompiler.ReplState).foreach { + case given State => + val defs = definitions(stats) + val res = wrapped(defs, Span(0, stats.last.span.end)) + res.putAttachment(ReplCompiler.ReplState, defs.state) + ctx.compilationUnit.untpdTree = res + } + case _ => + end run + + private case class Definitions(stats: List[untpd.Tree], state: State) + + private def definitions(trees: List[untpd.Tree])(using Context, State): Definitions = + import untpd.* + + // If trees is of the form `{ def1; def2; def3 }` then `List(def1, def2, def3)` + val flattened = trees match { + case List(Block(stats, expr)) => + if (expr eq EmptyTree) stats // happens when expr is not an expression + else stats :+ expr + case _ => + trees + } + + val state = summon[State] + var valIdx = state.valIndex + val defs = mutable.ListBuffer.empty[Tree] + + /** If the user inputs a definition whose name is of the form REPL_RES_PREFIX and a number, + * such as `val res9 = 1`, we bump `valIdx` to avoid name clashes. lampepfl/dotty#3536 */ + def maybeBumpValIdx(tree: Tree): Unit = tree match + case apply: Apply => for a <- apply.args do maybeBumpValIdx(a) + case tuple: Tuple => for t <- tuple.trees do maybeBumpValIdx(t) + case patDef: PatDef => for p <- patDef.pats do maybeBumpValIdx(p) + case tree: NameTree => tree.name.show.stripPrefix(str.REPL_RES_PREFIX).toIntOption match + case Some(n) if n >= valIdx => valIdx = n + 1 + case _ => + case _ => + + flattened.foreach { + case expr @ Assign(id: Ident, _) => + // special case simple reassignment (e.g. x = 3) + // in order to print the new value in the REPL + val assignName = (id.name ++ str.REPL_ASSIGN_SUFFIX).toTermName + val assign = ValDef(assignName, TypeTree(), id).withSpan(expr.span) + defs += expr += assign + case expr if expr.isTerm => + val resName = (str.REPL_RES_PREFIX + valIdx).toTermName + valIdx += 1 + val vd = ValDef(resName, TypeTree(), expr).withSpan(expr.span) + defs += vd + case other => + maybeBumpValIdx(other) + defs += other + } + + Definitions(defs.toList, state.copy(objectIndex = state.objectIndex + 1, valIndex = valIdx)) + end definitions + + /** Wrap trees in an object and add imports from the previous compilations. + * + * The resulting structure is something like: + * + * ``` + * package { + * object rs$line$nextId { + * import rs$line${i <- 0 until nextId}.* + * + * + * } + * } + * ``` + */ + private def wrapped(defs: Definitions, span: Span)(using Context): untpd.PackageDef = + import untpd.* + + val objectName = ctx.source.file.toString + assert(objectName.startsWith(str.REPL_SESSION_LINE)) + assert(objectName.endsWith(defs.state.objectIndex.toString)) + val objectTermName = objectName.toTermName + ReplCompiler.objectNames.update(defs.state.objectIndex, objectTermName) + + val tmpl = Template(emptyConstructor, Nil, Nil, EmptyValDef, defs.stats) + val module = ModuleDef(objectTermName, tmpl).withSpan(span) + + PackageDef(Ident(nme.EMPTY_PACKAGE), List(module)) +end ReplPhase diff --git a/compiler/src/dotty/tools/repl/ReplDriver.scala b/compiler/src/dotty/tools/repl/ReplDriver.scala index 77be30bec25e..e64ecf02e2e3 100644 --- a/compiler/src/dotty/tools/repl/ReplDriver.scala +++ b/compiler/src/dotty/tools/repl/ReplDriver.scala @@ -138,7 +138,7 @@ class ReplDriver(settings: Array[String], * observable outside of the CLI, for this reason, most helper methods are * `protected final` to facilitate testing. */ - final def runUntilQuit(initialState: State = initialState): State = { + final def runUntilQuit(using initialState: State = initialState)(): State = { val terminal = new JLineTerminal out.println( @@ -162,17 +162,17 @@ class ReplDriver(settings: Array[String], } } - @tailrec def loop(state: State): State = { + @tailrec def loop(using state: State)(): State = { val res = readLine(state) if (res == Quit) state - else loop(interpret(res)(state)) + else loop(using interpret(res))() } - try runBody { loop(initialState) } + try runBody { loop() } finally terminal.close() } - final def run(input: String)(implicit state: State): State = runBody { + final def run(input: String)(using state: State): State = runBody { val parsed = ParseResult(input)(state) interpret(parsed) } @@ -180,7 +180,7 @@ class ReplDriver(settings: Array[String], private def runBody(body: => State): State = rendering.classLoader()(using rootCtx).asContext(withRedirectedOutput(body)) // TODO: i5069 - final def bind(name: String, value: Any)(implicit state: State): State = state + final def bind(name: String, value: Any)(using state: State): State = state // redirecting the output allows us to test `println` in scripted tests private def withRedirectedOutput(op: => State): State = { @@ -243,7 +243,7 @@ class ReplDriver(settings: Array[String], .getOrElse(Nil) end completions - private def interpret(res: ParseResult)(implicit state: State): State = { + private def interpret(res: ParseResult)(using state: State): State = { res match { case parsed: Parsed if parsed.trees.nonEmpty => compile(parsed, state) @@ -279,7 +279,7 @@ class ReplDriver(settings: Array[String], imports.foldLeft(ctx.fresh.setNewScope)((ctx, imp) => ctx.importContext(imp, imp.symbol(using ctx))) - implicit val state = { + given State = { val state0 = newRun(istate, parsed.reporter) state0.copy(context = state0.context.withSource(parsed.source)) } @@ -305,14 +305,14 @@ class ReplDriver(settings: Array[String], inContext(newState.context) { val (updatedState, definitions) = if (!ctx.settings.XreplDisableDisplay.value) - renderDefinitions(unit.tpdTree, newestWrapper)(newStateWithImports) + renderDefinitions(unit.tpdTree, newestWrapper)(using newStateWithImports) else (newStateWithImports, Seq.empty) // output is printed in the order it was put in. warnings should be // shown before infos (eg. typedefs) for the same line. column // ordering is mostly to make tests deterministic - implicit val diagnosticOrdering: Ordering[Diagnostic] = + given Ordering[Diagnostic] = Ordering[(Int, Int, Int)].on(d => (d.pos.line, -d.level, d.pos.column)) (definitions ++ warnings) @@ -325,7 +325,7 @@ class ReplDriver(settings: Array[String], ) } - private def renderDefinitions(tree: tpd.Tree, newestWrapper: Name)(implicit state: State): (State, Seq[Diagnostic]) = { + private def renderDefinitions(tree: tpd.Tree, newestWrapper: Name)(using state: State): (State, Seq[Diagnostic]) = { given Context = state.context def resAndUnit(denot: Denotation) = { @@ -417,7 +417,7 @@ class ReplDriver(settings: Array[String], } /** Interpret `cmd` to action and propagate potentially new `state` */ - private def interpretCommand(cmd: Command)(implicit state: State): State = cmd match { + private def interpretCommand(cmd: Command)(using state: State): State = cmd match { case UnknownCommand(cmd) => out.println(s"""Unknown command: "$cmd", run ":help" for a list of commands""") state @@ -465,7 +465,7 @@ class ReplDriver(settings: Array[String], expr match { case "" => out.println(s":type ") case _ => - compiler.typeOf(expr)(newRun(state)).fold( + compiler.typeOf(expr)(using newRun(state)).fold( displayErrors, res => out.println(res) // result has some highlights ) @@ -476,7 +476,7 @@ class ReplDriver(settings: Array[String], expr match { case "" => out.println(s":doc ") case _ => - compiler.docOf(expr)(newRun(state)).fold( + compiler.docOf(expr)(using newRun(state)).fold( displayErrors, res => out.println(res) ) @@ -499,7 +499,7 @@ class ReplDriver(settings: Array[String], } /** shows all errors nicely formatted */ - private def displayErrors(errs: Seq[Diagnostic])(implicit state: State): State = { + private def displayErrors(errs: Seq[Diagnostic])(using state: State): State = { errs.foreach(printDiagnostic) state } @@ -513,7 +513,7 @@ class ReplDriver(settings: Array[String], } /** Print warnings & errors using ReplConsoleReporter, and info straight to out */ - private def printDiagnostic(dia: Diagnostic)(implicit state: State) = dia.level match + private def printDiagnostic(dia: Diagnostic)(using state: State) = dia.level match case interfaces.Diagnostic.INFO => out.println(dia.msg) // print REPL's special info diagnostics directly to out case _ => ReplConsoleReporter.doReport(dia)(using state.context) diff --git a/compiler/src/dotty/tools/repl/ScriptEngine.scala b/compiler/src/dotty/tools/repl/ScriptEngine.scala index 6be76e0fa369..7d385daa43e4 100644 --- a/compiler/src/dotty/tools/repl/ScriptEngine.scala +++ b/compiler/src/dotty/tools/repl/ScriptEngine.scala @@ -37,7 +37,7 @@ class ScriptEngine extends AbstractScriptEngine { @throws[ScriptException] def eval(script: String, context: ScriptContext): Object = { val vid = state.valIndex - state = driver.run(script)(state) + state = driver.run(script)(using state) val oid = state.objectIndex Class.forName(s"${Rendering.REPL_WRAPPER_NAME_PREFIX}$oid", true, rendering.classLoader()(using state.context)) .getDeclaredMethods.find(_.getName == s"${str.REPL_RES_PREFIX}$vid") diff --git a/compiler/test/dotty/tools/repl/ReplTest.scala b/compiler/test/dotty/tools/repl/ReplTest.scala index 1af1f68d6533..34cad747fde6 100644 --- a/compiler/test/dotty/tools/repl/ReplTest.scala +++ b/compiler/test/dotty/tools/repl/ReplTest.scala @@ -20,7 +20,7 @@ import org.junit.{After, Before} import org.junit.Assert._ class ReplTest(options: Array[String] = ReplTest.defaultOptions, out: ByteArrayOutputStream = new ByteArrayOutputStream) -extends ReplDriver(options, new PrintStream(out, true, StandardCharsets.UTF_8.name)) with MessageRendering { +extends ReplDriver(options, new PrintStream(out, true, StandardCharsets.UTF_8.name)) with MessageRendering: /** Get the stored output from `out`, resetting the buffer */ def storedOutput(): String = { val output = stripColor(out.toString(StandardCharsets.UTF_8.name)) @@ -50,7 +50,7 @@ extends ReplDriver(options, new PrintStream(out, true, StandardCharsets.UTF_8.na def evaluate(state: State, input: String) = try { - val nstate = run(input.drop(prompt.length))(state) + val nstate = run(input.drop(prompt.length))(using state) val out = input + EOL + storedOutput() (out, nstate) } @@ -102,7 +102,6 @@ extends ReplDriver(options, new PrintStream(out, true, StandardCharsets.UTF_8.na fail(s"Error in script $name, expected output did not match actual") end if } -} object ReplTest: val commonOptions = Array("-color:never", "-language:experimental.erasedDefinitions", "-pagewidth", "80") diff --git a/compiler/test/dotty/tools/repl/TypeTests.scala b/compiler/test/dotty/tools/repl/TypeTests.scala index 6a8c38867ff0..d864d61d07aa 100644 --- a/compiler/test/dotty/tools/repl/TypeTests.scala +++ b/compiler/test/dotty/tools/repl/TypeTests.scala @@ -1,9 +1,9 @@ package dotty.tools.repl -import org.junit.Assert._ +import org.junit.Assert.* import org.junit.Test -class TypeTests extends ReplTest { +class TypeTests extends ReplTest: @Test def typeOf1 = initially { run(":type 1") assertEquals("Int", storedOutput().trim) @@ -27,4 +27,3 @@ class TypeTests extends ReplTest { run(":type") assertEquals(":type ", storedOutput().trim) } -} diff --git a/language-server/src/dotty/tools/languageserver/worksheet/ReplProcess.scala b/language-server/src/dotty/tools/languageserver/worksheet/ReplProcess.scala index faed5754861e..e8c02744397a 100644 --- a/language-server/src/dotty/tools/languageserver/worksheet/ReplProcess.scala +++ b/language-server/src/dotty/tools/languageserver/worksheet/ReplProcess.scala @@ -10,7 +10,7 @@ object ReplProcess { while (true) { val code = in.next() // blocking - state = driver.run(code)(state) + state = driver.run(code)(using state) Console.print(InputStreamConsumer.delimiter) // needed to mark the end of REPL output Console.flush() }