Skip to content

Commit ffd42da

Browse files
committed
Use peaks-based checking for applications
1 parent 4f98208 commit ffd42da

15 files changed

+324
-261
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+7-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import util.common.alwaysTrue
1818
import scala.collection.{mutable, immutable}
1919
import CCState.*
2020
import TypeOps.AvoidMap
21+
import compiletime.uninitialized
2122

2223
/** A class for capture sets. Capture sets can be constants or variables.
2324
* Capture sets support inclusion constraints <:< where <:< is subcapturing.
@@ -942,8 +943,9 @@ object CaptureSet:
942943
* which are already subject through snapshotting and rollbacks in VarState.
943944
* It's advantageous if we don't need to deal with other pieces of state there.
944945
*/
945-
class HiddenSet(initialHidden: Refs = emptyRefs)(using @constructorOnly ictx: Context)
946-
extends Var(initialElems = initialHidden):
946+
class HiddenSet(owner: Symbol, initialHidden: Refs = emptyRefs)(using @constructorOnly ictx: Context)
947+
extends Var(owner, initialHidden):
948+
var owningCap: AnnotatedType = uninitialized
947949

948950
private def aliasRef: AnnotatedType | Null =
949951
if myElems.size == 1 then
@@ -959,6 +961,9 @@ object CaptureSet:
959961
case _ => this
960962
else this
961963

964+
def superCaps: List[AnnotatedType] =
965+
deps.toList.map(_.asInstanceOf[HiddenSet].owningCap)
966+
962967
override def elems: Refs =
963968
val al = aliasSet
964969
if al eq this then super.elems else al.elems

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ class CheckCaptures extends Recheck, SymTransformer:
803803
var refined: Type = core
804804
var allCaptures: CaptureSet =
805805
if core.derivesFromMutable then initCs ++ CaptureSet.fresh()
806-
else if core.derivesFromCapability then initCs ++ Fresh.Cap().readOnly.singletonCaptureSet
806+
else if core.derivesFromCapability then initCs ++ Fresh.Cap(core.classSymbol).readOnly.singletonCaptureSet
807807
else initCs
808808
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
809809
val getter = cls.info.member(getterName).suchThat(_.isRefiningParamAccessor).symbol

compiler/src/dotty/tools/dotc/cc/Existential.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ object Existential:
245245
/** Map top-level existentials to `Fresh.Cap`. */
246246
def toCap(tp: Type)(using Context): Type = tp.dealiasKeepAnnots match
247247
case Existential(boundVar, unpacked) =>
248-
unpacked.substParam(boundVar, Fresh.Cap())
248+
unpacked.substParam(boundVar, Fresh.Cap(NoSymbol))
249249
case tp1 @ CapturingType(parent, refs) =>
250250
tp1.derivedCapturingType(toCap(parent), refs)
251251
case tp1 @ AnnotatedType(parent, ann) =>
@@ -256,7 +256,7 @@ object Existential:
256256
*/
257257
def toCapDeeply(tp: Type)(using Context): Type = tp.dealiasKeepAnnots match
258258
case Existential(boundVar, unpacked) =>
259-
toCapDeeply(unpacked.substParam(boundVar, Fresh.Cap()))
259+
toCapDeeply(unpacked.substParam(boundVar, Fresh.Cap(NoSymbol)))
260260
case tp1 @ FunctionOrMethod(args, res) =>
261261
val tp2 = tp1.derivedFunctionOrMethod(args, toCapDeeply(res))
262262
if tp2 ne tp1 then tp2 else tp
@@ -317,7 +317,7 @@ object Existential:
317317
//.showing(i"mapcap $t = $result")
318318

319319
lazy val inverse = new BiTypeMap:
320-
lazy val freshCap = Fresh.Cap()
320+
lazy val freshCap = Fresh.Cap(NoSymbol)
321321
def apply(t: Type) = t match
322322
case t: TermParamRef if t eq boundVar => freshCap
323323
case _ => mapOver(t)

compiler/src/dotty/tools/dotc/cc/Fresh.scala

+9-5
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,21 @@ object Fresh:
5050
/** An extractor for "fresh" capabilities */
5151
object Cap:
5252

53-
def apply(initialHidden: Refs = emptyRefs)(using Context): CaptureRef =
53+
def apply(owner: Symbol, initialHidden: Refs = emptyRefs)(using Context): CaptureRef =
5454
if ccConfig.useSepChecks then
55-
AnnotatedType(defn.captureRoot.termRef, Annot(CaptureSet.HiddenSet(initialHidden)))
55+
val hiddenSet = CaptureSet.HiddenSet(owner, initialHidden)
56+
val res = AnnotatedType(defn.captureRoot.termRef, Annot(hiddenSet))
57+
hiddenSet.owningCap = res
58+
//assert(hiddenSet.id != 3)
59+
res
5660
else
5761
defn.captureRoot.termRef
5862

5963
def apply(owner: Symbol, reach: Boolean)(using Context): CaptureRef =
60-
apply(ownerToHidden(owner, reach))
64+
apply(owner, ownerToHidden(owner, reach))
6165

6266
def apply(owner: Symbol)(using Context): CaptureRef =
63-
apply(ownerToHidden(owner, reach = false))
67+
apply(owner, ownerToHidden(owner, reach = false))
6468

6569
def unapply(tp: AnnotatedType): Option[CaptureSet.HiddenSet] = tp.annot match
6670
case Annot(hidden) => Some(hidden)
@@ -77,7 +81,7 @@ object Fresh:
7781
if variance <= 0 then t
7882
else t match
7983
case t: CaptureRef if t.isCap =>
80-
Cap(ownerToHidden(owner, reach))
84+
Cap(owner, ownerToHidden(owner, reach))
8185
case t @ CapturingType(_, refs) =>
8286
val savedReach = reach
8387
if t.isBoxed then reach = true

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

+140-72
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ object SepCheck:
142142

143143
val EmptyConsumedSet = ConstConsumedSet(Array(), Array())
144144

145+
case class PeaksPair(actual: Refs, formal: Refs)
146+
145147
class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
146148
import checker.*
147149
import SepCheck.*
@@ -194,6 +196,52 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
194196
val elems: Refs = refs.filter(!_.isMaxCapability)
195197
recur(elems, elems.toList)
196198

199+
/** The members of type Fresh.Cap(...) or Fresh.Cap(...).rd in the transitive closure
200+
* of this set
201+
*/
202+
private def freshElems(using Context): Refs =
203+
def recur(seen: Refs, acc: Refs, newElems: List[CaptureRef]): Refs = newElems match
204+
case newElem :: newElems1 =>
205+
if seen.contains(newElem) then
206+
recur(seen, acc, newElems1)
207+
else newElem.stripReadOnly match
208+
case Fresh.Cap(_) =>
209+
recur(seen, acc + newElem, newElems1)
210+
//case _: TypeRef | _: TypeParamRef =>
211+
// recur(seen + newElem, acc, newElems1)
212+
case _ =>
213+
recur(seen + newElem, acc, newElem.captureSetOfInfo.elems.toList ++ newElems1)
214+
case Nil => acc
215+
recur(emptyRefs, emptyRefs, refs.toList)
216+
217+
private def peaks(using Context): Refs =
218+
def recur(seen: Refs, acc: Refs, newElems: List[CaptureRef]): Refs = newElems match
219+
case newElem :: newElems1 =>
220+
if seen.contains(newElem) then
221+
recur(seen, acc, newElems1)
222+
else newElem.stripReadOnly match
223+
case Fresh.Cap(hidden) =>
224+
if hidden.deps.isEmpty then recur(seen + newElem, acc + newElem, newElems1)
225+
else
226+
val superCaps =
227+
if newElem.isReadOnly then hidden.superCaps.map(_.readOnly)
228+
else hidden.superCaps
229+
recur(seen + newElem, acc, superCaps ++ newElems)
230+
case _ =>
231+
if newElem.isMaxCapability
232+
//|| newElem.isInstanceOf[TypeRef | TypeParamRef]
233+
then recur(seen + newElem, acc, newElems1)
234+
else recur(seen + newElem, acc, newElem.captureSetOfInfo.elems.toList ++ newElems1)
235+
case Nil => acc
236+
recur(emptyRefs, emptyRefs, refs.toList)
237+
238+
/** The shared peaks between `refs` and `other` */
239+
private def sharedWith(other: Refs)(using Context): Refs =
240+
def common(refs1: Refs, refs2: Refs) =
241+
refs1.filter: ref =>
242+
!ref.isReadOnly && refs2.exists(_.stripReadOnly eq ref)
243+
common(refs, other) ++ common(other, refs)
244+
197245
/** The overlap of two footprint sets F1 and F2. This contains all exclusive references `r`
198246
* such that one of the following is true:
199247
* 1.
@@ -266,6 +314,11 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
266314
recur(refs)
267315
end containsHidden
268316

317+
def hiddenSet(using Context): Refs =
318+
freshElems.flatMap:
319+
case Fresh.Cap(hidden) => hidden.elems
320+
case ReadOnlyCapability(Fresh.Cap(hidden)) => hidden.elems.map(_.readOnly)
321+
269322
/** Subtract all elements that are covered by some element in `others` from this set. */
270323
private def deduct(others: Refs)(using Context): Refs =
271324
refs.filter: ref =>
@@ -297,29 +350,21 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
297350

298351
/** Report a separation failure in an application `fn(args)`
299352
* @param fn the function
300-
* @param args the flattened argument lists
301-
* @param argIdx the index of the failing argument in `args`, starting at 0
302-
* @param overlap the overlap causing the failure
303-
* @param hiddenInArg the hidxden set of the type of the failing argument
304-
* @param footprints a sequence of partial footprints, and the index of the
305-
* last argument they cover.
306-
* @param deps cross argument dependencies: maps argument trees to
307-
* those other arguments that where mentioned by coorresponding
308-
* formal parameters.
353+
* @param parts the function prefix followed by the flattened argument list
354+
* @param polyArg the clashing argument to a polymorphic formal
355+
* @param clashing the argument with which it clashes
309356
*/
310-
private def sepApplyError(fn: Tree, args: List[Tree], argIdx: Int,
311-
overlap: Refs, hiddenInArg: Refs, footprints: List[(Refs, Int)],
312-
deps: collection.Map[Tree, List[Tree]])(using Context): Unit =
313-
val arg = args(argIdx)
357+
def sepApplyError(fn: Tree, parts: List[Tree], polyArg: Tree, clashing: Tree)(using Context): Unit =
358+
val polyArgIdx = parts.indexOf(polyArg).ensuring(_ >= 0) - 1
359+
val clashIdx = parts.indexOf(clashing).ensuring(_ >= 0)
314360
def paramName(mt: Type, idx: Int): Option[Name] = mt match
315361
case mt @ MethodType(pnames) =>
316362
if idx < pnames.length then Some(pnames(idx)) else paramName(mt.resType, idx - pnames.length)
317363
case mt: PolyType => paramName(mt.resType, idx)
318364
case _ => None
319-
def formalName = paramName(fn.nuType.widen, argIdx) match
365+
def formalName = paramName(fn.nuType.widen, polyArgIdx) match
320366
case Some(pname) => i"$pname "
321367
case _ => ""
322-
def whatStr = if overlap.size == 1 then "this capability is" else "these capabilities are"
323368
def qualifier = methPart(fn) match
324369
case Select(qual, _) => qual
325370
case _ => EmptyTree
@@ -329,43 +374,45 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
329374
def funStr =
330375
if isShowableMethod then i"${fn.symbol}: ${fn.symbol.info}"
331376
else i"a function of type ${funType.widen}"
332-
val clashIdx = footprints
333-
.collect:
334-
case (fp, idx) if !hiddenInArg.overlapWith(fp).isEmpty => idx
335-
.head
336-
def whereStr = clashIdx match
377+
def clashArgStr = clashIdx match
337378
case 0 => "function prefix"
338379
case 1 => "first argument "
339380
case 2 => "second argument"
340381
case 3 => "third argument "
341382
case n => s"${n}th argument "
342-
def clashTree =
343-
if clashIdx == 0 then qualifier
344-
else args(clashIdx - 1)
345383
def clashTypeStr =
346384
if clashIdx == 0 && !isShowableMethod then "" // we already mentioned the type in `funStr`
347-
else i" with type ${clashTree.nuType}"
348-
def clashCaptures = captures(clashTree)
349-
def hiddenCaptures = formalCaptures(arg).hidden
350-
def clashFootprint = clashCaptures.footprint
351-
def hiddenFootprint = hiddenCaptures.footprint
352-
def declaredFootprint = deps(arg).map(captures(_)).foldLeft(emptyRefs)(_ ++ _).footprint
353-
def footprintOverlap = hiddenFootprint.overlapWith(clashFootprint).deduct(declaredFootprint)
385+
else i" with type ${clashing.nuType}"
386+
val hiddenSet = formalCaptures(polyArg).hiddenSet
387+
val clashSet = captures(clashing)
388+
val hiddenFootprint = hiddenSet.footprint
389+
val clashFootprint = clashSet.footprint
390+
val overlapStr =
391+
// The overlap of footprints, or, of this empty the set of shared peaks.
392+
// We prefer footprint overlap since it tends to be more informative.
393+
val overlap = hiddenFootprint.overlapWith(clashFootprint)
394+
if !overlap.isEmpty then i"${CaptureSet(overlap)}"
395+
else
396+
val sharedPeaks = hiddenSet.peaks.sharedWith(clashSet.peaks)
397+
assert(!sharedPeaks.isEmpty,
398+
i"no overlap for $polyArg: $hiddenSet} vs $clashing: $clashSet")
399+
sharedPeaks.nth(0) match
400+
case fresh @ Fresh.Cap(hidden) =>
401+
if hidden.owner.exists then i"cap of ${hidden.owner}" else i"$fresh"
402+
354403
report.error(
355-
em"""Separation failure: argument of type ${arg.nuType}
404+
em"""Separation failure: argument of type ${polyArg.nuType}
356405
|to $funStr
357-
|corresponds to capture-polymorphic formal parameter ${formalName}of type ${arg.formalType}
358-
|and captures ${CaptureSet(overlap)}, but $whatStr also passed separately
359-
|in the ${whereStr.trim}$clashTypeStr.
406+
|corresponds to capture-polymorphic formal parameter ${formalName}of type ${polyArg.formalType}
407+
|and hides capabilities ${CaptureSet(hiddenSet)}.
408+
|Some of these overlap with the captures of the ${clashArgStr.trim}$clashTypeStr.
360409
|
361-
| Capture set of $whereStr : ${CaptureSet(clashCaptures)}
362-
| Hidden set of current argument : ${CaptureSet(hiddenCaptures)}
363-
| Footprint of $whereStr : ${CaptureSet(clashFootprint)}
364-
| Hidden footprint of current argument : ${CaptureSet(hiddenFootprint)}
365-
| Declared footprint of current argument: ${CaptureSet(declaredFootprint)}
366-
| Undeclared overlap of footprints : ${CaptureSet(footprintOverlap)}""",
367-
arg.srcPos)
368-
end sepApplyError
410+
| Hidden set of current argument : ${CaptureSet(hiddenSet)}
411+
| Hidden footprint of current argument : ${CaptureSet(hiddenSet.footprint)}
412+
| Capture set of $clashArgStr : ${CaptureSet(clashSet)}
413+
| Footprint set of $clashArgStr : ${CaptureSet(clashSet.footprint)}
414+
| The two sets overlap at : $overlapStr""",
415+
polyArg.srcPos)
369416

370417
/** Report a use/definition failure, where a previously hidden capability is
371418
* used again.
@@ -445,37 +492,58 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
445492
* formal parameters.
446493
*/
447494
private def checkApply(fn: Tree, args: List[Tree], deps: collection.Map[Tree, List[Tree]])(using Context): Unit =
448-
val fnCaptures = methPart(fn) match
449-
case Select(qual, _) => qual.nuType.captureSet
450-
case _ => CaptureSet.empty
451-
capt.println(i"check separate $fn($args), fnCaptures = $fnCaptures, argCaptures = ${args.map(arg => CaptureSet(formalCaptures(arg)))}, deps = ${deps.toList}")
452-
var footprint = fnCaptures.elems.footprint
453-
val footprints = mutable.ListBuffer[(Refs, Int)]((footprint, 0))
454-
val indexedArgs = args.zipWithIndex
455-
456-
// First, compute all footprints of arguments to monomorphic pararameters,
457-
// separately in `footprints`, and their union in `footprint`.
458-
for (arg, idx) <- indexedArgs do
459-
if !arg.needsSepCheck then
460-
footprint = footprint ++ captures(arg).footprint.deductCapturesOf(deps(arg))
461-
footprints += ((footprint, idx + 1))
462-
463-
// Then, for each argument to a polymorphic parameter:
464-
// - check formal type via checkType
465-
// - check that hidden set of argument does not overlap with current footprint
466-
// - add footprint of the deep capture set of actual type of argument
467-
// to global footprint(s)
468-
for (arg, idx) <- indexedArgs do
495+
val (qual, fnCaptures) = methPart(fn) match
496+
case Select(qual, _) => (qual, qual.nuType.captureSet)
497+
case _ => (fn, CaptureSet.empty)
498+
var currentPeaks = PeaksPair(fnCaptures.elems.peaks, emptyRefs)
499+
val peaksOfTree: Map[Tree, PeaksPair] =
500+
((qual -> currentPeaks) :: args.map: arg =>
501+
arg -> PeaksPair(
502+
captures(arg).peaks,
503+
if arg.needsSepCheck then formalCaptures(arg).hiddenSet.peaks else emptyRefs)
504+
).toMap
505+
capt.println(
506+
i"""check separate $fn($args), fnCaptures = $fnCaptures,
507+
| formalCaptures = ${args.map(arg => CaptureSet(formalCaptures(arg)))},
508+
| actualCaptures = ${args.map(arg => CaptureSet(captures(arg)))},
509+
| formalPeaks = ${peaksOfTree.values.map(_.formal).toList}
510+
| actualPeaks = ${peaksOfTree.values.map(_.actual).toList}
511+
| deps = ${deps.toList}""")
512+
val parts = qual :: args
513+
514+
for arg <- args do
515+
val argPeaks = peaksOfTree(arg)
516+
val argDeps = deps(arg)
517+
518+
def clashingPart(argPeaks: Refs, selector: PeaksPair => Refs): Tree =
519+
parts.iterator.takeWhile(_ ne arg).find: prev =>
520+
!argDeps.contains(prev)
521+
&& !selector(peaksOfTree(prev)).sharedWith(argPeaks).isEmpty
522+
.getOrElse(EmptyTree)
523+
524+
// 1. test argPeaks.actual against previously captured formals
525+
if !argPeaks.actual.sharedWith(currentPeaks.formal).isEmpty then
526+
val clashing = clashingPart(argPeaks.actual, _.formal)
527+
if !clashing.isEmpty then sepApplyError(fn, parts, clashing, arg)
528+
else assert(!argDeps.isEmpty)
529+
469530
if arg.needsSepCheck then
470-
val ac = formalCaptures(arg)
531+
//println(i"testing $arg, ${argPeaks.actual}/${argPeaks.formal} against ${currentPeaks.actual}")
471532
checkType(arg.formalType, arg.srcPos, TypeRole.Argument(arg))
472-
val hiddenInArg = ac.hidden.footprint
473-
//println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
474-
val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
475-
if !overlap.isEmpty then
476-
sepApplyError(fn, args, idx, overlap, hiddenInArg, footprints.toList, deps)
477-
footprint ++= captures(arg).footprint
478-
footprints += ((footprint, idx + 1))
533+
// 2. test argPeaks.formal against previously hidden actuals
534+
if !argPeaks.formal.sharedWith(currentPeaks.actual).isEmpty then
535+
val clashing = clashingPart(argPeaks.formal, _.actual)
536+
if !clashing.isEmpty then
537+
if !clashing.needsSepCheck then
538+
// if clashing needs a separation check then we already got an erro
539+
// in (1) at position of clashing. No need to report it twice.
540+
//println(i"CLASH $arg / ${argPeaks.formal} vs $clashing / ${peaksOfTree(clashing).actual} / ${captures(clashing).peaks}")
541+
sepApplyError(fn, parts, arg, clashing)
542+
else assert(!argDeps.isEmpty)
543+
544+
currentPeaks = PeaksPair(
545+
currentPeaks.actual ++ argPeaks.actual,
546+
currentPeaks.formal ++ argPeaks.formal)
479547
end checkApply
480548

481549
/** The def/use overlap between the references `hiddenByDef` hidden by
@@ -757,7 +825,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
757825
case dep: TermParamRef =>
758826
argMap(dep.binder)(dep.paramNum) :: Nil
759827
case dep: ThisType if dep.cls == fn.symbol.owner =>
760-
val Select(qual, _) = fn: @unchecked
828+
val Select(qual, _) = fn: @unchecked // TODO can we use fn instead?
761829
qual :: Nil
762830
case _ =>
763831
Nil

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,8 @@ class PlainPrinter(_ctx: Context) extends Printer {
442442
case ReachCapability(tp1) => toTextCaptureRef(tp1) ~ "*"
443443
case MaybeCapability(tp1) => toTextCaptureRef(tp1) ~ "?"
444444
case Fresh.Cap(hidden) =>
445-
if printFreshDetailed then s"<cap${hashStr(tp)} hiding " ~ toTextCaptureSet(hidden) ~ ">"
445+
val idStr = if showUniqueIds then s"#${hidden.id}" else ""
446+
if printFreshDetailed then s"<cap$idStr hiding " ~ toTextCaptureSet(hidden) ~ ">"
446447
else if printFresh then "fresh"
447448
else "cap"
448449
case tp => toText(tp)

0 commit comments

Comments
 (0)