Skip to content

Commit e4df95e

Browse files
authored
improvement: Detect objects with main class in scripts (#3479)
* improvement: Detect objects with main class in scripts Prebiously, if user had a legacy script with main method then it would not be picked up at all. Now, when we detect the correct signature we try to run it. This will work in case of `def main..` and when object extends App The possibility of false positives is pretty low, since user would have to have their own App, String or Array types. We will also only use that object if there are no toplevel statements * improvement: Print both warnings for scripts with main method defined
1 parent cc549b7 commit e4df95e

File tree

9 files changed

+296
-21
lines changed

9 files changed

+296
-21
lines changed

Diff for: modules/build/src/main/scala/scala/build/ScopedSources.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ final case class ScopedSources(
6262
): Either[BuildException, Sources] = either {
6363
val combinedOptions = combinedBuildOptions(scope, baseOptions)
6464

65-
val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions)
65+
val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions, logger)
6666

6767
val wrappedScripts = unwrappedScripts
6868
.flatMap(_.valueFor(scope).toSeq)
6969
.map(_.wrap(codeWrapper))
7070

7171
codeWrapper match {
72-
case _: AppCodeWrapper.type if wrappedScripts.size > 1 =>
72+
case _: AppCodeWrapper if wrappedScripts.size > 1 =>
7373
wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc"))
7474
.foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper))
7575
case _ => ()

Diff for: modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package scala.build.internal
22

3-
case object AppCodeWrapper extends CodeWrapper {
3+
case class AppCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
44
override def mainClassObject(className: Name) = className
55

66
def apply(
@@ -12,13 +12,19 @@ case object AppCodeWrapper extends CodeWrapper {
1212
) = {
1313
val wrapperObjectName = indexedWrapperName.backticked
1414

15+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
16+
val invokeMain = mainObject match
17+
case WrapperUtils.ScriptMainMethod.Exists(name) => s"\n$name.main(args)"
18+
case otherwise =>
19+
otherwise.warningMessage.foreach(log)
20+
""
1521
val packageDirective =
1622
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
1723
val top = AmmUtil.normalizeNewlines(
1824
s"""$packageDirective
1925
|
2026
|object $wrapperObjectName extends App {
21-
|val scriptPath = \"\"\"$scriptPath\"\"\"
27+
|val scriptPath = \"\"\"$scriptPath\"\"\"$invokeMain
2228
|""".stripMargin
2329
)
2430
val bottom = AmmUtil.normalizeNewlines(

Diff for: modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala

+11-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package scala.build.internal
55
* running interconnected scripts using Scala CLI <br> <br> Incompatible with Scala 2 - it uses
66
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
77
*/
8-
case object ClassCodeWrapper extends CodeWrapper {
8+
case class ClassCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
99

1010
override def mainClassObject(className: Name): Name =
1111
Name(className.raw ++ "_sc")
@@ -16,8 +16,16 @@ case object ClassCodeWrapper extends CodeWrapper {
1616
extraCode: String,
1717
scriptPath: String
1818
) = {
19+
20+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
21+
val mainInvocation = mainObject match
22+
case WrapperUtils.ScriptMainMethod.Exists(name) => s"script.$name.main(args)"
23+
case otherwise =>
24+
otherwise.warningMessage.foreach(log)
25+
s"val _ = script.hashCode()"
26+
1927
val name = mainClassObject(indexedWrapperName).backticked
20-
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
28+
val wrapperClassName = scala.build.internal.Name(indexedWrapperName.raw ++ "$_").backticked
2129
val mainObjectCode =
2230
AmmUtil.normalizeNewlines(s"""|object $name {
2331
| private var args$$opt0 = Option.empty[Array[String]]
@@ -33,7 +41,7 @@ case object ClassCodeWrapper extends CodeWrapper {
3341
|
3442
| def main(args: Array[String]): Unit = {
3543
| args$$set(args)
36-
| val _ = script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
44+
| $mainInvocation // hashCode to clear scalac warning about pure expression in statement position
3745
| }
3846
|}
3947
|

Diff for: modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala

+12-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package scala.build.internal
44
* or/and not using JS native prefer [[ClassCodeWrapper]], since it prevents deadlocks when running
55
* threads from script
66
*/
7-
case object ObjectCodeWrapper extends CodeWrapper {
7+
case class ObjectCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
88

99
override def mainClassObject(className: Name): Name =
1010
Name(className.raw ++ "_sc")
@@ -15,12 +15,19 @@ case object ObjectCodeWrapper extends CodeWrapper {
1515
extraCode: String,
1616
scriptPath: String
1717
) = {
18+
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
1819
val name = mainClassObject(indexedWrapperName).backticked
1920
val aliasedWrapperName = name + "$$alias"
20-
val funHashCodeMethod =
21+
val realScript =
2122
if (name == "main_sc")
22-
s"$aliasedWrapperName.alias.hashCode()" // https://github.com/VirtusLab/scala-cli/issues/314
23-
else s"${indexedWrapperName.backticked}.hashCode()"
23+
s"$aliasedWrapperName.alias" // https://github.com/VirtusLab/scala-cli/issues/314
24+
else s"${indexedWrapperName.backticked}"
25+
26+
val funHashCodeMethod = mainObject match
27+
case WrapperUtils.ScriptMainMethod.Exists(name) => s"$realScript.$name.main(args)"
28+
case otherwise =>
29+
otherwise.warningMessage.foreach(log)
30+
s"val _ = $realScript.hashCode()"
2431
// We need to call hashCode (or any other method so compiler does not report a warning)
2532
val mainObjectCode =
2633
AmmUtil.normalizeNewlines(s"""|object $name {
@@ -34,7 +41,7 @@ case object ObjectCodeWrapper extends CodeWrapper {
3441
| }
3542
| def main(args: Array[String]): Unit = {
3643
| args$$set(args)
37-
| val _ = $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
44+
| $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
3845
| }
3946
|}
4047
|""".stripMargin)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package scala.build.internal
2+
3+
import scala.build.internal.util.WarningMessages
4+
5+
object WrapperUtils {
6+
7+
enum ScriptMainMethod:
8+
case Exists(name: String)
9+
case Multiple(names: Seq[String])
10+
case ToplevelStatsPresent
11+
case ToplevelStatsWithMultiple(names: Seq[String])
12+
case NoMain
13+
14+
def warningMessage: List[String] =
15+
this match
16+
case ScriptMainMethod.Multiple(names) =>
17+
List(WarningMessages.multipleMainObjectsInScript(names))
18+
case ScriptMainMethod.ToplevelStatsPresent => List(
19+
WarningMessages.mixedToplvelAndObjectInScript
20+
)
21+
case ToplevelStatsWithMultiple(names) =>
22+
List(
23+
WarningMessages.multipleMainObjectsInScript(names),
24+
WarningMessages.mixedToplvelAndObjectInScript
25+
)
26+
case _ => Nil
27+
28+
def mainObjectInScript(scalaVersion: String, code: String): ScriptMainMethod =
29+
import scala.meta.*
30+
31+
val scriptDialect =
32+
if scalaVersion.startsWith("3") then dialects.Scala3Future else dialects.Scala213Source3
33+
34+
given Dialect = scriptDialect.withAllowToplevelStatements(true).withAllowToplevelTerms(true)
35+
val parsedCode = code.parse[Source] match
36+
case Parsed.Success(Source(stats)) => stats
37+
case _ => Nil
38+
39+
// Check if there is a main function defined inside an object
40+
def checkSignature(defn: Defn.Def) =
41+
defn.paramClauseGroups match
42+
case List(Member.ParamClauseGroup(
43+
Type.ParamClause(Nil),
44+
List(Term.ParamClause(
45+
List(Term.Param(
46+
Nil,
47+
_: Term.Name,
48+
Some(Type.Apply.After_4_6_0(
49+
Type.Name("Array"),
50+
Type.ArgClause(List(Type.Name("String")))
51+
)),
52+
None
53+
)),
54+
None
55+
))
56+
)) => true
57+
case _ => false
58+
59+
def noToplevelStatements = parsedCode.forall {
60+
case _: Term => false
61+
case _ => true
62+
}
63+
64+
def hasMainSignature(templ: Template) = templ.body.stats.exists {
65+
case defn: Defn.Def =>
66+
defn.name.value == "main" && checkSignature(defn)
67+
case _ => false
68+
}
69+
def extendsApp(templ: Template) = templ.inits match
70+
case Init.After_4_6_0(Type.Name("App"), _, Nil) :: Nil => true
71+
case _ => false
72+
val potentialMains = parsedCode.collect {
73+
case Defn.Object(_, objName, templ) if extendsApp(templ) || hasMainSignature(templ) =>
74+
Seq(objName.value)
75+
}.flatten
76+
77+
potentialMains match
78+
case head :: Nil if noToplevelStatements =>
79+
ScriptMainMethod.Exists(head)
80+
case head :: Nil =>
81+
ScriptMainMethod.ToplevelStatsPresent
82+
case Nil => ScriptMainMethod.NoMain
83+
case seq if noToplevelStatements =>
84+
ScriptMainMethod.Multiple(seq)
85+
case seq =>
86+
ScriptMainMethod.ToplevelStatsWithMultiple(seq)
87+
88+
}

Diff for: modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala

+6
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ object WarningMessages {
105105
val offlineModeBloopJvmNotFound =
106106
"Offline mode is ON and a JVM for Bloop could not be fetched from the local cache, using scalac as fallback"
107107

108+
def multipleMainObjectsInScript(names: Seq[String]) =
109+
s"Only a single main is allowed within scripts. Multiple main classes were found in the script: ${names.mkString(", ")}"
110+
111+
def mixedToplvelAndObjectInScript =
112+
"Script contains objects with main methods and top-level statements, only the latter will be run."
113+
108114
def directivesInMultipleFilesWarning(
109115
projectFilePath: String,
110116
pathsToReport: Iterable[String] = Nil

Diff for: modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala

+10-7
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ case object ScriptPreprocessor extends Preprocessor {
134134
(codeWrapper: CodeWrapper) =>
135135
if (containsMainAnnot) logger.diagnostic(
136136
codeWrapper match {
137-
case _: AppCodeWrapper.type =>
137+
case _: AppCodeWrapper =>
138138
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ true)
139139
case _ => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
140140
}
@@ -157,24 +157,27 @@ case object ScriptPreprocessor extends Preprocessor {
157157
* @return
158158
* code wrapper compatible with provided BuildOptions
159159
*/
160-
def getScriptWrapper(buildOptions: BuildOptions): CodeWrapper = {
160+
def getScriptWrapper(buildOptions: BuildOptions, logger: Logger): CodeWrapper = {
161161
val effectiveScalaVersion =
162162
buildOptions.scalaOptions.scalaVersion.flatMap(_.versionOpt)
163163
.orElse(buildOptions.scalaOptions.defaultScalaVersion)
164164
.getOrElse(Constants.defaultScalaVersion)
165+
def logWarning(msg: String) = logger.diagnostic(msg)
165166

166167
def objectCodeWrapperForScalaVersion =
167168
// AppObjectWrapper only introduces the 'main.sc' restriction when used in Scala 3, there's no gain in using it with Scala 3
168-
if effectiveScalaVersion.startsWith("2") then AppCodeWrapper
169-
else ObjectCodeWrapper
169+
if effectiveScalaVersion.startsWith("2") then
170+
AppCodeWrapper(effectiveScalaVersion, logWarning)
171+
else ObjectCodeWrapper(effectiveScalaVersion, logWarning)
170172

171173
buildOptions.scriptOptions.forceObjectWrapper match {
172174
case Some(true) => objectCodeWrapperForScalaVersion
173175
case _ =>
174176
buildOptions.scalaOptions.platform.map(_.value) match {
175-
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
176-
case _ if effectiveScalaVersion.startsWith("2") => AppCodeWrapper
177-
case _ => ClassCodeWrapper
177+
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
178+
case _ if effectiveScalaVersion.startsWith("2") =>
179+
AppCodeWrapper(effectiveScalaVersion, logWarning)
180+
case _ => ClassCodeWrapper(effectiveScalaVersion, logWarning)
178181
}
179182
}
180183
}

0 commit comments

Comments
 (0)