@@ -18,6 +18,7 @@ import com.intellij.psi.javadoc.PsiDocComment
18
18
import com .intellij .psi .util .PsiTreeUtil
19
19
import org .apache .commons .lang3 .StringUtils
20
20
import org .jetbrains .annotations .{NonNls , TestOnly }
21
+ import org .jetbrains .plugins .scala .ScalaBundle
21
22
import org .jetbrains .plugins .scala .extensions .{PsiElementExt , _ }
22
23
import org .jetbrains .plugins .scala .lang .formatting .scalafmt .processors .PsiChange ._
23
24
import org .jetbrains .plugins .scala .lang .formatting .scalafmt .processors .ScalaFmtPreFormatProcessor ._
@@ -38,9 +39,8 @@ import org.jetbrains.plugins.scala.lang.psi.impl.ScalaPsiElementFactory
38
39
import org .jetbrains .plugins .scala .lang .psi .impl .expr .ScBlockImpl
39
40
import org .jetbrains .plugins .scala .lang .psi .{ScalaPsiUtil , TypeAdjuster }
40
41
import org .jetbrains .plugins .scala .lang .scaladoc .psi .api .ScDocComment
41
- import org .jetbrains .plugins .scala .project .UserDataHolderExt
42
+ import org .jetbrains .plugins .scala .project .{ ScalaFeatures , UserDataHolderExt }
42
43
import org .jetbrains .plugins .scala .scalaMeta .ScalaMetaParseException
43
- import org .jetbrains .plugins .scala .{ScalaBundle , ScalaFileType }
44
44
import org .scalafmt .dynamic .exceptions .{PositionExceptionImpl , ReflectionException }
45
45
import org .scalafmt .dynamic .{ScalafmtReflect , ScalafmtReflectConfig , ScalafmtVersion }
46
46
@@ -49,6 +49,7 @@ import javax.swing.event.HyperlinkEvent
49
49
import scala .annotation .{nowarn , tailrec }
50
50
import scala .collection .immutable .ArraySeq
51
51
import scala .collection .mutable
52
+ import scala .jdk .CollectionConverters .CollectionHasAsScala
52
53
import scala .util .Try
53
54
import scala .util .control .NonFatal
54
55
import scala .util .matching .Regex
@@ -422,7 +423,7 @@ object ScalaFmtPreFormatProcessor {
422
423
}
423
424
}
424
425
425
- def processRange (elements : Seq [PsiElement ], wrap : Boolean , typeAdjuster : TypeAdjuster ): Either [Unit , Int ] = {
426
+ def processRange (elements : Seq [PsiElement ], features : ScalaFeatures , wrap : Boolean , typeAdjuster : TypeAdjuster ): Either [Unit , Int ] = {
426
427
val hasRewriteRules = context.config.hasRewriteRules
427
428
val rewriteElements : Seq [PsiElement ] = if (hasRewriteRules) elements.flatMap(maybeRewriteElements(_, range)) else Seq .empty
428
429
val rewriteElementsToFormatted = attachFormattedCode(rewriteElements)
@@ -433,7 +434,7 @@ object ScalaFmtPreFormatProcessor {
433
434
val formattedInSingleFile = formatInSingleFile(elements, wrap)(project, newContext)
434
435
formattedInSingleFile match {
435
436
case Some (formatted) =>
436
- replaceWithFormatted(elements, formatted, rewriteElementsToFormatted, range, typeAdjuster) match {
437
+ replaceWithFormatted(elements, formatted, rewriteElementsToFormatted, features, range, typeAdjuster) match {
437
438
case Left (err) =>
438
439
reportMarkerNotFound(file, err)
439
440
Left (())
@@ -455,11 +456,11 @@ object ScalaFmtPreFormatProcessor {
455
456
Left (())
456
457
} else {
457
458
// failed to wrap some elements, try the whole file
458
- processRange(Seq (file), wrap = false , typeAdjuster).map(Some (_))
459
+ processRange(Seq (file), file.features, wrap = false , typeAdjuster).map(Some (_))
459
460
Right (None )
460
461
}
461
462
} else {
462
- processRange(elementsWrapped, wrap = true , typeAdjuster).map(Some (_))
463
+ processRange(elementsWrapped, file.features, wrap = true , typeAdjuster).map(Some (_))
463
464
}
464
465
}
465
466
@@ -478,23 +479,29 @@ object ScalaFmtPreFormatProcessor {
478
479
private def getText (range : TextRange )(implicit fileText : String ): String =
479
480
fileText.substring(range.getStartOffset, range.getEndOffset)
480
481
481
- private def unwrap (wrapFile : PsiFile )(implicit project : Project ): Either [CantFindMarkerElementInFormattedCode , Seq [PsiElement ]] = {
482
+ private def unwrap (wrapFile : ScalaFile )(implicit project : Project ): Either [CantFindMarkerElementInFormattedCode , Seq [PsiElement ]] = {
482
483
val text = wrapFile.getText
483
484
484
485
val startMarkerIdx = findMarker(text, StartMarkerFormattedRegex )
485
486
// I don't know when it can be the case that start/end element are null, but handling it just to avoid exceptions
486
- val startElement = wrapFile.findElementAt(startMarkerIdx)
487
- if (startElement == null )
488
- return Left (CantFindMarkerElementInFormattedCode (true ))
487
+ if (startMarkerIdx == - 1 )
488
+ return Left (CantFindMarkerElementInFormattedCode (isStartMarker = true ))
489
489
490
490
val endMarkerIdx = findMarker(text, EndMarkerFormattedRegex )
491
- val endElement = wrapFile.findElementAt(endMarkerIdx)
492
- if (endElement == null )
493
- return Left (CantFindMarkerElementInFormattedCode (true ))
491
+ if (endMarkerIdx == - 1 )
492
+ return Left (CantFindMarkerElementInFormattedCode (isStartMarker = false ))
494
493
495
- // we need to call extra `getParent` because findElementAt returns DOC_COMMENT_START
496
- val startMarker = startElement.getParent
497
- val endMarker = endElement.getParent
494
+ val docComments = PsiTreeUtil .findChildrenOfType(wrapFile, classOf [ScDocComment ]).asScala
495
+
496
+ val startMarker = docComments.find(e => e.getTextOffset == startMarkerIdx && StartMarkerFormattedRegex .matches(e.getText)) match {
497
+ case None => return Left (CantFindMarkerElementInFormattedCode (isStartMarker = true ))
498
+ case Some (marker) => marker
499
+ }
500
+
501
+ val endMarker = docComments.find(e => e.getTextOffset == endMarkerIdx && EndMarkerFormattedRegex .matches(e.getText)) match {
502
+ case None => return Left (CantFindMarkerElementInFormattedCode (isStartMarker = false ))
503
+ case Some (marker) => marker
504
+ }
498
505
499
506
assert(startMarker.is[ScDocComment ])
500
507
assert(endMarker.is[ScDocComment ])
@@ -664,9 +671,10 @@ object ScalaFmtPreFormatProcessor {
664
671
}
665
672
666
673
private def unwrapPsiFromFormattedFile (
667
- formattedCode : WrappedCode
674
+ formattedCode : WrappedCode ,
675
+ features : ScalaFeatures
668
676
)(implicit project : Project ): Either [CantFindMarkerElementInFormattedCode , RewriteElements ] = {
669
- val wrapFile = PsiFileFactory .getInstance(project).createFileFromText( DummyWrapperClassName , ScalaFileType . INSTANCE , formattedCode.text )
677
+ val wrapFile = ScalaPsiElementFactory .createScalaFileFromText(formattedCode.text, features, shouldTrimText = false )
670
678
val elementsUnwrapped : Either [CantFindMarkerElementInFormattedCode , Seq [PsiElement ]] =
671
679
if (formattedCode.wrapped) unwrap(wrapFile)
672
680
else Right (Seq (wrapFile))
@@ -676,10 +684,11 @@ object ScalaFmtPreFormatProcessor {
676
684
}
677
685
678
686
private def unwrapPsiFromFormattedElements (
679
- elementsToFormatted : Seq [(PsiElement , WrappedCode )]
687
+ elementsToFormatted : Seq [(PsiElement , WrappedCode )],
688
+ features : ScalaFeatures
680
689
)(implicit project : Project ): Seq [(PsiElement , Either [CantFindMarkerElementInFormattedCode , RewriteElements ])] = {
681
690
val withUnwrapped = elementsToFormatted.map { case (element, formattedCode) =>
682
- val unwrapped = unwrapPsiFromFormattedFile(formattedCode)
691
+ val unwrapped = unwrapPsiFromFormattedFile(formattedCode, features )
683
692
(element, unwrapped)
684
693
}
685
694
withUnwrapped.sortBy(_._1.getTextRange.getStartOffset)
@@ -688,18 +697,19 @@ object ScalaFmtPreFormatProcessor {
688
697
private def replaceWithFormatted (elements : Iterable [PsiElement ],
689
698
formattedCode : WrappedCode ,
690
699
rewriteToFormatted : Seq [(PsiElement , WrappedCode )],
700
+ features : ScalaFeatures ,
691
701
range : TextRange ,
692
702
typeAdjuster : TypeAdjuster )
693
703
(implicit project : Project , fileText : String ): Either [CantFindMarkerElementInFormattedCode , Int ] = {
694
- val elementsUnwrapped : Seq [PsiElement ] = unwrapPsiFromFormattedFile(formattedCode) match {
704
+ val elementsUnwrapped : Seq [PsiElement ] = unwrapPsiFromFormattedFile(formattedCode, features ) match {
695
705
case Right (value) =>
696
706
value.elements
697
707
case Left (err) =>
698
708
return Left (err)
699
709
}
700
710
val elementsToTraverse : Iterable [(PsiElement , PsiElement )] = elements.zip(elementsUnwrapped)
701
711
val rewriteElementsToTraverse0 : Seq [(PsiElement , Either [CantFindMarkerElementInFormattedCode , RewriteElements ])] =
702
- unwrapPsiFromFormattedElements(rewriteToFormatted)
712
+ unwrapPsiFromFormattedElements(rewriteToFormatted, features )
703
713
704
714
rewriteElementsToTraverse0.find(_._2.isLeft) match {
705
715
case Some ((_, Left (err))) =>
0 commit comments