Skip to content

Commit 4f8893f

Browse files
committed
Keep type info when auto-tupling parameters
1 parent 17c2f37 commit 4f8893f

File tree

7 files changed

+77
-58
lines changed

7 files changed

+77
-58
lines changed

compiler/src/dotty/tools/dotc/core/Decorators.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,15 @@ object Decorators {
118118
* - instead of a copy - if function `f` maps all elements of
119119
* `xs` to themselves.
120120
*/
121-
def mapWithIndexConserve[U](f: (T, Int) => T): List[T] =
122-
def recur(xs: List[T], idx: Int): List[T] =
121+
def mapWithIndexConserve[U <: T](f: (T, Int) => U): List[U] =
122+
def recur(xs: List[T], idx: Int): List[U] =
123123
if xs.isEmpty then Nil
124124
else
125125
val x1 = f(xs.head, idx)
126126
val xs1 = recur(xs.tail, idx + 1)
127127
if (x1.asInstanceOf[AnyRef] eq xs.head.asInstanceOf[AnyRef])
128128
&& (xs1 eq xs.tail)
129-
then xs
129+
then xs.asInstanceOf[List[U]]
130130
else x1 :: xs1
131131
recur(xs, 0)
132132

compiler/src/dotty/tools/dotc/typer/Applications.scala

+50-47
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package dotc
33
package typer
44

55
import core._
6-
import ast.{Trees, tpd, untpd}
6+
import ast.{Trees, tpd, untpd, desugar}
77
import util.Spans._
88
import util.Stats.record
99
import util.{SourcePosition, NoSourcePosition, SourceFile}
@@ -864,7 +864,7 @@ trait Applications extends Compatibility {
864864
case funRef: TermRef =>
865865
val app =
866866
if (proto.allArgTypesAreCurrent())
867-
new ApplyToTyped(tree, fun1, funRef, proto.unforcedTypedArgs, pt)
867+
new ApplyToTyped(tree, fun1, funRef, proto.typedArgs(), pt)
868868
else
869869
new ApplyToUntyped(tree, fun1, funRef, proto, pt)(
870870
given fun1.nullableInArgContext(given argCtx(tree)))
@@ -891,7 +891,7 @@ trait Applications extends Compatibility {
891891
}
892892

893893
fun1.tpe match {
894-
case err: ErrorType => cpy.Apply(tree)(fun1, proto.unforcedTypedArgs).withType(err)
894+
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err)
895895
case TryDynamicCallType => typedDynamicApply(tree, pt)
896896
case _ =>
897897
if (originalProto.isDropped) fun1
@@ -1635,6 +1635,43 @@ trait Applications extends Compatibility {
16351635
def narrowByTypes(alts: List[TermRef], argTypes: List[Type], resultType: Type): List[TermRef] =
16361636
alts filter (isApplicableMethodRef(_, argTypes, resultType))
16371637

1638+
/** Normalization steps before checking arguments:
1639+
*
1640+
* { expr } --> expr
1641+
* (x1, ..., xn) => expr --> ((x1, ..., xn)) => expr
1642+
* if n != 1, no alternative has a corresponding formal parameter that
1643+
* is an n-ary function, and at least one alternative has a corresponding
1644+
* formal parameter that is a unary function.
1645+
*/
1646+
def normArg(alts: List[TermRef], arg: untpd.Tree, idx: Int): untpd.Tree = arg match
1647+
case Block(Nil, expr) => normArg(alts, expr, idx)
1648+
case untpd.Function(args: List[untpd.ValDef] @unchecked, body) =>
1649+
1650+
// If ref refers to a method whose parameter at index `idx` is a function type,
1651+
// the arity of that function, otherise -1.
1652+
def paramCount(ref: TermRef) =
1653+
val formals = ref.widen.firstParamTypes
1654+
if formals.length > idx then
1655+
formals(idx) match
1656+
case defn.FunctionOf(args, _, _, _) => args.length
1657+
case _ => -1
1658+
else -1
1659+
1660+
val numArgs = args.length
1661+
if numArgs != 1
1662+
&& !alts.exists(paramCount(_) == numArgs)
1663+
&& alts.exists(paramCount(_) == 1)
1664+
then
1665+
desugar.makeTupledFunction(args, body, isGenericTuple = true)
1666+
// `isGenericTuple = true` is the safe choice here. It means the i'th tuple
1667+
// element is selected with `(i)` instead of `_i`, which gives the same code
1668+
// in the end, but the compilation time and the ascribed type are more involved.
1669+
// It also means that -Ytest-pickler -Xprint-types fails for sources exercising
1670+
// the idiom since after pickling the target is known, so _i is used directly.
1671+
else arg
1672+
case _ => arg
1673+
end normArg
1674+
16381675
val candidates = pt match {
16391676
case pt @ FunProto(args, resultType) =>
16401677
val numArgs = args.length
@@ -1656,42 +1693,10 @@ trait Applications extends Compatibility {
16561693
alts.filter(sizeFits(_))
16571694

16581695
def narrowByShapes(alts: List[TermRef]): List[TermRef] =
1659-
1660-
/** Normalization steps before shape-checking arguments:
1661-
*
1662-
* { expr } --> expr
1663-
* (x1, ..., xn) => expr --> ((x1, ..., xn)) => expr
1664-
* if n > 1, no alternative has a corresponding formal parameter that
1665-
* is an n-ary function, and at least one alternative has a corresponding
1666-
* formal parameter that is a unary function.
1667-
*/
1668-
def normArg(arg: untpd.Tree, idx: Int): untpd.Tree = arg match
1669-
case Block(Nil, expr) => normArg(expr, idx)
1670-
case untpd.Function(args, body) =>
1671-
1672-
// If ref refers to a method whose parameter at index `idx` is a function type,
1673-
// the arity of that function, otherise 0.
1674-
def paramCount(ref: TermRef) =
1675-
val formals = ref.widen.firstParamTypes
1676-
if formals.length > idx then
1677-
formals(idx) match
1678-
case defn.FunctionOf(args, _, _, _) => args.length
1679-
case _ => 0
1680-
else 0
1681-
1682-
val numArgs = args.length
1683-
if numArgs > 1
1684-
&& !alts.exists(paramCount(_) == numArgs)
1685-
&& alts.exists(paramCount(_) == 1)
1686-
then untpd.Function(untpd.Tuple(args) :: Nil, body)
1687-
else arg
1688-
case _ => arg
1689-
end normArg
1690-
1691-
val normArgs = args.mapWithIndexConserve(normArg)
1692-
if (normArgs exists untpd.isFunctionWithUnknownParamType)
1693-
if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType)
1694-
else narrowByTypes(alts, normArgs map typeShape, resultType)
1696+
val normArgs = args.mapWithIndexConserve(normArg(alts, _, _))
1697+
if normArgs.exists(untpd.isFunctionWithUnknownParamType) then
1698+
if hasNamedArg(args) then narrowByTrees(alts, normArgs.map(treeShape), resultType)
1699+
else narrowByTypes(alts, normArgs.map(typeShape), resultType)
16951700
else
16961701
alts
16971702

@@ -1709,16 +1714,14 @@ trait Applications extends Compatibility {
17091714

17101715
val alts1 = narrowBySize(alts)
17111716
//ctx.log(i"narrowed by size: ${alts1.map(_.symbol.showDcl)}%, %")
1712-
if (isDetermined(alts1)) alts1
1713-
else {
1717+
if isDetermined(alts1) then alts1
1718+
else
17141719
val alts2 = narrowByShapes(alts1)
17151720
//ctx.log(i"narrowed by shape: ${alts2.map(_.symbol.showDcl)}%, %")
1716-
if (isDetermined(alts2)) alts2
1717-
else {
1721+
if isDetermined(alts2) then alts2
1722+
else
17181723
pretypeArgs(alts2, pt)
1719-
narrowByTrees(alts2, pt.unforcedTypedArgs, resultType)
1720-
}
1721-
}
1724+
narrowByTrees(alts2, pt.typedArgs(normArg(alts2, _, _)), resultType)
17221725

17231726
case pt @ PolyProto(targs1, pt1) if targs.isEmpty =>
17241727
val alts1 = alts.filter(pt.isMatchedBy(_))
@@ -1777,7 +1780,7 @@ trait Applications extends Compatibility {
17771780
else pt match {
17781781
case pt @ FunProto(_, resType: FunProto) =>
17791782
// try to narrow further with snd argument list
1780-
val advanced = advanceCandidates(pt.unforcedTypedArgs.tpes)
1783+
val advanced = advanceCandidates(pt.typedArgs().tpes)
17811784
resolveOverloaded(advanced.map(_._1), resType, Nil) // resolve with candidates where first params are stripped
17821785
.map(advanced.toMap) // map surviving result(s) back to original candidates
17831786
case _ =>

compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ object ErrorReporting {
5555
case _: WildcardType | _: IgnoredProto => ""
5656
case tp => em" and expected result type $tp"
5757
}
58-
em"arguments (${tp.unforcedTypedArgs.tpes}%, %)$result"
58+
em"arguments (${tp.typedArgs().tpes}%, %)$result"
5959
case _ =>
6060
em"expected type $tp"
6161
}

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

+10-5
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ object ProtoTypes {
250250
override def resultType(implicit ctx: Context): Type = resType
251251

252252
def isMatchedBy(tp: Type, keepConstraint: Boolean)(implicit ctx: Context): Boolean = {
253-
val args = unforcedTypedArgs
253+
val args = typedArgs()
254254
def isPoly(tree: Tree) = tree.tpe.widenSingleton.isInstanceOf[PolyType]
255255
// See remark in normalizedCompatible for why we can't keep the constraint
256256
// if one of the arguments has a PolyType.
@@ -301,15 +301,18 @@ object ProtoTypes {
301301
* However, any constraint changes are also propagated to the currently passed
302302
* context.
303303
*
304+
* @param norm a normalization function that is applied to an untyped argument tree
305+
* before it is typed. The second Int parameter is the parameter index.
304306
*/
305-
def unforcedTypedArgs(implicit ctx: Context): List[Tree] =
307+
def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree = sameTree)(implicit ctx: Context): List[Tree] =
306308
if (state.typedArgs.size == args.length) state.typedArgs
307309
else {
308310
val prevConstraint = this.ctx.typerState.constraint
309311

310312
try {
311313
implicit val ctx = this.ctx
312-
val args1 = args.mapconserve(cacheTypedArg(_, typer.typed(_), force = false))
314+
val args1 = args.mapWithIndexConserve((arg, idx) =>
315+
cacheTypedArg(arg, arg => typer.typed(norm(arg, idx)), force = false))
313316
if (!args1.exists(arg => isUndefined(arg.tpe))) state.typedArgs = args1
314317
args1
315318
}
@@ -375,7 +378,7 @@ object ProtoTypes {
375378
derivedFunProto(args, tm(resultType), typer)
376379

377380
def fold[T](x: T, ta: TypeAccumulator[T])(implicit ctx: Context): T =
378-
ta(ta.foldOver(x, unforcedTypedArgs.tpes), resultType)
381+
ta(ta.foldOver(x, typedArgs().tpes), resultType)
379382

380383
override def deepenProto(implicit ctx: Context): FunProto = derivedFunProto(args, resultType.deepenProto, typer)
381384

@@ -389,7 +392,7 @@ object ProtoTypes {
389392
* [](args): resultType, where args are known to be typed
390393
*/
391394
class FunProtoTyped(args: List[tpd.Tree], resultType: Type)(typer: Typer, isGivenApply: Boolean)(implicit ctx: Context) extends FunProto(args, resultType)(typer, isGivenApply)(ctx) {
392-
override def unforcedTypedArgs(implicit ctx: Context): List[tpd.Tree] = args
395+
override def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree)(implicit ctx: Context): List[tpd.Tree] = args
393396
override def withContext(ctx: Context): FunProtoTyped = this
394397
}
395398

@@ -682,4 +685,6 @@ object ProtoTypes {
682685
case _ => None
683686
}
684687
}
688+
689+
private val sameTree = (t: untpd.Tree, n: Int) => t
685690
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ class Typer extends Namer
947947
}
948948

949949
def typedFunctionValue(tree: untpd.Function, pt: Type)(implicit ctx: Context): Tree = {
950-
val untpd.Function(params: List[untpd.ValDef] @unchecked, body) = tree
950+
val untpd.Function(params: List[untpd.ValDef] @unchecked, _) = tree
951951

952952
val isContextual = tree match {
953953
case tree: untpd.FunctionWithMods => tree.mods.is(Given)

compiler/test/dotc/pos-test-pickling.blacklist

+3
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,6 @@ i7580.scala
3232

3333
# Nullability
3434
nullable.scala
35+
36+
# parameter untupling with overloaded functions (see comment in Applications.normArg)
37+
i7757.scala

tests/pos/i7757.scala

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,10 @@
11
val m: Map[Int, String] = ???
2-
val _ = m.map((a, b) => a + b.length)
2+
val _ = m.map((a, b) => a + b.length)
3+
4+
trait Foo
5+
def g(f: ((Int, Int)) => Int): Int = 1
6+
def g(f: ((Int, Int)) => (Int, Int)): String = "2"
7+
8+
@main def Test =
9+
val m: Foo = ???
10+
m.g((x: Int, b: Int) => (x, x))

0 commit comments

Comments
 (0)