Skip to content

Commit 2cdc7ce

Browse files
committed
Fix soundness problem with curried functions
Flags the following as an error: ```scala val foo: (x: Ref[Int]^) -> (y: Ref[Int]^{a}) ->{x} Unit = x => y => swap(x, y) val f: (y: Ref[Int]^{a}) ->{a} Unit = foo(a) // error f(a) ``` Here, the result type of `foo(a)` takes an argument with `a` capture but also refers to a hidden `a` in its `x` dependency. We now recognize and reject this case. ```
1 parent c5bad2c commit 2cdc7ce

File tree

4 files changed

+142
-41
lines changed

4 files changed

+142
-41
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureRef.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ trait CaptureRef extends TypeProxy, ValueType:
259259
vs.ifNotSeen(this)(hidden.elems.exists(_.subsumes(y)))
260260
|| !y.stripReadOnly.isCap && canAddHidden && vs.addHidden(hidden, y)
261261
case _ =>
262-
this.isCap && canAddHidden
262+
this.isCap && canAddHidden && vs != VarState.HardSeparate
263263
|| y.match
264264
case ReadOnlyCapability(y1) => this.stripReadOnly.maxSubsumes(y1, canAddHidden)
265265
case _ => false

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+20-6
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ sealed abstract class CaptureSet extends Showable:
197197
// For instance x: C^{y, z}. Then neither y nor z subsumes x but {y, z} accounts for x.
198198
!x.isMaxCapability
199199
&& !x.derivesFrom(defn.Caps_CapSet)
200-
&& !(vs == VarState.Separate && x.captureSetOfInfo.containsRootCapability)
200+
&& !(vs.isSeparating && x.captureSetOfInfo.containsRootCapability)
201201
// in VarState.Separate, don't try to widen to cap since that might succeed with {cap} <: {cap}
202202
&& x.captureSetOfInfo.subCaptures(this, VarState.Separate).isOK
203203

@@ -257,9 +257,9 @@ sealed abstract class CaptureSet extends Showable:
257257
* `this` and `that`
258258
*/
259259
def ++ (that: CaptureSet)(using Context): CaptureSet =
260-
if this.subCaptures(that, VarState.Separate).isOK then
260+
if this.subCaptures(that, VarState.HardSeparate).isOK then
261261
if that.isAlwaysEmpty && this.keepAlways then this else that
262-
else if that.subCaptures(this, VarState.Separate).isOK then this
262+
else if that.subCaptures(this, VarState.HardSeparate).isOK then this
263263
else if this.isConst && that.isConst then Const(this.elems ++ that.elems)
264264
else Union(this, that)
265265

@@ -554,7 +554,7 @@ object CaptureSet:
554554
else
555555
// id == 108 then assert(false, i"trying to add $elem to $this")
556556
assert(elem.isTrackableRef, elem)
557-
assert(!this.isInstanceOf[HiddenSet] || summon[VarState] == VarState.Separate, summon[VarState])
557+
assert(!this.isInstanceOf[HiddenSet] || summon[VarState].isSeparating, summon[VarState])
558558
elems += elem
559559
if elem.isRootCapability then
560560
rootAddedHandler()
@@ -1157,6 +1157,7 @@ object CaptureSet:
11571157

11581158
/** Does this state allow additions of elements to capture set variables? */
11591159
def isOpen = true
1160+
def isSeparating = false
11601161

11611162
/** Add element to hidden set, recording it in elemsMap,
11621163
* return whether this was allowed. By default, recording is allowed
@@ -1204,10 +1205,23 @@ object CaptureSet:
12041205
* reference `r` only if `r` is already present in the hidden set of the instance.
12051206
* No new references can be added.
12061207
*/
1207-
@sharable
1208-
object Separate extends Closed:
1208+
class Separating extends Closed:
12091209
override def addHidden(hidden: HiddenSet, elem: CaptureRef)(using Context): Boolean = false
12101210
override def toString = "separating varState"
1211+
override def isSeparating = true
1212+
1213+
/** A closed state that allows a Fresh.Cap instance to subsume a
1214+
* reference `r` only if `r` is already present in the hidden set of the instance.
1215+
* No new references can be added.
1216+
*/
1217+
@sharable
1218+
object Separate extends Separating
1219+
1220+
/** Like Separate but in addition we assume that `cap` never subsumes anything else.
1221+
* Used in `++` to not lose track of dependencies between function parameters.
1222+
*/
1223+
@sharable
1224+
object HardSeparate extends Separating
12111225

12121226
/** A special state that turns off recording of elements. Used only
12131227
* in `addSub` to prevent cycles in recordings.

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

+72-34
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,13 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
340340
* @param fn the function
341341
* @param parts the function prefix followed by the flattened argument list
342342
* @param polyArg the clashing argument to a polymorphic formal
343-
* @param clashing the argument with which it clashes
343+
* @param clashing the argument, function prefix, or entire function application result with
344+
* which it clashes,
345+
*
344346
*/
345347
def sepApplyError(fn: Tree, parts: List[Tree], polyArg: Tree, clashing: Tree)(using Context): Unit =
346348
val polyArgIdx = parts.indexOf(polyArg).ensuring(_ >= 0) - 1
347-
val clashIdx = parts.indexOf(clashing).ensuring(_ >= 0)
349+
val clashIdx = parts.indexOf(clashing) // -1 means entire function application
348350
def paramName(mt: Type, idx: Int): Option[Name] = mt match
349351
case mt @ MethodType(pnames) =>
350352
if idx < pnames.length then Some(pnames(idx)) else paramName(mt.resType, idx - pnames.length)
@@ -363,11 +365,12 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
363365
if isShowableMethod then i"${fn.symbol}: ${fn.symbol.info}"
364366
else i"a function of type ${funType.widen}"
365367
def clashArgStr = clashIdx match
366-
case 0 => "function prefix"
367-
case 1 => "first argument "
368-
case 2 => "second argument"
369-
case 3 => "third argument "
370-
case n => s"${n}th argument "
368+
case -1 => "function result"
369+
case 0 => "function prefix"
370+
case 1 => "first argument "
371+
case 2 => "second argument"
372+
case 3 => "third argument "
373+
case n => s"${n}th argument "
371374
def clashTypeStr =
372375
if clashIdx == 0 && !isShowableMethod then "" // we already mentioned the type in `funStr`
373376
else i" with type ${clashing.nuType}"
@@ -455,11 +458,12 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
455458
*
456459
* @param fn the applied function
457460
* @param args the flattened argument lists
461+
* @param app the entire application tree
458462
* @param deps cross argument dependencies: maps argument trees to
459463
* those other arguments that where mentioned by coorresponding
460464
* formal parameters.
461465
*/
462-
private def checkApply(fn: Tree, args: List[Tree], deps: collection.Map[Tree, List[Tree]])(using Context): Unit =
466+
private def checkApply(fn: Tree, args: List[Tree], app: Tree, deps: collection.Map[Tree, List[Tree]])(using Context): Unit =
463467
val (qual, fnCaptures) = methPart(fn) match
464468
case Select(qual, _) => (qual, qual.nuType.captureSet)
465469
case _ => (fn, CaptureSet.empty)
@@ -511,6 +515,29 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
511515
currentPeaks = PeaksPair(
512516
currentPeaks.actual ++ argPeaks.actual,
513517
currentPeaks.hidden ++ argPeaks.hidden)
518+
end for
519+
520+
def collectRefs(args: List[Type], res: Type) =
521+
args.foldLeft(argCaptures(res)): (refs, arg) =>
522+
refs ++ arg.deepCaptureSet.elems
523+
524+
/** The deep capture sets of all parameters of this type (if it is a function type) */
525+
def argCaptures(tpe: Type): Refs = tpe match
526+
case defn.FunctionOf(args, resultType, isContextual) =>
527+
collectRefs(args, resultType)
528+
case defn.RefinedFunctionOf(mt) =>
529+
collectRefs(mt.paramInfos, mt.resType)
530+
case CapturingType(parent, _) =>
531+
argCaptures(parent)
532+
case _ =>
533+
emptyRefs
534+
535+
if !deps(app).isEmpty then
536+
lazy val appPeaks = argCaptures(app.nuType).peaks
537+
lazy val partPeaks = partsWithPeaks.toMap
538+
for arg <- deps(app) do
539+
if arg.needsSepCheck && !partPeaks(arg).hidden.sharedWith(appPeaks).isEmpty then
540+
sepApplyError(fn, parts, arg, app)
514541
end checkApply
515542

516543
/** 1. Check that the capabilities used at `tree` don't overlap with
@@ -782,44 +809,55 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
782809
*
783810
* f(x: A, y: B^{cap, x}, z: C^{x, y}): D
784811
*
785-
* then the dependencies of an application `f(a, b)` is a map that takes
786-
* `b` to `List(a)` and `c` to `List(a, b)`.
812+
* then the dependencies of an application `f(a, b, c)` of type C^{y} is the map
813+
*
814+
* [ b -> [a]
815+
* , c -> [a, b]
816+
* , f(a, b, c) -> [b]]
787817
*/
788-
private def dependencies(fn: Tree, argss: List[List[Tree]])(using Context): collection.Map[Tree, List[Tree]] =
818+
private def dependencies(fn: Tree, argss: List[List[Tree]], app: Tree)(using Context): collection.Map[Tree, List[Tree]] =
819+
def isFunApply(sym: Symbol) =
820+
sym.name == nme.apply && defn.isFunctionClass(sym.owner)
789821
val mtpe =
790-
if fn.symbol.exists then fn.symbol.info
791-
else fn.tpe.widen // happens for PolyFunction applies
822+
if fn.symbol.exists && !isFunApply(fn.symbol) then fn.symbol.info
823+
else fn.nuType.widen
792824
val mtps = collectMethodTypes(mtpe)
793825
assert(mtps.hasSameLengthAs(argss), i"diff for $fn: ${fn.symbol} /// $mtps /// $argss")
794826
val mtpsWithArgs = mtps.zip(argss)
795827
val argMap = mtpsWithArgs.toMap
796828
val deps = mutable.HashMap[Tree, List[Tree]]().withDefaultValue(Nil)
797-
for
798-
(mt, args) <- mtpsWithArgs
799-
(formal, arg) <- mt.paramInfos.zip(args)
800-
dep <- formal.captureSet.elems.toList
801-
do
802-
val referred = dep.stripReach match
803-
case dep: TermParamRef =>
804-
argMap(dep.binder)(dep.paramNum) :: Nil
805-
case dep: ThisType if dep.cls == fn.symbol.owner =>
806-
val Select(qual, _) = fn: @unchecked // TODO can we use fn instead?
807-
qual :: Nil
808-
case _ =>
809-
Nil
810-
deps(arg) ++= referred
829+
830+
def recordDeps(formal: Type, actual: Tree) =
831+
for dep <- formal.captureSet.elems.toList do
832+
val referred = dep.stripReach match
833+
case dep: TermParamRef =>
834+
argMap(dep.binder)(dep.paramNum) :: Nil
835+
case dep: ThisType if dep.cls == fn.symbol.owner =>
836+
val Select(qual, _) = fn: @unchecked // TODO can we use fn instead?
837+
qual :: Nil
838+
case _ =>
839+
Nil
840+
deps(actual) ++= referred
841+
842+
for (mt, args) <- mtpsWithArgs; (formal, arg) <- mt.paramInfos.zip(args) do
843+
recordDeps(formal, arg)
844+
recordDeps(mtpe.finalResultType, app)
845+
capt.println(i"deps for $app = ${deps.toList}")
811846
deps
812847

848+
813849
/** Decompose an application into a function prefix and a list of argument lists.
814850
* If some of the arguments need a separation check because they are capture polymorphic,
815851
* perform a separation check with `checkApply`
816852
*/
817-
private def traverseApply(tree: Tree, argss: List[List[Tree]])(using Context): Unit = tree match
818-
case Apply(fn, args) => traverseApply(fn, args :: argss)
819-
case TypeApply(fn, args) => traverseApply(fn, argss) // skip type arguments
820-
case _ =>
821-
if argss.nestedExists(_.needsSepCheck) then
822-
checkApply(tree, argss.flatten, dependencies(tree, argss))
853+
private def traverseApply(app: Tree)(using Context): Unit =
854+
def recur(tree: Tree, argss: List[List[Tree]]): Unit = tree match
855+
case Apply(fn, args) => recur(fn, args :: argss)
856+
case TypeApply(fn, args) => recur(fn, argss) // skip type arguments
857+
case _ =>
858+
if argss.nestedExists(_.needsSepCheck) then
859+
checkApply(tree, argss.flatten, app, dependencies(tree, argss, app))
860+
recur(app, Nil)
823861

824862
/** Is `tree` an application of `caps.unsafe.unsafeAssumeSeparate`? */
825863
def isUnsafeAssumeSeparate(tree: Tree)(using Context): Boolean = tree match
@@ -866,7 +904,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
866904
traverseChildren(tree)
867905
tree.tpe match
868906
case _: MethodOrPoly =>
869-
case _ => traverseApply(tree, Nil)
907+
case _ => traverseApply(tree)
870908
case _: Block | _: Template =>
871909
traverseSection(tree)
872910
case tree: ValDef =>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import language.experimental.captureChecking
2+
import caps.*
3+
4+
class Ref[T](init: T) extends Mutable:
5+
private var value: T = init
6+
def get: T = value
7+
mut def set(newValue: T): Unit = value = newValue
8+
9+
// a library function that assumes that a and b MUST BE separate
10+
def swap[T](a: Ref[Int]^, b: Ref[Int]^): Unit = ???
11+
12+
def test0(): Unit =
13+
val a: Ref[Int]^ = Ref(0)
14+
def foo(x: Ref[Int]^)(y: Ref[Int]^{a}): Unit =
15+
swap(x, y)
16+
foo(a)(a) // error
17+
18+
def test1(): Unit =
19+
val a: Ref[Int]^ = Ref(0)
20+
val foo: (x: Ref[Int]^) -> (y: Ref[Int]^{a}) ->{x} Unit =
21+
x => y => swap(x, y)
22+
val f: (y: Ref[Int]^{a}) ->{a} Unit = foo(a) // error
23+
f(a)
24+
25+
def test2(): Unit =
26+
val a: Ref[Int]^ = Ref(0)
27+
val foo: (x: Ref[Int]^) -> (y: Ref[Int]^{a}) ->{x} Unit =
28+
x => y => swap(x, y)
29+
foo(a)(a) // error
30+
31+
def test3(): Unit =
32+
val a: Ref[Int]^ = Ref(0)
33+
val foo: (x: Ref[Int]^) -> (y: Ref[Int]^) ->{x} Unit =
34+
x => y => swap(x, y)
35+
foo(a)(a) // error
36+
37+
def test4(): Unit =
38+
val a: Ref[Int]^ = Ref(0)
39+
val foo: (x: Ref[Int]^) -> (y: Ref[Int]^) ->{x} Unit =
40+
x => y => swap(x, y)
41+
val f = foo(a)
42+
f(a) // error
43+
44+
def test5(): Unit =
45+
val a: Ref[Int]^ = Ref(0)
46+
val foo: (x: Ref[Int]^) -> (y: Ref[Int]^) ->{x} Unit =
47+
x => y => swap(x, y)
48+
val f: (y: Ref[Int]^{a}) ->{a} Unit = foo(a) // should be error, but we don't check params
49+
f(a)

0 commit comments

Comments
 (0)