Skip to content

Commit 11d65aa

Browse files
authored
Merge pull request #15625 from dotty-staging/fix-15618
Fix two problems related to match types as array elements
2 parents 794e7c9 + 8baaeae commit 11d65aa

File tree

7 files changed

+105
-25
lines changed

7 files changed

+105
-25
lines changed

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

+3
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ object TypeErasure {
328328
isGenericArrayElement(tp.alias, isScala2)
329329
case tp: TypeBounds =>
330330
!fitsInJVMArray(tp.hi)
331+
case tp: MatchType =>
332+
val alts = tp.alternatives
333+
alts.nonEmpty && !fitsInJVMArray(alts.reduce(OrType(_, _, soft = true)))
331334
case tp: TypeProxy =>
332335
isGenericArrayElement(tp.translucentSuperType, isScala2)
333336
case tp: AndType =>

compiler/src/dotty/tools/dotc/core/Types.scala

+7-5
Original file line numberDiff line numberDiff line change
@@ -419,11 +419,13 @@ object Types {
419419

420420
/** Is this a match type or a higher-kinded abstraction of one?
421421
*/
422-
def isMatch(using Context): Boolean = stripped match {
423-
case _: MatchType => true
424-
case tp: HKTypeLambda => tp.resType.isMatch
425-
case tp: AppliedType => tp.isMatchAlias
426-
case _ => false
422+
def isMatch(using Context): Boolean = underlyingMatchType.exists
423+
424+
def underlyingMatchType(using Context): Type = stripped match {
425+
case tp: MatchType => tp
426+
case tp: HKTypeLambda => tp.resType.underlyingMatchType
427+
case tp: AppliedType if tp.isMatchAlias => tp.superType.underlyingMatchType
428+
case _ => NoType
427429
}
428430

429431
/** Is this a higher-kinded type lambda with given parameter variances? */

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

+29-20
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,35 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
2828
private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)]
2929

3030
val synthesizedClassTag: SpecialHandler = (formal, span) =>
31-
formal.argInfos match
32-
case arg :: Nil =>
33-
if isFullyDefined(arg, ForceDegree.all) then
34-
arg match
35-
case defn.ArrayOf(elemTp) =>
36-
val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span)
37-
if etag.tpe.isError then EmptyTreeNoError else withNoErrors(etag.select(nme.wrap))
38-
case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) =>
39-
val sym = tp.typeSymbol
40-
val classTag = ref(defn.ClassTagModule)
41-
val tag =
42-
if defn.SpecialClassTagClasses.contains(sym) then
43-
classTag.select(sym.name.toTermName)
44-
else
45-
val clsOfType = escapeJavaArray(erasure(tp))
46-
classTag.select(nme.apply).appliedToType(tp).appliedTo(clsOf(clsOfType))
47-
withNoErrors(tag.withSpan(span))
48-
case tp => EmptyTreeNoError
49-
else EmptyTreeNoError
50-
case _ => EmptyTreeNoError
31+
val tag = formal.argInfos match
32+
case arg :: Nil if isFullyDefined(arg, ForceDegree.all) =>
33+
arg match
34+
case defn.ArrayOf(elemTp) =>
35+
val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span)
36+
if etag.tpe.isError then EmptyTree else etag.select(nme.wrap)
37+
case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) =>
38+
val sym = tp.typeSymbol
39+
val classTagModul = ref(defn.ClassTagModule)
40+
if defn.SpecialClassTagClasses.contains(sym) then
41+
classTagModul.select(sym.name.toTermName).withSpan(span)
42+
else
43+
def clsOfType(tp: Type): Type = tp.dealias.underlyingMatchType match
44+
case matchTp: MatchType =>
45+
matchTp.alternatives.map(clsOfType) match
46+
case ct1 :: cts if cts.forall(ct1 == _) => ct1
47+
case _ => NoType
48+
case _ =>
49+
escapeJavaArray(erasure(tp))
50+
val ctype = clsOfType(tp)
51+
if ctype.exists then
52+
classTagModul.select(nme.apply)
53+
.appliedToType(tp)
54+
.appliedTo(clsOf(ctype))
55+
.withSpan(span)
56+
else EmptyTree
57+
case _ => EmptyTree
58+
case _ => EmptyTree
59+
(tag, Nil)
5160
end synthesizedClassTag
5261

5362
val synthesizedTypeTest: SpecialHandler =

tests/neg/i15618.check

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
-- Error: tests/neg/i15618.scala:17:44 ---------------------------------------------------------------------------------
2+
17 | def toArray: Array[ScalaType[T]] = Array() // error
3+
| ^
4+
| No ClassTag available for ScalaType[T]
5+
|
6+
| where: T is a type in class Tensor with bounds <: DType
7+
|
8+
|
9+
| Note: a match type could not be fully reduced:
10+
|
11+
| trying to reduce ScalaType[T]
12+
| failed since selector T
13+
| does not match case Float16 => Float
14+
| and cannot be shown to be disjoint from it either.
15+
| Therefore, reduction cannot advance to the remaining cases
16+
|
17+
| case Float32 => Float
18+
| case Int32 => Int

tests/neg/i15618.scala

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
sealed abstract class DType
2+
sealed class Float16 extends DType
3+
sealed class Float32 extends DType
4+
sealed class Int32 extends DType
5+
6+
object Float16 extends Float16
7+
object Float32 extends Float32
8+
object Int32 extends Int32
9+
10+
type ScalaType[U <: DType] <: Int | Float = U match
11+
case Float16 => Float
12+
case Float32 => Float
13+
case Int32 => Int
14+
15+
class Tensor[T <: DType](dtype: T):
16+
def toSeq: Seq[ScalaType[T]] = Seq()
17+
def toArray: Array[ScalaType[T]] = Array() // error
18+
19+
@main
20+
def Test =
21+
val t = Tensor(Float32) // Tensor[Float32]
22+
println(t.toSeq.headOption) // works, Seq[Float]
23+
println(t.toArray.headOption) // ClassCastException

tests/run/i15618.check

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Some(1)

tests/run/i15618.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
sealed abstract class DType
2+
sealed class Float16 extends DType
3+
sealed class Float32 extends DType
4+
sealed class Int32 extends DType
5+
6+
object Float16 extends Float16
7+
object Float32 extends Float32
8+
object Int32 extends Int32
9+
10+
type ScalaType[U <: DType] <: Int | Float = U match
11+
case Float16 => Float
12+
case Float32 => Float
13+
case Int32 => Int
14+
15+
abstract class Tensor[T <: DType]:
16+
def toArray: Array[ScalaType[T]]
17+
18+
object IntTensor extends Tensor[Int32]:
19+
def toArray: Array[Int] = Array(1, 2, 3)
20+
21+
@main
22+
def Test =
23+
val t = IntTensor: Tensor[Int32]
24+
println(t.toArray.headOption) // was ClassCastException

0 commit comments

Comments
 (0)