Skip to content

Commit 7df862a

Browse files
committed
Treat exceptions as capabilities
1. Make CanThrow a @capability class 2. Fix pure arrow handling in parser 3. Avoid misleading type mismatch message 4. Make map and filter conserve Const capturesets if there's no change 5. Expand $throws clauses to context function types 6. Exempt compiletime.erasedValue for "no '*'" checks 7. Capability escape checking for try
1 parent aaa3cff commit 7df862a

File tree

14 files changed

+188
-34
lines changed

14 files changed

+188
-34
lines changed

Diff for: compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+10-3
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,10 @@ sealed abstract class CaptureSet extends Showable:
173173
this -- ref.singletonCaptureSet
174174

175175
def filter(p: CaptureRef => Boolean)(using Context): CaptureSet =
176-
if this.isConst then Const(elems.filter(p))
176+
if this.isConst then
177+
val elems1 = elems.filter(p)
178+
if elems1 == elems then this
179+
else Const(elems.filter(p))
177180
else Filtered(asVar, p)
178181

179182
/** capture set obtained by applying `f` to all elements of the current capture set
@@ -183,11 +186,15 @@ sealed abstract class CaptureSet extends Showable:
183186
def map(tm: TypeMap)(using Context): CaptureSet = tm match
184187
case tm: BiTypeMap =>
185188
val mappedElems = elems.map(tm.forward)
186-
if isConst then Const(mappedElems)
189+
if isConst then
190+
if mappedElems == elems then this
191+
else Const(mappedElems)
187192
else BiMapped(asVar, tm, mappedElems)
188193
case _ =>
189194
val mapped = mapRefs(elems, tm, tm.variance)
190-
if isConst then mapped
195+
if isConst then
196+
if mapped.isConst && mapped.elems == elems then this
197+
else mapped
191198
else Mapped(asVar, tm, tm.variance, mapped)
192199

193200
def substParams(tl: BindingType, to: List[Type])(using Context) =

Diff for: compiler/src/dotty/tools/dotc/cc/CapturingType.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end CapturingType
2929

3030
/** An extractor for types that will be capturing types at phase CheckCaptures. Also
3131
* included are types that indicate captures on enclosing call-by-name parameters
32-
* before phase ElimByName
32+
* before phase ElimByName.
3333
*/
3434
object EventuallyCapturingType:
3535

Diff for: compiler/src/dotty/tools/dotc/cc/Setup.scala

+31-5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,29 @@ extends tpd.TreeTraverser:
4444
case _ =>
4545
traverseChildren(t)
4646

47+
/** Expand some aliases of function types to the underlying functions.
48+
* Right now, these are only $throws aliases, but this could be generalized.
49+
*/
50+
def expandInlineAlias(tp: Type)(using Context) = tp match
51+
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
52+
// hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->`
53+
defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = true)
54+
case _ => tp
55+
56+
private def expandInlineAliases(using Context) = new TypeMap:
57+
def apply(t: Type) = t match
58+
case _: AppliedType =>
59+
val t1 = expandInlineAlias(t)
60+
if t1 ne t then apply(t1) else mapOver(t)
61+
case _: LazyRef =>
62+
t
63+
case t @ AnnotatedType(t1, ann) =>
64+
// Don't map capture sets, since that would implicitly normalize sets that
65+
// are not well-formed.
66+
t.derivedAnnotatedType(apply(t1), ann)
67+
case _ =>
68+
mapOver(t)
69+
4770
/** Perform the following transformation steps everywhere in a type:
4871
* 1. Drop retains annotations
4972
* 2. Turn plain function types into dependent function types, so that
@@ -143,7 +166,8 @@ extends tpd.TreeTraverser:
143166
try ts.mapConserve(this) finally isTopLevel = saved
144167

145168
def apply(t: Type) =
146-
val t1 = t match
169+
val tp = expandInlineAlias(t)
170+
val tp1 = tp match
147171
case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot =>
148172
apply(parent)
149173
case tp @ AppliedType(tycon, args) =>
@@ -172,8 +196,8 @@ extends tpd.TreeTraverser:
172196
paramInfos = tp.paramInfos.mapConserve(cleanup(_).bounds),
173197
resType = this(tp.resType))
174198
case _ =>
175-
mapOver(t)
176-
addVar(addCaptureRefinements(t1))
199+
mapOver(tp)
200+
addVar(addCaptureRefinements(tp1))
177201
end mapInferred
178202

179203
private def expandAbbreviations(using Context) = new TypeMap:
@@ -232,8 +256,10 @@ extends tpd.TreeTraverser:
232256
private def transformExplicitType(tp: Type, boxed: Boolean)(using Context): Type =
233257
addBoxes.traverse(tp)
234258
if boxed then setBoxed(tp)
235-
if ctx.settings.YccNoAbbrev.value then tp
236-
else expandAbbreviations(tp)
259+
val tp1 = expandInlineAliases(tp)
260+
if tp1 ne tp then capt.println(i"expanded: $tp --> $tp1")
261+
if ctx.settings.YccNoAbbrev.value then tp1
262+
else expandAbbreviations(tp1)
237263

238264
// Substitute parameter symbols in `from` to paramRefs in corresponding
239265
// method or poly types `to`. We use a single BiTypeMap to do everything.

Diff for: compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+5-2
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,10 @@ object Parsers {
426426
/** Convert tree to formal parameter list
427427
*/
428428
def convertToParams(tree: Tree): List[ValDef] =
429-
val mods = if in.token == CTXARROW then Modifiers(Given) else EmptyModifiers
429+
val mods =
430+
if in.token == CTXARROW || in.isIdent(nme.PURECTXARROW)
431+
then Modifiers(Given)
432+
else EmptyModifiers
430433
tree match
431434
case Parens(t) =>
432435
convertToParam(t, mods) :: Nil
@@ -1511,7 +1514,7 @@ object Parsers {
15111514
commaSeparatedRest(t, funArgType)
15121515
}
15131516
accept(RPAREN)
1514-
if isValParamList || in.isArrow then
1517+
if isValParamList || in.isArrow || in.isPureArrow then
15151518
functionRest(ts)
15161519
else {
15171520
val ts1 = ts.mapConserve { t =>

Diff for: compiler/src/dotty/tools/dotc/parsing/Scanners.scala

+3
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ object Scanners {
8888

8989
def isArrow =
9090
token == ARROW || token == CTXARROW
91+
92+
def isPureArrow =
93+
isIdent(nme.PUREARROW) || isIdent(nme.PURECTXARROW)
9194
}
9295

9396
abstract class ScannerCommon(source: SourceFile)(using Context) extends CharArrayReader with TokenData {

Diff for: compiler/src/dotty/tools/dotc/transform/Recheck.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,9 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
337337

338338
def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type =
339339
checkConforms(tpe, pt, tree)
340-
if keepTypes then tree.rememberType(tpe)
340+
if keepTypes
341+
|| tree.isInstanceOf[Try] // type needs tp be checked for * escapes
342+
then tree.rememberType(tpe)
341343
tpe
342344

343345
def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
@@ -363,6 +365,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
363365
|| expected.isRepeatedParam
364366
&& actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass))
365367
if !isCompatible then
368+
recheckr.println(i"conforms failed for ${tree}: $tpe vs $expected")
366369
err.typeMismatch(tree.withType(tpe), expected)
367370
else if debugSuccesses then
368371
tree match

Diff for: compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

+28-17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import transform.Recheck
1616
import Recheck.*
1717
import scala.collection.mutable
1818
import CaptureSet.withCaptureSetsExplained
19+
import reporting.trace
1920

2021
object CheckCaptures:
2122
import ast.tpd.*
@@ -75,7 +76,15 @@ object CheckCaptures:
7576
if remaining.accountsFor(firstRef) then
7677
report.warning(em"redundant capture: $remaining already accounts for $firstRef", ann.srcPos)
7778

78-
private inline val disallowGlobal = true
79+
/** Does this function allow type arguments carrying the universal capability?
80+
* Currently this is true only for `erasedValue` since this function is magic in
81+
* that is allows to conjure global capabilies from nothing (aside: can we find a
82+
* more controlled way to achieve this?).
83+
* But it could be generalized to other functions that so that they can take capability
84+
* classes as arguments.
85+
*/
86+
private def allowUniversalArguments(fn: Tree)(using Context): Boolean =
87+
fn.symbol == defn.Compiletime_erasedValue
7988

8089
class CheckCaptures extends Recheck:
8190
thisPhase =>
@@ -305,13 +314,13 @@ class CheckCaptures extends Recheck:
305314
.traverse(ctx.compilationUnit.tpdTree)
306315
withCaptureSetsExplained {
307316
super.checkUnit(unit)
308-
PostRefinerCheck.traverse(unit.tpdTree)
317+
PostCheck.traverse(unit.tpdTree)
309318
if ctx.settings.YccDebug.value then
310319
show(unit.tpdTree) // this dows not print tree, but makes its variables visible for dependency printing
311320
}
312321

313-
def checkNotGlobal(tree: Tree, tp: Type, allArgs: Tree*)(using Context): Unit =
314-
for ref <-tp.captureSet.elems do
322+
def checkNotGlobal(tree: Tree, tp: Type, isVar: Boolean, allArgs: Tree*)(using Context): Unit =
323+
for ref <- tp.captureSet.elems do
315324
val isGlobal = ref match
316325
case ref: TermRef => ref.isRootCapability
317326
case _ => false
@@ -320,7 +329,7 @@ class CheckCaptures extends Recheck:
320329
val notAllowed = i" is not allowed to capture the $what capability $ref"
321330
def msg =
322331
if allArgs.isEmpty then
323-
i"type of mutable variable ${tree.knownType}$notAllowed"
332+
i"${if isVar then "type of mutable variable" else "result type"} ${tree.knownType}$notAllowed"
324333
else tree match
325334
case tree: InferredTypeTree =>
326335
i"""inferred type argument ${tree.knownType}$notAllowed
@@ -330,12 +339,11 @@ class CheckCaptures extends Recheck:
330339
report.error(msg, tree.srcPos)
331340

332341
def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit =
333-
if disallowGlobal then
334-
tree match
335-
case LambdaTypeTree(_, restpt) =>
336-
checkNotGlobal(restpt, allArgs*)
337-
case _ =>
338-
checkNotGlobal(tree, tree.knownType, allArgs*)
342+
tree match
343+
case LambdaTypeTree(_, restpt) =>
344+
checkNotGlobal(restpt, allArgs*)
345+
case _ =>
346+
checkNotGlobal(tree, tree.knownType, isVar = false, allArgs*)
339347

340348
def checkNotGlobalDeep(tree: Tree)(using Context): Unit =
341349
val checker = new TypeTraverser:
@@ -346,12 +354,12 @@ class CheckCaptures extends Recheck:
346354
case _ =>
347355
case tp: TermRef =>
348356
case _ =>
349-
checkNotGlobal(tree, tp)
357+
checkNotGlobal(tree, tp, isVar = true)
350358
traverseChildren(tp)
351359
checker.traverse(tree.knownType)
352360

353-
object PostRefinerCheck extends TreeTraverser:
354-
def traverse(tree: Tree)(using Context) =
361+
object PostCheck extends TreeTraverser:
362+
def traverse(tree: Tree)(using Context) = trace{i"post check $tree"} {
355363
tree match
356364
case _: InferredTypeTree =>
357365
case tree: TypeTree if !tree.span.isZeroExtent =>
@@ -362,7 +370,7 @@ class CheckCaptures extends Recheck:
362370
checkWellformedPost(annot.tree)
363371
case _ =>
364372
}
365-
case tree1 @ TypeApply(fn, args) if disallowGlobal =>
373+
case tree1 @ TypeApply(fn, args) if !allowUniversalArguments(fn) =>
366374
for arg <- args do
367375
//println(i"checking $arg in $tree: ${tree.knownType.captureSet}")
368376
checkNotGlobal(arg, args*)
@@ -390,11 +398,14 @@ class CheckCaptures extends Recheck:
390398
inferred.foreachPart(checkPure, StopAt.Static)
391399
case t: ValDef if t.symbol.is(Mutable) =>
392400
checkNotGlobalDeep(t.tpt)
401+
case t: Try =>
402+
checkNotGlobal(t)
393403
case _ =>
394404
traverseChildren(tree)
405+
}
395406

396-
def postRefinerCheck(tree: tpd.Tree)(using Context): Unit =
397-
PostRefinerCheck.traverse(tree)
407+
def postCheck(tree: tpd.Tree)(using Context): Unit =
408+
PostCheck.traverse(tree)
398409

399410
end CaptureChecker
400411
end CheckCaptures

Diff for: compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala

+14-2
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,25 @@ object ErrorReporting {
125125
def typeMismatch(tree: Tree, pt: Type, implicitFailure: SearchFailureType = NoMatchingImplicits): Tree = {
126126
val normTp = normalize(tree.tpe, pt)
127127
val normPt = normalize(pt, pt)
128+
129+
def contextFunctionCount(tp: Type): Int = tp.stripped match
130+
case defn.ContextFunctionType(_, restp, _) => 1 + contextFunctionCount(restp)
131+
case _ => 0
132+
def strippedTpCount = contextFunctionCount(tree.tpe) - contextFunctionCount(normTp)
133+
def strippedPtCount = contextFunctionCount(pt) - contextFunctionCount(normPt)
134+
128135
val (treeTp, expectedTp) =
129-
if (normTp <:< normPt) (tree.tpe, pt) else (normTp, normPt)
130-
// use normalized types if that also shows an error, original types otherwise
136+
if normTp <:< normPt || strippedTpCount != strippedPtCount
137+
then (tree.tpe, pt)
138+
else (normTp, normPt)
139+
// use normalized types if that also shows an error, and both sides stripped
140+
// the same number of context functions. Use original types otherwise.
141+
131142
def missingElse = tree match
132143
case If(_, _, elsep @ Literal(Constant(()))) if elsep.span.isSynthetic =>
133144
"\nMaybe you are missing an else part for the conditional?"
134145
case _ => ""
146+
135147
errorTree(tree, TypeMismatch(treeTp, expectedTp, Some(tree), implicitFailure.whyNoConversion, missingElse))
136148
}
137149

Diff for: compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -2412,7 +2412,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24122412
//todo: make sure dependent method types do not depend on implicits or by-name params
24132413
}
24142414

2415-
/** (1) Check that the signature of the class mamber does not return a repeated parameter type
2415+
/** (1) Check that the signature of the class member does not return a repeated parameter type
24162416
* (2) If info is an erased class, set erased flag of member
24172417
*/
24182418
private def postProcessInfo(sym: Symbol)(using Context): Unit =

Diff for: compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ abstract class SimpleIdentitySet[+Elem <: AnyRef] {
3838
((SimpleIdentitySet.empty: SimpleIdentitySet[E]) /: this) { (s, x) =>
3939
if (that.contains(x)) s else s + x
4040
}
41+
42+
def == [E >: Elem <: AnyRef](that: SimpleIdentitySet[E]): Boolean =
43+
this.size == that.size && forall(that.contains)
44+
4145
override def toString: String = toList.mkString("{", ", ", "}")
4246
}
4347

Diff for: library/src/scala/CanThrow.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package scala
22
import language.experimental.erasedDefinitions
3-
import annotation.{implicitNotFound, experimental}
3+
import annotation.{implicitNotFound, experimental, capability}
44

55
/** A capability class that allows to throw exception `E`. When used with the
66
* experimental.saferExceptions feature, a `throw Ex()` expression will require
77
* a given of class `CanThrow[Ex]` to be available.
88
*/
9-
@experimental
9+
@experimental @capability
1010
@implicitNotFound("The capability to throw exception ${E} is missing.\nThe capability can be provided by one of the following:\n - Adding a using clause `(using CanThrow[${E}])` to the definition of the enclosing method\n - Adding `throws ${E}` clause after the result type of the enclosing method\n - Wrapping this piece of code with a `try` block that catches ${E}")
1111
erased class CanThrow[-E <: Exception]
1212

Diff for: tests/neg-custom-args/captures/real-try.check

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Error: tests/neg-custom-args/captures/real-try.scala:10:2 -----------------------------------------------------------
2+
10 | try // error
3+
| ^
4+
| result type {*} () -> Unit is not allowed to capture the universal capability *.type
5+
11 | () => foo(1)
6+
12 | catch
7+
13 | case _: Ex1 => ???
8+
14 | case _: Ex2 => ???

Diff for: tests/neg-custom-args/captures/real-try.scala

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import language.experimental.saferExceptions
2+
3+
class Ex1 extends Exception("Ex1")
4+
class Ex2 extends Exception("Ex2")
5+
6+
def foo(i: Int): (CanThrow[Ex1], CanThrow[Ex2]) ?-> Unit =
7+
if i > 0 then throw new Ex1 else throw new Ex2
8+
9+
def test() =
10+
try // error
11+
() => foo(1)
12+
catch
13+
case _: Ex1 => ???
14+
case _: Ex2 => ???

0 commit comments

Comments
 (0)