diff --git a/modules/build/src/main/scala/scala/build/ScopedSources.scala b/modules/build/src/main/scala/scala/build/ScopedSources.scala
index 75bb29e710..4684b836c5 100644
--- a/modules/build/src/main/scala/scala/build/ScopedSources.scala
+++ b/modules/build/src/main/scala/scala/build/ScopedSources.scala
@@ -62,14 +62,14 @@ final case class ScopedSources(
): Either[BuildException, Sources] = either {
val combinedOptions = combinedBuildOptions(scope, baseOptions)
- val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions)
+ val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions, logger)
val wrappedScripts = unwrappedScripts
.flatMap(_.valueFor(scope).toSeq)
.map(_.wrap(codeWrapper))
codeWrapper match {
- case _: AppCodeWrapper.type if wrappedScripts.size > 1 =>
+ case _: AppCodeWrapper if wrappedScripts.size > 1 =>
wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc"))
.foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper))
case _ => ()
diff --git a/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala
index c20fd26983..f96cfb865d 100644
--- a/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala
+++ b/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala
@@ -1,6 +1,6 @@
package scala.build.internal
-case object AppCodeWrapper extends CodeWrapper {
+case class AppCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
override def mainClassObject(className: Name) = className
def apply(
@@ -12,13 +12,19 @@ case object AppCodeWrapper extends CodeWrapper {
) = {
val wrapperObjectName = indexedWrapperName.backticked
+ val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
+ val invokeMain = mainObject match
+ case WrapperUtils.ScriptMainMethod.Exists(name) => s"\n$name.main(args)"
+ case otherwise =>
+ otherwise.warningMessage.foreach(log)
+ ""
val packageDirective =
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
val top = AmmUtil.normalizeNewlines(
s"""$packageDirective
|
|object $wrapperObjectName extends App {
- |val scriptPath = \"\"\"$scriptPath\"\"\"
+ |val scriptPath = \"\"\"$scriptPath\"\"\"$invokeMain
|""".stripMargin
)
val bottom = AmmUtil.normalizeNewlines(
diff --git a/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala
index 8adcd8ad8b..97479587a5 100644
--- a/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala
+++ b/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala
@@ -5,7 +5,7 @@ package scala.build.internal
* running interconnected scripts using Scala CLI
Incompatible with Scala 2 - it uses
* Scala 3 feature 'export'
Incompatible with native JS members - the wrapper is a class
*/
-case object ClassCodeWrapper extends CodeWrapper {
+case class ClassCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
override def mainClassObject(className: Name): Name =
Name(className.raw ++ "_sc")
@@ -16,8 +16,16 @@ case object ClassCodeWrapper extends CodeWrapper {
extraCode: String,
scriptPath: String
) = {
+
+ val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
+ val mainInvocation = mainObject match
+ case WrapperUtils.ScriptMainMethod.Exists(name) => s"script.$name.main(args)"
+ case otherwise =>
+ otherwise.warningMessage.foreach(log)
+ s"val _ = script.hashCode()"
+
val name = mainClassObject(indexedWrapperName).backticked
- val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
+ val wrapperClassName = scala.build.internal.Name(indexedWrapperName.raw ++ "$_").backticked
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
| private var args$$opt0 = Option.empty[Array[String]]
@@ -33,7 +41,7 @@ case object ClassCodeWrapper extends CodeWrapper {
|
| def main(args: Array[String]): Unit = {
| args$$set(args)
- | val _ = script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
+ | $mainInvocation // hashCode to clear scalac warning about pure expression in statement position
| }
|}
|
diff --git a/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala
index bbd1f9e9b9..a594f2fcfa 100644
--- a/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala
+++ b/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala
@@ -4,7 +4,7 @@ package scala.build.internal
* or/and not using JS native prefer [[ClassCodeWrapper]], since it prevents deadlocks when running
* threads from script
*/
-case object ObjectCodeWrapper extends CodeWrapper {
+case class ObjectCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
override def mainClassObject(className: Name): Name =
Name(className.raw ++ "_sc")
@@ -15,12 +15,19 @@ case object ObjectCodeWrapper extends CodeWrapper {
extraCode: String,
scriptPath: String
) = {
+ val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
val name = mainClassObject(indexedWrapperName).backticked
val aliasedWrapperName = name + "$$alias"
- val funHashCodeMethod =
+ val realScript =
if (name == "main_sc")
- s"$aliasedWrapperName.alias.hashCode()" // https://github.com/VirtusLab/scala-cli/issues/314
- else s"${indexedWrapperName.backticked}.hashCode()"
+ s"$aliasedWrapperName.alias" // https://github.com/VirtusLab/scala-cli/issues/314
+ else s"${indexedWrapperName.backticked}"
+
+ val funHashCodeMethod = mainObject match
+ case WrapperUtils.ScriptMainMethod.Exists(name) => s"$realScript.$name.main(args)"
+ case otherwise =>
+ otherwise.warningMessage.foreach(log)
+ s"val _ = $realScript.hashCode()"
// We need to call hashCode (or any other method so compiler does not report a warning)
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
@@ -34,7 +41,7 @@ case object ObjectCodeWrapper extends CodeWrapper {
| }
| def main(args: Array[String]): Unit = {
| args$$set(args)
- | val _ = $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
+ | $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
| }
|}
|""".stripMargin)
diff --git a/modules/build/src/main/scala/scala/build/internal/WrapperUtils.scala b/modules/build/src/main/scala/scala/build/internal/WrapperUtils.scala
new file mode 100644
index 0000000000..7197c0d48a
--- /dev/null
+++ b/modules/build/src/main/scala/scala/build/internal/WrapperUtils.scala
@@ -0,0 +1,88 @@
+package scala.build.internal
+
+import scala.build.internal.util.WarningMessages
+
+object WrapperUtils {
+
+ enum ScriptMainMethod:
+ case Exists(name: String)
+ case Multiple(names: Seq[String])
+ case ToplevelStatsPresent
+ case ToplevelStatsWithMultiple(names: Seq[String])
+ case NoMain
+
+ def warningMessage: List[String] =
+ this match
+ case ScriptMainMethod.Multiple(names) =>
+ List(WarningMessages.multipleMainObjectsInScript(names))
+ case ScriptMainMethod.ToplevelStatsPresent => List(
+ WarningMessages.mixedToplvelAndObjectInScript
+ )
+ case ToplevelStatsWithMultiple(names) =>
+ List(
+ WarningMessages.multipleMainObjectsInScript(names),
+ WarningMessages.mixedToplvelAndObjectInScript
+ )
+ case _ => Nil
+
+ def mainObjectInScript(scalaVersion: String, code: String): ScriptMainMethod =
+ import scala.meta.*
+
+ val scriptDialect =
+ if scalaVersion.startsWith("3") then dialects.Scala3Future else dialects.Scala213Source3
+
+ given Dialect = scriptDialect.withAllowToplevelStatements(true).withAllowToplevelTerms(true)
+ val parsedCode = code.parse[Source] match
+ case Parsed.Success(Source(stats)) => stats
+ case _ => Nil
+
+ // Check if there is a main function defined inside an object
+ def checkSignature(defn: Defn.Def) =
+ defn.paramClauseGroups match
+ case List(Member.ParamClauseGroup(
+ Type.ParamClause(Nil),
+ List(Term.ParamClause(
+ List(Term.Param(
+ Nil,
+ _: Term.Name,
+ Some(Type.Apply.After_4_6_0(
+ Type.Name("Array"),
+ Type.ArgClause(List(Type.Name("String")))
+ )),
+ None
+ )),
+ None
+ ))
+ )) => true
+ case _ => false
+
+ def noToplevelStatements = parsedCode.forall {
+ case _: Term => false
+ case _ => true
+ }
+
+ def hasMainSignature(templ: Template) = templ.body.stats.exists {
+ case defn: Defn.Def =>
+ defn.name.value == "main" && checkSignature(defn)
+ case _ => false
+ }
+ def extendsApp(templ: Template) = templ.inits match
+ case Init.After_4_6_0(Type.Name("App"), _, Nil) :: Nil => true
+ case _ => false
+ val potentialMains = parsedCode.collect {
+ case Defn.Object(_, objName, templ) if extendsApp(templ) || hasMainSignature(templ) =>
+ Seq(objName.value)
+ }.flatten
+
+ potentialMains match
+ case head :: Nil if noToplevelStatements =>
+ ScriptMainMethod.Exists(head)
+ case head :: Nil =>
+ ScriptMainMethod.ToplevelStatsPresent
+ case Nil => ScriptMainMethod.NoMain
+ case seq if noToplevelStatements =>
+ ScriptMainMethod.Multiple(seq)
+ case seq =>
+ ScriptMainMethod.ToplevelStatsWithMultiple(seq)
+
+}
diff --git a/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala b/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala
index 271c31bf15..9f74e9c212 100644
--- a/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala
+++ b/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala
@@ -105,6 +105,12 @@ object WarningMessages {
val offlineModeBloopJvmNotFound =
"Offline mode is ON and a JVM for Bloop could not be fetched from the local cache, using scalac as fallback"
+ def multipleMainObjectsInScript(names: Seq[String]) =
+ s"Only a single main is allowed within scripts. Multiple main classes were found in the script: ${names.mkString(", ")}"
+
+ def mixedToplvelAndObjectInScript =
+ "Script contains objects with main methods and top-level statements, only the latter will be run."
+
def directivesInMultipleFilesWarning(
projectFilePath: String,
pathsToReport: Iterable[String] = Nil
diff --git a/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala b/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala
index 3d3633f303..c0c4d24a16 100644
--- a/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala
+++ b/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala
@@ -134,7 +134,7 @@ case object ScriptPreprocessor extends Preprocessor {
(codeWrapper: CodeWrapper) =>
if (containsMainAnnot) logger.diagnostic(
codeWrapper match {
- case _: AppCodeWrapper.type =>
+ case _: AppCodeWrapper =>
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ true)
case _ => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
}
@@ -157,24 +157,27 @@ case object ScriptPreprocessor extends Preprocessor {
* @return
* code wrapper compatible with provided BuildOptions
*/
- def getScriptWrapper(buildOptions: BuildOptions): CodeWrapper = {
+ def getScriptWrapper(buildOptions: BuildOptions, logger: Logger): CodeWrapper = {
val effectiveScalaVersion =
buildOptions.scalaOptions.scalaVersion.flatMap(_.versionOpt)
.orElse(buildOptions.scalaOptions.defaultScalaVersion)
.getOrElse(Constants.defaultScalaVersion)
+ def logWarning(msg: String) = logger.diagnostic(msg)
def objectCodeWrapperForScalaVersion =
// AppObjectWrapper only introduces the 'main.sc' restriction when used in Scala 3, there's no gain in using it with Scala 3
- if effectiveScalaVersion.startsWith("2") then AppCodeWrapper
- else ObjectCodeWrapper
+ if effectiveScalaVersion.startsWith("2") then
+ AppCodeWrapper(effectiveScalaVersion, logWarning)
+ else ObjectCodeWrapper(effectiveScalaVersion, logWarning)
buildOptions.scriptOptions.forceObjectWrapper match {
case Some(true) => objectCodeWrapperForScalaVersion
case _ =>
buildOptions.scalaOptions.platform.map(_.value) match {
- case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
- case _ if effectiveScalaVersion.startsWith("2") => AppCodeWrapper
- case _ => ClassCodeWrapper
+ case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
+ case _ if effectiveScalaVersion.startsWith("2") =>
+ AppCodeWrapper(effectiveScalaVersion, logWarning)
+ case _ => ClassCodeWrapper(effectiveScalaVersion, logWarning)
}
}
}
diff --git a/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala
index e75c63060b..4bbc867003 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala
@@ -69,6 +69,163 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions =>
}
}
+ test("main.sc has an object with a main method") {
+ val message = "Hello"
+ val inputs = TestInputs(
+ os.rel / "main.sc" ->
+ s"""|
+ |object Main {
+ | def main(args: Array[String]): Unit = println("$message")
+ |}
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val output = os.proc(TestUtil.cli, extraOptions, "main.sc").call(cwd =
+ root
+ ).out.trim()
+ expect(output == message)
+ }
+ }
+ test("main.sc has an object that extends App") {
+ val message = "Hello"
+ val inputs = TestInputs(
+ os.rel / "main.sc" ->
+ s"""|
+ |object Main extends App{
+ | println("$message")
+ |}
+ |
+ |object Other {}
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val output = os.proc(TestUtil.cli, extraOptions, "main.sc").call(cwd =
+ root
+ ).out.trim()
+ expect(output == message)
+ }
+ }
+
+ test("main.sc has an object with a main method and an object wrapper") {
+ val message = "Hello"
+ val inputs = TestInputs(
+ os.rel / "main.sc" ->
+ s"""|//> using objectWrapper
+ |object Main {
+ | def main(args: Array[String]): Unit = println("$message")
+ |}
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val output = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call(cwd =
+ root
+ ).out.trim()
+ expect(output == message)
+ }
+ }
+
+ test("main.sc has multiple main methods") {
+ val inputs = TestInputs(
+ os.rel / "main.sc" ->
+ s"""|//> using objectWrapper
+ |object Main {
+ | def main(args: Array[String]): Unit = println("1")
+ |}
+ |object AnotherMain {
+ | def main(args: Array[String]): Unit = println("2")
+ |}
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val result = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call(
+ cwd = root,
+ stderr = os.Pipe
+ )
+ val output = result.out.trim()
+ val err = result.err.trim()
+ expect(output == "")
+ expect(err.contains(
+ "Only a single main is allowed within scripts. Multiple main classes were found in the script: Main, AnotherMain"
+ ))
+ }
+ }
+ test("main.sc has multiple main methods and top-level definitions") {
+ val inputs = TestInputs(
+ os.rel / "main.sc" ->
+ s"""|//> using objectWrapper
+ |object Main {
+ | def main(args: Array[String]): Unit = println("1")
+ |}
+ |object AnotherMain {
+ | def main(args: Array[String]): Unit = println("2")
+ |}
+ |
+ |println("3")
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val result = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call(
+ cwd = root,
+ stderr = os.Pipe
+ )
+ val output = result.out.trim()
+ val err = result.err.trim()
+ expect(output == "3")
+ expect(err.contains(
+ "Only a single main is allowed within scripts. Multiple main classes were found in the script: Main, AnotherMain"
+ ))
+ expect(err.contains(
+ "Script contains objects with main methods and top-level statements, only the latter will be run."
+ ))
+ }
+ }
+
+ test("main.sc has both an object with a main method as well as top-level definitions") {
+ val message1 = "Hello"
+ val message2 = "Another hello"
+ val inputs = TestInputs(
+ os.rel / "main.sc" ->
+ s"""|object Main {
+ | def main(args: Array[String]): Unit = println("$message1")
+ |}
+ |println("$message2")
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val result = os.proc(TestUtil.cli, extraOptions, "main.sc").call(
+ cwd = root,
+ stderr = os.Pipe
+ )
+ val output = result.out.trim()
+ val err = result.err.trim()
+ expect(output == message2)
+ expect(err.contains(
+ "Script contains objects with main methods and top-level statements, only the latter will be run."
+ ))
+ expect(output == message2)
+ }
+ }
+
+ test(
+ "main.sc has both an object with a main method and an object wrapper as well as top-level calls"
+ ) {
+ val message1 = "Hello"
+ val message2 = "Another hello"
+ val inputs = TestInputs(
+ os.rel / "main.sc" ->
+ s"""|//> using objectWrapper
+ |object Main {
+ | def main(args: Array[String]): Unit = println("$message1")
+ |}
+ |println("$message2")
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val output = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc")
+ .call(cwd = root).out.trim()
+ expect(output == message2)
+ }
+ }
if (actualScalaVersion.startsWith("3"))
test("use method from main.sc file") {
val message = "Hello"
diff --git a/project/deps.sc b/project/deps.sc
index a48bd4fa32..35552d0a94 100644
--- a/project/deps.sc
+++ b/project/deps.sc
@@ -117,7 +117,7 @@ object Deps {
def jsoniterScala = "2.23.2"
def jsoniterScalaJava8 = "2.13.5.2"
def jsoup = "1.18.3"
- def scalaMeta = "4.9.9"
+ def scalaMeta = "4.12.7"
def scalaNative04 = "0.4.17"
def scalaNative05 = "0.5.6"
def scalaNative = scalaNative05
@@ -227,7 +227,7 @@ object Deps {
def semanticDbJavac = ivy"com.sourcegraph:semanticdb-javac:${Versions.javaSemanticdb}"
def semanticDbScalac = ivy"org.scalameta:::semanticdb-scalac:${Versions.scalaMeta}"
def scalametaSemanticDbShared =
- ivy"org.scalameta:semanticdb-shared_${Scala.scala213}:${Versions.scalaMeta}"
+ ivy"org.scalameta:semanticdb-shared_2.13:${Versions.scalaMeta}"
.exclude("org.jline" -> "jline") // to prevent incompatibilities with GraalVM <23
.exclude("com.lihaoyi" -> "sourcecode_2.13")
.exclude("org.scala-lang.modules" -> "scala-collection-compat_2.13")