|
| 1 | +package scala.build.internal |
| 2 | + |
| 3 | +import scala.build.internal.util.WarningMessages |
| 4 | + |
| 5 | +object WrapperUtils { |
| 6 | + |
| 7 | + enum ScriptMainMethod: |
| 8 | + case Exists(name: String) |
| 9 | + case Multiple(names: Seq[String]) |
| 10 | + case ToplevelStatsPresent |
| 11 | + case ToplevelStatsWithMultiple(names: Seq[String]) |
| 12 | + case NoMain |
| 13 | + |
| 14 | + def warningMessage: List[String] = |
| 15 | + this match |
| 16 | + case ScriptMainMethod.Multiple(names) => |
| 17 | + List(WarningMessages.multipleMainObjectsInScript(names)) |
| 18 | + case ScriptMainMethod.ToplevelStatsPresent => List( |
| 19 | + WarningMessages.mixedToplvelAndObjectInScript |
| 20 | + ) |
| 21 | + case ToplevelStatsWithMultiple(names) => |
| 22 | + List( |
| 23 | + WarningMessages.multipleMainObjectsInScript(names), |
| 24 | + WarningMessages.mixedToplvelAndObjectInScript |
| 25 | + ) |
| 26 | + case _ => Nil |
| 27 | + |
| 28 | + def mainObjectInScript(scalaVersion: String, code: String): ScriptMainMethod = |
| 29 | + import scala.meta.* |
| 30 | + |
| 31 | + val scriptDialect = |
| 32 | + if scalaVersion.startsWith("3") then dialects.Scala3Future else dialects.Scala213Source3 |
| 33 | + |
| 34 | + given Dialect = scriptDialect.withAllowToplevelStatements(true).withAllowToplevelTerms(true) |
| 35 | + val parsedCode = code.parse[Source] match |
| 36 | + case Parsed.Success(Source(stats)) => stats |
| 37 | + case _ => Nil |
| 38 | + |
| 39 | + // Check if there is a main function defined inside an object |
| 40 | + def checkSignature(defn: Defn.Def) = |
| 41 | + defn.paramClauseGroups match |
| 42 | + case List(Member.ParamClauseGroup( |
| 43 | + Type.ParamClause(Nil), |
| 44 | + List(Term.ParamClause( |
| 45 | + List(Term.Param( |
| 46 | + Nil, |
| 47 | + _: Term.Name, |
| 48 | + Some(Type.Apply.After_4_6_0( |
| 49 | + Type.Name("Array"), |
| 50 | + Type.ArgClause(List(Type.Name("String"))) |
| 51 | + )), |
| 52 | + None |
| 53 | + )), |
| 54 | + None |
| 55 | + )) |
| 56 | + )) => true |
| 57 | + case _ => false |
| 58 | + |
| 59 | + def noToplevelStatements = parsedCode.forall { |
| 60 | + case _: Term => false |
| 61 | + case _ => true |
| 62 | + } |
| 63 | + |
| 64 | + def hasMainSignature(templ: Template) = templ.body.stats.exists { |
| 65 | + case defn: Defn.Def => |
| 66 | + defn.name.value == "main" && checkSignature(defn) |
| 67 | + case _ => false |
| 68 | + } |
| 69 | + def extendsApp(templ: Template) = templ.inits match |
| 70 | + case Init.After_4_6_0(Type.Name("App"), _, Nil) :: Nil => true |
| 71 | + case _ => false |
| 72 | + val potentialMains = parsedCode.collect { |
| 73 | + case Defn.Object(_, objName, templ) if extendsApp(templ) || hasMainSignature(templ) => |
| 74 | + Seq(objName.value) |
| 75 | + }.flatten |
| 76 | + |
| 77 | + potentialMains match |
| 78 | + case head :: Nil if noToplevelStatements => |
| 79 | + ScriptMainMethod.Exists(head) |
| 80 | + case head :: Nil => |
| 81 | + ScriptMainMethod.ToplevelStatsPresent |
| 82 | + case Nil => ScriptMainMethod.NoMain |
| 83 | + case seq if noToplevelStatements => |
| 84 | + ScriptMainMethod.Multiple(seq) |
| 85 | + case seq => |
| 86 | + ScriptMainMethod.ToplevelStatsWithMultiple(seq) |
| 87 | + |
| 88 | +} |
0 commit comments