Skip to content

Commit 4a65e8f

Browse files
authored
Merge pull request #206 from retronym/bug/post-refchecks
Detect and deal with non-RefTree captures
2 parents 89f5921 + 9bf63b6 commit 4a65e8f

File tree

4 files changed

+126
-25
lines changed

4 files changed

+126
-25
lines changed

src/main/scala/scala/async/internal/AsyncTransform.scala

+24-5
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,26 @@ trait AsyncTransform {
154154
sym.asModule.moduleClass.setOwner(stateMachineClass)
155155
}
156156
}
157+
158+
def adjustType(tree: Tree): Tree = {
159+
val resultType = if (tree.tpe eq null) null else tree.tpe.map {
160+
case TypeRef(pre, sym, args) if liftedSyms.contains(sym) =>
161+
val tp1 = internal.typeRef(thisType(sym.owner.asClass), sym, args)
162+
tp1
163+
case SingleType(pre, sym) if liftedSyms.contains(sym) =>
164+
val tp1 = internal.singleType(thisType(sym.owner.asClass), sym)
165+
tp1
166+
case tp => tp
167+
}
168+
setType(tree, resultType)
169+
}
170+
157171
// Replace the ValDefs in the splicee with Assigns to the corresponding lifted
158172
// fields. Similarly, replace references to them with references to the field.
159173
//
160174
// This transform will only be run on the RHS of `def foo`.
161-
val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => tree match {
162-
case _ if api.currentOwner == stateMachineClass =>
175+
val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => tree match {
176+
case _ if api.currentOwner == stateMachineClass =>
163177
api.default(tree)
164178
case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) =>
165179
api.atOwner(api.currentOwner) {
@@ -172,14 +186,19 @@ trait AsyncTransform {
172186
treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner)
173187
}
174188
}
175-
case _: DefTree if liftedSyms(tree.symbol) =>
189+
case _: DefTree if liftedSyms(tree.symbol) =>
176190
EmptyTree
177-
case Ident(name) if liftedSyms(tree.symbol) =>
191+
case Ident(name) if liftedSyms(tree.symbol) =>
178192
val fieldSym = tree.symbol
179193
atPos(tree.pos) {
180194
gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe)
181195
}
182-
case _ =>
196+
case sel @ Select(n@New(tt: TypeTree), nme.CONSTRUCTOR) =>
197+
adjustType(sel)
198+
adjustType(n)
199+
adjustType(tt)
200+
sel
201+
case _ =>
183202
api.default(tree)
184203
}
185204

src/main/scala/scala/async/internal/Lifter.scala

+17-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package scala.async.internal
22

33
import scala.collection.mutable
4+
import scala.collection.mutable.ListBuffer
45

56
trait Lifter {
67
self: AsyncMacro =>
@@ -77,13 +78,25 @@ trait Lifter {
7778
// The direct references of each block, excluding references of `DefTree`-s which
7879
// are already accounted for.
7980
val stateIdToDirectlyReferenced: mutable.LinkedHashMap[Int, List[Symbol]] = {
80-
val refs: List[(Int, Symbol)] = asyncStates.flatMap(
81-
asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect {
81+
val result = new mutable.LinkedHashMap[Int, ListBuffer[Symbol]]()
82+
asyncStates.foreach(
83+
asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).foreach(_.foreach {
8284
case rt: RefTree
83-
if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
85+
if symToDefiningState.contains(rt.symbol) =>
86+
result.getOrElseUpdate(asyncState.state, new ListBuffer) += rt.symbol
87+
case tt: TypeTree =>
88+
tt.tpe.foreach { tp =>
89+
val termSym = tp.termSymbol
90+
if (symToDefiningState.contains(termSym))
91+
result.getOrElseUpdate(asyncState.state, new ListBuffer) += termSym
92+
val typeSym = tp.typeSymbol
93+
if (symToDefiningState.contains(typeSym))
94+
result.getOrElseUpdate(asyncState.state, new ListBuffer) += typeSym
95+
}
96+
case _ =>
8497
})
8598
)
86-
toMultiMap(refs)
99+
result.map { case (a, b) => (a, b.result())}
87100
}
88101

89102
def liftableSyms: mutable.LinkedHashSet[Symbol] = {

src/test/scala/scala/async/TreeInterrogation.scala

+21-9
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,29 @@ object TreeInterrogationApp extends App {
7070
val tree = tb.parse(
7171
"""
7272
| import scala.async.internal.AsyncId._
73-
| async {
74-
| var b = true
75-
| while(await(b)) {
76-
| b = false
77-
| }
78-
| (1, 1) match {
79-
| case (x, y) => await(2); println(x)
80-
| }
81-
| await(b)
73+
| trait QBound { type D; trait ResultType { case class Inner() }; def toResult: ResultType = ??? }
74+
| trait QD[Q <: QBound] {
75+
| val operation: Q
76+
| type D = operation.D
8277
| }
8378
|
79+
| async {
80+
| if (!"".isEmpty) {
81+
| val treeResult = null.asInstanceOf[QD[QBound]]
82+
| await(0)
83+
| val y = treeResult.operation
84+
| type RD = treeResult.operation.D
85+
| (null: Object) match {
86+
| case (_, _: RD) => ???
87+
| case _ => val x = y.toResult; x.Inner()
88+
| }
89+
| await(1)
90+
| (y, null.asInstanceOf[RD])
91+
| ""
92+
| }
93+
|
94+
| }
95+
|
8496
| """.stripMargin)
8597
println(tree)
8698
val tree1 = tb.typeCheck(tree.duplicate)

src/test/scala/scala/async/run/late/LateExpansion.scala

+64-7
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ package scala.async.run.late
33
import java.io.File
44

55
import junit.framework.Assert.assertEquals
6-
import org.junit.{Assert, Test}
6+
import org.junit.{Assert, Ignore, Test}
77

88
import scala.annotation.StaticAnnotation
99
import scala.annotation.meta.{field, getter}
10-
import scala.async.TreeInterrogation
1110
import scala.async.internal.AsyncId
1211
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
1312
import scala.tools.nsc._
@@ -19,6 +18,57 @@ import scala.tools.nsc.transform.TypingTransformers
1918
// calls it from a new phase that runs after patmat.
2019
class LateExpansion {
2120

21+
@Test def testRewrittenApply(): Unit = {
22+
val result = wrapAndRun(
23+
"""
24+
| object O {
25+
| case class Foo(a: Any)
26+
| }
27+
| @autoawait def id(a: String) = a
28+
| O.Foo
29+
| id("foo") + id("bar")
30+
| O.Foo(1)
31+
| """.stripMargin)
32+
assertEquals("Foo(1)", result.toString)
33+
}
34+
35+
@Ignore("Need to use adjustType more pervasively in AsyncTransform, but that exposes bugs in {Type, ... }Symbol's cache invalidation")
36+
@Test def testIsInstanceOfType(): Unit = {
37+
val result = wrapAndRun(
38+
"""
39+
| class Outer
40+
| @autoawait def id(a: String) = a
41+
| val o = new Outer
42+
| id("foo") + id("bar")
43+
| ("": Object).isInstanceOf[o.type]
44+
| """.stripMargin)
45+
assertEquals(false, result)
46+
}
47+
48+
@Test def testIsInstanceOfTerm(): Unit = {
49+
val result = wrapAndRun(
50+
"""
51+
| class Outer
52+
| @autoawait def id(a: String) = a
53+
| val o = new Outer
54+
| id("foo") + id("bar")
55+
| o.isInstanceOf[Outer]
56+
| """.stripMargin)
57+
assertEquals(true, result)
58+
}
59+
60+
@Test def testArrayLocalModule(): Unit = {
61+
val result = wrapAndRun(
62+
"""
63+
| class Outer
64+
| @autoawait def id(a: String) = a
65+
| val O = ""
66+
| id("foo") + id("bar")
67+
| new Array[O.type](0)
68+
| """.stripMargin)
69+
assertEquals(classOf[Array[String]], result.getClass)
70+
}
71+
2272
@Test def test0(): Unit = {
2373
val result = wrapAndRun(
2474
"""
@@ -27,6 +77,7 @@ class LateExpansion {
2777
| """.stripMargin)
2878
assertEquals("foobar", result)
2979
}
80+
3081
@Test def testGuard(): Unit = {
3182
val result = wrapAndRun(
3283
"""
@@ -143,6 +194,7 @@ class LateExpansion {
143194
|}
144195
| """.stripMargin)
145196
}
197+
146198
@Test def shadowing2(): Unit = {
147199
val result = run(
148200
"""
@@ -369,6 +421,7 @@ class LateExpansion {
369421
}
370422
""")
371423
}
424+
372425
@Test def testNegativeArraySizeExceptionFine1(): Unit = {
373426
val result = run(
374427
"""
@@ -389,18 +442,20 @@ class LateExpansion {
389442
}
390443
""")
391444
}
445+
392446
private def createTempDir(): File = {
393447
val f = File.createTempFile("output", "")
394448
f.delete()
395449
f.mkdirs()
396450
f
397451
}
452+
398453
def run(code: String): Any = {
399-
// settings.processArgumentString("-Xprint:patmat,postpatmat,jvm -Ybackend:GenASM -nowarn")
400454
val out = createTempDir()
401455
try {
402456
val reporter = new StoreReporter
403457
val settings = new Settings(println(_))
458+
//settings.processArgumentString("-Xprint:refchecks,patmat,postpatmat,jvm -nowarn")
404459
settings.outdir.value = out.getAbsolutePath
405460
settings.embeddedDefaults(getClass.getClassLoader)
406461
val isInSBT = !settings.classpath.isSetByUser
@@ -432,6 +487,7 @@ class LateExpansion {
432487
}
433488

434489
abstract class LatePlugin extends Plugin {
490+
435491
import global._
436492

437493
override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
@@ -448,16 +504,16 @@ abstract class LatePlugin extends Plugin {
448504
super.transform(tree) match {
449505
case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) =>
450506
localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
451-
case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi] ) =>
507+
case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi]) =>
452508
localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(sel.tpe) :: Nil), sel :: Nil))
453509
case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
454-
deriveDefDef(dd){ rhs: Tree =>
510+
deriveDefDef(dd) { rhs: Tree =>
455511
val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs))
456512
localTyper.typed(atPos(dd.pos)(invoke))
457513
}
458514
}
459515
case vd: ValDef if vd.symbol.hasAnnotation(lateAsyncSym) => atOwner(vd.symbol) {
460-
deriveValDef(vd){ rhs: Tree =>
516+
deriveValDef(vd) { rhs: Tree =>
461517
val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs))
462518
localTyper.typed(atPos(vd.pos)(invoke))
463519
}
@@ -468,6 +524,7 @@ abstract class LatePlugin extends Plugin {
468524
}
469525
}
470526
}
527+
471528
override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
472529
override def apply(unit: CompilationUnit): Unit = {
473530
val translated = newTransformer(unit).transformUnit(unit)
@@ -476,7 +533,7 @@ abstract class LatePlugin extends Plugin {
476533
}
477534
}
478535

479-
override val runsAfter: List[String] = "patmat" :: Nil
536+
override val runsAfter: List[String] = "refchecks" :: Nil
480537
override val phaseName: String = "postpatmat"
481538

482539
})

0 commit comments

Comments
 (0)