diff --git a/modules/build/src/main/scala/scala/build/ScopedSources.scala b/modules/build/src/main/scala/scala/build/ScopedSources.scala index 75bb29e710..4684b836c5 100644 --- a/modules/build/src/main/scala/scala/build/ScopedSources.scala +++ b/modules/build/src/main/scala/scala/build/ScopedSources.scala @@ -62,14 +62,14 @@ final case class ScopedSources( ): Either[BuildException, Sources] = either { val combinedOptions = combinedBuildOptions(scope, baseOptions) - val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions) + val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions, logger) val wrappedScripts = unwrappedScripts .flatMap(_.valueFor(scope).toSeq) .map(_.wrap(codeWrapper)) codeWrapper match { - case _: AppCodeWrapper.type if wrappedScripts.size > 1 => + case _: AppCodeWrapper if wrappedScripts.size > 1 => wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc")) .foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper)) case _ => () diff --git a/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala index c20fd26983..f96cfb865d 100644 --- a/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala +++ b/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala @@ -1,6 +1,6 @@ package scala.build.internal -case object AppCodeWrapper extends CodeWrapper { +case class AppCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper { override def mainClassObject(className: Name) = className def apply( @@ -12,13 +12,19 @@ case object AppCodeWrapper extends CodeWrapper { ) = { val wrapperObjectName = indexedWrapperName.backticked + val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code) + val invokeMain = mainObject match + case WrapperUtils.ScriptMainMethod.Exists(name) => s"\n$name.main(args)" + case otherwise => + otherwise.warningMessage.foreach(log) + "" val packageDirective = if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n" val top = AmmUtil.normalizeNewlines( s"""$packageDirective | |object $wrapperObjectName extends App { - |val scriptPath = \"\"\"$scriptPath\"\"\" + |val scriptPath = \"\"\"$scriptPath\"\"\"$invokeMain |""".stripMargin ) val bottom = AmmUtil.normalizeNewlines( diff --git a/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala index 8adcd8ad8b..97479587a5 100644 --- a/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala +++ b/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala @@ -5,7 +5,7 @@ package scala.build.internal * running interconnected scripts using Scala CLI

Incompatible with Scala 2 - it uses * Scala 3 feature 'export'
Incompatible with native JS members - the wrapper is a class */ -case object ClassCodeWrapper extends CodeWrapper { +case class ClassCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper { override def mainClassObject(className: Name): Name = Name(className.raw ++ "_sc") @@ -16,8 +16,16 @@ case object ClassCodeWrapper extends CodeWrapper { extraCode: String, scriptPath: String ) = { + + val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code) + val mainInvocation = mainObject match + case WrapperUtils.ScriptMainMethod.Exists(name) => s"script.$name.main(args)" + case otherwise => + otherwise.warningMessage.foreach(log) + s"val _ = script.hashCode()" + val name = mainClassObject(indexedWrapperName).backticked - val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked + val wrapperClassName = scala.build.internal.Name(indexedWrapperName.raw ++ "$_").backticked val mainObjectCode = AmmUtil.normalizeNewlines(s"""|object $name { | private var args$$opt0 = Option.empty[Array[String]] @@ -33,7 +41,7 @@ case object ClassCodeWrapper extends CodeWrapper { | | def main(args: Array[String]): Unit = { | args$$set(args) - | val _ = script.hashCode() // hashCode to clear scalac warning about pure expression in statement position + | $mainInvocation // hashCode to clear scalac warning about pure expression in statement position | } |} | diff --git a/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala index bbd1f9e9b9..a594f2fcfa 100644 --- a/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala +++ b/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala @@ -4,7 +4,7 @@ package scala.build.internal * or/and not using JS native prefer [[ClassCodeWrapper]], since it prevents deadlocks when running * threads from script */ -case object ObjectCodeWrapper extends CodeWrapper { +case class ObjectCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper { override def mainClassObject(className: Name): Name = Name(className.raw ++ "_sc") @@ -15,12 +15,19 @@ case object ObjectCodeWrapper extends CodeWrapper { extraCode: String, scriptPath: String ) = { + val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code) val name = mainClassObject(indexedWrapperName).backticked val aliasedWrapperName = name + "$$alias" - val funHashCodeMethod = + val realScript = if (name == "main_sc") - s"$aliasedWrapperName.alias.hashCode()" // https://github.com/VirtusLab/scala-cli/issues/314 - else s"${indexedWrapperName.backticked}.hashCode()" + s"$aliasedWrapperName.alias" // https://github.com/VirtusLab/scala-cli/issues/314 + else s"${indexedWrapperName.backticked}" + + val funHashCodeMethod = mainObject match + case WrapperUtils.ScriptMainMethod.Exists(name) => s"$realScript.$name.main(args)" + case otherwise => + otherwise.warningMessage.foreach(log) + s"val _ = $realScript.hashCode()" // We need to call hashCode (or any other method so compiler does not report a warning) val mainObjectCode = AmmUtil.normalizeNewlines(s"""|object $name { @@ -34,7 +41,7 @@ case object ObjectCodeWrapper extends CodeWrapper { | } | def main(args: Array[String]): Unit = { | args$$set(args) - | val _ = $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position + | $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position | } |} |""".stripMargin) diff --git a/modules/build/src/main/scala/scala/build/internal/WrapperUtils.scala b/modules/build/src/main/scala/scala/build/internal/WrapperUtils.scala new file mode 100644 index 0000000000..7197c0d48a --- /dev/null +++ b/modules/build/src/main/scala/scala/build/internal/WrapperUtils.scala @@ -0,0 +1,88 @@ +package scala.build.internal + +import scala.build.internal.util.WarningMessages + +object WrapperUtils { + + enum ScriptMainMethod: + case Exists(name: String) + case Multiple(names: Seq[String]) + case ToplevelStatsPresent + case ToplevelStatsWithMultiple(names: Seq[String]) + case NoMain + + def warningMessage: List[String] = + this match + case ScriptMainMethod.Multiple(names) => + List(WarningMessages.multipleMainObjectsInScript(names)) + case ScriptMainMethod.ToplevelStatsPresent => List( + WarningMessages.mixedToplvelAndObjectInScript + ) + case ToplevelStatsWithMultiple(names) => + List( + WarningMessages.multipleMainObjectsInScript(names), + WarningMessages.mixedToplvelAndObjectInScript + ) + case _ => Nil + + def mainObjectInScript(scalaVersion: String, code: String): ScriptMainMethod = + import scala.meta.* + + val scriptDialect = + if scalaVersion.startsWith("3") then dialects.Scala3Future else dialects.Scala213Source3 + + given Dialect = scriptDialect.withAllowToplevelStatements(true).withAllowToplevelTerms(true) + val parsedCode = code.parse[Source] match + case Parsed.Success(Source(stats)) => stats + case _ => Nil + + // Check if there is a main function defined inside an object + def checkSignature(defn: Defn.Def) = + defn.paramClauseGroups match + case List(Member.ParamClauseGroup( + Type.ParamClause(Nil), + List(Term.ParamClause( + List(Term.Param( + Nil, + _: Term.Name, + Some(Type.Apply.After_4_6_0( + Type.Name("Array"), + Type.ArgClause(List(Type.Name("String"))) + )), + None + )), + None + )) + )) => true + case _ => false + + def noToplevelStatements = parsedCode.forall { + case _: Term => false + case _ => true + } + + def hasMainSignature(templ: Template) = templ.body.stats.exists { + case defn: Defn.Def => + defn.name.value == "main" && checkSignature(defn) + case _ => false + } + def extendsApp(templ: Template) = templ.inits match + case Init.After_4_6_0(Type.Name("App"), _, Nil) :: Nil => true + case _ => false + val potentialMains = parsedCode.collect { + case Defn.Object(_, objName, templ) if extendsApp(templ) || hasMainSignature(templ) => + Seq(objName.value) + }.flatten + + potentialMains match + case head :: Nil if noToplevelStatements => + ScriptMainMethod.Exists(head) + case head :: Nil => + ScriptMainMethod.ToplevelStatsPresent + case Nil => ScriptMainMethod.NoMain + case seq if noToplevelStatements => + ScriptMainMethod.Multiple(seq) + case seq => + ScriptMainMethod.ToplevelStatsWithMultiple(seq) + +} diff --git a/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala b/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala index 271c31bf15..9f74e9c212 100644 --- a/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala +++ b/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala @@ -105,6 +105,12 @@ object WarningMessages { val offlineModeBloopJvmNotFound = "Offline mode is ON and a JVM for Bloop could not be fetched from the local cache, using scalac as fallback" + def multipleMainObjectsInScript(names: Seq[String]) = + s"Only a single main is allowed within scripts. Multiple main classes were found in the script: ${names.mkString(", ")}" + + def mixedToplvelAndObjectInScript = + "Script contains objects with main methods and top-level statements, only the latter will be run." + def directivesInMultipleFilesWarning( projectFilePath: String, pathsToReport: Iterable[String] = Nil diff --git a/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala b/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala index 3d3633f303..c0c4d24a16 100644 --- a/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala +++ b/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala @@ -134,7 +134,7 @@ case object ScriptPreprocessor extends Preprocessor { (codeWrapper: CodeWrapper) => if (containsMainAnnot) logger.diagnostic( codeWrapper match { - case _: AppCodeWrapper.type => + case _: AppCodeWrapper => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ true) case _ => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false) } @@ -157,24 +157,27 @@ case object ScriptPreprocessor extends Preprocessor { * @return * code wrapper compatible with provided BuildOptions */ - def getScriptWrapper(buildOptions: BuildOptions): CodeWrapper = { + def getScriptWrapper(buildOptions: BuildOptions, logger: Logger): CodeWrapper = { val effectiveScalaVersion = buildOptions.scalaOptions.scalaVersion.flatMap(_.versionOpt) .orElse(buildOptions.scalaOptions.defaultScalaVersion) .getOrElse(Constants.defaultScalaVersion) + def logWarning(msg: String) = logger.diagnostic(msg) def objectCodeWrapperForScalaVersion = // AppObjectWrapper only introduces the 'main.sc' restriction when used in Scala 3, there's no gain in using it with Scala 3 - if effectiveScalaVersion.startsWith("2") then AppCodeWrapper - else ObjectCodeWrapper + if effectiveScalaVersion.startsWith("2") then + AppCodeWrapper(effectiveScalaVersion, logWarning) + else ObjectCodeWrapper(effectiveScalaVersion, logWarning) buildOptions.scriptOptions.forceObjectWrapper match { case Some(true) => objectCodeWrapperForScalaVersion case _ => buildOptions.scalaOptions.platform.map(_.value) match { - case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion - case _ if effectiveScalaVersion.startsWith("2") => AppCodeWrapper - case _ => ClassCodeWrapper + case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion + case _ if effectiveScalaVersion.startsWith("2") => + AppCodeWrapper(effectiveScalaVersion, logWarning) + case _ => ClassCodeWrapper(effectiveScalaVersion, logWarning) } } } diff --git a/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala index e75c63060b..4bbc867003 100644 --- a/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala +++ b/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala @@ -69,6 +69,163 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions => } } + test("main.sc has an object with a main method") { + val message = "Hello" + val inputs = TestInputs( + os.rel / "main.sc" -> + s"""| + |object Main { + | def main(args: Array[String]): Unit = println("$message") + |} + |""".stripMargin + ) + inputs.fromRoot { root => + val output = os.proc(TestUtil.cli, extraOptions, "main.sc").call(cwd = + root + ).out.trim() + expect(output == message) + } + } + test("main.sc has an object that extends App") { + val message = "Hello" + val inputs = TestInputs( + os.rel / "main.sc" -> + s"""| + |object Main extends App{ + | println("$message") + |} + | + |object Other {} + |""".stripMargin + ) + inputs.fromRoot { root => + val output = os.proc(TestUtil.cli, extraOptions, "main.sc").call(cwd = + root + ).out.trim() + expect(output == message) + } + } + + test("main.sc has an object with a main method and an object wrapper") { + val message = "Hello" + val inputs = TestInputs( + os.rel / "main.sc" -> + s"""|//> using objectWrapper + |object Main { + | def main(args: Array[String]): Unit = println("$message") + |} + |""".stripMargin + ) + inputs.fromRoot { root => + val output = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call(cwd = + root + ).out.trim() + expect(output == message) + } + } + + test("main.sc has multiple main methods") { + val inputs = TestInputs( + os.rel / "main.sc" -> + s"""|//> using objectWrapper + |object Main { + | def main(args: Array[String]): Unit = println("1") + |} + |object AnotherMain { + | def main(args: Array[String]): Unit = println("2") + |} + |""".stripMargin + ) + inputs.fromRoot { root => + val result = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call( + cwd = root, + stderr = os.Pipe + ) + val output = result.out.trim() + val err = result.err.trim() + expect(output == "") + expect(err.contains( + "Only a single main is allowed within scripts. Multiple main classes were found in the script: Main, AnotherMain" + )) + } + } + test("main.sc has multiple main methods and top-level definitions") { + val inputs = TestInputs( + os.rel / "main.sc" -> + s"""|//> using objectWrapper + |object Main { + | def main(args: Array[String]): Unit = println("1") + |} + |object AnotherMain { + | def main(args: Array[String]): Unit = println("2") + |} + | + |println("3") + |""".stripMargin + ) + inputs.fromRoot { root => + val result = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call( + cwd = root, + stderr = os.Pipe + ) + val output = result.out.trim() + val err = result.err.trim() + expect(output == "3") + expect(err.contains( + "Only a single main is allowed within scripts. Multiple main classes were found in the script: Main, AnotherMain" + )) + expect(err.contains( + "Script contains objects with main methods and top-level statements, only the latter will be run." + )) + } + } + + test("main.sc has both an object with a main method as well as top-level definitions") { + val message1 = "Hello" + val message2 = "Another hello" + val inputs = TestInputs( + os.rel / "main.sc" -> + s"""|object Main { + | def main(args: Array[String]): Unit = println("$message1") + |} + |println("$message2") + |""".stripMargin + ) + inputs.fromRoot { root => + val result = os.proc(TestUtil.cli, extraOptions, "main.sc").call( + cwd = root, + stderr = os.Pipe + ) + val output = result.out.trim() + val err = result.err.trim() + expect(output == message2) + expect(err.contains( + "Script contains objects with main methods and top-level statements, only the latter will be run." + )) + expect(output == message2) + } + } + + test( + "main.sc has both an object with a main method and an object wrapper as well as top-level calls" + ) { + val message1 = "Hello" + val message2 = "Another hello" + val inputs = TestInputs( + os.rel / "main.sc" -> + s"""|//> using objectWrapper + |object Main { + | def main(args: Array[String]): Unit = println("$message1") + |} + |println("$message2") + |""".stripMargin + ) + inputs.fromRoot { root => + val output = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc") + .call(cwd = root).out.trim() + expect(output == message2) + } + } if (actualScalaVersion.startsWith("3")) test("use method from main.sc file") { val message = "Hello" diff --git a/project/deps.sc b/project/deps.sc index a48bd4fa32..35552d0a94 100644 --- a/project/deps.sc +++ b/project/deps.sc @@ -117,7 +117,7 @@ object Deps { def jsoniterScala = "2.23.2" def jsoniterScalaJava8 = "2.13.5.2" def jsoup = "1.18.3" - def scalaMeta = "4.9.9" + def scalaMeta = "4.12.7" def scalaNative04 = "0.4.17" def scalaNative05 = "0.5.6" def scalaNative = scalaNative05 @@ -227,7 +227,7 @@ object Deps { def semanticDbJavac = ivy"com.sourcegraph:semanticdb-javac:${Versions.javaSemanticdb}" def semanticDbScalac = ivy"org.scalameta:::semanticdb-scalac:${Versions.scalaMeta}" def scalametaSemanticDbShared = - ivy"org.scalameta:semanticdb-shared_${Scala.scala213}:${Versions.scalaMeta}" + ivy"org.scalameta:semanticdb-shared_2.13:${Versions.scalaMeta}" .exclude("org.jline" -> "jline") // to prevent incompatibilities with GraalVM <23 .exclude("com.lihaoyi" -> "sourcecode_2.13") .exclude("org.scala-lang.modules" -> "scala-collection-compat_2.13")