Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 01375b9

Browse files
committedAug 17, 2022
Map regular function types to impure function types when unpickling
Map regular function types to impure function types when unpickling a class under -Ycc that was not itself compiled with -Ycc.
1 parent 0b87afb commit 01375b9

File tree

6 files changed

+73
-4
lines changed

6 files changed

+73
-4
lines changed
 

‎compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

+15-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package cc
55
import core.*
66
import Types.*, Symbols.*, Contexts.*, Annotations.*
77
import ast.{tpd, untpd}
8-
import Decorators.*
8+
import Decorators.*, NameOps.*
99
import config.Printers.capt
1010
import util.Property.Key
1111
import tpd.*
@@ -71,3 +71,17 @@ extension (tp: Type)
7171
atd.derivedAnnotatedType(parent.stripCapturing, annot)
7272
case _ =>
7373
tp
74+
75+
/** Under -Ycc, map regular function type to impure function type
76+
*/
77+
def adaptFunctionType(using Context): Type = tp match
78+
case AppliedType(fn, args)
79+
if ctx.settings.Ycc.value && defn.isFunctionClass(fn.typeSymbol) =>
80+
val fname = fn.typeSymbol.name
81+
defn.FunctionType(
82+
fname.functionArity,
83+
isContextual = fname.isContextFunction,
84+
isErased = fname.isErasedFunction,
85+
isImpure = true).appliedTo(args)
86+
case _ =>
87+
tp

‎compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

+15-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import util.{SourceFile, Property}
3131
import ast.{Trees, tpd, untpd}
3232
import Trees._
3333
import Decorators._
34+
import transform.SymUtils._
35+
import cc.adaptFunctionType
3436

3537
import dotty.tools.tasty.{TastyBuffer, TastyReader}
3638
import TastyBuffer._
@@ -85,6 +87,9 @@ class TreeUnpickler(reader: TastyReader,
8587
/** The root owner tree. See `OwnerTree` class definition. Set by `enterTopLevel`. */
8688
private var ownerTree: OwnerTree = _
8789

90+
/** Was unpickled class compiled with -Ycc? */
91+
private var wasCaptureChecked: Boolean = false
92+
8893
private def registerSym(addr: Addr, sym: Symbol) =
8994
symAtAddr(addr) = sym
9095

@@ -371,7 +376,7 @@ class TreeUnpickler(reader: TastyReader,
371376
// Note that the lambda "rt => ..." is not equivalent to a wildcard closure!
372377
// Eta expansion of the latter puts readType() out of the expression.
373378
case APPLIEDtype =>
374-
readType().appliedTo(until(end)(readType()))
379+
postProcessFunction(readType().appliedTo(until(end)(readType())))
375380
case TYPEBOUNDS =>
376381
val lo = readType()
377382
if nothingButMods(end) then
@@ -484,6 +489,12 @@ class TreeUnpickler(reader: TastyReader,
484489
def readTermRef()(using Context): TermRef =
485490
readType().asInstanceOf[TermRef]
486491

492+
/** Under -Ycc, map all function types to impure function types,
493+
* unless the unpickled class was also compiled with -Ycc.
494+
*/
495+
private def postProcessFunction(tp: Type)(using Context): Type =
496+
if wasCaptureChecked then tp else tp.adaptFunctionType
497+
487498
// ------ Reading definitions -----------------------------------------------------
488499

489500
private def nothingButMods(end: Addr): Boolean =
@@ -631,6 +642,8 @@ class TreeUnpickler(reader: TastyReader,
631642
}
632643
registerSym(start, sym)
633644
if (isClass) {
645+
if sym.owner.is(Package) && annots.exists(_.symbol == defn.CaptureCheckedAnnot) then
646+
wasCaptureChecked = true
634647
sym.completer.withDecls(newScope)
635648
forkAt(templateStart).indexTemplateParams()(using localContext(sym))
636649
}
@@ -1339,7 +1352,7 @@ class TreeUnpickler(reader: TastyReader,
13391352
val args = until(end)(readTpt())
13401353
val tree = untpd.AppliedTypeTree(tycon, args)
13411354
val ownType = ctx.typeAssigner.processAppliedType(tree, tycon.tpe.safeAppliedTo(args.tpes))
1342-
tree.withType(ownType)
1355+
tree.withType(postProcessFunction(ownType))
13431356
case ANNOTATEDtpt =>
13441357
Annotated(readTpt(), readTerm())
13451358
case LAMBDAtpt =>

‎compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import scala.collection.mutable
3232
import scala.collection.mutable.ListBuffer
3333
import scala.annotation.switch
3434
import reporting._
35+
import cc.adaptFunctionType
3536

3637
object Scala2Unpickler {
3738

@@ -822,7 +823,9 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
822823
// special-case in erasure, see TypeErasure#eraseInfo.
823824
OrType(args(0), args(1), soft = false)
824825
}
825-
else if (args.nonEmpty) tycon.safeAppliedTo(EtaExpandIfHK(sym.typeParams, args.map(translateTempPoly)))
826+
else if args.nonEmpty then
827+
tycon.safeAppliedTo(EtaExpandIfHK(sym.typeParams, args.map(translateTempPoly)))
828+
.adaptFunctionType
826829
else if (sym.typeParams.nonEmpty) tycon.EtaExpand(sym.typeParams)
827830
else tycon
828831
case TYPEBOUNDStpe =>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
object Lib:
2+
extension [A](xs: Seq[A])
3+
def mapp[B](f: A => B): Seq[B] =
4+
xs.map(f.asInstanceOf[A -> B])
5+
6+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import language.experimental.saferExceptions
2+
import Lib.*
3+
4+
class LimitExceeded extends Exception
5+
6+
val limit = 10e9
7+
8+
def f(x: Double): Double throws LimitExceeded =
9+
if x < limit then x * x else throw LimitExceeded()
10+
11+
@main def test(xs: Double*) =
12+
try println(xs.mapp(f).sum)
13+
catch case ex: LimitExceeded => println("too large")
14+
15+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import language.experimental.saferExceptions
2+
3+
class LimitExceeded extends Exception
4+
5+
val limit = 10e9
6+
7+
extension [A](xs: Seq[A])
8+
def mapp[B](f: A => B): Seq[B] =
9+
xs.map(f.asInstanceOf[A -> B])
10+
11+
def f(x: Double): Double throws LimitExceeded =
12+
if x < limit then x * x else throw LimitExceeded()
13+
14+
@main def test(xs: Double*) =
15+
try println(xs.mapp(f).sum + xs.map(f).sum)
16+
catch case ex: LimitExceeded => println("too large")
17+
18+

0 commit comments

Comments
 (0)
Please sign in to comment.