Skip to content

Commit 7263aaa

Browse files
committed
Merge pull request #141 from retronym/ticket/await-extractor
Enable a compiler plugin to use the async transform after patmat
2 parents 93f207f + 168e10c commit 7263aaa

11 files changed

+459
-55
lines changed

src/main/scala/scala/async/internal/AnfTransform.scala

+70-11
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,18 @@ private[async] trait AnfTransform {
1616
import c.internal._
1717
import decorators._
1818

19-
def anfTransform(tree: Tree): Block = {
19+
def anfTransform(tree: Tree, owner: Symbol): Block = {
2020
// Must prepend the () for issue #31.
21-
val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe)
21+
val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe)
2222

2323
sealed abstract class AnfMode
2424
case object Anf extends AnfMode
2525
case object Linearizing extends AnfMode
2626

27+
val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)
28+
2729
var mode: AnfMode = Anf
28-
typingTransform(block)((tree, api) => {
30+
typingTransform(tree1, owner)((tree, api) => {
2931
def blockToList(tree: Tree): List[Tree] = tree match {
3032
case Block(stats, expr) => stats :+ expr
3133
case t => t :: Nil
@@ -34,7 +36,7 @@ private[async] trait AnfTransform {
3436
def listToBlock(trees: List[Tree]): Block = trees match {
3537
case trees @ (init :+ last) =>
3638
val pos = trees.map(_.pos).reduceLeft(_ union _)
37-
Block(init, last).setType(last.tpe).setPos(pos)
39+
newBlock(init, last).setType(last.tpe).setPos(pos)
3840
}
3941

4042
object linearize {
@@ -66,6 +68,17 @@ private[async] trait AnfTransform {
6668
stats :+ valDef :+ atPos(tree.pos)(ref1)
6769

6870
case If(cond, thenp, elsep) =>
71+
// If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}`
72+
// as though it was typed with `Unit`.
73+
def isPatMatGeneratedJump(t: Tree): Boolean = t match {
74+
case Block(_, expr) => isPatMatGeneratedJump(expr)
75+
case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep)
76+
case _: Apply if isLabel(t.symbol) => true
77+
case _ => false
78+
}
79+
if (isPatMatGeneratedJump(expr)) {
80+
internal.setType(expr, definitions.UnitTpe)
81+
}
6982
// if type of if-else is Unit don't introduce assignment,
7083
// but add Unit value to bring it into form expected by async transform
7184
if (expr.tpe =:= definitions.UnitTpe) {
@@ -77,7 +90,7 @@ private[async] trait AnfTransform {
7790
def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) {
7891
def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol))
7992
orig match {
80-
case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
93+
case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
8194
case _ => Assign(Ident(varDef.symbol), cast(orig))
8295
}
8396
})
@@ -115,7 +128,7 @@ private[async] trait AnfTransform {
115128
}
116129
}
117130

118-
private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
131+
def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
119132
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
120133
valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos)
121134
}
@@ -152,8 +165,7 @@ private[async] trait AnfTransform {
152165
}
153166

154167
def _transformToList(tree: Tree): List[Tree] = trace(tree) {
155-
val containsAwait = tree exists isAwait
156-
if (!containsAwait) {
168+
if (!containsAwait(tree)) {
157169
tree match {
158170
case Block(stats, expr) =>
159171
// avoids nested block in `while(await(false)) ...`.
@@ -207,10 +219,11 @@ private[async] trait AnfTransform {
207219
funStats ++ argStatss.flatten.flatten :+ typedNewApply
208220

209221
case Block(stats, expr) =>
210-
(stats :+ expr).flatMap(linearize.transformToList)
222+
val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
223+
eliminateMatchEndLabelParameter(trees)
211224

212225
case ValDef(mods, name, tpt, rhs) =>
213-
if (rhs exists isAwait) {
226+
if (containsAwait(rhs)) {
214227
val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs))
215228
stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner))
216229
stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
@@ -247,7 +260,7 @@ private[async] trait AnfTransform {
247260
scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs)
248261

249262
case LabelDef(name, params, rhs) =>
250-
List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
263+
List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
251264

252265
case TypeApply(fun, targs) =>
253266
val funStats :+ simpleFun = linearize.transformToList(fun)
@@ -259,6 +272,52 @@ private[async] trait AnfTransform {
259272
}
260273
}
261274

275+
// Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
276+
//
277+
// CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts
278+
// a parameter which is the result of the match (this is regular, so even Unit-typed matches have this).
279+
//
280+
// For our purposes, it is easier to:
281+
// - extract a `matchRes` variable
282+
// - rewrite the terminal label def to take no parameters, and instead read this temp variable
283+
// - change jumps to the terminal label to an assignment and a no-arg label application
284+
def eliminateMatchEndLabelParameter(statsExpr: List[Tree]): List[Tree] = {
285+
import internal.{methodType, setInfo}
286+
val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]()
287+
288+
val matchResults = collection.mutable.Buffer[Tree]()
289+
val statsExpr0 = statsExpr.reverseMap {
290+
case ld @ LabelDef(_, param :: Nil, body) =>
291+
val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
292+
matchResults += matchResult
293+
caseDefToMatchResult(ld.symbol) = matchResult.symbol
294+
val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil))
295+
setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType))
296+
ld2
297+
case t =>
298+
if (caseDefToMatchResult.isEmpty) t
299+
else typingTransform(t)((tree, api) =>
300+
tree match {
301+
case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
302+
api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil))))
303+
case Block(stats, expr) =>
304+
api.default(tree) match {
305+
case Block(stats, Block(stats1, expr)) =>
306+
treeCopy.Block(tree, stats ::: stats1, expr)
307+
case t => t
308+
}
309+
case _ =>
310+
api.default(tree)
311+
}
312+
)
313+
}
314+
matchResults.toList match {
315+
case Nil => statsExpr
316+
case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
317+
case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
318+
}
319+
}
320+
262321
def anfLinearize(tree: Tree): Block = {
263322
val trees: List[Tree] = mode match {
264323
case Anf => anf._transformToList(tree)

src/main/scala/scala/async/internal/AsyncBase.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ abstract class AsyncBase {
4343
(body: c.Expr[T])
4444
(execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
4545
import c.universe._, c.internal._, decorators._
46-
val asyncMacro = AsyncMacro(c, self)
46+
val asyncMacro = AsyncMacro(c, self)(body.tree)
4747

48-
val code = asyncMacro.asyncTransform[T](body.tree, execContext.tree)(c.weakTypeTag[T])
48+
val code = asyncMacro.asyncTransform[T](execContext.tree)(c.weakTypeTag[T])
4949
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")
5050

5151
// Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges

src/main/scala/scala/async/internal/AsyncId.scala

+5-5
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ object AsyncTestLV extends AsyncBase {
4141
* A trivial implementation of [[FutureSystem]] that performs computations
4242
* on the current thread. Useful for testing.
4343
*/
44+
class Box[A] {
45+
var a: A = _
46+
}
4447
object IdentityFutureSystem extends FutureSystem {
45-
46-
class Prom[A] {
47-
var a: A = _
48-
}
48+
type Prom[A] = Box[A]
4949

5050
type Fut[A] = A
5151
type ExecContext = Unit
@@ -57,7 +57,7 @@ object IdentityFutureSystem extends FutureSystem {
5757

5858
def execContext: Expr[ExecContext] = c.Expr[Unit](Literal(Constant(())))
5959

60-
def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]]
60+
def promType[A: WeakTypeTag]: Type = weakTypeOf[Box[A]]
6161
def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]]
6262
def execContextType: Type = weakTypeOf[Unit]
6363

src/main/scala/scala/async/internal/AsyncMacro.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
package scala.async.internal
22

33
object AsyncMacro {
4-
def apply(c0: reflect.macros.Context, base: AsyncBase): AsyncMacro { val c: c0.type } = {
4+
def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = {
55
import language.reflectiveCalls
66
new AsyncMacro { self =>
77
val c: c0.type = c0
8+
val body: c.Tree = body0
89
// This member is required by `AsyncTransform`:
910
val asyncBase: AsyncBase = base
1011
// These members are required by `ExprBuilder`:
1112
val futureSystem: FutureSystem = base.futureSystem
1213
val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c)
14+
val containsAwait: c.Tree => Boolean = containsAwaitCached(body0)
1315
}
1416
}
1517
}
@@ -19,7 +21,10 @@ private[async] trait AsyncMacro
1921
with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables {
2022

2123
val c: scala.reflect.macros.Context
24+
val body: c.Tree
25+
val containsAwait: c.Tree => Boolean
2226

2327
lazy val macroPos = c.macroApplication.pos.makeTransparent
2428
def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t)
29+
2530
}

src/main/scala/scala/async/internal/AsyncTransform.scala

+6-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ trait AsyncTransform {
99

1010
val asyncBase: AsyncBase
1111

12-
def asyncTransform[T](body: Tree, execContext: Tree)
12+
def asyncTransform[T](execContext: Tree)
1313
(resultType: WeakTypeTag[T]): Tree = {
1414

1515
// We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce
@@ -22,7 +22,7 @@ trait AsyncTransform {
2222
// Transform to A-normal form:
2323
// - no await calls in qualifiers or arguments,
2424
// - if/match only used in statement position.
25-
val anfTree0: Block = anfTransform(body)
25+
val anfTree0: Block = anfTransform(body, c.internal.enclosingOwner)
2626

2727
val anfTree = futureSystemOps.postAnfTransform(anfTree0)
2828

@@ -35,15 +35,15 @@ trait AsyncTransform {
3535
val stateMachine: ClassDef = {
3636
val body: List[Tree] = {
3737
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial)))
38-
val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
38+
val resultAndAccessors = mkMutableField(futureSystemOps.promType[T](uncheckedBoundsResultTag), name.result, futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
3939
val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
4040

4141
val apply0DefDef: DefDef = {
4242
// We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
4343
// See SI-1247 for the the optimization that avoids creation.
4444
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil))
4545
}
46-
List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
46+
List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
4747
}
4848

4949
val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit])
@@ -98,10 +98,11 @@ trait AsyncTransform {
9898
}
9999

100100
val isSimple = asyncBlock.asyncStates.size == 1
101-
if (isSimple)
101+
val result = if (isSimple)
102102
futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
103103
else
104104
startStateMachine
105+
cleanupContainsAwaitAttachments(result)
105106
}
106107

107108
def logDiagnostics(anfTree: Tree, states: Seq[String]) {

0 commit comments

Comments
 (0)