Skip to content

Commit efd2d9a

Browse files
committed
Three changes to typing rules
The following two rules replace #13657: 1. Exploit capture monotonicity in the apply rule, as discussed in #14387. 2. A rule to make typing nested classes more flexible as discussed in #14390. There's also a bug fix where we now enforce a previously missing subcapturing relationship between the capture set of parent of a class and the capture set of the class itself. Clearly a class captures all variables captured by one of its parent classes.
1 parent 60d8195 commit efd2d9a

File tree

11 files changed

+198
-66
lines changed

11 files changed

+198
-66
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ sealed abstract class CaptureSet extends Showable:
104104
extension (x: CaptureRef) private def subsumes(y: CaptureRef) =
105105
(x eq y)
106106
|| y.match
107-
case y: TermRef => y.prefix eq x // ^^^ y.prefix.subsumes(x) ?
107+
case y: TermRef => y.prefix eq x
108108
case _ => false
109109

110110
/** {x} <:< this where <:< is subcapturing, but treating all variables

compiler/src/dotty/tools/dotc/transform/Recheck.scala

+18-20
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,12 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
126126
bindType
127127

128128
def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit =
129-
if !tree.rhs.isEmpty then recheckRHS(tree.rhs, sym.info, sym)
129+
if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info)
130130

131131
def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit =
132132
val rhsCtx = linkConstructorParams(sym).withOwner(sym)
133133
if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then
134-
inContext(rhsCtx) { recheckRHS(tree.rhs, recheck(tree.tpt), sym) }
135-
136-
def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type =
137-
recheck(tree, pt)
134+
inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) }
138135

139136
def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type =
140137
recheck(tree.rhs)
@@ -358,21 +355,22 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
358355
// Don't report closure nodes, since their span is a point; wait instead
359356
// for enclosing block to preduce an error
360357
case _ =>
361-
val actual = tpe.widenExpr
362-
val expected = pt.widenExpr
363-
//println(i"check conforms $actual <:< $expected")
364-
val isCompatible =
365-
actual <:< expected
366-
|| expected.isRepeatedParam
367-
&& actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass))
368-
if !isCompatible then
369-
recheckr.println(i"conforms failed for ${tree}: $tpe vs $expected")
370-
err.typeMismatch(tree.withType(tpe), expected)
371-
else if debugSuccesses then
372-
tree match
373-
case _: Ident =>
374-
println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}")
375-
case _ =>
358+
checkConformsExpr(tpe, tpe.widenExpr, pt.widenExpr, tree)
359+
360+
def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit =
361+
//println(i"check conforms $actual <:< $expected")
362+
val isCompatible =
363+
actual <:< expected
364+
|| expected.isRepeatedParam
365+
&& actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass))
366+
if !isCompatible then
367+
recheckr.println(i"conforms failed for ${tree}: $original vs $expected")
368+
err.typeMismatch(tree.withType(original), expected)
369+
else if debugSuccesses then
370+
tree match
371+
case _: Ident =>
372+
println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}")
373+
case _ =>
376374

377375
def checkUnit(unit: CompilationUnit)(using Context): Unit =
378376
recheck(unit.tpdTree)

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

+65-17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import transform.Recheck
1616
import Recheck.*
1717
import scala.collection.mutable
1818
import CaptureSet.withCaptureSetsExplained
19+
import StdNames.nme
1920
import reporting.trace
2021

2122
object CheckCaptures:
@@ -213,22 +214,6 @@ class CheckCaptures extends Recheck:
213214
interpolateVarsIn(tree.tpt)
214215
curEnv = saved
215216

216-
override def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type =
217-
val pt1 = pt match
218-
case CapturingType(core, refs, _)
219-
if sym.owner.isClass && !sym.owner.isExtensibleClass
220-
&& refs.elems.contains(sym.owner.thisType) =>
221-
val paramCaptures =
222-
sym.paramSymss.flatten.foldLeft(CaptureSet.empty) { (cs, p) =>
223-
val pcs = p.info.captureSet
224-
(cs ++ (if pcs.isConst then pcs else CaptureSet.universal)).asConst
225-
}
226-
val declaredCaptures = sym.owner.asClass.givenSelfType.captureSet
227-
pt.derivedCapturingType(core, refs ++ (declaredCaptures -- paramCaptures))
228-
case _ =>
229-
pt
230-
recheck(tree, pt1)
231-
232217
override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type =
233218
for param <- cls.paramGetters do
234219
if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then
@@ -237,6 +222,8 @@ class CheckCaptures extends Recheck:
237222
param.srcPos)
238223
val saved = curEnv
239224
val localSet = capturedVars(cls)
225+
for parent <- impl.parents do
226+
checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos)
240227
if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, false, curEnv)
241228
try super.recheckClassDef(tree, impl, cls)
242229
finally curEnv = saved
@@ -289,9 +276,34 @@ class CheckCaptures extends Recheck:
289276
finally curEnv = curEnv.outer
290277
recheckFinish(result, arg, pt)
291278

279+
/** A specialized implementation of the apply rule from https://github.com/lampepfl/dotty/discussions/14387:
280+
*
281+
* E |- f: Cf (Ra -> Cr Rr)
282+
* E |- a: Ra
283+
* ------------------------
284+
* E |- f a: Cr /\ {f} Rr
285+
*
286+
* Specialized for the case where `f` is a tracked and the arguments are pure.
287+
* This replaces the previous rule #13657 while still allowing the code in pos/lazylists1.scala.
288+
* We could consider generalizing to the case where the function arguments have non-empty
289+
* capture sets as suggested in #14387, but that would make capture set computations more complex,
290+
* so we should also evaluate the performance impact.
291+
*/
292292
override def recheckApply(tree: Apply, pt: Type)(using Context): Type =
293293
includeCallCaptures(tree.symbol, tree.srcPos)
294-
super.recheckApply(tree, pt)
294+
super.recheckApply(tree, pt) match
295+
case tp @ CapturingType(tp1, refs, kind) =>
296+
tree.fun match
297+
case Select(qual, nme.apply)
298+
if defn.isFunctionType(qual.tpe.widen) =>
299+
qual.tpe match
300+
case ref: CaptureRef
301+
if ref.isTracked && tree.args.forall(_.tpe.captureSet.isAlwaysEmpty) =>
302+
tp.derivedCapturingType(tp1, refs ** ref.singletonCaptureSet)
303+
.showing(i"narrow $tree: $tp --> $result", capt)
304+
case _ => tp
305+
case _ => tp
306+
case tp => tp
295307

296308
override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
297309
val res = super.recheck(tree, pt)
@@ -319,6 +331,42 @@ class CheckCaptures extends Recheck:
319331
case _ =>
320332
super.recheckFinish(tpe, tree, pt)
321333

334+
/** This method implements the rule outlined in #14390:
335+
* When checking an expression `e: Ca Ta` against an expected type `Cx Tx`
336+
* where the capture set of `Cx` contains this and any method inside the class
337+
* `Cls` of `this` that contains `e` has only pure parameters, drop from `Ca`
338+
* all references to variables or this references outside `Cls`. These are all
339+
* accessed through this, so are already accounted for by `Cx`.
340+
*/
341+
override def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit =
342+
def isPure(info: Type): Boolean = info match
343+
case info: PolyType => isPure(info.resType)
344+
case info: MethodType => info.paramInfos.forall(_.captureSet.isAlwaysEmpty) && isPure(info.resType)
345+
case _ => true
346+
def isPureContext(owner: Symbol, limit: Symbol): Boolean =
347+
if owner == limit then true
348+
else if !owner.exists then false
349+
else isPure(owner.info) && isPureContext(owner.owner, limit)
350+
val actual1 = (expected, actual.widen) match
351+
case (CapturingType(ecore, erefs, _), actualw @ CapturingType(acore, arefs, _)) =>
352+
val arefs1 = (arefs /: erefs.elems) { (arefs1, eref) =>
353+
eref match
354+
case eref: ThisType if isPureContext(ctx.owner, eref.cls) =>
355+
arefs1.filter {
356+
case aref1: TermRef => !eref.cls.isContainedIn(aref1.symbol.owner)
357+
case aref1: ThisType => !eref.cls.isContainedIn(aref1.cls)
358+
case _ => true
359+
}
360+
case _ =>
361+
arefs1
362+
}
363+
if arefs1 eq arefs then actual
364+
else actualw.derivedCapturingType(acore, arefs1)
365+
.showing(i"healing $actual --> $result", capt)
366+
case _ =>
367+
actual
368+
super.checkConformsExpr(original, actual1, expected, tree)
369+
322370
override def checkUnit(unit: CompilationUnit)(using Context): Unit =
323371
Setup(preRecheckPhase, thisPhase, recheckDef)
324372
.traverse(ctx.compilationUnit.tpdTree)

tests/neg-custom-args/captures/lazylist.check

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ longer explanation available when compiling with `-explain`
3737
17 | def tail = xs() // error: cannot have an inferred type
3838
| ^^^^^^^^^^^^^^^
3939
| Non-local method tail cannot have an inferred result type
40-
| {*} lazylists.LazyList[T]
41-
| with non-empty capture set {*}.
40+
| {LazyCons.this.xs} lazylists.LazyList[T]
41+
| with non-empty capture set {LazyCons.this.xs}.
4242
| The type needs to be declared explicitly.
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:63 -----------------------------------
2-
25 | def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
3-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
4-
| Found: {xs, f} LazyList[A]
5-
| Required: {Mapped.this, xs} LazyList[A]
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:66 -----------------------------------
2+
25 | def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
3+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
4+
| Found: {xs, f} LazyList[A]
5+
| Required: {Mapped.this, f} LazyList[A]
66

77
longer explanation available when compiling with `-explain`

tests/neg-custom-args/captures/lazylists1.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ extension [A](xs: {*} LazyList[A])
2222
def head: B = f(xs.head)
2323
def tail: {this} LazyList[B] = xs.tail.map(f) // OK
2424
def drop(n: Int): {this} LazyList[B] = ??? : ({xs, f} LazyList[B]) // OK
25-
def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
25+
def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
2626
new Mapped
2727

tests/neg-custom-args/captures/lazylists2.check

+22-10
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,29 @@ longer explanation available when compiling with `-explain`
2929
32 | new Mapped
3030

3131
longer explanation available when compiling with `-explain`
32-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:41:48 -----------------------------------
33-
41 | def tail: {this} LazyList[B] = xs.tail.map(f) // error
34-
| ^^^^^^^^^^^^^^
35-
| Found: {f} LazyList[B]
36-
| Required: {xs} LazyList[B]
32+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:36:4 ------------------------------------
33+
36 | final class Mapped extends LazyList[B]: // error
34+
| ^
35+
| Found: {f, xs} LazyList[B]
36+
| Required: {xs} LazyList[B]
37+
37 | this: ({xs} Mapped) =>
38+
38 | def isEmpty = false
39+
39 | def head: B = f(xs.head)
40+
40 | def tail: {this} LazyList[B] = xs.tail.map(f)
41+
41 | new Mapped
3742

3843
longer explanation available when compiling with `-explain`
39-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:59:48 -----------------------------------
40-
59 | def tail: {this} LazyList[B] = xs.tail.map(f) // error
41-
| ^^^^^^^^^^^^^^
42-
| Found: {f} LazyList[B]
43-
| Required: {Mapped.this} LazyList[B]
44+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:54:4 ------------------------------------
45+
54 | class Mapped extends LazyList[B]: // error
46+
| ^
47+
| Found: {f, xs} LazyList[B]
48+
| Required: LazyList[B]
49+
55 | this: ({xs, f} Mapped) =>
50+
56 | def isEmpty = false
51+
57 | def head: B = f(xs.head)
52+
58 | def tail: {this} LazyList[B] = xs.tail.map(f)
53+
59 | class Mapped2 extends Mapped:
54+
60 | this: Mapped =>
55+
61 | new Mapped2
4456

4557
longer explanation available when compiling with `-explain`

tests/neg-custom-args/captures/lazylists2.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ extension [A](xs: {*} LazyList[A])
3333
new Mapped
3434

3535
def map3[B](f: A => B): {xs} LazyList[B] =
36-
final class Mapped extends LazyList[B]:
36+
final class Mapped extends LazyList[B]: // error
3737
this: ({xs} Mapped) =>
3838

3939
def isEmpty = false
4040
def head: B = f(xs.head)
41-
def tail: {this} LazyList[B] = xs.tail.map(f) // error
41+
def tail: {this} LazyList[B] = xs.tail.map(f)
4242
new Mapped
4343

4444
def map4[B](f: A => B): {xs} LazyList[B] =
@@ -51,12 +51,12 @@ extension [A](xs: {*} LazyList[A])
5151
new Mapped
5252

5353
def map5[B](f: A => B): LazyList[B] =
54-
class Mapped extends LazyList[B]:
54+
class Mapped extends LazyList[B]: // error
5555
this: ({xs, f} Mapped) =>
5656

5757
def isEmpty = false
5858
def head: B = f(xs.head)
59-
def tail: {this} LazyList[B] = xs.tail.map(f) // error
59+
def tail: {this} LazyList[B] = xs.tail.map(f)
6060
class Mapped2 extends Mapped:
6161
this: Mapped =>
6262
new Mapped2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import language.experimental.saferExceptions
2+
import annotation.unchecked.uncheckedVariance
3+
4+
trait LazyList[+A]:
5+
this: {*} LazyList[A] =>
6+
7+
def isEmpty: Boolean
8+
def head: A
9+
def tail: {this} LazyList[A]
10+
11+
object LazyNil extends LazyList[Nothing]:
12+
def isEmpty: Boolean = true
13+
def head = ???
14+
def tail = ???
15+
16+
final class LazyCons[+T](val x: T, val xs: () => {*} LazyList[T]) extends LazyList[T]:
17+
this: {*} LazyList[T] =>
18+
19+
var forced = false
20+
var cache: {this} LazyList[T @uncheckedVariance] = compiletime.uninitialized
21+
22+
private def force =
23+
if !forced then
24+
cache = xs()
25+
forced = true
26+
cache
27+
28+
def isEmpty = false
29+
def head = x
30+
def tail: {this} LazyList[T] = force
31+
end LazyCons
32+
33+
extension [A](xs: {*} LazyList[A])
34+
def map[B](f: A => B): {xs, f} LazyList[B] =
35+
if xs.isEmpty then LazyNil
36+
else LazyCons(f(xs.head), () => xs.tail.map(f))
37+
38+
def filter(p: A => Boolean): {xs, p} LazyList[A] =
39+
if xs.isEmpty then LazyNil
40+
else if p(xs.head) then LazyCons(xs.head, () => xs.tail.filter(p))
41+
else xs.tail.filter(p)
42+
43+
def concat(ys: {*} LazyList[A]): {xs, ys} LazyList[A] =
44+
if xs.isEmpty then ys
45+
else LazyCons(xs.head, () => xs.tail.concat(ys))
46+
end extension
47+
48+
class Ex1 extends Exception
49+
class Ex2 extends Exception
50+
51+
def test(using cap1: CanThrow[Ex1], cap2: CanThrow[Ex2]) =
52+
val xs = LazyCons(1, () => LazyNil)
53+
54+
def f(x: Int): Int throws Ex1 =
55+
if x < 0 then throw Ex1()
56+
x * x
57+
58+
def g(x: Int): Int throws Ex1 =
59+
if x < 0 then throw Ex1()
60+
x * x
61+
62+
def x1 = xs.map(f)
63+
def x1c: {cap1} LazyList[Int] = x1
64+
65+
def x2 = x1.concat(xs.map(g).filter(_ > 0))
66+
def x2c: {cap1, cap2} LazyList[Int] = x2
67+
68+

tests/pos-custom-args/captures/lazylists.scala

-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ extension [A](xs: {*} LazyList[A])
2121
def isEmpty = false
2222
def head: B = f(xs.head)
2323
def tail: {this} LazyList[B] = xs.tail.map(f) // OK
24-
def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : {xs, f} LazyList[A] // OK
2524
if xs.isEmpty then LazyNil
2625
else new Mapped
2726

0 commit comments

Comments
 (0)