@@ -340,11 +340,13 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
340
340
* @param fn the function
341
341
* @param parts the function prefix followed by the flattened argument list
342
342
* @param polyArg the clashing argument to a polymorphic formal
343
- * @param clashing the argument with which it clashes
343
+ * @param clashing the argument, function prefix, or entire function application result with
344
+ * which it clashes,
345
+ *
344
346
*/
345
347
def sepApplyError (fn : Tree , parts : List [Tree ], polyArg : Tree , clashing : Tree )(using Context ): Unit =
346
348
val polyArgIdx = parts.indexOf(polyArg).ensuring(_ >= 0 ) - 1
347
- val clashIdx = parts.indexOf(clashing).ensuring(_ >= 0 )
349
+ val clashIdx = parts.indexOf(clashing) // -1 means entire function application
348
350
def paramName (mt : Type , idx : Int ): Option [Name ] = mt match
349
351
case mt @ MethodType (pnames) =>
350
352
if idx < pnames.length then Some (pnames(idx)) else paramName(mt.resType, idx - pnames.length)
@@ -363,11 +365,12 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
363
365
if isShowableMethod then i " ${fn.symbol}: ${fn.symbol.info}"
364
366
else i " a function of type ${funType.widen}"
365
367
def clashArgStr = clashIdx match
366
- case 0 => " function prefix"
367
- case 1 => " first argument "
368
- case 2 => " second argument"
369
- case 3 => " third argument "
370
- case n => s " ${n}th argument "
368
+ case - 1 => " function result"
369
+ case 0 => " function prefix"
370
+ case 1 => " first argument "
371
+ case 2 => " second argument"
372
+ case 3 => " third argument "
373
+ case n => s " ${n}th argument "
371
374
def clashTypeStr =
372
375
if clashIdx == 0 && ! isShowableMethod then " " // we already mentioned the type in `funStr`
373
376
else i " with type ${clashing.nuType}"
@@ -455,11 +458,12 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
455
458
*
456
459
* @param fn the applied function
457
460
* @param args the flattened argument lists
461
+ * @param app the entire application tree
458
462
* @param deps cross argument dependencies: maps argument trees to
459
463
* those other arguments that where mentioned by coorresponding
460
464
* formal parameters.
461
465
*/
462
- private def checkApply (fn : Tree , args : List [Tree ], deps : collection.Map [Tree , List [Tree ]])(using Context ): Unit =
466
+ private def checkApply (fn : Tree , args : List [Tree ], app : Tree , deps : collection.Map [Tree , List [Tree ]])(using Context ): Unit =
463
467
val (qual, fnCaptures) = methPart(fn) match
464
468
case Select (qual, _) => (qual, qual.nuType.captureSet)
465
469
case _ => (fn, CaptureSet .empty)
@@ -511,6 +515,29 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
511
515
currentPeaks = PeaksPair (
512
516
currentPeaks.actual ++ argPeaks.actual,
513
517
currentPeaks.hidden ++ argPeaks.hidden)
518
+ end for
519
+
520
+ def collectRefs (args : List [Type ], res : Type ) =
521
+ args.foldLeft(argCaptures(res)): (refs, arg) =>
522
+ refs ++ arg.deepCaptureSet.elems
523
+
524
+ /** The deep capture sets of all parameters of this type (if it is a function type) */
525
+ def argCaptures (tpe : Type ): Refs = tpe match
526
+ case defn.FunctionOf (args, resultType, isContextual) =>
527
+ collectRefs(args, resultType)
528
+ case defn.RefinedFunctionOf (mt) =>
529
+ collectRefs(mt.paramInfos, mt.resType)
530
+ case CapturingType (parent, _) =>
531
+ argCaptures(parent)
532
+ case _ =>
533
+ emptyRefs
534
+
535
+ if ! deps(app).isEmpty then
536
+ lazy val appPeaks = argCaptures(app.nuType).peaks
537
+ lazy val partPeaks = partsWithPeaks.toMap
538
+ for arg <- deps(app) do
539
+ if arg.needsSepCheck && ! partPeaks(arg).hidden.sharedWith(appPeaks).isEmpty then
540
+ sepApplyError(fn, parts, arg, app)
514
541
end checkApply
515
542
516
543
/** 1. Check that the capabilities used at `tree` don't overlap with
@@ -782,44 +809,55 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
782
809
*
783
810
* f(x: A, y: B^{cap, x}, z: C^{x, y}): D
784
811
*
785
- * then the dependencies of an application `f(a, b)` is a map that takes
786
- * `b` to `List(a)` and `c` to `List(a, b)`.
812
+ * then the dependencies of an application `f(a, b, c)` of type C^{y} is the map
813
+ *
814
+ * [ b -> [a]
815
+ * , c -> [a, b]
816
+ * , f(a, b, c) -> [b]]
787
817
*/
788
- private def dependencies (fn : Tree , argss : List [List [Tree ]])(using Context ): collection.Map [Tree , List [Tree ]] =
818
+ private def dependencies (fn : Tree , argss : List [List [Tree ]], app : Tree )(using Context ): collection.Map [Tree , List [Tree ]] =
819
+ def isFunApply (sym : Symbol ) =
820
+ sym.name == nme.apply && defn.isFunctionClass(sym.owner)
789
821
val mtpe =
790
- if fn.symbol.exists then fn.symbol.info
791
- else fn.tpe .widen // happens for PolyFunction applies
822
+ if fn.symbol.exists && ! isFunApply(fn.symbol) then fn.symbol.info
823
+ else fn.nuType .widen
792
824
val mtps = collectMethodTypes(mtpe)
793
825
assert(mtps.hasSameLengthAs(argss), i " diff for $fn: ${fn.symbol} /// $mtps /// $argss" )
794
826
val mtpsWithArgs = mtps.zip(argss)
795
827
val argMap = mtpsWithArgs.toMap
796
828
val deps = mutable.HashMap [Tree , List [Tree ]]().withDefaultValue(Nil )
797
- for
798
- (mt, args) <- mtpsWithArgs
799
- (formal, arg) <- mt.paramInfos.zip(args)
800
- dep <- formal.captureSet.elems.toList
801
- do
802
- val referred = dep.stripReach match
803
- case dep : TermParamRef =>
804
- argMap(dep.binder)(dep.paramNum) :: Nil
805
- case dep : ThisType if dep.cls == fn.symbol.owner =>
806
- val Select (qual, _) = fn : @ unchecked // TODO can we use fn instead?
807
- qual :: Nil
808
- case _ =>
809
- Nil
810
- deps(arg) ++= referred
829
+
830
+ def recordDeps (formal : Type , actual : Tree ) =
831
+ for dep <- formal.captureSet.elems.toList do
832
+ val referred = dep.stripReach match
833
+ case dep : TermParamRef =>
834
+ argMap(dep.binder)(dep.paramNum) :: Nil
835
+ case dep : ThisType if dep.cls == fn.symbol.owner =>
836
+ val Select (qual, _) = fn : @ unchecked // TODO can we use fn instead?
837
+ qual :: Nil
838
+ case _ =>
839
+ Nil
840
+ deps(actual) ++= referred
841
+
842
+ for (mt, args) <- mtpsWithArgs; (formal, arg) <- mt.paramInfos.zip(args) do
843
+ recordDeps(formal, arg)
844
+ recordDeps(mtpe.finalResultType, app)
845
+ capt.println(i " deps for $app = ${deps.toList}" )
811
846
deps
812
847
848
+
813
849
/** Decompose an application into a function prefix and a list of argument lists.
814
850
* If some of the arguments need a separation check because they are capture polymorphic,
815
851
* perform a separation check with `checkApply`
816
852
*/
817
- private def traverseApply (tree : Tree , argss : List [List [Tree ]])(using Context ): Unit = tree match
818
- case Apply (fn, args) => traverseApply(fn, args :: argss)
819
- case TypeApply (fn, args) => traverseApply(fn, argss) // skip type arguments
820
- case _ =>
821
- if argss.nestedExists(_.needsSepCheck) then
822
- checkApply(tree, argss.flatten, dependencies(tree, argss))
853
+ private def traverseApply (app : Tree )(using Context ): Unit =
854
+ def recur (tree : Tree , argss : List [List [Tree ]]): Unit = tree match
855
+ case Apply (fn, args) => recur(fn, args :: argss)
856
+ case TypeApply (fn, args) => recur(fn, argss) // skip type arguments
857
+ case _ =>
858
+ if argss.nestedExists(_.needsSepCheck) then
859
+ checkApply(tree, argss.flatten, app, dependencies(tree, argss, app))
860
+ recur(app, Nil )
823
861
824
862
/** Is `tree` an application of `caps.unsafe.unsafeAssumeSeparate`? */
825
863
def isUnsafeAssumeSeparate (tree : Tree )(using Context ): Boolean = tree match
@@ -866,7 +904,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
866
904
traverseChildren(tree)
867
905
tree.tpe match
868
906
case _ : MethodOrPoly =>
869
- case _ => traverseApply(tree, Nil )
907
+ case _ => traverseApply(tree)
870
908
case _ : Block | _ : Template =>
871
909
traverseSection(tree)
872
910
case tree : ValDef =>
0 commit comments