Skip to content

Commit 1c80c61

Browse files
Backport "Two fixes to NamedTuple pattern matching" to 3.7.0 (#22995)
Backports #22953 to the 3.7.0-RC2. PR submitted by the release tooling.
2 parents 5517193 + 13204d0 commit 1c80c61

File tree

9 files changed

+108
-5
lines changed

9 files changed

+108
-5
lines changed

compiler/src/dotty/tools/dotc/transform/FirstTransform.scala

+12-1
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ import NameKinds.OuterSelectName
1919
import StdNames.*
2020
import config.Feature
2121
import inlines.Inlines.inInlineMethod
22+
import util.Property
2223

2324
object FirstTransform {
2425
val name: String = "firstTransform"
2526
val description: String = "some transformations to put trees into a canonical form"
27+
28+
/** Attachment key for named argument patterns */
29+
val WasNamedArg: Property.StickyKey[Unit] = Property.StickyKey()
2630
}
2731

2832
/** The first tree transform
@@ -38,6 +42,7 @@ object FirstTransform {
3842
*/
3943
class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>
4044
import ast.tpd.*
45+
import FirstTransform.*
4146

4247
override def phaseName: String = FirstTransform.name
4348

@@ -156,7 +161,13 @@ class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>
156161

157162
override def transformOther(tree: Tree)(using Context): Tree = tree match {
158163
case tree: Export => EmptyTree
159-
case tree: NamedArg => transformAllDeep(tree.arg)
164+
case tree: NamedArg =>
165+
val res = transformAllDeep(tree.arg)
166+
if ctx.mode.is(Mode.Pattern) then
167+
// Need to keep NamedArg status for pattern matcher to work correctly when faced
168+
// with single-element named tuples.
169+
res.pushAttachment(WasNamedArg, ())
170+
res
160171
case tree => if (tree.isType) toTypeTree(tree) else tree
161172
}
162173

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

+14-3
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,20 @@ object PatternMatcher {
386386
}
387387
else
388388
letAbstract(get) { getResult =>
389-
val selectors =
390-
if (args.tail.isEmpty) ref(getResult) :: Nil
391-
else productSelectors(getResult.info).map(ref(getResult).select(_))
389+
def isUnaryNamedTupleSelectArg(arg: Tree) =
390+
get.tpe.widenDealias.isNamedTupleType
391+
&& arg.removeAttachment(FirstTransform.WasNamedArg).isDefined
392+
// Special case: Normally, we pull out the argument wholesale if
393+
// there is only one. But if the argument is a named argument for
394+
// a single-element named tuple, we have to select the field instead.
395+
// NamedArg trees are eliminated in FirstTransform but for named arguments
396+
// of patterns we add a WasNamedArg attachment, which is used to guide the
397+
// logic here. See i22900.scala for test cases.
398+
val selectors = args match
399+
case arg :: Nil if !isUnaryNamedTupleSelectArg(arg) =>
400+
ref(getResult) :: Nil
401+
case _ =>
402+
productSelectors(getResult.info).map(ref(getResult).select(_))
392403
matchArgsPlan(selectors, args, onSuccess)
393404
}
394405
}

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ object SpaceEngine {
279279
|| unappResult <:< ConstantType(Constant(true)) // only for unapply
280280
|| (unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) // scala2 compatibility
281281
|| unapplySeqTypeElemTp(unappResult).exists // only for unapplySeq
282-
|| isProductMatch(unappResult, argLen)
282+
|| isProductMatch(unappResult.stripNamedTuple, argLen)
283283
|| extractorMemberType(unappResult, nme.isEmpty, NoSourcePosition) <:< ConstantType(Constant(false))
284284
|| unappResult.derivesFrom(defn.NonEmptyTupleClass)
285285
|| unapp.symbol == defn.TupleXXL_unapplySeq // Fixes TupleXXL.unapplySeq which returns Some but declares Option

compiler/src/dotty/tools/dotc/typer/Checking.scala

+2
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,8 @@ trait Checking {
10361036
pats.forall(recur(_, pt))
10371037
case Typed(arg, tpt) =>
10381038
check(pat, pt) && recur(arg, pt)
1039+
case NamedArg(name, pat) =>
1040+
recur(pat, pt)
10391041
case Ident(nme.WILDCARD) =>
10401042
true
10411043
case pat: QuotePattern =>

tests/run/i22900.check

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
6
2+
6
3+
6
4+
6
5+
7
6+
6
7+
7
8+
(6)

tests/run/i22900.scala

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
object NameBaseExtractor {
2+
def unapply(x: Int): Some[(someName: Int)] = Some((someName = x + 3))
3+
}
4+
object NameBaseExtractor2 {
5+
def unapply(x: Int): Some[(someName: Int, age: Int)] = Some((someName = x + 3, age = x + 4))
6+
}
7+
@main
8+
def Test =
9+
val x1 = 3 match
10+
case NameBaseExtractor(someName = x) => x
11+
println(x1)
12+
val NameBaseExtractor(someName = x2) = 3
13+
println(x2)
14+
val NameBaseExtractor((someName = x3)) = 3
15+
println(x3)
16+
17+
val NameBaseExtractor2(someName = x4, age = x5) = 3
18+
println(x4)
19+
println(x5)
20+
21+
val NameBaseExtractor2((someName = x6, age = x7)) = 3
22+
println(x6)
23+
println(x7)
24+
25+
val NameBaseExtractor(y1) = 3
26+
println(y1)

tests/run/i22900a.check

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
3
2+
6
3+
3

tests/run/i22900a.scala

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
case class C(someName: Int)
2+
3+
object NameBaseExtractor3 {
4+
def unapply(x: Int): Some[C] = Some(C(someName = x + 3))
5+
}
6+
7+
@main
8+
def Test = {
9+
val C(someName = xx) = C(3)
10+
println(xx)
11+
val NameBaseExtractor3(C(someName = x)) = 3
12+
println(x)
13+
C(3) match
14+
case C(someName = xx) => println(xx)
15+
}

tests/warn/i22899.scala

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
case class CaseClass(a: Int)
2+
3+
object ProductMatch_CaseClass {
4+
def unapply(int: Int): CaseClass = CaseClass(int)
5+
}
6+
7+
object ProductMatch_NamedTuple {
8+
def unapply(int: Int): (a: Int) = (a = int)
9+
}
10+
11+
object NameBasedMatch_CaseClass {
12+
def unapply(int: Int): Some[CaseClass] = Some(CaseClass(int))
13+
}
14+
15+
object NameBasedMatch_NamedTuple {
16+
def unapply(int: Int): Some[(a: Int)] = Some((a = int))
17+
}
18+
19+
object Test {
20+
val ProductMatch_CaseClass(a = x1) = 1 // ok, was pattern's type (x1 : Int) is more specialized than the right hand side expression's type Int
21+
val ProductMatch_NamedTuple(a = x2) = 2 // ok, was pattern binding uses refutable extractor `org.test.ProductMatch_NamedTuple`
22+
val NameBasedMatch_CaseClass(a = x3) = 3 // ok, was pattern's type (x3 : Int) is more specialized than the right hand side expression's type Int
23+
val NameBasedMatch_NamedTuple(a = x4) = 4 // ok, was pattern's type (x4 : Int) is more specialized than the right hand side expression's type Int
24+
25+
val CaseClass(a = x5) = CaseClass(5) // ok, was pattern's type (x5 : Int) is more specialized than the right hand side expression's type Int
26+
val (a = x6) = (a = 6) // ok
27+
}

0 commit comments

Comments
 (0)