@@ -21,9 +21,11 @@ import cats.data.NonEmptyList
2121import cats .instances .list ._
2222import cats .syntax .all ._
2323import cats .effect ._
24-
2524import fetch ._
2625
26+ import java .util .concurrent .atomic .AtomicInteger
27+ import scala .concurrent .duration .{DurationInt , FiniteDuration }
28+
2729class FetchBatchingTests extends FetchSpec {
2830 import TestHelper ._
2931
@@ -94,6 +96,51 @@ class FetchBatchingTests extends FetchSpec {
9496 }
9597 }
9698
99+ case class BatchAcrossFetchData (id : Int )
100+
101+ object BatchAcrossFetches extends Data [BatchAcrossFetchData , String ] {
102+ def name = " Batch across Fetches"
103+
104+ private val batchesCounter = new AtomicInteger (0 )
105+ private val fetchesCounter = new AtomicInteger (0 )
106+
107+ def reset (): Unit = {
108+ batchesCounter.set(0 )
109+ fetchesCounter.set(0 )
110+ }
111+
112+ def counters : (Int , Int ) =
113+ (fetchesCounter.get(), batchesCounter.get())
114+
115+ def unBatchedSource [F [_]: Concurrent ]: DataSource [F , BatchAcrossFetchData , String ] =
116+ new DataSource [F , BatchAcrossFetchData , String ] {
117+ override def data = BatchAcrossFetches
118+
119+ override def CF = Concurrent [F ]
120+
121+ override def fetch (request : BatchAcrossFetchData ): F [Option [String ]] = {
122+ fetchesCounter.incrementAndGet()
123+ CF .pure(Some (request.toString))
124+ }
125+
126+ override def batch (
127+ ids : NonEmptyList [BatchAcrossFetchData ]
128+ ): F [Map [BatchAcrossFetchData , String ]] = {
129+ batchesCounter.incrementAndGet()
130+ CF .pure(
131+ ids.map(id => id -> id.toString).toList.toMap
132+ )
133+ }
134+
135+ override val batchExecution = InParallel
136+ }
137+
138+ def batchedSource [F [_]: Async ](
139+ interval : FiniteDuration
140+ ): Resource [F , DataSource [F , BatchAcrossFetchData , String ]] =
141+ DataSource .batchAcrossFetches(unBatchedSource, interval)
142+ }
143+
97144 def fetchBatchedDataSeq [F [_]: Concurrent ](id : Int ): Fetch [F , Int ] =
98145 Fetch (BatchedDataSeq (id), SeqBatch .source)
99146
@@ -207,4 +254,38 @@ class FetchBatchingTests extends FetchSpec {
207254 result shouldEqual ids.map(_.toString)
208255 }.unsafeToFuture()
209256 }
257+
258+ " Fetches produced across unrelated fetches to a DataSource that is NOT batched across fetch executions should NOT be bundled together" in {
259+ BatchAcrossFetches .reset()
260+ val dataSource = BatchAcrossFetches .unBatchedSource[IO ]
261+ val id1 = BatchAcrossFetchData (1 )
262+ val id2 = BatchAcrossFetchData (2 )
263+ val execution1 = Fetch .run[IO ](Fetch (id1, dataSource))
264+ val execution2 = Fetch .run[IO ](Fetch (id2, dataSource))
265+ val singleExecution = (execution1, execution2).parMapN { (_, _) =>
266+ val (fetchRequests, batchRequests) = BatchAcrossFetches .counters
267+ fetchRequests shouldEqual 2
268+ batchRequests shouldEqual 0
269+ }
270+ singleExecution.unsafeToFuture()
271+ }
272+
273+ " Fetches produced across unrelated fetches to a DataSource that is batched across fetch executions should be bundled together" in {
274+ BatchAcrossFetches .reset()
275+ val dataSource = BatchAcrossFetches .batchedSource[IO ](500 .millis)
276+ val id1 = BatchAcrossFetchData (1 )
277+ val id2 = BatchAcrossFetchData (2 )
278+ dataSource
279+ .use { dataSource =>
280+ val execution1 = Fetch .run[IO ](Fetch (id1, dataSource))
281+ val execution2 = Fetch .run[IO ](Fetch (id2, dataSource))
282+ val singleExecution = (execution1, execution2).parMapN { (_, _) =>
283+ val (fetchRequests, batchRequests) = BatchAcrossFetches .counters
284+ fetchRequests shouldEqual 0
285+ batchRequests shouldEqual 1
286+ }
287+ singleExecution
288+ }
289+ .unsafeToFuture()
290+ }
210291}
0 commit comments