diff --git a/build.sbt b/build.sbt index c4224b6..b96d560 100644 --- a/build.sbt +++ b/build.sbt @@ -2,7 +2,7 @@ Global / idePackagePrefix := Some("dummy") def DottyProject(name: String): Project = Project.apply(name, file(name)).settings( - scalaVersion := "3.1.0", + scalaVersion := "3.1.3", Compile / scalaSource := baseDirectory.value / "src", ) diff --git a/build.sc b/build.sc index 048701a..398429b 100644 --- a/build.sc +++ b/build.sc @@ -1,7 +1,7 @@ import mill._, scalalib._ trait DottyModule extends ScalaModule { - def scalaVersion = "3.1.0" + def scalaVersion = "3.1.3" def scalacOptions = Seq("-Xcheck-macros") } diff --git a/defaultParamsInference/src/Test.scala b/defaultParamsInference/src/Test.scala index 4c5b933..3ded467 100644 --- a/defaultParamsInference/src/Test.scala +++ b/defaultParamsInference/src/Test.scala @@ -1,6 +1,6 @@ package dummy -case class Person(name: String, address: String = "Zuricch", foo: Int, age: Int = 26) +case class Person[T](name: String, address: String = "Zuricch", foo: Int, age: Int = 26, bar: List[T] = Nil) object Person: val x = 10 @@ -8,5 +8,5 @@ object Person: @main def test(): Unit = val p1 = Person("John", foo = 10) println(p1) - println(defaultParams[Person]) - assert(defaultParams[Person] == Map("address" -> "Zuricch", "age" -> 26)) + println(defaultParams[Person[Double]]) + assert(defaultParams[Person[Double]] == Map("address" -> "Zuricch", "age" -> 26, "bar" -> Nil)) diff --git a/defaultParamsInference/src/macro.scala b/defaultParamsInference/src/macro.scala index 7fc350a..68e9f7d 100644 --- a/defaultParamsInference/src/macro.scala +++ b/defaultParamsInference/src/macro.scala @@ -1,12 +1,16 @@ package dummy +import scala.annotation.experimental import scala.quoted.* inline def defaultParams[T]: Map[String, Any] = ${ defaultParmasImpl[T] } +@experimental // because .typeArgs is @experimental def defaultParmasImpl[T](using quotes: Quotes, tpe: Type[T]): Expr[Map[String, Any]] = import quotes.reflect.* - val sym = TypeTree.of[T].symbol + val typ = TypeRepr.of[T] + val sym = typ.typeSymbol + val typeArgs = typ.typeArgs val comp = sym.companionClass val mod = Ref(sym.companionModule) val names = @@ -16,10 +20,10 @@ def defaultParmasImpl[T](using quotes: Quotes, tpe: Type[T]): Expr[Map[String, A Expr.ofList(names.map(Expr(_))) val body = comp.tree.asInstanceOf[ClassDef].body - val idents: List[Ref] = + val idents: List[Term] = for case deff @ DefDef(name, _, _, _) <- body if name.startsWith("$lessinit$greater$default") - yield mod.select(deff.symbol) + yield mod.select(deff.symbol).appliedToTypes(typeArgs) val identsExpr: Expr[List[Any]] = Expr.ofList(idents.map(_.asExpr))