From 93f0484857a65cfb8c1c4492101aff1e9a725eb8 Mon Sep 17 00:00:00 2001 From: Albert Meltzer Date: Fri, 24 Dec 2021 08:26:30 -0800 Subject: [PATCH] Support cross- and multi-source projects with git --- .../org/scalafmt/sbt/ScalafmtPlugin.scala | 80 +++++++++++-------- plugin/src/sbt-test/scalafmt-sbt/sbt/test | 2 +- 2 files changed, 48 insertions(+), 34 deletions(-) diff --git a/plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala b/plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala index c987b84..c4e31e6 100644 --- a/plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala +++ b/plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala @@ -176,15 +176,17 @@ object ScalafmtPlugin extends AutoPlugin { @inline private def asRelative(file: File): String = baseDir.relativize(file.getCanonicalFile.toPath).toString - private def filterFiles(sources: Seq[File]): Seq[File] = { - val filter = getFileFilter() + private def filterFiles(sources: Seq[File], dirs: Seq[File]): Seq[File] = { + val filter = getFileFilter(dirs) sources.map(_.getCanonicalFile).distinct.filter { file => val path = file.toPath scalafmtSession.matchesProjectFilters(path) && filter(path) } } - private def getFileFilter(): Path => Boolean = { + private def getFileFilter(dirs: Seq[File]): Path => Boolean = { + // dirs don't have to be within baseDir but within the same git tree + def absDirs = dirs.map(x => AbsoluteFile(x.getCanonicalFile.toPath)) def gitOps = GitOps.FactoryImpl(AbsoluteFile(baseDir)) def getFromFiles(getFiles: => Seq[AbsoluteFile], gitCmd: => String) = { def gitMessage = s"[git $gitCmd] ($baseDir)" @@ -199,12 +201,12 @@ object ScalafmtPlugin extends AutoPlugin { } if (filterMode == FilterMode.diffDirty) - getFromFiles(gitOps.status(), "status") + getFromFiles(gitOps.status(absDirs: _*), "status") else if (filterMode.startsWith(FilterMode.diffRefPrefix)) { val branch = filterMode.substring(FilterMode.diffRefPrefix.length) - getFromFiles(gitOps.diff(branch), s"diff $branch") + getFromFiles(gitOps.diff(branch, absDirs: _*), s"diff $branch") } else if (filterMode != FilterMode.none && scalafmtSession.isGitOnly) - getFromFiles(gitOps.lsTree(), "ls-files") + getFromFiles(gitOps.lsTree(absDirs: _*), "ls-files") else { log.debug("considering all files (no git)") _ => true @@ -244,8 +246,8 @@ object ScalafmtPlugin extends AutoPlugin { res } - def formatTrackedSources(sources: Seq[File]): Unit = { - val filteredSources = filterFiles(sources) + def formatTrackedSources(sources: Seq[File], dirs: Seq[File]): Unit = { + val filteredSources = filterFiles(sources, dirs) trackSourcesAndConfig(cacheStoreFactory, filteredSources) { (outDiff, configChanged, prev) => val filesToFormat: Seq[File] = @@ -261,8 +263,8 @@ object ScalafmtPlugin extends AutoPlugin { } } - def formatSources(sources: Seq[File]): Unit = - formatFilteredSources(filterFiles(sources)) + def formatSources(sources: Seq[File], dirs: Seq[File]): Unit = + formatFilteredSources(filterFiles(sources, dirs)) private def formatFilteredSources(sources: Seq[File]): Unit = { if (sources.nonEmpty) @@ -274,8 +276,8 @@ object ScalafmtPlugin extends AutoPlugin { if (cnt > 0) log.info(s"Reformatted $cnt Scala sources") } - def checkTrackedSources(sources: Seq[File]): Unit = { - val filteredSources = filterFiles(sources) + def checkTrackedSources(sources: Seq[File], dirs: Seq[File]): Unit = { + val filteredSources = filterFiles(sources, dirs) val result = trackSourcesAndConfig(cacheStoreFactory, filteredSources) { (outDiff, configChanged, prev) => val filesToCheck: Seq[File] = @@ -300,8 +302,8 @@ object ScalafmtPlugin extends AutoPlugin { throwOnFailure(result) } - def checkSources(sources: Seq[File]): Unit = - throwOnFailure(checkFilteredSources(filterFiles(sources))) + def checkSources(sources: Seq[File], dirs: Seq[File]): Unit = + throwOnFailure(checkFilteredSources(filterFiles(sources, dirs))) private def checkFilteredSources(sources: Seq[File]): ScalafmtAnalysis = { if (sources.nonEmpty) { @@ -393,57 +395,69 @@ object ScalafmtPlugin extends AutoPlugin { } } - private def scalafmtTask(sources: Seq[File], session: FormatSession) = + private def scalafmtTask( + sources: Seq[File], + dirs: Seq[File], + session: FormatSession + ) = Def.task { - session.formatTrackedSources(sources) + session.formatTrackedSources(sources, dirs) } tag (ScalafmtTagPack: _*) - private def scalafmtCheckTask(sources: Seq[File], session: FormatSession) = + private def scalafmtCheckTask( + sources: Seq[File], + dirs: Seq[File], + session: FormatSession + ) = Def.task { - session.checkTrackedSources(sources) + session.checkTrackedSources(sources, dirs) } tag (ScalafmtTagPack: _*) private def getScalafmtSourcesTask( - f: (Seq[File], FormatSession) => InitTask + f: (Seq[File], Seq[File], FormatSession) => InitTask ) = Def.taskDyn[Unit] { val sources = (unmanagedSources in scalafmt).?.value.getOrElse(Seq.empty) - getScalafmtTask(f)(sources, scalaConfig.value) + val dirs = (unmanagedSourceDirectories in scalafmt).?.value.getOrElse(Nil) + getScalafmtTask(f)(sources, dirs, scalaConfig.value) } private def scalafmtSbtTask( sources: Seq[File], + dirs: Seq[File], session: FormatSession ) = Def.task { - session.formatSources(sources) + session.formatSources(sources, dirs) } tag (ScalafmtTagPack: _*) private def scalafmtSbtCheckTask( sources: Seq[File], + dirs: Seq[File], session: FormatSession ) = Def.task { - session.checkSources(sources) + session.checkSources(sources, dirs) } tag (ScalafmtTagPack: _*) private def getScalafmtSbtTasks( - func: (Seq[File], FormatSession) => InitTask + func: (Seq[File], Seq[File], FormatSession) => InitTask ) = Def.taskDyn { joinScalafmtTasks(func)( - (sbtSources.value, sbtConfig.value), - (metabuildSources.value, scalaConfig.value) + (sbtSources.value, Nil, sbtConfig.value), + (metabuildSources.value, Nil, scalaConfig.value) ) } private def joinScalafmtTasks( - func: (Seq[File], FormatSession) => InitTask - )(tuples: (Seq[File], Path)*) = { - val tasks = tuples - .map { case (files, config) => getScalafmtTask(func)(files, config) } + func: (Seq[File], Seq[File], FormatSession) => InitTask + )(tuples: (Seq[File], Seq[File], Path)*) = { + val tasks = tuples.map { case (files, dirs, config) => + getScalafmtTask(func)(files, dirs, config) + } Def.sequential(tasks.tail.toList, tasks.head) } private def getScalafmtTask( - func: (Seq[File], FormatSession) => InitTask - )(files: Seq[File], config: Path) = Def.taskDyn[Unit] { + func: (Seq[File], Seq[File], FormatSession) => InitTask + )(files: Seq[File], dirs: Seq[File], config: Path) = Def.taskDyn[Unit] { if (files.isEmpty) Def.task(Unit) else { val session = new FormatSession( @@ -460,7 +474,7 @@ object ScalafmtPlugin extends AutoPlugin { scalafmtDetailedError.value ) ) - func(files, session) + func(files, dirs, session) } } @@ -505,7 +519,7 @@ object ScalafmtPlugin extends AutoPlugin { scalafmtFailOnErrors.value, scalafmtDetailedError.value ) - ).formatSources(absFiles) + ).formatSources(absFiles, Nil) } ) diff --git a/plugin/src/sbt-test/scalafmt-sbt/sbt/test b/plugin/src/sbt-test/scalafmt-sbt/sbt/test index 3e73192..a305e2f 100644 --- a/plugin/src/sbt-test/scalafmt-sbt/sbt/test +++ b/plugin/src/sbt-test/scalafmt-sbt/sbt/test @@ -202,7 +202,7 @@ $ exec git -C p19 add "jvm/src/main/scala/TestGood.scala" > p19/scalafmtCheck $ copy-file changes/invalid.scala p19/shared/src/main/scala/TestInvalid1.scala $ exec git -C p19 add "shared/src/main/scala/TestInvalid1.scala" -> p19/scalafmtCheck +-> p19/scalafmtCheck $ copy-file changes/target/managed.scala project/target/managed.scala $ copy-file changes/x/Something.scala project/x/Something.scala