Skip to content

REPL goes through a phase #15663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions compiler/src/dotty/tools/repl/ReplCompilationUnit.scala

This file was deleted.

258 changes: 145 additions & 113 deletions compiler/src/dotty/tools/repl/ReplCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ import dotty.tools.dotc.transform.PostTyper
import dotty.tools.dotc.typer.ImportInfo.{withRootImports, RootRef}
import dotty.tools.dotc.typer.TyperPhase
import dotty.tools.dotc.util.Spans._
import dotty.tools.dotc.util.{ParsedComment, SourceFile}
import dotty.tools.dotc.util.{ParsedComment, Property, SourceFile}
import dotty.tools.dotc.{CompilationUnit, Compiler, Run}
import dotty.tools.repl.results._

import scala.collection.mutable
import scala.util.chaining.given

/** This subclass of `Compiler` is adapted for use in the REPL.
*
Expand All @@ -29,12 +30,14 @@ import scala.collection.mutable
* - provides utility to query the type of an expression
* - provides utility to query the documentation of an expression
*/
class ReplCompiler extends Compiler {
class ReplCompiler extends Compiler:

override protected def frontendPhases: List[List[Phase]] = List(
List(new TyperPhase(addRootImports = false)),
List(new CollectTopLevelImports),
List(new PostTyper),
List(Parser()),
List(ReplPhase()),
List(TyperPhase(addRootImports = false)),
List(CollectTopLevelImports()),
List(PostTyper()),
)

def newRun(initCtx: Context, state: State): Run =
Expand All @@ -46,7 +49,7 @@ class ReplCompiler extends Compiler {

def importPreviousRun(id: Int)(using Context) = {
// we first import the wrapper object id
val path = nme.EMPTY_PACKAGE ++ "." ++ objectNames(id)
val path = nme.EMPTY_PACKAGE ++ "." ++ ReplCompiler.objectNames(id)
val ctx0 = ctx.fresh
.setNewScope
.withRootImports(RootRef(() => requiredModuleRef(path)) :: Nil)
Expand All @@ -67,117 +70,29 @@ class ReplCompiler extends Compiler {
}
run.suppressions.initSuspendedMessages(state.context.run)
run
end newRun

private val objectNames = mutable.Map.empty[Int, TermName]
private def packaged(stats: List[untpd.Tree])(using Context): untpd.PackageDef =
import untpd.*
PackageDef(Ident(nme.EMPTY_PACKAGE), stats)

private case class Definitions(stats: List[untpd.Tree], state: State)

private def definitions(trees: List[untpd.Tree], state: State): Definitions = inContext(state.context) {
import untpd._

// If trees is of the form `{ def1; def2; def3 }` then `List(def1, def2, def3)`
val flattened = trees match {
case List(Block(stats, expr)) =>
if (expr eq EmptyTree) stats // happens when expr is not an expression
else stats :+ expr
case _ =>
trees
}

var valIdx = state.valIndex
val defs = new mutable.ListBuffer[Tree]

/** If the user inputs a definition whose name is of the form REPL_RES_PREFIX and a number,
* such as `val res9 = 1`, we bump `valIdx` to avoid name clashes. lampepfl/dotty#3536 */
def maybeBumpValIdx(tree: Tree): Unit = tree match
case apply: Apply => for a <- apply.args do maybeBumpValIdx(a)
case tuple: Tuple => for t <- tuple.trees do maybeBumpValIdx(t)
case patDef: PatDef => for p <- patDef.pats do maybeBumpValIdx(p)
case tree: NameTree => tree.name.show.stripPrefix(str.REPL_RES_PREFIX).toIntOption match
case Some(n) if n >= valIdx => valIdx = n + 1
case _ =>
case _ =>

flattened.foreach {
case expr @ Assign(id: Ident, _) =>
// special case simple reassignment (e.g. x = 3)
// in order to print the new value in the REPL
val assignName = (id.name ++ str.REPL_ASSIGN_SUFFIX).toTermName
val assign = ValDef(assignName, TypeTree(), id).withSpan(expr.span)
defs += expr += assign
case expr if expr.isTerm =>
val resName = (str.REPL_RES_PREFIX + valIdx).toTermName
valIdx += 1
val vd = ValDef(resName, TypeTree(), expr).withSpan(expr.span)
defs += vd
case other =>
maybeBumpValIdx(other)
defs += other
}

Definitions(
defs.toList,
state.copy(
objectIndex = state.objectIndex + 1,
valIndex = valIdx
)
)
}

/** Wrap trees in an object and add imports from the previous compilations
*
* The resulting structure is something like:
*
* ```
* package <none> {
* object rs$line$nextId {
* import rs$line${i <- 0 until nextId}._
*
* <trees>
* }
* }
* ```
*/
private def wrapped(defs: Definitions, objectTermName: TermName, span: Span): untpd.PackageDef =
inContext(defs.state.context) {
import untpd._

val tmpl = Template(emptyConstructor, Nil, Nil, EmptyValDef, defs.stats)
val module = ModuleDef(objectTermName, tmpl)
.withSpan(span)
final def compile(parsed: Parsed)(using state: State): Result[(CompilationUnit, State)] =
assert(!parsed.trees.isEmpty)

PackageDef(Ident(nme.EMPTY_PACKAGE), List(module))
given Context = state.context
val unit = ReplCompilationUnit(ctx.source).tap { unit =>
unit.untpdTree = packaged(parsed.trees)
unit.untpdTree.putAttachment(ReplCompiler.ReplState, state)
}

private def createUnit(defs: Definitions, span: Span)(using Context): CompilationUnit = {
val objectName = ctx.source.file.toString
assert(objectName.startsWith(str.REPL_SESSION_LINE))
assert(objectName.endsWith(defs.state.objectIndex.toString))
val objectTermName = ctx.source.file.toString.toTermName
objectNames.update(defs.state.objectIndex, objectTermName)

val unit = new ReplCompilationUnit(ctx.source)
unit.untpdTree = wrapped(defs, objectTermName, span)
unit
}

private def runCompilationUnit(unit: CompilationUnit, state: State): Result[(CompilationUnit, State)] = {
val ctx = state.context
ctx.run.nn.compileUnits(unit :: Nil)
ctx.run.nn.printSummary() // this outputs "2 errors found" like normal - but we might decide that's needlessly noisy for the REPL

if (!ctx.reporter.hasErrors) (unit, state).result
else ctx.reporter.removeBufferedMessages(using ctx).errors
}
ctx.run.nn.printSummary() // "2 errors found"

final def compile(parsed: Parsed)(implicit state: State): Result[(CompilationUnit, State)] = {
assert(!parsed.trees.isEmpty)
val defs = definitions(parsed.trees, state)
val unit = createUnit(defs, Span(0, parsed.trees.last.span.end))(using state.context)
runCompilationUnit(unit, defs.state)
}
val newState = unit.tpdTree.getAttachment(ReplCompiler.ReplState).get
if !ctx.reporter.hasErrors then (unit, newState).result
else ctx.reporter.removeBufferedMessages.errors
end compile

final def typeOf(expr: String)(implicit state: State): Result[String] =
final def typeOf(expr: String)(using state: State): Result[String] =
typeCheck(expr).map { tree =>
given Context = state.context
tree.rhs match {
Expand All @@ -190,7 +105,7 @@ class ReplCompiler extends Compiler {
}
}

def docOf(expr: String)(implicit state: State): Result[String] = inContext(state.context) {
def docOf(expr: String)(using state: State): Result[String] = inContext(state.context) {

/** Extract the "selected" symbol from `tree`.
*
Expand Down Expand Up @@ -237,7 +152,7 @@ class ReplCompiler extends Compiler {
}
}

final def typeCheck(expr: String, errorsAllowed: Boolean = false)(implicit state: State): Result[tpd.ValDef] = {
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {

def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
Expand Down Expand Up @@ -300,4 +215,121 @@ class ReplCompiler extends Compiler {
}
}
}
}
object ReplCompiler:
val ReplState: Property.StickyKey[State] = Property.StickyKey()
val objectNames = mutable.Map.empty[Int, TermName]
end ReplCompiler

class ReplCompilationUnit(source: SourceFile) extends CompilationUnit(source):
override def isSuspendable: Boolean = false

/** A placeholder phase that receives parse trees..
*
* It is called "parser" for the convenience of collective muscle memory.
*
* This enables -Vprint:parser.
*/
class Parser extends Phase:
def phaseName: String = "parser"
def run(using Context): Unit = ()
end Parser

/** A phase that assembles wrapped parse trees from user input.
*
* Ths `ReplState` attachment indicates Repl wrapping is required.
*
* This enables -Vprint:repl so that users can see how their code snippet was wrapped.
*/
class ReplPhase extends Phase:
def phaseName: String = "repl"

def run(using Context): Unit =
ctx.compilationUnit.untpdTree match
case pkg @ PackageDef(_, stats) =>
pkg.getAttachment(ReplCompiler.ReplState).foreach {
case given State =>
val defs = definitions(stats)
val res = wrapped(defs, Span(0, stats.last.span.end))
res.putAttachment(ReplCompiler.ReplState, defs.state)
ctx.compilationUnit.untpdTree = res
}
case _ =>
end run

private case class Definitions(stats: List[untpd.Tree], state: State)

private def definitions(trees: List[untpd.Tree])(using Context, State): Definitions =
import untpd.*

// If trees is of the form `{ def1; def2; def3 }` then `List(def1, def2, def3)`
val flattened = trees match {
case List(Block(stats, expr)) =>
if (expr eq EmptyTree) stats // happens when expr is not an expression
else stats :+ expr
case _ =>
trees
}

val state = summon[State]
var valIdx = state.valIndex
val defs = mutable.ListBuffer.empty[Tree]

/** If the user inputs a definition whose name is of the form REPL_RES_PREFIX and a number,
* such as `val res9 = 1`, we bump `valIdx` to avoid name clashes. lampepfl/dotty#3536 */
def maybeBumpValIdx(tree: Tree): Unit = tree match
case apply: Apply => for a <- apply.args do maybeBumpValIdx(a)
case tuple: Tuple => for t <- tuple.trees do maybeBumpValIdx(t)
case patDef: PatDef => for p <- patDef.pats do maybeBumpValIdx(p)
case tree: NameTree => tree.name.show.stripPrefix(str.REPL_RES_PREFIX).toIntOption match
case Some(n) if n >= valIdx => valIdx = n + 1
case _ =>
case _ =>

flattened.foreach {
case expr @ Assign(id: Ident, _) =>
// special case simple reassignment (e.g. x = 3)
// in order to print the new value in the REPL
val assignName = (id.name ++ str.REPL_ASSIGN_SUFFIX).toTermName
val assign = ValDef(assignName, TypeTree(), id).withSpan(expr.span)
defs += expr += assign
case expr if expr.isTerm =>
val resName = (str.REPL_RES_PREFIX + valIdx).toTermName
valIdx += 1
val vd = ValDef(resName, TypeTree(), expr).withSpan(expr.span)
defs += vd
case other =>
maybeBumpValIdx(other)
defs += other
}

Definitions(defs.toList, state.copy(objectIndex = state.objectIndex + 1, valIndex = valIdx))
end definitions

/** Wrap trees in an object and add imports from the previous compilations.
*
* The resulting structure is something like:
*
* ```
* package <none> {
* object rs$line$nextId {
* import rs$line${i <- 0 until nextId}.*
*
* <trees>
* }
* }
* ```
*/
private def wrapped(defs: Definitions, span: Span)(using Context): untpd.PackageDef =
import untpd.*

val objectName = ctx.source.file.toString
assert(objectName.startsWith(str.REPL_SESSION_LINE))
assert(objectName.endsWith(defs.state.objectIndex.toString))
val objectTermName = objectName.toTermName
ReplCompiler.objectNames.update(defs.state.objectIndex, objectTermName)

val tmpl = Template(emptyConstructor, Nil, Nil, EmptyValDef, defs.stats)
val module = ModuleDef(objectTermName, tmpl).withSpan(span)

PackageDef(Ident(nme.EMPTY_PACKAGE), List(module))
end ReplPhase
Loading