Skip to content

improvement: Detect objects with main class in scripts #3479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions modules/build/src/main/scala/scala/build/ScopedSources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ final case class ScopedSources(
): Either[BuildException, Sources] = either {
val combinedOptions = combinedBuildOptions(scope, baseOptions)

val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions)
val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions, logger)

val wrappedScripts = unwrappedScripts
.flatMap(_.valueFor(scope).toSeq)
.map(_.wrap(codeWrapper))

codeWrapper match {
case _: AppCodeWrapper.type if wrappedScripts.size > 1 =>
case _: AppCodeWrapper if wrappedScripts.size > 1 =>
wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc"))
.foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper))
case _ => ()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scala.build.internal

case object AppCodeWrapper extends CodeWrapper {
case class AppCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
override def mainClassObject(className: Name) = className

def apply(
Expand All @@ -12,13 +12,19 @@ case object AppCodeWrapper extends CodeWrapper {
) = {
val wrapperObjectName = indexedWrapperName.backticked

val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
val invokeMain = mainObject match
case WrapperUtils.ScriptMainMethod.Exists(name) => s"\n$name.main(args)"
case otherwise =>
otherwise.warningMessage.foreach(log)
""
val packageDirective =
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
val top = AmmUtil.normalizeNewlines(
s"""$packageDirective
|
|object $wrapperObjectName extends App {
|val scriptPath = \"\"\"$scriptPath\"\"\"
|val scriptPath = \"\"\"$scriptPath\"\"\"$invokeMain
|""".stripMargin
)
val bottom = AmmUtil.normalizeNewlines(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package scala.build.internal
* running interconnected scripts using Scala CLI <br> <br> Incompatible with Scala 2 - it uses
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
*/
case object ClassCodeWrapper extends CodeWrapper {
case class ClassCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {

override def mainClassObject(className: Name): Name =
Name(className.raw ++ "_sc")
Expand All @@ -16,8 +16,16 @@ case object ClassCodeWrapper extends CodeWrapper {
extraCode: String,
scriptPath: String
) = {

val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
val mainInvocation = mainObject match
case WrapperUtils.ScriptMainMethod.Exists(name) => s"script.$name.main(args)"
case otherwise =>
otherwise.warningMessage.foreach(log)
s"val _ = script.hashCode()"

val name = mainClassObject(indexedWrapperName).backticked
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
val wrapperClassName = scala.build.internal.Name(indexedWrapperName.raw ++ "$_").backticked
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
| private var args$$opt0 = Option.empty[Array[String]]
Expand All @@ -33,7 +41,7 @@ case object ClassCodeWrapper extends CodeWrapper {
|
| def main(args: Array[String]): Unit = {
| args$$set(args)
| val _ = script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
| $mainInvocation // hashCode to clear scalac warning about pure expression in statement position
| }
|}
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package scala.build.internal
* or/and not using JS native prefer [[ClassCodeWrapper]], since it prevents deadlocks when running
* threads from script
*/
case object ObjectCodeWrapper extends CodeWrapper {
case class ObjectCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {

override def mainClassObject(className: Name): Name =
Name(className.raw ++ "_sc")
Expand All @@ -15,12 +15,19 @@ case object ObjectCodeWrapper extends CodeWrapper {
extraCode: String,
scriptPath: String
) = {
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
val name = mainClassObject(indexedWrapperName).backticked
val aliasedWrapperName = name + "$$alias"
val funHashCodeMethod =
val realScript =
if (name == "main_sc")
s"$aliasedWrapperName.alias.hashCode()" // https://github.com/VirtusLab/scala-cli/issues/314
else s"${indexedWrapperName.backticked}.hashCode()"
s"$aliasedWrapperName.alias" // https://github.com/VirtusLab/scala-cli/issues/314
else s"${indexedWrapperName.backticked}"

val funHashCodeMethod = mainObject match
case WrapperUtils.ScriptMainMethod.Exists(name) => s"$realScript.$name.main(args)"
case otherwise =>
otherwise.warningMessage.foreach(log)
s"val _ = $realScript.hashCode()"
// We need to call hashCode (or any other method so compiler does not report a warning)
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
Expand All @@ -34,7 +41,7 @@ case object ObjectCodeWrapper extends CodeWrapper {
| }
| def main(args: Array[String]): Unit = {
| args$$set(args)
| val _ = $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
| $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
| }
|}
|""".stripMargin)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package scala.build.internal

import scala.build.internal.util.WarningMessages

object WrapperUtils {

enum ScriptMainMethod:
case Exists(name: String)
case Multiple(names: Seq[String])
case ToplevelStatsPresent
case ToplevelStatsWithMultiple(names: Seq[String])
case NoMain

def warningMessage: List[String] =
this match
case ScriptMainMethod.Multiple(names) =>
List(WarningMessages.multipleMainObjectsInScript(names))
case ScriptMainMethod.ToplevelStatsPresent => List(
WarningMessages.mixedToplvelAndObjectInScript
)
case ToplevelStatsWithMultiple(names) =>
List(
WarningMessages.multipleMainObjectsInScript(names),
WarningMessages.mixedToplvelAndObjectInScript
)
case _ => Nil

def mainObjectInScript(scalaVersion: String, code: String): ScriptMainMethod =
import scala.meta.*

val scriptDialect =
if scalaVersion.startsWith("3") then dialects.Scala3Future else dialects.Scala213Source3

given Dialect = scriptDialect.withAllowToplevelStatements(true).withAllowToplevelTerms(true)
val parsedCode = code.parse[Source] match
case Parsed.Success(Source(stats)) => stats
case _ => Nil

// Check if there is a main function defined inside an object
def checkSignature(defn: Defn.Def) =
defn.paramClauseGroups match
case List(Member.ParamClauseGroup(
Type.ParamClause(Nil),
List(Term.ParamClause(
List(Term.Param(
Nil,
_: Term.Name,
Some(Type.Apply.After_4_6_0(
Type.Name("Array"),
Type.ArgClause(List(Type.Name("String")))
)),
None
)),
None
))
)) => true
case _ => false

def noToplevelStatements = parsedCode.forall {
case _: Term => false
case _ => true
}

def hasMainSignature(templ: Template) = templ.body.stats.exists {
case defn: Defn.Def =>
defn.name.value == "main" && checkSignature(defn)
case _ => false
}
def extendsApp(templ: Template) = templ.inits match
case Init.After_4_6_0(Type.Name("App"), _, Nil) :: Nil => true
case _ => false
val potentialMains = parsedCode.collect {
case Defn.Object(_, objName, templ) if extendsApp(templ) || hasMainSignature(templ) =>
Seq(objName.value)
}.flatten

potentialMains match
case head :: Nil if noToplevelStatements =>
ScriptMainMethod.Exists(head)
case head :: Nil =>
ScriptMainMethod.ToplevelStatsPresent
case Nil => ScriptMainMethod.NoMain
case seq if noToplevelStatements =>
ScriptMainMethod.Multiple(seq)
case seq =>
ScriptMainMethod.ToplevelStatsWithMultiple(seq)

}
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ object WarningMessages {
val offlineModeBloopJvmNotFound =
"Offline mode is ON and a JVM for Bloop could not be fetched from the local cache, using scalac as fallback"

def multipleMainObjectsInScript(names: Seq[String]) =
s"Only a single main is allowed within scripts. Multiple main classes were found in the script: ${names.mkString(", ")}"

def mixedToplvelAndObjectInScript =
"Script contains objects with main methods and top-level statements, only the latter will be run."

def directivesInMultipleFilesWarning(
projectFilePath: String,
pathsToReport: Iterable[String] = Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ case object ScriptPreprocessor extends Preprocessor {
(codeWrapper: CodeWrapper) =>
if (containsMainAnnot) logger.diagnostic(
codeWrapper match {
case _: AppCodeWrapper.type =>
case _: AppCodeWrapper =>
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ true)
case _ => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
}
Expand All @@ -157,24 +157,27 @@ case object ScriptPreprocessor extends Preprocessor {
* @return
* code wrapper compatible with provided BuildOptions
*/
def getScriptWrapper(buildOptions: BuildOptions): CodeWrapper = {
def getScriptWrapper(buildOptions: BuildOptions, logger: Logger): CodeWrapper = {
val effectiveScalaVersion =
buildOptions.scalaOptions.scalaVersion.flatMap(_.versionOpt)
.orElse(buildOptions.scalaOptions.defaultScalaVersion)
.getOrElse(Constants.defaultScalaVersion)
def logWarning(msg: String) = logger.diagnostic(msg)

def objectCodeWrapperForScalaVersion =
// AppObjectWrapper only introduces the 'main.sc' restriction when used in Scala 3, there's no gain in using it with Scala 3
if effectiveScalaVersion.startsWith("2") then AppCodeWrapper
else ObjectCodeWrapper
if effectiveScalaVersion.startsWith("2") then
AppCodeWrapper(effectiveScalaVersion, logWarning)
else ObjectCodeWrapper(effectiveScalaVersion, logWarning)

buildOptions.scriptOptions.forceObjectWrapper match {
case Some(true) => objectCodeWrapperForScalaVersion
case _ =>
buildOptions.scalaOptions.platform.map(_.value) match {
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
case _ if effectiveScalaVersion.startsWith("2") => AppCodeWrapper
case _ => ClassCodeWrapper
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
case _ if effectiveScalaVersion.startsWith("2") =>
AppCodeWrapper(effectiveScalaVersion, logWarning)
case _ => ClassCodeWrapper(effectiveScalaVersion, logWarning)
}
}
}
Expand Down
Loading
Loading