Skip to content

Commit bfbfe06

Browse files
committed
feat: Add TraversalBuilder.getValuePresentedSource method for further optimization.
1 parent 2469f72 commit bfbfe06

File tree

5 files changed

+140
-19
lines changed

5 files changed

+140
-19
lines changed

Diff for: stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala

+95-1
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@ import org.apache.pekko
1717
import pekko.NotUsed
1818
import pekko.stream._
1919
import pekko.stream.impl.TraversalTestUtils._
20-
import pekko.stream.scaladsl.Keep
20+
import pekko.stream.impl.fusing.IterableSource
21+
import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource }
22+
import pekko.stream.scaladsl.{ Keep, Source }
23+
import pekko.util.OptionVal
2124
import pekko.testkit.PekkoSpec
2225

26+
import scala.concurrent.Future
27+
2328
class TraversalBuilderSpec extends PekkoSpec {
2429

2530
"CompositeTraversalBuilder" must {
@@ -447,4 +452,93 @@ class TraversalBuilderSpec extends PekkoSpec {
447452
}
448453
}
449454

455+
"find Source.single via TraversalBuilder" in {
456+
TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a")
457+
TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None)
458+
459+
val singleSourceA = new SingleSource("a")
460+
TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA))
461+
462+
TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None)
463+
TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
464+
}
465+
466+
"find Source.single via TraversalBuilder with getValuePresentedSource" in {
467+
TraversalBuilder.getValuePresentedSource(Source.single("a")).get.asInstanceOf[SingleSource[String]].elem should ===(
468+
"a")
469+
val singleSourceA = new SingleSource("a")
470+
TraversalBuilder.getValuePresentedSource(singleSourceA) should be(OptionVal.Some(singleSourceA))
471+
472+
TraversalBuilder.getValuePresentedSource(Source.single("c").async) should be(OptionVal.None)
473+
TraversalBuilder.getValuePresentedSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(
474+
OptionVal.None)
475+
}
476+
477+
"find Source.empty via TraversalBuilder with getValuePresentedSource" in {
478+
val emptySource = EmptySource
479+
TraversalBuilder.getValuePresentedSource(emptySource) should be(OptionVal.Some(emptySource))
480+
481+
TraversalBuilder.getValuePresentedSource(Source.empty.async) should be(OptionVal.None)
482+
TraversalBuilder.getValuePresentedSource(Source.empty.mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
483+
}
484+
485+
"find javadsl Source.empty via TraversalBuilder with getValuePresentedSource" in {
486+
import pekko.stream.javadsl.Source
487+
val emptySource = Source.empty()
488+
TraversalBuilder.getValuePresentedSource(Source.empty()) should be(OptionVal.Some(emptySource))
489+
490+
TraversalBuilder.getValuePresentedSource(Source.empty().async) should be(OptionVal.None)
491+
TraversalBuilder.getValuePresentedSource(Source.empty().mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
492+
}
493+
494+
"find Source.future via TraversalBuilder with getValuePresentedSource" in {
495+
val future = Future.successful("a")
496+
TraversalBuilder.getValuePresentedSource(Source.future(future)).get.asInstanceOf[FutureSource[String]].future should ===(
497+
future)
498+
val futureSourceA = new FutureSource(future)
499+
TraversalBuilder.getValuePresentedSource(futureSourceA) should be(OptionVal.Some(futureSourceA))
500+
501+
TraversalBuilder.getValuePresentedSource(Source.future(future).async) should be(OptionVal.None)
502+
TraversalBuilder.getValuePresentedSource(Source.future(future).mapMaterializedValue(_ => "Mat")) should be(
503+
OptionVal.None)
504+
}
505+
506+
"find Source.iterable via TraversalBuilder with getValuePresentedSource" in {
507+
val iterable = List("a")
508+
TraversalBuilder.getValuePresentedSource(Source(iterable)).get.asInstanceOf[IterableSource[String]].elements should ===(
509+
iterable)
510+
val iterableSource = new IterableSource(iterable)
511+
TraversalBuilder.getValuePresentedSource(iterableSource) should be(OptionVal.Some(iterableSource))
512+
513+
TraversalBuilder.getValuePresentedSource(Source(iterable).async) should be(OptionVal.None)
514+
TraversalBuilder.getValuePresentedSource(Source(iterable).mapMaterializedValue(_ => "Mat")) should be(
515+
OptionVal.None)
516+
}
517+
518+
"find Source.javaStreamSource via TraversalBuilder with getValuePresentedSource" in {
519+
val javaStream = java.util.stream.Stream.empty[String]()
520+
TraversalBuilder.getValuePresentedSource(Source.fromJavaStream(() => javaStream)).get
521+
.asInstanceOf[JavaStreamSource[String, _]].open() shouldEqual javaStream
522+
val streamSource = new JavaStreamSource(() => javaStream)
523+
TraversalBuilder.getValuePresentedSource(streamSource) should be(OptionVal.Some(streamSource))
524+
525+
TraversalBuilder.getValuePresentedSource(Source.fromJavaStream(() => javaStream).async) should be(OptionVal.None)
526+
TraversalBuilder.getValuePresentedSource(
527+
Source.fromJavaStream(() => javaStream).mapMaterializedValue(_ => "Mat")) should be(
528+
OptionVal.None)
529+
}
530+
531+
"find Source.failed via TraversalBuilder with getValuePresentedSource" in {
532+
val failure = new RuntimeException("failure")
533+
TraversalBuilder.getValuePresentedSource(Source.failed(failure)).get.asInstanceOf[FailedSource[String]]
534+
.failure should ===(
535+
failure)
536+
val failedSourceA = new FailedSource(failure)
537+
TraversalBuilder.getValuePresentedSource(failedSourceA) should be(OptionVal.Some(failedSourceA))
538+
539+
TraversalBuilder.getValuePresentedSource(Source.failed(failure).async) should be(OptionVal.None)
540+
TraversalBuilder.getValuePresentedSource(Source.failed(failure).mapMaterializedValue(_ => "Mat")) should be(
541+
OptionVal.None)
542+
}
543+
450544
}

Diff for: stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala

-14
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ import scala.concurrent.duration._
1919
import org.apache.pekko
2020
import pekko.NotUsed
2121
import pekko.stream._
22-
import pekko.stream.impl.TraversalBuilder
23-
import pekko.stream.impl.fusing.GraphStages.SingleSource
2422
import pekko.stream.stage.GraphStage
2523
import pekko.stream.stage.GraphStageLogic
2624
import pekko.stream.stage.OutHandler
@@ -29,7 +27,6 @@ import pekko.stream.testkit.TestPublisher
2927
import pekko.stream.testkit.Utils.TE
3028
import pekko.stream.testkit.scaladsl.TestSink
3129
import pekko.testkit.TestLatch
32-
import pekko.util.OptionVal
3330

3431
import org.scalatest.exceptions.TestFailedException
3532

@@ -283,16 +280,5 @@ class FlowFlattenMergeSpec extends StreamSpec {
283280
probe.expectComplete()
284281
}
285282

286-
"find Source.single via TraversalBuilder" in {
287-
TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a")
288-
TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None)
289-
290-
val singleSourceA = new SingleSource("a")
291-
TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA))
292-
293-
TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None)
294-
TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
295-
}
296-
297283
}
298284
}

Diff for: stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import pekko.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
2222
/**
2323
* INTERNAL API
2424
*/
25-
@InternalApi private[pekko] final class FailedSource[T](failure: Throwable) extends GraphStage[SourceShape[T]] {
25+
@InternalApi private[pekko] final class FailedSource[T](val failure: Throwable) extends GraphStage[SourceShape[T]] {
2626
val out = Outlet[T]("FailedSource.out")
2727
override val shape = SourceShape(out)
2828

Diff for: stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import java.util.function.Consumer
2323

2424
/** INTERNAL API */
2525
@InternalApi private[stream] final class JavaStreamSource[T, S <: java.util.stream.BaseStream[T, S]](
26-
open: () => java.util.stream.BaseStream[T, S])
26+
val open: () => java.util.stream.BaseStream[T, S])
2727
extends GraphStage[SourceShape[T]] {
2828

2929
val out: Outlet[T] = Outlet("JavaStreamSource")

Diff for: stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala

+43-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import pekko.annotation.{ DoNotInherit, InternalApi }
2121
import pekko.stream._
2222
import pekko.stream.impl.StreamLayout.AtomicModule
2323
import pekko.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 }
24-
import pekko.stream.impl.fusing.GraphStageModule
25-
import pekko.stream.impl.fusing.GraphStages.SingleSource
24+
import pekko.stream.impl.fusing.{ GraphStageModule, IterableSource }
25+
import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource }
2626
import pekko.stream.scaladsl.Keep
2727
import pekko.util.OptionVal
2828
import pekko.util.unused
@@ -380,12 +380,53 @@ import pekko.util.unused
380380
}
381381
}
382382

383+
/**
384+
* Try to find `SingleSource` or wrapped such. This is used as a
385+
* performance optimization in FlattenConcat and possibly other places.
386+
* @since 1.2.0
387+
*/
388+
@InternalApi def getValuePresentedSource[A >: Null](
389+
graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = {
390+
def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match {
391+
case _: SingleSource[_] | _: FutureSource[_] | _: IterableSource[_] | _: JavaStreamSource[_, _] |
392+
_: FailedSource[_] =>
393+
true
394+
case maybeEmpty if isEmptySource(maybeEmpty) => true
395+
case _ => false
396+
}
397+
graph match {
398+
case _ if isValuePresentedSource(graph) => OptionVal.Some(graph)
399+
case _ =>
400+
graph.traversalBuilder match {
401+
case l: LinearTraversalBuilder =>
402+
l.pendingBuilder match {
403+
case OptionVal.Some(a: AtomicTraversalBuilder) =>
404+
a.module match {
405+
case m: GraphStageModule[_, _] =>
406+
m.stage match {
407+
case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) =>
408+
// It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize.
409+
if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync)
410+
OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]])
411+
else OptionVal.None
412+
case _ => OptionVal.None
413+
}
414+
case _ => OptionVal.None
415+
}
416+
case _ => OptionVal.None
417+
}
418+
case _ => OptionVal.None
419+
}
420+
}
421+
}
422+
383423
/**
384424
* Test if a Graph is an empty Source.
385425
*/
386426
def isEmptySource(graph: Graph[SourceShape[_], _]): Boolean = graph match {
387427
case source: scaladsl.Source[_, _] if source eq scaladsl.Source.empty => true
388428
case source: javadsl.Source[_, _] if source eq javadsl.Source.empty() => true
429+
case EmptySource => true
389430
case _ => false
390431
}
391432

0 commit comments

Comments
 (0)