Skip to content

Commit 0382f6b

Browse files
Merge pull request #7766 from dotty-staging/fix-#7757
Fix #7757: Do auto-parameter-untupling also for overloaded methods
2 parents 8d2ee65 + 4f8893f commit 0382f6b

File tree

8 files changed

+106
-35
lines changed

8 files changed

+106
-35
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+15-9
Original file line numberDiff line numberDiff line change
@@ -1321,33 +1321,39 @@ object desugar {
13211321
Function(params, Match(makeSelector(selector, checkMode), cases))
13221322
}
13231323

1324-
/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
1324+
/** Map n-ary function `(x1: T1, ..., xn: Tn) => body` where n != 1 to unary function as follows:
13251325
*
1326-
* x$1 => {
1327-
* def p1 = x$1._1
1326+
* (x$1: (T1, ..., Tn)) => {
1327+
* def x1: T1 = x$1._1
13281328
* ...
1329-
* def pn = x$1._n
1329+
* def xn: Tn = x$1._n
13301330
* body
13311331
* }
13321332
*
13331333
* or if `isGenericTuple`
13341334
*
1335-
* x$1 => {
1336-
* def p1 = x$1.apply(0)
1335+
* (x$1: (T1, ... Tn) => {
1336+
* def x1: T1 = x$1.apply(0)
13371337
* ...
1338-
* def pn = x$1.apply(n-1)
1338+
* def xn: Tn = x$1.apply(n-1)
13391339
* body
13401340
* }
1341+
*
1342+
* If some of the Ti's are absent, omit the : (T1, ..., Tn) type ascription
1343+
* in the selector.
13411344
*/
13421345
def makeTupledFunction(params: List[ValDef], body: Tree, isGenericTuple: Boolean)(implicit ctx: Context): Tree = {
1343-
val param = makeSyntheticParameter()
1346+
val param = makeSyntheticParameter(
1347+
tpt =
1348+
if params.exists(_.tpt.isEmpty) then TypeTree()
1349+
else Tuple(params.map(_.tpt)))
13441350
def selector(n: Int) =
13451351
if (isGenericTuple) Apply(Select(refOfDef(param), nme.apply), Literal(Constant(n)))
13461352
else Select(refOfDef(param), nme.selectorName(n))
13471353
val vdefs =
13481354
params.zipWithIndex.map {
13491355
case (param, idx) =>
1350-
DefDef(param.name, Nil, Nil, TypeTree(), selector(idx)).withSpan(param.span)
1356+
DefDef(param.name, Nil, Nil, param.tpt, selector(idx)).withSpan(param.span)
13511357
}
13521358
Function(param :: Nil, Block(vdefs, body))
13531359
}

compiler/src/dotty/tools/dotc/core/Decorators.scala

+16
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,22 @@ object Decorators {
114114
else x1 :: xs1
115115
}
116116

117+
/** Like `xs.lazyZip(xs.indices).map(f)`, but returns list `xs` itself
118+
* - instead of a copy - if function `f` maps all elements of
119+
* `xs` to themselves.
120+
*/
121+
def mapWithIndexConserve[U <: T](f: (T, Int) => U): List[U] =
122+
def recur(xs: List[T], idx: Int): List[U] =
123+
if xs.isEmpty then Nil
124+
else
125+
val x1 = f(xs.head, idx)
126+
val xs1 = recur(xs.tail, idx + 1)
127+
if (x1.asInstanceOf[AnyRef] eq xs.head.asInstanceOf[AnyRef])
128+
&& (xs1 eq xs.tail)
129+
then xs.asInstanceOf[List[U]]
130+
else x1 :: xs1
131+
recur(xs, 0)
132+
117133
final def hasSameLengthAs[U](ys: List[U]): Boolean = {
118134
@tailrec def loop(xs: List[T], ys: List[U]): Boolean =
119135
if (xs.isEmpty) ys.isEmpty

compiler/src/dotty/tools/dotc/typer/Applications.scala

+50-19
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package dotc
33
package typer
44

55
import core._
6-
import ast.{Trees, tpd, untpd}
6+
import ast.{Trees, tpd, untpd, desugar}
77
import util.Spans._
88
import util.Stats.record
99
import util.{SourcePosition, NoSourcePosition, SourceFile}
@@ -864,7 +864,7 @@ trait Applications extends Compatibility {
864864
case funRef: TermRef =>
865865
val app =
866866
if (proto.allArgTypesAreCurrent())
867-
new ApplyToTyped(tree, fun1, funRef, proto.unforcedTypedArgs, pt)
867+
new ApplyToTyped(tree, fun1, funRef, proto.typedArgs(), pt)
868868
else
869869
new ApplyToUntyped(tree, fun1, funRef, proto, pt)(
870870
given fun1.nullableInArgContext(given argCtx(tree)))
@@ -891,7 +891,7 @@ trait Applications extends Compatibility {
891891
}
892892

893893
fun1.tpe match {
894-
case err: ErrorType => cpy.Apply(tree)(fun1, proto.unforcedTypedArgs).withType(err)
894+
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs()).withType(err)
895895
case TryDynamicCallType => typedDynamicApply(tree, pt)
896896
case _ =>
897897
if (originalProto.isDropped) fun1
@@ -1635,14 +1635,46 @@ trait Applications extends Compatibility {
16351635
def narrowByTypes(alts: List[TermRef], argTypes: List[Type], resultType: Type): List[TermRef] =
16361636
alts filter (isApplicableMethodRef(_, argTypes, resultType))
16371637

1638+
/** Normalization steps before checking arguments:
1639+
*
1640+
* { expr } --> expr
1641+
* (x1, ..., xn) => expr --> ((x1, ..., xn)) => expr
1642+
* if n != 1, no alternative has a corresponding formal parameter that
1643+
* is an n-ary function, and at least one alternative has a corresponding
1644+
* formal parameter that is a unary function.
1645+
*/
1646+
def normArg(alts: List[TermRef], arg: untpd.Tree, idx: Int): untpd.Tree = arg match
1647+
case Block(Nil, expr) => normArg(alts, expr, idx)
1648+
case untpd.Function(args: List[untpd.ValDef] @unchecked, body) =>
1649+
1650+
// If ref refers to a method whose parameter at index `idx` is a function type,
1651+
// the arity of that function, otherise -1.
1652+
def paramCount(ref: TermRef) =
1653+
val formals = ref.widen.firstParamTypes
1654+
if formals.length > idx then
1655+
formals(idx) match
1656+
case defn.FunctionOf(args, _, _, _) => args.length
1657+
case _ => -1
1658+
else -1
1659+
1660+
val numArgs = args.length
1661+
if numArgs != 1
1662+
&& !alts.exists(paramCount(_) == numArgs)
1663+
&& alts.exists(paramCount(_) == 1)
1664+
then
1665+
desugar.makeTupledFunction(args, body, isGenericTuple = true)
1666+
// `isGenericTuple = true` is the safe choice here. It means the i'th tuple
1667+
// element is selected with `(i)` instead of `_i`, which gives the same code
1668+
// in the end, but the compilation time and the ascribed type are more involved.
1669+
// It also means that -Ytest-pickler -Xprint-types fails for sources exercising
1670+
// the idiom since after pickling the target is known, so _i is used directly.
1671+
else arg
1672+
case _ => arg
1673+
end normArg
1674+
16381675
val candidates = pt match {
16391676
case pt @ FunProto(args, resultType) =>
16401677
val numArgs = args.length
1641-
val normArgs = args.mapConserve {
1642-
case Block(Nil, expr) => expr
1643-
case x => x
1644-
}
1645-
16461678
def sizeFits(alt: TermRef): Boolean = alt.widen.stripPoly match {
16471679
case tp: MethodType =>
16481680
val ptypes = tp.paramInfos
@@ -1661,9 +1693,10 @@ trait Applications extends Compatibility {
16611693
alts.filter(sizeFits(_))
16621694

16631695
def narrowByShapes(alts: List[TermRef]): List[TermRef] =
1664-
if (normArgs exists untpd.isFunctionWithUnknownParamType)
1665-
if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType)
1666-
else narrowByTypes(alts, normArgs map typeShape, resultType)
1696+
val normArgs = args.mapWithIndexConserve(normArg(alts, _, _))
1697+
if normArgs.exists(untpd.isFunctionWithUnknownParamType) then
1698+
if hasNamedArg(args) then narrowByTrees(alts, normArgs.map(treeShape), resultType)
1699+
else narrowByTypes(alts, normArgs.map(typeShape), resultType)
16671700
else
16681701
alts
16691702

@@ -1681,16 +1714,14 @@ trait Applications extends Compatibility {
16811714

16821715
val alts1 = narrowBySize(alts)
16831716
//ctx.log(i"narrowed by size: ${alts1.map(_.symbol.showDcl)}%, %")
1684-
if (isDetermined(alts1)) alts1
1685-
else {
1717+
if isDetermined(alts1) then alts1
1718+
else
16861719
val alts2 = narrowByShapes(alts1)
16871720
//ctx.log(i"narrowed by shape: ${alts2.map(_.symbol.showDcl)}%, %")
1688-
if (isDetermined(alts2)) alts2
1689-
else {
1721+
if isDetermined(alts2) then alts2
1722+
else
16901723
pretypeArgs(alts2, pt)
1691-
narrowByTrees(alts2, pt.unforcedTypedArgs, resultType)
1692-
}
1693-
}
1724+
narrowByTrees(alts2, pt.typedArgs(normArg(alts2, _, _)), resultType)
16941725

16951726
case pt @ PolyProto(targs1, pt1) if targs.isEmpty =>
16961727
val alts1 = alts.filter(pt.isMatchedBy(_))
@@ -1749,7 +1780,7 @@ trait Applications extends Compatibility {
17491780
else pt match {
17501781
case pt @ FunProto(_, resType: FunProto) =>
17511782
// try to narrow further with snd argument list
1752-
val advanced = advanceCandidates(pt.unforcedTypedArgs.tpes)
1783+
val advanced = advanceCandidates(pt.typedArgs().tpes)
17531784
resolveOverloaded(advanced.map(_._1), resType, Nil) // resolve with candidates where first params are stripped
17541785
.map(advanced.toMap) // map surviving result(s) back to original candidates
17551786
case _ =>

compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ object ErrorReporting {
5555
case _: WildcardType | _: IgnoredProto => ""
5656
case tp => em" and expected result type $tp"
5757
}
58-
em"arguments (${tp.unforcedTypedArgs.tpes}%, %)$result"
58+
em"arguments (${tp.typedArgs().tpes}%, %)$result"
5959
case _ =>
6060
em"expected type $tp"
6161
}

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

+10-5
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ object ProtoTypes {
254254
override def resultType(implicit ctx: Context): Type = resType
255255

256256
def isMatchedBy(tp: Type, keepConstraint: Boolean)(implicit ctx: Context): Boolean = {
257-
val args = unforcedTypedArgs
257+
val args = typedArgs()
258258
def isPoly(tree: Tree) = tree.tpe.widenSingleton.isInstanceOf[PolyType]
259259
// See remark in normalizedCompatible for why we can't keep the constraint
260260
// if one of the arguments has a PolyType.
@@ -305,15 +305,18 @@ object ProtoTypes {
305305
* However, any constraint changes are also propagated to the currently passed
306306
* context.
307307
*
308+
* @param norm a normalization function that is applied to an untyped argument tree
309+
* before it is typed. The second Int parameter is the parameter index.
308310
*/
309-
def unforcedTypedArgs(implicit ctx: Context): List[Tree] =
311+
def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree = sameTree)(implicit ctx: Context): List[Tree] =
310312
if (state.typedArgs.size == args.length) state.typedArgs
311313
else {
312314
val prevConstraint = this.ctx.typerState.constraint
313315

314316
try {
315317
implicit val ctx = this.ctx
316-
val args1 = args.mapconserve(cacheTypedArg(_, typer.typed(_), force = false))
318+
val args1 = args.mapWithIndexConserve((arg, idx) =>
319+
cacheTypedArg(arg, arg => typer.typed(norm(arg, idx)), force = false))
317320
if (!args1.exists(arg => isUndefined(arg.tpe))) state.typedArgs = args1
318321
args1
319322
}
@@ -379,7 +382,7 @@ object ProtoTypes {
379382
derivedFunProto(args, tm(resultType), typer)
380383

381384
def fold[T](x: T, ta: TypeAccumulator[T])(implicit ctx: Context): T =
382-
ta(ta.foldOver(x, unforcedTypedArgs.tpes), resultType)
385+
ta(ta.foldOver(x, typedArgs().tpes), resultType)
383386

384387
override def deepenProto(implicit ctx: Context): FunProto = derivedFunProto(args, resultType.deepenProto, typer)
385388

@@ -393,7 +396,7 @@ object ProtoTypes {
393396
* [](args): resultType, where args are known to be typed
394397
*/
395398
class FunProtoTyped(args: List[tpd.Tree], resultType: Type)(typer: Typer, isGivenApply: Boolean)(implicit ctx: Context) extends FunProto(args, resultType)(typer, isGivenApply)(ctx) {
396-
override def unforcedTypedArgs(implicit ctx: Context): List[tpd.Tree] = args
399+
override def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree)(implicit ctx: Context): List[tpd.Tree] = args
397400
override def withContext(ctx: Context): FunProtoTyped = this
398401
}
399402

@@ -686,4 +689,6 @@ object ProtoTypes {
686689
case _ => None
687690
}
688691
}
692+
693+
private val sameTree = (t: untpd.Tree, n: Int) => t
689694
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ class Typer extends Namer
947947
}
948948

949949
def typedFunctionValue(tree: untpd.Function, pt: Type)(implicit ctx: Context): Tree = {
950-
val untpd.Function(params: List[untpd.ValDef] @unchecked, body) = tree
950+
val untpd.Function(params: List[untpd.ValDef] @unchecked, _) = tree
951951

952952
val isContextual = tree match {
953953
case tree: untpd.FunctionWithMods => tree.mods.is(Given)

compiler/test/dotc/pos-test-pickling.blacklist

+3
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,6 @@ i7580.scala
3232

3333
# Nullability
3434
nullable.scala
35+
36+
# parameter untupling with overloaded functions (see comment in Applications.normArg)
37+
i7757.scala

tests/pos/i7757.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
val m: Map[Int, String] = ???
2+
val _ = m.map((a, b) => a + b.length)
3+
4+
trait Foo
5+
def g(f: ((Int, Int)) => Int): Int = 1
6+
def g(f: ((Int, Int)) => (Int, Int)): String = "2"
7+
8+
@main def Test =
9+
val m: Foo = ???
10+
m.g((x: Int, b: Int) => (x, x))

0 commit comments

Comments
 (0)