Skip to content

Commit 8baaeae

Browse files
committed
Refactor classtag synthesis
1 parent c480bd8 commit 8baaeae

File tree

4 files changed

+40
-46
lines changed

4 files changed

+40
-46
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

+7-5
Original file line numberDiff line numberDiff line change
@@ -419,11 +419,13 @@ object Types {
419419

420420
/** Is this a match type or a higher-kinded abstraction of one?
421421
*/
422-
def isMatch(using Context): Boolean = stripped match {
423-
case _: MatchType => true
424-
case tp: HKTypeLambda => tp.resType.isMatch
425-
case tp: AppliedType => tp.isMatchAlias
426-
case _ => false
422+
def isMatch(using Context): Boolean = underlyingMatchType.exists
423+
424+
def underlyingMatchType(using Context): Type = stripped match {
425+
case tp: MatchType => tp
426+
case tp: HKTypeLambda => tp.resType.underlyingMatchType
427+
case tp: AppliedType if tp.isMatchAlias => tp.superType.underlyingMatchType
428+
case _ => NoType
427429
}
428430

429431
/** Is this a higher-kinded type lambda with given parameter variances? */

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

+29-37
Original file line numberDiff line numberDiff line change
@@ -28,43 +28,35 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
2828
private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)]
2929

3030
val synthesizedClassTag: SpecialHandler = (formal, span) =>
31-
formal.argInfos match
32-
case arg :: Nil =>
33-
if isFullyDefined(arg, ForceDegree.all) then
34-
arg match
35-
case defn.ArrayOf(elemTp) =>
36-
val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span)
37-
if etag.tpe.isError then EmptyTreeNoError else withNoErrors(etag.select(nme.wrap))
38-
case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) =>
39-
val sym = tp.typeSymbol
40-
val classTag = ref(defn.ClassTagModule)
41-
val tag =
42-
if defn.SpecialClassTagClasses.contains(sym) then
43-
classTag.select(sym.name.toTermName).withSpan(span)
44-
else
45-
def clsOfType(tp: Type): Type =
46-
val tp1 = tp.dealias
47-
if tp1.isMatch then
48-
val matchTp = tp1.underlyingIterator.collect {
49-
case mt: MatchType => mt
50-
}.next
51-
matchTp.alternatives.map(clsOfType) match
52-
case ct1 :: cts if cts.forall(ct1 == _) => ct1
53-
case _ => NoType
54-
else
55-
escapeJavaArray(erasure(tp))
56-
val ctype = clsOfType(tp)
57-
if ctype.exists then
58-
classTag.select(nme.apply)
59-
.appliedToType(tp)
60-
.appliedTo(clsOf(ctype))
61-
.withSpan(span)
62-
else
63-
EmptyTree
64-
withNoErrors(tag)
65-
case tp => EmptyTreeNoError
66-
else EmptyTreeNoError
67-
case _ => EmptyTreeNoError
31+
val tag = formal.argInfos match
32+
case arg :: Nil if isFullyDefined(arg, ForceDegree.all) =>
33+
arg match
34+
case defn.ArrayOf(elemTp) =>
35+
val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span)
36+
if etag.tpe.isError then EmptyTree else etag.select(nme.wrap)
37+
case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) =>
38+
val sym = tp.typeSymbol
39+
val classTagModul = ref(defn.ClassTagModule)
40+
if defn.SpecialClassTagClasses.contains(sym) then
41+
classTagModul.select(sym.name.toTermName).withSpan(span)
42+
else
43+
def clsOfType(tp: Type): Type = tp.dealias.underlyingMatchType match
44+
case matchTp: MatchType =>
45+
matchTp.alternatives.map(clsOfType) match
46+
case ct1 :: cts if cts.forall(ct1 == _) => ct1
47+
case _ => NoType
48+
case _ =>
49+
escapeJavaArray(erasure(tp))
50+
val ctype = clsOfType(tp)
51+
if ctype.exists then
52+
classTagModul.select(nme.apply)
53+
.appliedToType(tp)
54+
.appliedTo(clsOf(ctype))
55+
.withSpan(span)
56+
else EmptyTree
57+
case _ => EmptyTree
58+
case _ => EmptyTree
59+
(tag, Nil)
6860
end synthesizedClassTag
6961

7062
val synthesizedTypeTest: SpecialHandler =

tests/run/i15618.check

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Some(1.0)
1+
Some(1)

tests/run/i15618.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ type ScalaType[U <: DType] <: Int | Float = U match
1515
abstract class Tensor[T <: DType]:
1616
def toArray: Array[ScalaType[T]]
1717

18-
object FloatTensor extends Tensor[Float16]:
19-
def toArray: Array[Float] = Array(1, 2, 3)
18+
object IntTensor extends Tensor[Int32]:
19+
def toArray: Array[Int] = Array(1, 2, 3)
2020

2121
@main
2222
def Test =
23-
val t = FloatTensor: Tensor[Float16] // Tensor[Float32]
23+
val t = IntTensor: Tensor[Int32]
2424
println(t.toArray.headOption) // was ClassCastException

0 commit comments

Comments
 (0)