Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add test cases discovery for TestNG #7200

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ lazy val metals = project
"org.scala-lang.modules" %% "scala-xml" % "2.3.0",
("org.virtuslab.scala-cli" % "scala-cli-bsp" % V.scalaCli)
.exclude("ch.epfl.scala", "bsp4j"),
// For test frameworks
"ch.epfl.scala" %% "bloop-config" % V.bloopConfig,
),
buildInfoPackage := "scala.meta.internal.metals",
buildInfoKeys := Seq[BuildInfoKey](
Expand Down Expand Up @@ -760,7 +762,6 @@ lazy val unit = project
Test / javaOptions += "-Xmx2G",
libraryDependencies ++= List(
"io.get-coursier" %% "coursier" % V.coursier, // for jars
"ch.epfl.scala" %% "bloop-config" % V.bloopConfig,
"org.scalameta" %% "munit" % V.munit,
),
buildInfoPackage := "tests",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import scala.meta.internal.metals.debug.BuildTargetClasses.TestSymbolInfo
import scala.meta.internal.semanticdb.Scala.Descriptor
import scala.meta.internal.semanticdb.Scala.Symbols

import bloop.config.Config.TestFramework
import ch.epfl.scala.{bsp4j => b}

/**
Expand Down Expand Up @@ -186,7 +187,7 @@ final class BuildTargetClasses(val buildTargets: BuildTargets)(implicit
)
} {
// item.getFramework() can return null!
val framework = TestFramework(Option(item.getFramework()))
val framework = TestFrameworkUtils.from(Option(item.getFramework()))
val testInfo = BuildTargetClasses.TestSymbolInfo(className, framework)
classes(target).testClasses.put(symbol, testInfo)
}
Expand Down Expand Up @@ -225,44 +226,32 @@ final class BuildTargetClasses(val buildTargets: BuildTargets)(implicit
}
}

sealed abstract class TestFramework(val canResolveChildren: Boolean) {
def names: List[String]
}
object TestFrameworkUtils {
val WeaverTestFramework: TestFramework = TestFramework(
List("weaver.framework.CatsEffect")
)
private lazy val supportedFrameworks = Set(
TestFramework.JUnit,
TestFramework.munit,
TestFramework.ScalaTest,
WeaverTestFramework,
TestFramework.TestNG,
)

object TestFramework {
def apply(framework: Option[String]): TestFramework = framework
def from(framework: Option[String]): TestFramework = framework
.map {
case "JUnit" => JUnit4
case "munit" => MUnit
case "ScalaTest" => Scalatest
case "weaver-cats-effect" => WeaverCatsEffect
case _ => Unknown
case "JUnit" => TestFramework.JUnit
case "munit" => TestFramework.munit
case "ScalaTest" => TestFramework.ScalaTest
case "weaver-cats-effect" => WeaverTestFramework
case "TestNG" => TestFramework.TestNG
case _ => TestFramework(Nil)
}
.getOrElse(Unknown)
}

case object JUnit4 extends TestFramework(true) {
def names: List[String] = List("com.novocode.junit.JUnitFramework")
}

case object MUnit extends TestFramework(true) {
def names: List[String] = List("munit.Framework")
}

case object Scalatest extends TestFramework(true) {
def names: List[String] =
List(
"org.scalatest.tools.Framework",
"org.scalatest.tools.ScalaTestFramework",
)
}
.getOrElse(TestFramework(Nil))

case object WeaverCatsEffect extends TestFramework(true) {
def names: List[String] = List("weaver.BaseCatsSuite")
}

case object Unknown extends TestFramework(false) {
def names: List[String] = Nil
def canResolveTests(framework: TestFramework): Boolean = supportedFrameworks(
framework
)
}

object BuildTargetClasses {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import scala.meta.internal.metals.testProvider.TestSuitesProvider
import scala.meta.internal.mtags.OnDemandSymbolIndex
import scala.meta.io.AbsolutePath

import bloop.config.Config
import ch.epfl.scala.bsp4j.BuildTargetIdentifier
import ch.epfl.scala.bsp4j.DebugSessionParams
import ch.epfl.scala.bsp4j.ScalaMainClass
Expand Down Expand Up @@ -392,7 +393,7 @@ class DebugProvider(
private def discoverTests(
id: BuildTargetIdentifier,
testClasses: b.ScalaTestSuites,
): Future[Map[TestFramework, List[Discovered]]] = {
): Future[Map[Config.TestFramework, List[Discovered]]] = {
val symbolInfosList =
for {
selection <- testClasses.getSuites().asScala.toList
Expand Down Expand Up @@ -595,7 +596,8 @@ class DebugProvider(
request.requestData.copy(
suites = request.requestData.suites.map { suite =>
testProvider.getFramework(buildTarget, suite) match {
case JUnit4 | MUnit =>
case Config.TestFramework.JUnit |
Config.TestFramework.munit =>
suite.copy(tests = suite.tests.map(escapeTestName))
case _ => suite
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import scala.concurrent.ExecutionContext

import scala.meta.internal.metals.JdkSources
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.debug.TestFramework
import scala.meta.internal.metals.debug.server.testing.FingerprintInfo
import scala.meta.internal.metals.debug.server.testing.LoggingEventHandler
import scala.meta.internal.metals.debug.server.testing.TestInternals
import scala.meta.internal.metals.debug.server.testing.TestServer
import scala.meta.io.AbsolutePath

import bloop.config.Config
import ch.epfl.scala.bsp4j.ScalaTestSuites
import ch.epfl.scala.debugadapter.CancelableFuture
import ch.epfl.scala.debugadapter.DebuggeeListener
Expand All @@ -31,7 +31,7 @@ class TestSuiteDebugAdapter(
testClasses: ScalaTestSuites,
project: DebugeeProject,
userJavaHome: Option[String],
discoveredTests: Map[TestFramework, List[Discovered]],
discoveredTests: Map[Config.TestFramework, List[Discovered]],
)(implicit ec: ExecutionContext)
extends MetalsDebuggee() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package scala.meta.internal.metals.testProvider
import scala.collection.concurrent.TrieMap

import scala.meta.internal.metals.debug.BuildTargetClasses
import scala.meta.internal.metals.debug.TestFramework
import scala.meta.internal.metals.debug.TestFrameworkUtils
import scala.meta.internal.metals.testProvider.TestExplorerEvent._
import scala.meta.internal.mtags
import scala.meta.io.AbsolutePath

import bloop.config.Config
import ch.epfl.scala.bsp4j.BuildTarget
import org.eclipse.{lsp4j => l}

Expand Down Expand Up @@ -36,7 +37,7 @@ private[testProvider] final case class TestEntry(

private[testProvider] final case class TestSuiteDetails(
fullyQualifiedName: FullyQualifiedName,
framework: TestFramework,
framework: Config.TestFramework,
className: ClassName,
symbol: mtags.Symbol,
location: l.Location,
Expand All @@ -46,7 +47,7 @@ private[testProvider] final case class TestSuiteDetails(
className = className.value,
symbol = symbol.value,
location = location,
canResolveChildren = framework.canResolveChildren,
canResolveChildren = TestFrameworkUtils.canResolveTests(framework),
)

def asRemoveEvent: TestExplorerEvent = RemoveTestSuite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,12 @@ import scala.meta.internal.metals.UserConfiguration
import scala.meta.internal.metals.clients.language.MetalsLanguageClient
import scala.meta.internal.metals.codelenses.CodeLens
import scala.meta.internal.metals.debug.BuildTargetClasses
import scala.meta.internal.metals.debug.JUnit4
import scala.meta.internal.metals.debug.MUnit
import scala.meta.internal.metals.debug.Scalatest
import scala.meta.internal.metals.debug.TestFramework
import scala.meta.internal.metals.debug.Unknown
import scala.meta.internal.metals.debug.WeaverCatsEffect
import scala.meta.internal.metals.debug.TestFrameworkUtils
import scala.meta.internal.metals.testProvider.TestExplorerEvent._
import scala.meta.internal.metals.testProvider.frameworks.JunitTestFinder
import scala.meta.internal.metals.testProvider.frameworks.MunitTestFinder
import scala.meta.internal.metals.testProvider.frameworks.ScalatestTestFinder
import scala.meta.internal.metals.testProvider.frameworks.TestNGTestFinder
import scala.meta.internal.metals.testProvider.frameworks.WeaverCatsEffectTestFinder
import scala.meta.internal.mtags
import scala.meta.internal.mtags.GlobalSymbolIndex
Expand All @@ -42,6 +38,7 @@ import scala.meta.internal.semanticdb.TextDocument
import scala.meta.internal.semanticdb.TextDocuments
import scala.meta.io.AbsolutePath

import bloop.config.Config
import ch.epfl.scala.bsp4j.BuildTarget
import ch.epfl.scala.bsp4j.ScalaPlatform
import ch.epfl.scala.{bsp4j => b}
Expand All @@ -65,6 +62,7 @@ final class TestSuitesProvider(

private val index = new TestSuitesIndex
private val junitTestFinder = new JunitTestFinder
private val testNGTestFinder = new TestNGTestFinder
private val munitTestFinder =
new MunitTestFinder(trees, symbolIndex, semanticdbs)
private val scalatestTestFinder =
Expand Down Expand Up @@ -278,7 +276,9 @@ final class TestSuitesProvider(
metadata <- index.getMetadata(path).toList
events = {
val suites = metadata.entries.map(_.suiteDetails).distinct
val canResolve = suites.exists(_.framework.canResolveChildren)
val canResolve = suites.exists(suite =>
TestFrameworkUtils.canResolveTests(suite.framework)
)
if (canResolve) getTestCasesForSuites(path, suites, textDocument)
else Seq.empty
}
Expand Down Expand Up @@ -307,34 +307,40 @@ final class TestSuitesProvider(
.map { semanticdb =>
suites.flatMap { suite =>
val testCases = suite.framework match {
case JUnit4 =>
case Config.TestFramework.JUnit =>
junitTestFinder.findTests(
doc = semanticdb,
path = path,
suiteSymbol = suite.symbol,
)
case MUnit =>
case Config.TestFramework.munit =>
munitTestFinder.findTests(
doc = semanticdb,
path = path,
suiteName = suite.fullyQualifiedName,
symbol = suite.symbol,
)
case Scalatest =>
case Config.TestFramework.ScalaTest =>
scalatestTestFinder.findTests(
doc = semanticdb,
path = path,
suiteName = suite.fullyQualifiedName,
symbol = suite.symbol,
)
case WeaverCatsEffect =>
case TestFrameworkUtils.WeaverTestFramework =>
weaverCatsEffect.findTests(
doc = semanticdb,
path = path,
suiteName = suite.fullyQualifiedName,
symbol = suite.symbol,
)
case Unknown => Vector.empty
case Config.TestFramework.TestNG =>
testNGTestFinder.findTests(
doc = semanticdb,
path = path,
suiteSymbol = suite.symbol,
)
case _ => Vector.empty
}

if (testCases.nonEmpty) {
Expand Down Expand Up @@ -385,7 +391,8 @@ final class TestSuitesProvider(
if (isExplorerEnabled) {
val addedTestCases = addedEntries.mapValues {
_.flatMap { entry =>
val canResolve = entry.suiteDetails.framework.canResolveChildren
val canResolve =
TestFrameworkUtils.canResolveTests(entry.suiteDetails.framework)
if (canResolve && buffers.contains(entry.path))
getTestCasesForSuites(entry.path, Vector(entry.suiteDetails), None)
else Nil
Expand Down Expand Up @@ -562,9 +569,9 @@ final class TestSuitesProvider(
def getFramework(
target: BuildTarget,
selection: ScalaTestSuiteSelection,
): TestFramework = getFromCache(target, selection.className)
): Config.TestFramework = getFromCache(target, selection.className)
.map(_.suiteDetails.framework)
.getOrElse(Unknown)
.getOrElse(Config.TestFramework(Nil))

def getFromCache(
target: BuildTarget,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package scala.meta.internal.metals.testProvider.frameworks

import scala.reflect.NameTransformer

import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.testProvider.TestCaseEntry
import scala.meta.internal.mtags
import scala.meta.internal.semanticdb.SymbolInformation
import scala.meta.internal.semanticdb.TextDocument
import scala.meta.internal.semanticdb.TypeRef
import scala.meta.io.AbsolutePath

trait AnnotationTestFinder {
def expectedAnnotationSymbol: String

def findTests(
doc: TextDocument,
path: AbsolutePath,
suiteSymbol: mtags.Symbol,
): Vector[TestCaseEntry] = {
val uri = path.toURI

def isMethodWithTestAnnotation(symbol: SymbolInformation) = {
symbol.kind == SymbolInformation.Kind.METHOD && symbol.annotations
.exists(_.tpe match {
case TypeRef(_, annotationSymbol, _) =>
annotationSymbol == expectedAnnotationSymbol
case _ => false
})
}

def isValid(symbol: SymbolInformation): Boolean =
isMethodWithTestAnnotation(symbol) && symbol.symbol.startsWith(
suiteSymbol.value
)

doc.symbols
.collect {
case symbol if isValid(symbol) =>
doc
.toLocation(uri, symbol.symbol)
.map { location =>
val encodedName = NameTransformer.encode(symbol.displayName)
TestCaseEntry(
encodedName,
symbol.displayName,
location,
)
}
}
.flatten
.toVector
}
}
Loading