Skip to content

Commit 559edd4

Browse files
committed
improvement: look for definition in pc only for local symbols in the current tree
1 parent 19690b4 commit 559edd4

File tree

3 files changed

+34
-24
lines changed

3 files changed

+34
-24
lines changed

presentation-compiler/src/main/dotty/tools/pc/PcDefinitionProvider.scala

+17-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package dotty.tools.pc
22

3+
import java.net.URI
34
import java.nio.file.Paths
45
import java.util.ArrayList
56

@@ -16,6 +17,7 @@ import dotty.tools.dotc.core.Contexts.Context
1617
import dotty.tools.dotc.core.Flags.{Exported, ModuleClass}
1718
import dotty.tools.dotc.core.Symbols.*
1819
import dotty.tools.dotc.interactive.Interactive
20+
import dotty.tools.dotc.interactive.Interactive.Include
1921
import dotty.tools.dotc.interactive.InteractiveDriver
2022
import dotty.tools.dotc.util.SourceFile
2123
import dotty.tools.dotc.util.SourcePosition
@@ -51,10 +53,10 @@ class PcDefinitionProvider(
5153
given ctx: Context = driver.localContext(params)
5254
val indexedContext = IndexedContext(ctx)
5355
val result =
54-
if findTypeDef then findTypeDefinitions(path, pos, indexedContext)
55-
else findDefinitions(path, pos, indexedContext)
56+
if findTypeDef then findTypeDefinitions(path, pos, indexedContext, uri)
57+
else findDefinitions(path, pos, indexedContext, uri)
5658

57-
if result.locations().nn.isEmpty() then fallbackToUntyped(pos)(using ctx)
59+
if result.locations().nn.isEmpty() then fallbackToUntyped(pos, uri)(using ctx)
5860
else result
5961
end definitions
6062

@@ -70,32 +72,35 @@ class PcDefinitionProvider(
7072
* @param pos cursor position
7173
* @return definition result
7274
*/
73-
private def fallbackToUntyped(pos: SourcePosition)(
75+
private def fallbackToUntyped(pos: SourcePosition, uri: URI)(
7476
using ctx: Context
7577
) =
7678
lazy val untpdPath = NavigateAST
7779
.untypedPath(pos.span)
7880
.collect { case t: untpd.Tree => t }
7981

80-
definitionsForSymbol(untpdPath.headOption.map(_.symbol).toList, pos)
82+
definitionsForSymbol(untpdPath.headOption.map(_.symbol).toList, uri, pos)
8183
end fallbackToUntyped
8284

8385
private def findDefinitions(
8486
path: List[Tree],
8587
pos: SourcePosition,
86-
indexed: IndexedContext
88+
indexed: IndexedContext,
89+
uri: URI,
8790
): DefinitionResult =
8891
import indexed.ctx
8992
definitionsForSymbol(
9093
MetalsInteractive.enclosingSymbols(path, pos, indexed),
94+
uri,
9195
pos
9296
)
9397
end findDefinitions
9498

9599
private def findTypeDefinitions(
96100
path: List[Tree],
97101
pos: SourcePosition,
98-
indexed: IndexedContext
102+
indexed: IndexedContext,
103+
uri: URI,
99104
): DefinitionResult =
100105
import indexed.ctx
101106
val enclosing = path.expandRangeToEnclosingApply(pos)
@@ -108,24 +113,25 @@ class PcDefinitionProvider(
108113
case Nil =>
109114
path.headOption match
110115
case Some(value: Literal) =>
111-
definitionsForSymbol(List(value.typeOpt.widen.typeSymbol), pos)
116+
definitionsForSymbol(List(value.typeOpt.widen.typeSymbol), uri, pos)
112117
case _ => DefinitionResultImpl.empty
113118
case _ =>
114-
definitionsForSymbol(typeSymbols, pos)
119+
definitionsForSymbol(typeSymbols, uri, pos)
115120

116121
end findTypeDefinitions
117122

118123
private def definitionsForSymbol(
119124
symbols: List[Symbol],
125+
uri: URI,
120126
pos: SourcePosition
121127
)(using ctx: Context): DefinitionResult =
122128
symbols match
123129
case symbols @ (sym :: other) =>
124130
val isLocal = sym.source == pos.source
125131
if isLocal then
132+
val include = Include.definitions | Include.local
126133
val (exportedDefs, otherDefs) =
127-
Interactive.findDefinitions(List(sym), driver, false, false)
128-
.filter(_.source == sym.source)
134+
Interactive.findTreesMatching(driver.openedTrees(uri), include, sym)
129135
.partition(_.tree.symbol.is(Exported))
130136

131137
otherDefs.headOption.orElse(exportedDefs.headOption) match

presentation-compiler/test/dotty/tools/pc/tests/CompilerCachingSuite.scala

+9-5
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ import dotty.tools.pc.ScalaPresentationCompiler
66
import org.junit.{Before, Test}
77

88
import scala.language.unsafeNulls
9-
import scala.meta.internal.metals.EmptyCancelToken
109
import scala.meta.internal.metals.CompilerOffsetParams
10+
import scala.meta.internal.metals.EmptyCancelToken
11+
import scala.meta.internal.metals.EmptyReportContext
12+
import scala.meta.internal.metals.PcQueryContext
1113
import scala.meta.pc.OffsetParams
1214
import scala.concurrent.Future
1315
import scala.concurrent.Await
@@ -26,20 +28,22 @@ class CompilerCachingSuite extends BasePCSuite:
2628
private def checkCompilationCount(expected: Int): Unit =
2729
presentationCompiler match
2830
case pc: ScalaPresentationCompiler =>
29-
val compilations = pc.compilerAccess.withNonInterruptableCompiler(None)(-1, EmptyCancelToken) { driver =>
31+
val compilations = pc.compilerAccess.withNonInterruptableCompiler(-1, EmptyCancelToken) { driver =>
3032
driver.compiler().currentCtx.runId
31-
}.get(timeout.length, timeout.unit)
33+
}(emptyQueryContext).get(timeout.length, timeout.unit)
3234
assertEquals(expected, compilations, s"Expected $expected compilations but got $compilations")
3335
case _ => throw IllegalStateException("Presentation compiler should always be of type of ScalaPresentationCompiler")
3436

3537
private def getContext(): Context =
3638
presentationCompiler match
3739
case pc: ScalaPresentationCompiler =>
38-
pc.compilerAccess.withNonInterruptableCompiler(None)(null, EmptyCancelToken) { driver =>
40+
pc.compilerAccess.withNonInterruptableCompiler(null, EmptyCancelToken) { driver =>
3941
driver.compiler().currentCtx
40-
}.get(timeout.length, timeout.unit)
42+
}(emptyQueryContext).get(timeout.length, timeout.unit)
4143
case _ => throw IllegalStateException("Presentation compiler should always be of type of ScalaPresentationCompiler")
4244

45+
private def emptyQueryContext = PcQueryContext(None, () => "")(using EmptyReportContext)
46+
4347
@Before
4448
def beforeEach: Unit =
4549
presentationCompiler.restart()

presentation-compiler/test/dotty/tools/pc/utils/TestInlayHints.scala

+8-8
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package dotty.tools.pc.utils
33
import scala.collection.mutable.ListBuffer
44

55
import scala.meta.internal.jdk.CollectionConverters._
6+
import scala.meta.internal.pc.InlayHints
67
import dotty.tools.pc.utils.InteractiveEnrichments.*
78

9+
import com.google.gson.JsonElement
810
import org.eclipse.lsp4j.InlayHint
911
import org.eclipse.lsp4j.TextEdit
1012
import org.eclipse.{lsp4j => l}
@@ -31,7 +33,7 @@ object TestInlayHints {
3133
case Right(labelParts) => labelParts.asScala.map(_.getValue()).toList
3234
}
3335
val data =
34-
inlayHint.getData().asInstanceOf[Array[Any]]
36+
InlayHints.fromData(inlayHint.getData().asInstanceOf[JsonElement])._2
3537
buffer += "/*"
3638
labels.zip(data).foreach { case (label, data) =>
3739
buffer += label.nn
@@ -41,15 +43,13 @@ object TestInlayHints {
4143
buffer.toList.mkString
4244
}
4345

44-
private def readData(data: Any): List[String] = {
45-
data match {
46-
case data: String if data.isEmpty => Nil
47-
case data: String => List("<<", data, ">>")
48-
case data: l.Position =>
46+
private def readData(data: Either[String, l.Position]): List[String] =
47+
data match
48+
case Left("") => Nil
49+
case Left(data) => List("<<", data, ">>")
50+
case Right(data) =>
4951
val str = s"(${data.getLine()}:${data.getCharacter()})"
5052
List("<<", str, ">>")
51-
}
52-
}
5353

5454
def applyInlayHints(text: String, inlayHints: List[InlayHint]): String = {
5555
val textEdits = inlayHints.map { hint =>

0 commit comments

Comments
 (0)