@@ -18,6 +18,7 @@ import com.intellij.psi.javadoc.PsiDocComment
18
18
import com .intellij .psi .util .PsiTreeUtil
19
19
import org .apache .commons .lang3 .StringUtils
20
20
import org .jetbrains .annotations .{NonNls , TestOnly }
21
+ import org .jetbrains .plugins .scala .ScalaBundle
21
22
import org .jetbrains .plugins .scala .extensions .{PsiElementExt , _ }
22
23
import org .jetbrains .plugins .scala .lang .formatting .scalafmt .processors .PsiChange ._
23
24
import org .jetbrains .plugins .scala .lang .formatting .scalafmt .processors .ScalaFmtPreFormatProcessor ._
@@ -38,8 +39,7 @@ import org.jetbrains.plugins.scala.lang.psi.impl.ScalaPsiElementFactory
38
39
import org .jetbrains .plugins .scala .lang .psi .impl .expr .ScBlockImpl
39
40
import org .jetbrains .plugins .scala .lang .psi .{ScalaPsiUtil , TypeAdjuster }
40
41
import org .jetbrains .plugins .scala .lang .scaladoc .psi .api .ScDocComment
41
- import org .jetbrains .plugins .scala .project .UserDataHolderExt
42
- import org .jetbrains .plugins .scala .{ScalaBundle , ScalaFileType }
42
+ import org .jetbrains .plugins .scala .project .{ScalaFeatures , UserDataHolderExt }
43
43
import org .scalafmt .dynamic .exceptions .{PositionExceptionImpl , ReflectionException }
44
44
import org .scalafmt .dynamic .{ScalafmtReflect , ScalafmtReflectConfig , ScalafmtVersion }
45
45
@@ -48,6 +48,7 @@ import javax.swing.event.HyperlinkEvent
48
48
import scala .annotation .{nowarn , tailrec }
49
49
import scala .collection .immutable .ArraySeq
50
50
import scala .collection .mutable
51
+ import scala .jdk .CollectionConverters .CollectionHasAsScala
51
52
import scala .util .Try
52
53
import scala .util .control .NonFatal
53
54
import scala .util .matching .Regex
@@ -421,7 +422,7 @@ object ScalaFmtPreFormatProcessor {
421
422
}
422
423
}
423
424
424
- def processRange (elements : Seq [PsiElement ], wrap : Boolean , typeAdjuster : TypeAdjuster ): Either [Unit , Int ] = {
425
+ def processRange (elements : Seq [PsiElement ], features : ScalaFeatures , wrap : Boolean , typeAdjuster : TypeAdjuster ): Either [Unit , Int ] = {
425
426
val hasRewriteRules = context.config.hasRewriteRules
426
427
val rewriteElements : Seq [PsiElement ] = if (hasRewriteRules) elements.flatMap(maybeRewriteElements(_, range)) else Seq .empty
427
428
val rewriteElementsToFormatted = attachFormattedCode(rewriteElements)
@@ -432,7 +433,7 @@ object ScalaFmtPreFormatProcessor {
432
433
val formattedInSingleFile = formatInSingleFile(elements, wrap)(project, newContext)
433
434
formattedInSingleFile match {
434
435
case Some (formatted) =>
435
- replaceWithFormatted(elements, formatted, rewriteElementsToFormatted, range, typeAdjuster) match {
436
+ replaceWithFormatted(elements, formatted, rewriteElementsToFormatted, features, range, typeAdjuster) match {
436
437
case Left (err) =>
437
438
reportMarkerNotFound(file, err)
438
439
Left (())
@@ -454,11 +455,11 @@ object ScalaFmtPreFormatProcessor {
454
455
Left (())
455
456
} else {
456
457
// failed to wrap some elements, try the whole file
457
- processRange(Seq (file), wrap = false , typeAdjuster).map(Some (_))
458
+ processRange(Seq (file), file.features, wrap = false , typeAdjuster).map(Some (_))
458
459
Right (None )
459
460
}
460
461
} else {
461
- processRange(elementsWrapped, wrap = true , typeAdjuster).map(Some (_))
462
+ processRange(elementsWrapped, file.features, wrap = true , typeAdjuster).map(Some (_))
462
463
}
463
464
}
464
465
@@ -477,23 +478,29 @@ object ScalaFmtPreFormatProcessor {
477
478
private def getText (range : TextRange )(implicit fileText : String ): String =
478
479
fileText.substring(range.getStartOffset, range.getEndOffset)
479
480
480
- private def unwrap (wrapFile : PsiFile )(implicit project : Project ): Either [CantFindMarkerElementInFormattedCode , Seq [PsiElement ]] = {
481
+ private def unwrap (wrapFile : ScalaFile )(implicit project : Project ): Either [CantFindMarkerElementInFormattedCode , Seq [PsiElement ]] = {
481
482
val text = wrapFile.getText
482
483
483
484
val startMarkerIdx = findMarker(text, StartMarkerFormattedRegex )
484
485
// I don't know when it can be the case that start/end element are null, but handling it just to avoid exceptions
485
- val startElement = wrapFile.findElementAt(startMarkerIdx)
486
- if (startElement == null )
487
- return Left (CantFindMarkerElementInFormattedCode (true ))
486
+ if (startMarkerIdx == - 1 )
487
+ return Left (CantFindMarkerElementInFormattedCode (isStartMarker = true ))
488
488
489
489
val endMarkerIdx = findMarker(text, EndMarkerFormattedRegex )
490
- val endElement = wrapFile.findElementAt(endMarkerIdx)
491
- if (endElement == null )
492
- return Left (CantFindMarkerElementInFormattedCode (true ))
490
+ if (endMarkerIdx == - 1 )
491
+ return Left (CantFindMarkerElementInFormattedCode (isStartMarker = false ))
493
492
494
- // we need to call extra `getParent` because findElementAt returns DOC_COMMENT_START
495
- val startMarker = startElement.getParent
496
- val endMarker = endElement.getParent
493
+ val docComments = PsiTreeUtil .findChildrenOfType(wrapFile, classOf [ScDocComment ]).asScala
494
+
495
+ val startMarker = docComments.find(e => e.getTextOffset == startMarkerIdx && StartMarkerFormattedRegex .matches(e.getText)) match {
496
+ case None => return Left (CantFindMarkerElementInFormattedCode (isStartMarker = true ))
497
+ case Some (marker) => marker
498
+ }
499
+
500
+ val endMarker = docComments.find(e => e.getTextOffset == endMarkerIdx && EndMarkerFormattedRegex .matches(e.getText)) match {
501
+ case None => return Left (CantFindMarkerElementInFormattedCode (isStartMarker = false ))
502
+ case Some (marker) => marker
503
+ }
497
504
498
505
assert(startMarker.is[ScDocComment ])
499
506
assert(endMarker.is[ScDocComment ])
@@ -663,9 +670,10 @@ object ScalaFmtPreFormatProcessor {
663
670
}
664
671
665
672
private def unwrapPsiFromFormattedFile (
666
- formattedCode : WrappedCode
673
+ formattedCode : WrappedCode ,
674
+ features : ScalaFeatures
667
675
)(implicit project : Project ): Either [CantFindMarkerElementInFormattedCode , RewriteElements ] = {
668
- val wrapFile = PsiFileFactory .getInstance(project).createFileFromText( DummyWrapperClassName , ScalaFileType . INSTANCE , formattedCode.text )
676
+ val wrapFile = ScalaPsiElementFactory .createScalaFileFromText(formattedCode.text, features, shouldTrimText = false )
669
677
val elementsUnwrapped : Either [CantFindMarkerElementInFormattedCode , Seq [PsiElement ]] =
670
678
if (formattedCode.wrapped) unwrap(wrapFile)
671
679
else Right (Seq (wrapFile))
@@ -675,10 +683,11 @@ object ScalaFmtPreFormatProcessor {
675
683
}
676
684
677
685
private def unwrapPsiFromFormattedElements (
678
- elementsToFormatted : Seq [(PsiElement , WrappedCode )]
686
+ elementsToFormatted : Seq [(PsiElement , WrappedCode )],
687
+ features : ScalaFeatures
679
688
)(implicit project : Project ): Seq [(PsiElement , Either [CantFindMarkerElementInFormattedCode , RewriteElements ])] = {
680
689
val withUnwrapped = elementsToFormatted.map { case (element, formattedCode) =>
681
- val unwrapped = unwrapPsiFromFormattedFile(formattedCode)
690
+ val unwrapped = unwrapPsiFromFormattedFile(formattedCode, features )
682
691
(element, unwrapped)
683
692
}
684
693
withUnwrapped.sortBy(_._1.getTextRange.getStartOffset)
@@ -687,18 +696,19 @@ object ScalaFmtPreFormatProcessor {
687
696
private def replaceWithFormatted (elements : Iterable [PsiElement ],
688
697
formattedCode : WrappedCode ,
689
698
rewriteToFormatted : Seq [(PsiElement , WrappedCode )],
699
+ features : ScalaFeatures ,
690
700
range : TextRange ,
691
701
typeAdjuster : TypeAdjuster )
692
702
(implicit project : Project , fileText : String ): Either [CantFindMarkerElementInFormattedCode , Int ] = {
693
- val elementsUnwrapped : Seq [PsiElement ] = unwrapPsiFromFormattedFile(formattedCode) match {
703
+ val elementsUnwrapped : Seq [PsiElement ] = unwrapPsiFromFormattedFile(formattedCode, features ) match {
694
704
case Right (value) =>
695
705
value.elements
696
706
case Left (err) =>
697
707
return Left (err)
698
708
}
699
709
val elementsToTraverse : Iterable [(PsiElement , PsiElement )] = elements.zip(elementsUnwrapped)
700
710
val rewriteElementsToTraverse0 : Seq [(PsiElement , Either [CantFindMarkerElementInFormattedCode , RewriteElements ])] =
701
- unwrapPsiFromFormattedElements(rewriteToFormatted)
711
+ unwrapPsiFromFormattedElements(rewriteToFormatted, features )
702
712
703
713
rewriteElementsToTraverse0.find(_._2.isLeft) match {
704
714
case Some ((_, Left (err))) =>
0 commit comments