Skip to content

Commit d4421d0

Browse files
Under betterFors don't drop the trailing map if it would result in a different type (also drop _ => ()) (#22619)
closes #21804
1 parent 4d48bce commit d4421d0

File tree

7 files changed

+158
-15
lines changed

7 files changed

+158
-15
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import parsing.Parser
88
import Phases.Phase
99
import transform.*
1010
import backend.jvm.{CollectSuperCalls, GenBCode}
11-
import localopt.StringInterpolatorOpt
11+
import localopt.{StringInterpolatorOpt, DropForMap}
1212

1313
/** The central class of the dotc compiler. The job of a compiler is to create
1414
* runs, which process given `phases` in a given `rootContext`.
@@ -68,7 +68,8 @@ class Compiler {
6868
new InlineVals, // Check right hand-sides of an `inline val`s
6969
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
7070
new ElimRepeated, // Rewrite vararg parameters and arguments
71-
new RefChecks) :: // Various checks mostly related to abstract members and overriding
71+
new RefChecks, // Various checks mostly related to abstract members and overriding
72+
new DropForMap) :: // Drop unused trailing map calls in for comprehensions
7273
List(new semanticdb.ExtractSemanticDB.AppendDiagnostics) :: // Attach warnings to extracted SemanticDB and write to .semanticdb file
7374
List(new init.Checker) :: // Check initialization of objects
7475
List(new ProtectedAccessors, // Add accessors for protected members

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+19-12
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ object desugar {
6464
*/
6565
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()
6666

67+
/** An attachment key to indicate that an Apply is created as a last `map`
68+
* scall in a for-comprehension.
69+
*/
70+
val TrailingForMap: Property.Key[Unit] = Property.StickyKey()
71+
6772
/** What static check should be applied to a Match? */
6873
enum MatchCheck {
6974
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
@@ -1967,14 +1972,8 @@ object desugar {
19671972
*
19681973
* 3.
19691974
*
1970-
* for (P <- G) yield P ==> G
1971-
*
1972-
* If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter.
1973-
*
19741975
* for (P <- G) yield E ==> G.map (P => E)
19751976
*
1976-
* Otherwise
1977-
*
19781977
* 4.
19791978
*
19801979
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
@@ -2147,14 +2146,20 @@ object desugar {
21472146
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
21482147
case _ => false
21492148

2149+
def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName): Unit =
2150+
if betterForsEnabled
2151+
&& selectName == mapName
2152+
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
2153+
&& (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil)))
2154+
then
2155+
aply.putAttachment(TrailingForMap, ())
2156+
21502157
enums match {
21512158
case Nil if betterForsEnabled => body
21522159
case (gen: GenFrom) :: Nil =>
2153-
if betterForsEnabled
2154-
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
2155-
&& deepEquals(gen.pat, body)
2156-
then gen.expr // avoid a redundant map with identity
2157-
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2160+
val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2161+
markTrailingMap(aply, gen, mapName)
2162+
aply
21582163
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
21592164
val cont = makeFor(mapName, flatMapName, rest, body)
21602165
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
@@ -2165,7 +2170,9 @@ object desugar {
21652170
val selectName =
21662171
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
21672172
else mapName
2168-
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2173+
val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2174+
markTrailingMap(aply, gen, selectName)
2175+
aply
21692176
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
21702177
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
21712178
val pats = valeqs map { case GenAlias(pat, _) => pat }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package dotty.tools.dotc
2+
package transform.localopt
3+
4+
import dotty.tools.dotc.ast.tpd.*
5+
import dotty.tools.dotc.core.Decorators.*
6+
import dotty.tools.dotc.core.Contexts.*
7+
import dotty.tools.dotc.core.StdNames.*
8+
import dotty.tools.dotc.core.Symbols.*
9+
import dotty.tools.dotc.core.Types.*
10+
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
11+
import dotty.tools.dotc.ast.desugar
12+
13+
/** Drop unused trailing map calls in for comprehensions.
14+
* We can drop the map call if:
15+
* - it won't change the type of the expression, and
16+
* - the function is an identity function or a const function to unit.
17+
*
18+
* The latter condition is checked in [[Desugar.scala#makeFor]]
19+
*/
20+
class DropForMap extends MiniPhase:
21+
import DropForMap.*
22+
23+
override def phaseName: String = DropForMap.name
24+
25+
override def description: String = DropForMap.description
26+
27+
override def transformApply(tree: Apply)(using Context): Tree =
28+
if !tree.hasAttachment(desugar.TrailingForMap) then tree
29+
else tree match
30+
case aply @ Apply(MapCall(f), List(Lambda(List(param), body)))
31+
if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change
32+
f // drop the map call
33+
case _ =>
34+
tree.removeAttachment(desugar.TrailingForMap)
35+
tree
36+
37+
private object Lambda:
38+
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =
39+
tree match
40+
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
41+
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
42+
Some((defdef.termParamss.flatten, defdef.rhs))
43+
case _ => None
44+
45+
private object MapCall:
46+
def unapply(tree: Tree)(using Context): Option[Tree] = tree match
47+
case Select(f, nme.map) => Some(f)
48+
case Apply(fn, _) => unapply(fn)
49+
case TypeApply(fn, _) => unapply(fn)
50+
case _ => None
51+
52+
object DropForMap:
53+
val name: String = "dropForMap"
54+
val description: String = "Drop unused trailing map calls in for comprehensions"

docs/_docs/reference/experimental/better-fors.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ Additionally this extension changes the way `for`-comprehensions are desugared.
6060
This change makes the desugaring more intuitive and avoids unnecessary `map` calls, when an alias is not followed by a guard.
6161

6262
2. **Avoiding Redundant `map` Calls**:
63-
When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. but th eequality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables.
63+
When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. But the equality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables. There is also a special case for dropping the `map`, if its body is a constant function, that returns `()` (`Unit` constant).
6464
**Current Desugaring**:
6565
```scala
6666
for {

tests/pos/better-fors-i21804.scala

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.language.experimental.betterFors
2+
3+
case class Container[A](val value: A) {
4+
def map[B](f: A => B): Container[B] = Container(f(value))
5+
}
6+
7+
sealed trait Animal
8+
case class Dog() extends Animal
9+
10+
def opOnDog(dog: Container[Dog]): Container[Animal] =
11+
for
12+
v <- dog
13+
yield v

tests/run/better-fors-map-elim.check

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
MySome(())
2+
MySome(2)
3+
MySome((2,3))
4+
MySome((2,(3,4)))

tests/run/better-fors-map-elim.scala

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import scala.language.experimental.betterFors
2+
3+
class myOptionModule(doOnMap: => Unit) {
4+
sealed trait MyOption[+A] {
5+
def map[B](f: A => B): MyOption[B] = this match {
6+
case MySome(x) => {
7+
doOnMap
8+
MySome(f(x))
9+
}
10+
case MyNone => MyNone
11+
}
12+
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
13+
case MySome(x) => f(x)
14+
case MyNone => MyNone
15+
}
16+
}
17+
case class MySome[A](x: A) extends MyOption[A]
18+
case object MyNone extends MyOption[Nothing]
19+
object MyOption {
20+
def apply[A](x: A): MyOption[A] = MySome(x)
21+
}
22+
}
23+
24+
object Test extends App {
25+
26+
val myOption = new myOptionModule(println("map called"))
27+
28+
import myOption.*
29+
30+
def portablePrintMyOption(opt: MyOption[Any]): Unit =
31+
if opt == MySome(()) then
32+
println("MySome(())")
33+
else
34+
println(opt)
35+
36+
val z = for {
37+
a <- MyOption(1)
38+
b <- MyOption(())
39+
} yield ()
40+
41+
portablePrintMyOption(z)
42+
43+
val z2 = for {
44+
a <- MyOption(1)
45+
b <- MyOption(2)
46+
} yield b
47+
48+
portablePrintMyOption(z2)
49+
50+
val z3 = for {
51+
a <- MyOption(1)
52+
(b, c) <- MyOption((2, 3))
53+
} yield (b, c)
54+
55+
portablePrintMyOption(z3)
56+
57+
val z4 = for {
58+
a <- MyOption(1)
59+
(b, (c, d)) <- MyOption((2, (3, 4)))
60+
} yield (b, (c, d))
61+
62+
portablePrintMyOption(z4)
63+
64+
}

0 commit comments

Comments
 (0)