Skip to content

Commit aaa3cff

Browse files
committed
Handle captures in by-name parameters
1. Infrastructure to deal with capturesets in byname parameters 2. Handle retainsByName annotations in ElimByName Convert them to regular annotations on the generated function types. This enables capture checking on by-name parameters. 3. Add a style warning for misleading by-name parameter type formatting. By-name types should be formatted `{...}-> T`. `{...} -> T` looks too much like a function type.
1 parent f8f7680 commit aaa3cff

24 files changed

+233
-84
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+9-5
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ object desugar {
468468

469469
if mods.is(Trait) then
470470
for vparams <- originalVparamss; vparam <- vparams do
471-
if vparam.tpt.isInstanceOf[ByNameTypeTree] then
471+
if isByNameType(vparam.tpt) then
472472
report.error(em"implementation restriction: traits cannot have by name parameters", vparam.srcPos)
473473

474474
// Annotations on class _type_ parameters are set on the derived parameters
@@ -576,9 +576,8 @@ object desugar {
576576
appliedTypeTree(tycon, targs)
577577
}
578578

579-
def isRepeated(tree: Tree): Boolean = tree match {
579+
def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match {
580580
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
581-
case ByNameTypeTree(tree1) => isRepeated(tree1)
582581
case _ => false
583582
}
584583

@@ -1811,8 +1810,13 @@ object desugar {
18111810
case ext: ExtMethods =>
18121811
Block(List(ext), Literal(Constant(())).withSpan(ext.span))
18131812
case CapturingTypeTree(refs, parent) =>
1814-
val annot = New(scalaDot(tpnme.retains), List(refs))
1815-
Annotated(parent, annot)
1813+
def annotate(annotName: TypeName, tp: Tree) =
1814+
Annotated(tp, New(scalaDot(annotName), List(refs)))
1815+
parent match
1816+
case ByNameTypeTree(restpt) =>
1817+
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
1818+
case _ =>
1819+
annotate(tpnme.retains, parent)
18161820
}
18171821
desugared.withSpan(tree.span)
18181822
}

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+21-2
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
178178
}
179179

180180
/** Is tpt a vararg type of the form T* or => T*? */
181-
def isRepeatedParamType(tpt: Tree)(using Context): Boolean = tpt match {
182-
case ByNameTypeTree(tpt1) => isRepeatedParamType(tpt1)
181+
def isRepeatedParamType(tpt: Tree)(using Context): Boolean = stripByNameType(tpt) match {
183182
case tpt: TypeTree => tpt.typeOpt.isRepeatedParam
184183
case AppliedTypeTree(Select(_, tpnme.REPEATED_PARAM_CLASS), _) => true
185184
case _ => false
@@ -196,6 +195,16 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
196195
case arg => arg.typeOpt.widen.isRepeatedParam
197196
}
198197

198+
def isByNameType(tree: Tree)(using Context): Boolean =
199+
stripByNameType(tree) ne tree
200+
201+
def stripByNameType(tree: Tree)(using Context): Tree = unsplice(tree) match
202+
case ByNameTypeTree(t1) => t1
203+
case untpd.CapturingTypeTree(_, parent) =>
204+
val parent1 = stripByNameType(parent)
205+
if parent1 eq parent then tree else parent1
206+
case _ => tree
207+
199208
/** All type and value parameter symbols of this DefDef */
200209
def allParamSyms(ddef: DefDef)(using Context): List[Symbol] =
201210
ddef.paramss.flatten.map(_.symbol)
@@ -388,6 +397,16 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
388397
case _ => None
389398
}
390399
}
400+
401+
object ImpureByNameTypeTree:
402+
def apply(tp: ByNameTypeTree)(using Context): untpd.CapturingTypeTree =
403+
untpd.CapturingTypeTree(
404+
Ident(nme.CAPTURE_ROOT).withSpan(tp.span.startPos) :: Nil, tp)
405+
def unapply(tp: Tree)(using Context): Option[ByNameTypeTree] = tp match
406+
case untpd.CapturingTypeTree(id @ Ident(nme.CAPTURE_ROOT) :: Nil, bntp: ByNameTypeTree)
407+
if id.span == bntp.span.startPos => Some(bntp)
408+
case _ => None
409+
end ImpureByNameTypeTree
391410
}
392411

393412
trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>

compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala

+11-9
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import printing.Printer
1212
import printing.Texts.Text
1313

1414

15-
case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation:
15+
case class CaptureAnnotation(refs: CaptureSet, kind: CapturingKind) extends Annotation:
1616
import CaptureAnnotation.*
1717
import tpd.*
1818

@@ -25,25 +25,26 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio
2525
val arg = repeated(elems, TypeTree(defn.AnyType))
2626
New(symbol.typeRef, arg :: Nil)
2727

28-
override def symbol(using Context) = defn.RetainsAnnot
28+
override def symbol(using Context) =
29+
if kind == CapturingKind.ByName then defn.RetainsByNameAnnot else defn.RetainsAnnot
2930

3031
override def derivedAnnotation(tree: Tree)(using Context): Annotation =
3132
unsupported("derivedAnnotation(Tree)")
3233

33-
def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation =
34-
if (this.refs eq refs) && (this.boxed == boxed) then this
35-
else CaptureAnnotation(refs, boxed)
34+
def derivedAnnotation(refs: CaptureSet, kind: CapturingKind)(using Context): Annotation =
35+
if (this.refs eq refs) && (this.kind == kind) then this
36+
else CaptureAnnotation(refs, kind)
3637

3738
override def sameAnnotation(that: Annotation)(using Context): Boolean = that match
38-
case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2
39+
case CaptureAnnotation(refs2, kind2) => refs == refs2 && kind == kind2
3940
case _ => false
4041

4142
override def mapWith(tp: TypeMap)(using Context) =
4243
val elems = refs.elems.toList
4344
val elems1 = elems.mapConserve(tp)
4445
if elems1 eq elems then this
4546
else if elems1.forall(_.isInstanceOf[CaptureRef])
46-
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed)
47+
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), kind)
4748
else EmptyAnnotation
4849

4950
override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
@@ -54,10 +55,11 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio
5455

5556
override def toText(printer: Printer): Text = refs.toText(printer)
5657

57-
override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0)
58+
override def hash: Int =
59+
(refs.hashCode << 1) | (if kind == CapturingKind.Regular then 0 else 1)
5860

5961
override def eql(that: Annotation) = that match
60-
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed)
62+
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.kind == kind)
6163
case _ => false
6264

6365
end CaptureAnnotation

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

+4-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ extension (tree: Tree)
4343
extension (tp: Type)
4444

4545
def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match
46-
case CapturingType(p, r, b) =>
46+
case CapturingType(p, r, k) =>
4747
if (parent eq p) && (refs eq r) then tp
48-
else CapturingType(parent, refs, b)
48+
else CapturingType(parent, refs, k)
4949

5050
/** If this is type variable instantiated or upper bounded with a capturing type,
5151
* the capture set associated with that type. Extended to and-or types and
@@ -54,7 +54,8 @@ extension (tp: Type)
5454
*/
5555
def boxedCaptured(using Context): CaptureSet =
5656
def getBoxed(tp: Type): CaptureSet = tp match
57-
case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty
57+
case CapturingType(_, refs, CapturingKind.Boxed) => refs
58+
case CapturingType(_, _, _) => CaptureSet.empty
5859
case tp: TypeProxy => getBoxed(tp.superType)
5960
case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2)
6061
case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2)

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ sealed abstract class CaptureSet extends Showable:
209209
((NoType: Type) /: elems) ((tp, ref) =>
210210
if tp.exists then OrType(tp, ref, soft = false) else ref)
211211

212-
def toRegularAnnotation(using Context): Annotation =
213-
Annotation(CaptureAnnotation(this, boxed = false).tree)
212+
def toRegularAnnotation(byName: Boolean)(using Context): Annotation =
213+
val kind = if byName then CapturingKind.ByName else CapturingKind.Regular
214+
Annotation(CaptureAnnotation(this, kind).tree)
214215

215216
override def toText(printer: Printer): Text =
216217
Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
5+
/** Possible kinds of captures */
6+
enum CapturingKind:
7+
case Regular // normal capture
8+
case Boxed // capture under box
9+
case ByName // capture applies to enclosing by-name type (only possible before ElimByName)

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

+29-9
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,46 @@ package cc
55
import core.*
66
import Types.*, Symbols.*, Contexts.*
77

8+
/** A capturing type. This is internally represented as an annotated type with a `retains`
9+
* annotation, but the extractor will succeed only at phase CheckCaptures.
10+
* Annotated types with `@retainsByName` annotation can also be created that way, by
11+
* giving a `CapturingKind.ByName` as `kind` argument, but they are never extracted,
12+
* since they have already been converted to regular capturing types before CheckCaptures.
13+
*/
814
object CapturingType:
915

10-
def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type =
16+
def apply(parent: Type, refs: CaptureSet, kind: CapturingKind)(using Context): Type =
1117
if refs.isAlwaysEmpty then parent
12-
else AnnotatedType(parent, CaptureAnnotation(refs, boxed))
13-
14-
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
15-
if ctx.phase == Phases.checkCapturesPhase then EventuallyCapturingType.unapply(tp)
18+
else AnnotatedType(parent, CaptureAnnotation(refs, kind))
19+
20+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] =
21+
if ctx.phase == Phases.checkCapturesPhase then
22+
val r = EventuallyCapturingType.unapply(tp)
23+
r match
24+
case Some((_, _, CapturingKind.ByName)) => None
25+
case _ => r
1626
else None
1727

1828
end CapturingType
1929

30+
/** An extractor for types that will be capturing types at phase CheckCaptures. Also
31+
* included are types that indicate captures on enclosing call-by-name parameters
32+
* before phase ElimByName
33+
*/
2034
object EventuallyCapturingType:
2135

22-
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
23-
if tp.annot.symbol == defn.RetainsAnnot then
36+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] =
37+
val sym = tp.annot.symbol
38+
if sym == defn.RetainsAnnot || sym == defn.RetainsByNameAnnot then
2439
tp.annot match
25-
case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed))
40+
case ann: CaptureAnnotation =>
41+
Some((tp.parent, ann.refs, ann.kind))
2642
case ann =>
27-
try Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing))
43+
val kind =
44+
if ann.tree.isBoxedCapturing then CapturingKind.Boxed
45+
else if sym == defn.RetainsByNameAnnot then CapturingKind.ByName
46+
else CapturingKind.Regular
47+
try Some((tp.parent, ann.tree.toCaptureSet, kind))
2848
catch case ex: IllegalCaptureRef => None
2949
else None
3050

compiler/src/dotty/tools/dotc/cc/Setup.scala

+4-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ extends tpd.TreeTraverser:
2525
.toFunctionType(isJava = false, alwaysDependent = true)
2626

2727
private def box(tp: Type)(using Context): Type = tp match
28-
case CapturingType(parent, refs, false) => CapturingType(parent, refs, true)
28+
case CapturingType(parent, refs, CapturingKind.Regular) =>
29+
CapturingType(parent, refs, CapturingKind.Boxed)
2930
case _ => tp
3031

3132
private def setBoxed(tp: Type)(using Context) = tp match
@@ -77,7 +78,7 @@ extends tpd.TreeTraverser:
7778
cls.paramGetters.foldLeft(tp) { (core, getter) =>
7879
if getter.termRef.isTracked then
7980
val getterType = tp.memberInfo(getter).strippedDealias
80-
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false))
81+
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), CapturingKind.Regular))
8182
.showing(i"add capture refinement $tp --> $result", capt)
8283
else
8384
core
@@ -130,7 +131,7 @@ extends tpd.TreeTraverser:
130131
case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) =>
131132
CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2)
132133
case _ if canHaveInferredCapture(tp) =>
133-
CapturingType(tp, CaptureSet.Var(), boxed = false)
134+
CapturingType(tp, CaptureSet.Var(), CapturingKind.Regular)
134135
case _ =>
135136
tp
136137

compiler/src/dotty/tools/dotc/core/Definitions.scala

+14-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import Comments.CommentsContext
1515
import Comments.Comment
1616
import util.Spans.NoSpan
1717
import Symbols.requiredModuleRef
18-
import cc.{CapturingType, CaptureSet}
18+
import cc.{CapturingType, CaptureSet, CapturingKind, EventuallyCapturingType}
1919

2020
import scala.annotation.tailrec
2121

@@ -118,9 +118,9 @@ class Definitions {
118118
*
119119
* ErasedFunctionN and ErasedContextFunctionN erase to Function0.
120120
*
121-
* EffXYZFunctionN afollow this template:
121+
* ImpureXYZFunctionN follow this template:
122122
*
123-
* type EffXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
123+
* type ImpureXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
124124
*/
125125
private def newFunctionNType(name: TypeName): Symbol = {
126126
val impure = name.startsWith("Impure")
@@ -136,7 +136,7 @@ class Definitions {
136136
HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)(
137137
tl => List.fill(arity + 1)(TypeBounds.empty),
138138
tl => CapturingType(underlyingClass.typeRef.appliedTo(tl.paramRefs),
139-
CaptureSet.universal, boxed = false)
139+
CaptureSet.universal, CapturingKind.Regular)
140140
))
141141
else
142142
val cls = denot.asClass.classSymbol
@@ -1016,6 +1016,7 @@ class Definitions {
10161016
@tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs")
10171017
@tu lazy val SinceAnnot: ClassSymbol = requiredClass("scala.annotation.since")
10181018
@tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains")
1019+
@tu lazy val RetainsByNameAnnot: ClassSymbol = requiredClass("scala.retainsByName")
10191020

10201021
@tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable")
10211022

@@ -1149,9 +1150,16 @@ class Definitions {
11491150
}
11501151
}
11511152

1153+
/** Extractor for function types representing by-name parameters, of the form
1154+
* `() ?=> T`.
1155+
* Under -Ycc, this becomes `() ?-> T` or `{r1, ..., rN} () ?-> T`.
1156+
*/
11521157
object ByNameFunction:
1153-
def apply(tp: Type)(using Context): Type =
1154-
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
1158+
def apply(tp: Type)(using Context): Type = tp match
1159+
case EventuallyCapturingType(tp1, refs, CapturingKind.ByName) =>
1160+
CapturingType(apply(tp1), refs, CapturingKind.Regular)
1161+
case _ =>
1162+
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
11551163
def unapply(tp: Type)(using Context): Option[Type] = tp match
11561164
case tp @ AppliedType(tycon, arg :: Nil) if defn.isByNameFunctionClass(tycon.typeSymbol) =>
11571165
Some(arg)

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ object StdNames {
566566
val reify : N = "reify"
567567
val releaseFence : N = "releaseFence"
568568
val retains: N = "retains"
569+
val retainsByName: N = "retainsByName"
569570
val rootMirror : N = "rootMirror"
570571
val run: N = "run"
571572
val runOrElse: N = "runOrElse"

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import typer.ProtoTypes.constrained
2323
import typer.Applications.productSelectorTypes
2424
import reporting.trace
2525
import annotation.constructorOnly
26-
import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing}
26+
import cc.{CapturingType, derivedCapturingType, CaptureSet, CapturingKind, stripCapturing}
2727

2828
/** Provides methods to compare types.
2929
*/
@@ -876,7 +876,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
876876
tp1 match
877877
case tp1: CaptureRef if tp1.isTracked =>
878878
val stripped = tp1w.stripCapturing
879-
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false)
879+
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, CapturingKind.Regular)
880880
case _ =>
881881
isSubType(tp1w, tp2, approx.addLow)
882882
}

compiler/src/dotty/tools/dotc/core/Types.scala

+9-6
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import config.Printers.{core, typr, matchTypes}
3636
import reporting.{trace, Message}
3737
import java.lang.ref.WeakReference
3838
import compiletime.uninitialized
39-
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing}
39+
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing, CapturingKind}
4040
import CaptureSet.CompareResult
4141

4242
import scala.annotation.internal.sharable
@@ -1880,13 +1880,15 @@ object Types {
18801880

18811881
def capturing(ref: CaptureRef)(using Context): Type =
18821882
if captureSet.accountsFor(ref) then this
1883-
else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing)
1883+
else CapturingType(this, ref.singletonCaptureSet,
1884+
if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular)
18841885

18851886
def capturing(cs: CaptureSet)(using Context): Type =
18861887
if cs.isConst && cs.subCaptures(captureSet, frozen = true).isOK then this
18871888
else this match
18881889
case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs)
1889-
case _ => CapturingType(this, cs, this.isBoxedCapturing)
1890+
case _ => CapturingType(this, cs,
1891+
if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular)
18901892

18911893
/** The set of distinct symbols referred to by this type, after all aliases are expanded */
18921894
def coveringSet(using Context): Set[Symbol] =
@@ -3840,10 +3842,11 @@ object Types {
38403842
CapturingType(parent1, CaptureSet.universal, boxed))
38413843
case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) =>
38423844
val parent1 = mapOver(parent)
3843-
if ann.symbol == defn.RetainsAnnot then
3845+
if ann.symbol == defn.RetainsAnnot || ann.symbol == defn.RetainsByNameAnnot then
3846+
val byName = ann.symbol == defn.RetainsByNameAnnot
38443847
range(
3845-
AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation),
3846-
AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation))
3848+
AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation(byName)),
3849+
AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation(byName)))
38473850
else
38483851
parent1
38493852
case _ => mapOver(tp)

0 commit comments

Comments
 (0)