Skip to content

Commit ed42111

Browse files
authored
Merge pull request #7466 from dotty-staging/fix-#7459
Fix #7459: Fix two crash conditions in Inliner
2 parents f73eec7 + 39ade4e commit ed42111

File tree

4 files changed

+91
-9
lines changed

4 files changed

+91
-9
lines changed

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ object Inferencing {
4343
if (isFullyDefined(tp, ForceDegree.all)) tp
4444
else throw new Error(i"internal error: type of $what $tp is not fully defined, pos = $span") // !!! DEBUG
4545

46-
4746
/** Instantiate selected type variables `tvars` in type `tp` */
4847
def instantiateSelected(tp: Type, tvars: List[Type])(implicit ctx: Context): Unit =
4948
if (tvars.nonEmpty)

compiler/src/dotty/tools/dotc/typer/Inliner.scala

+11-8
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import Names.{Name, TermName}
1818
import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName}
1919
import ProtoTypes.selectionProto
2020
import SymDenotations.SymDenotation
21-
import Inferencing.fullyDefinedType
21+
import Inferencing.isFullyDefined
2222
import config.Printers.inlining
2323
import ErrorReporting.errorTree
2424
import dotty.tools.dotc.tastyreflect.ReflectionImpl
@@ -239,8 +239,10 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
239239

240240
inlining.println(i"-----------------------\nInlining $call\nWith RHS $rhsToInline")
241241

242-
// Make sure all type arguments to the call are fully determined
243-
for (targ <- callTypeArgs) fullyDefinedType(targ.tpe, "inlined type argument", targ.span)
242+
// Make sure all type arguments to the call are fully determined,
243+
// but continue if that's not achievable (or else i7459.scala would crash).
244+
for arg <- callTypeArgs do
245+
isFullyDefined(arg.tpe, ForceDegree.all)
244246

245247
/** A map from parameter names of the inlineable method to references of the actual arguments.
246248
* For a type argument this is the full argument type.
@@ -313,9 +315,9 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
313315

314316
/** Populate `paramBinding` and `bindingsBuf` by matching parameters with
315317
* corresponding arguments. `bindingbuf` will be further extended later by
316-
* proxies to this-references.
318+
* proxies to this-references. Issue an error if some arguments are missing.
317319
*/
318-
private def computeParamBindings(tp: Type, targs: List[Tree], argss: List[List[Tree]]): Unit = tp match {
320+
private def computeParamBindings(tp: Type, targs: List[Tree], argss: List[List[Tree]]): Boolean = tp match
319321
case tp: PolyType =>
320322
tp.paramNames.lazyZip(targs).foreach { (name, arg) =>
321323
paramSpan(name) = arg.span
@@ -324,8 +326,8 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
324326
computeParamBindings(tp.resultType, Nil, argss)
325327
case tp: MethodType =>
326328
if argss.isEmpty then
327-
// can happen if arguments have errors, see i7438.scala
328329
ctx.error(i"mising arguments for inline method $inlinedMethod", call.sourcePos)
330+
false
329331
else
330332
tp.paramNames.lazyZip(tp.paramInfos).lazyZip(argss.head).foreach { (name, paramtp, arg) =>
331333
paramSpan(name) = arg.span
@@ -338,7 +340,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
338340
case _ =>
339341
assert(targs.isEmpty)
340342
assert(argss.isEmpty)
341-
}
343+
true
342344

343345
// Compute val-definitions for all this-proxies and append them to `bindingsBuf`
344346
private def computeThisBindings() = {
@@ -447,7 +449,8 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
447449
}
448450

449451
// Compute bindings for all parameters, appending them to bindingsBuf
450-
computeParamBindings(inlinedMethod.info, callTypeArgs, callValueArgss)
452+
if !computeParamBindings(inlinedMethod.info, callTypeArgs, callValueArgss) then
453+
return call
451454

452455
// make sure prefix is executed if it is impure
453456
if (!isIdempotentExpr(inlineCallPrefix)) registerType(inlinedMethod.owner.thisType)

tests/neg/i7459.scala

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
object Foo {
2+
inline def summon[T](x: T): T = x match {
3+
case t: T => t
4+
}
5+
println(summon) // error
6+
}
7+
8+
import scala.deriving._
9+
import scala.compiletime.erasedValue
10+
11+
inline def summon[T](given t:T): T = t match {
12+
case t: T => t
13+
}
14+
15+
inline def summonAll[T <: Tuple]: List[Eq[_]] = inline erasedValue[T] match {
16+
case _: Unit => Nil
17+
case _: (t *: ts) => summon[Eq[t]] :: summonAll[ts] // error
18+
}
19+
20+
trait Eq[T] {
21+
def eqv(x: T, y: T): Boolean
22+
}
23+
24+
object Eq {
25+
given Eq[Int] {
26+
def eqv(x: Int, y: Int) = x == y
27+
}
28+
29+
def check(elem: Eq[_])(x: Any, y: Any): Boolean =
30+
elem.asInstanceOf[Eq[Any]].eqv(x, y)
31+
32+
def iterator[T](p: T) = p.asInstanceOf[Product].productIterator
33+
34+
def eqSum[T](s: Mirror.SumOf[T], elems: List[Eq[_]]): Eq[T] =
35+
new Eq[T] {
36+
def eqv(x: T, y: T): Boolean = {
37+
val ordx = s.ordinal(x)
38+
(s.ordinal(y) == ordx) && check(elems(ordx))(x, y)
39+
}
40+
}
41+
42+
def eqProduct[T](p: Mirror.ProductOf[T], elems: List[Eq[_]]): Eq[T] =
43+
new Eq[T] {
44+
def eqv(x: T, y: T): Boolean =
45+
iterator(x).zip(iterator(y)).zip(elems.iterator).forall {
46+
case ((x, y), elem) => check(elem)(x, y)
47+
}
48+
}
49+
50+
inline given derived[T](given m: Mirror.Of[T]): Eq[T] = {
51+
val elemInstances = summonAll[m.MirroredElemTypes]
52+
inline m match {
53+
case s: Mirror.SumOf[T] => eqSum(s, elemInstances)
54+
case p: Mirror.ProductOf[T] => eqProduct(p, elemInstances)
55+
}
56+
}
57+
}
58+
59+
60+
enum Opt[+T] derives Eq {
61+
case Sm(t: T)
62+
case Nn
63+
}

tests/pos/matrixOps.scala

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
object Test with
2+
3+
type Matrix = Array[Array[Double]]
4+
type Vector = Array[Double]
5+
6+
given (m: Matrix)
7+
def nRows = m.length
8+
def nCols = m(0).length
9+
def row(i: Int): Vector = m(i)
10+
def col(j: Int): Vector = Array.tabulate(m.length)(i => m(i)(j))
11+
12+
def pairwise(m: Matrix) =
13+
for
14+
i <- 0 until m.nRows
15+
j <- 0 until m.nCols
16+
yield
17+
m.row(i).zip(m.row(j)).map(_ - _).sum

0 commit comments

Comments
 (0)