Skip to content

Commit 32a50dc

Browse files
authored
Merge pull request #4694 from typelevel/oscar/20250102_drop_take_chain
2 parents 8ce7326 + fdcff72 commit 32a50dc

File tree

4 files changed

+242
-11
lines changed

4 files changed

+242
-11
lines changed

core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala

+8-6
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,17 @@ private[data] trait ChainCompanionCompat {
3838
}
3939

4040
private def fromImmutableSeq[A](s: immutable.Seq[A]): Chain[A] = {
41-
if (s.isEmpty) nil
42-
else if (s.lengthCompare(1) == 0) one(s.head)
43-
else Wrap(s)
41+
val lc = s.lengthCompare(1)
42+
if (lc < 0) nil
43+
else if (lc > 0) Wrap(s)
44+
else one(s.head)
4445
}
4546

4647
private def fromMutableSeq[A](s: Seq[A]): Chain[A] = {
47-
if (s.isEmpty) nil
48-
else if (s.lengthCompare(1) == 0) one(s.head)
49-
else Wrap(s.toVector)
48+
val lc = s.lengthCompare(1)
49+
if (lc < 0) nil
50+
else if (lc > 0) Wrap(s.toVector)
51+
else one(s.head)
5052
}
5153

5254
/**

core/src/main/scala-2.13+/cats/data/ChainCompanionCompat.scala

+6-4
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ private[data] trait ChainCompanionCompat {
2828
/**
2929
* Creates a Chain from the specified sequence.
3030
*/
31-
def fromSeq[A](s: Seq[A]): Chain[A] =
32-
if (s.isEmpty) nil
33-
else if (s.lengthCompare(1) == 0) one(s.head)
34-
else Wrap(s)
31+
def fromSeq[A](s: Seq[A]): Chain[A] = {
32+
val lc = s.lengthCompare(1)
33+
if (lc < 0) nil
34+
else if (lc > 0) Wrap(s)
35+
else one(s.head)
36+
}
3537

3638
/**
3739
* Creates a Chain from the specified IterableOnce.

core/src/main/scala/cats/data/Chain.scala

+194-1
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,99 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
256256
result
257257
}
258258

259+
/**
260+
* take a certain amount of items from the front of the Chain
261+
*/
262+
final def take(count: Long): Chain[A] = {
263+
// invariant count >= 1
264+
@tailrec
265+
def go(lhs: Chain[A], count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] =
266+
arg match {
267+
case Wrap(seq) =>
268+
if (count == 1) {
269+
lhs.append(seq.head)
270+
} else {
271+
// count > 1
272+
val taken =
273+
if (count < Int.MaxValue) seq.take(count.toInt)
274+
else seq.take(Int.MaxValue)
275+
// we may have not taken all of count
276+
val newCount = count - taken.length
277+
val wrapped = Wrap(taken)
278+
// this is more efficient than using concat
279+
val newLhs = if (lhs.isEmpty) wrapped else Append(lhs, wrapped)
280+
rhs match {
281+
case rhsNE: NonEmpty[A] if newCount > 0L =>
282+
// we have to keep taking on the rhs
283+
go(newLhs, newCount, rhsNE, Empty)
284+
case _ =>
285+
newLhs
286+
}
287+
}
288+
case Append(l, r) =>
289+
go(lhs, count, l, if (rhs.isEmpty) r else Append(r, rhs))
290+
case s @ Singleton(_) =>
291+
// due to the invariant count >= 1
292+
val newLhs = if (lhs.isEmpty) s else Append(lhs, s)
293+
rhs match {
294+
case rhsNE: NonEmpty[A] if count > 1L =>
295+
go(newLhs, count - 1L, rhsNE, Empty)
296+
case _ => newLhs
297+
}
298+
}
299+
300+
this match {
301+
case ne: NonEmpty[A] if count > 0L =>
302+
go(Empty, count, ne, Empty)
303+
case _ => Empty
304+
}
305+
}
306+
307+
/**
308+
* take a certain amount of items from the back of the Chain
309+
*/
310+
final def takeRight(count: Long): Chain[A] = {
311+
// invariant count >= 1
312+
@tailrec
313+
def go(lhs: Chain[A], count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] =
314+
arg match {
315+
case Wrap(seq) =>
316+
if (count == 1L) {
317+
seq.last +: rhs
318+
} else {
319+
// count > 1
320+
val taken =
321+
if (count < Int.MaxValue) seq.takeRight(count.toInt)
322+
else seq.takeRight(Int.MaxValue)
323+
// we may have not taken all of count
324+
val newCount = count - taken.length
325+
val wrapped = Wrap(taken)
326+
val newRhs = if (rhs.isEmpty) wrapped else Append(wrapped, rhs)
327+
lhs match {
328+
case lhsNE: NonEmpty[A] if newCount > 0 =>
329+
go(Empty, newCount, lhsNE, newRhs)
330+
case _ => newRhs
331+
}
332+
}
333+
case Append(l, r) =>
334+
go(if (lhs.isEmpty) l else Append(lhs, l), count, r, rhs)
335+
case s @ Singleton(_) =>
336+
// due to the invariant count >= 1
337+
val newRhs = if (rhs.isEmpty) s else Append(s, rhs)
338+
lhs match {
339+
case lhsNE: NonEmpty[A] if count > 1 =>
340+
go(Empty, count - 1, lhsNE, newRhs)
341+
case _ => newRhs
342+
}
343+
}
344+
345+
this match {
346+
case ne: NonEmpty[A] if count > 0L =>
347+
go(Empty, count, ne, Empty)
348+
case _ => Empty
349+
}
350+
}
351+
259352
/**
260353
* Drops longest prefix of elements that satisfy a predicate.
261354
*
@@ -275,6 +368,105 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
275368
go(this)
276369
}
277370

371+
/**
372+
* Drop a certain amount of items from the front of the Chain
373+
*/
374+
final def drop(count: Long): Chain[A] = {
375+
// invariant count >= 1
376+
@tailrec
377+
def go(count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] =
378+
arg match {
379+
case Wrap(seq) =>
380+
val dropped = if (count < Int.MaxValue) seq.drop(count.toInt) else seq.drop(Int.MaxValue)
381+
val lc = dropped.lengthCompare(1)
382+
if (lc < 0) {
383+
// if dropped.length < 1, then it is zero
384+
// we may have not dropped all of count
385+
val newCount = count - seq.length
386+
rhs match {
387+
case rhsNE: NonEmpty[A] if newCount > 0 =>
388+
// we have to keep dropping on the rhs
389+
go(newCount, rhsNE, Empty)
390+
case _ =>
391+
// we know that count >= seq.length else we wouldn't be empty
392+
// so in this case, it is exactly count == seq.length
393+
rhs
394+
}
395+
} else {
396+
// dropped is not empty
397+
val wrapped = if (lc > 0) Wrap(dropped) else Singleton(dropped.head)
398+
// we must be done
399+
if (rhs.isEmpty) wrapped else Append(wrapped, rhs)
400+
}
401+
case Append(l, r) =>
402+
go(count, l, if (rhs.isEmpty) r else Append(r, rhs))
403+
case Singleton(_) =>
404+
// due to the invariant count >= 1
405+
rhs match {
406+
case rhsNE: NonEmpty[A] if count > 1L =>
407+
go(count - 1L, rhsNE, Empty)
408+
case _ =>
409+
rhs
410+
}
411+
}
412+
413+
this match {
414+
case ne: NonEmpty[A] if count > 0L =>
415+
go(count, ne, Empty)
416+
case _ => this
417+
}
418+
}
419+
420+
/**
421+
* Drop a certain amount of items from the back of the Chain
422+
*/
423+
final def dropRight(count: Long): Chain[A] = {
424+
// invariant count >= 1
425+
@tailrec
426+
def go(lhs: Chain[A], count: Long, arg: NonEmpty[A]): Chain[A] =
427+
arg match {
428+
case Wrap(seq) =>
429+
val dropped = if (count < Int.MaxValue) seq.dropRight(count.toInt) else seq.dropRight(Int.MaxValue)
430+
val lc = dropped.lengthCompare(1)
431+
if (lc < 0) {
432+
// if dropped.length < 1, then it is zero
433+
// we may have not dropped all of count
434+
val newCount = count - seq.length
435+
lhs match {
436+
case lhsNE: NonEmpty[A] if newCount > 0L =>
437+
// we have to keep dropping on the lhs
438+
go(Empty, newCount, lhsNE)
439+
case _ =>
440+
// we know that count >= seq.length else we wouldn't be empty
441+
// so in this case, it is exactly count == seq.length
442+
lhs
443+
}
444+
} else {
445+
// we must be done
446+
// note: dropped.nonEmpty
447+
val wrapped = if (lc > 0) Wrap(dropped) else Singleton(dropped.head)
448+
if (lhs.isEmpty) wrapped else Append(lhs, wrapped)
449+
}
450+
case Append(l, r) =>
451+
go(if (lhs.isEmpty) l else Append(lhs, l), count, r)
452+
case Singleton(_) =>
453+
// due to the invariant count >= 1
454+
lhs match {
455+
case lhsNE: NonEmpty[A] if count > 1L =>
456+
go(Empty, count - 1L, lhsNE)
457+
case _ =>
458+
lhs
459+
}
460+
}
461+
462+
this match {
463+
case ne: NonEmpty[A] if count > 0L =>
464+
go(Empty, count, ne)
465+
case _ =>
466+
this
467+
}
468+
}
469+
278470
/**
279471
* Folds over the elements from right to left using the supplied initial value and function.
280472
*/
@@ -940,7 +1132,8 @@ object Chain extends ChainInstances with ChainCompanionCompat {
9401132
* if the length is one, fromSeq returns Singleton
9411133
*
9421134
* The only places we create Wrap is in fromSeq and in methods that preserve
943-
* length: zipWithIndex, map, sort
1135+
* length: zipWithIndex, map, sort. Additionally, in drop/dropRight we carefully
1136+
* preserve this invariant.
9441137
*/
9451138
final private[data] case class Wrap[A](seq: immutable.Seq[A]) extends NonEmpty[A]
9461139

tests/shared/src/test/scala/cats/tests/ChainSuite.scala

+34
Original file line numberDiff line numberDiff line change
@@ -448,4 +448,38 @@ class ChainSuite extends CatsSuite {
448448
assert(chain.foldRight(init)(fn) == chain.toList.foldRight(init)(fn))
449449
}
450450
}
451+
452+
private val genChainDropTakeArgs =
453+
Arbitrary.arbitrary[Chain[Int]].flatMap { chain =>
454+
// Bias to values close to the length
455+
Gen
456+
.oneOf(
457+
Gen.choose(Int.MinValue, Int.MaxValue),
458+
Gen.choose(-1, chain.length.toInt + 1)
459+
)
460+
.map((chain, _))
461+
}
462+
463+
test("drop(cnt).toList == toList.drop(cnt)") {
464+
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
465+
assertEquals(chain.drop(count).toList, chain.toList.drop(count))
466+
}
467+
}
468+
469+
test("dropRight(cnt).toList == toList.dropRight(cnt)") {
470+
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
471+
assertEquals(chain.dropRight(count).toList, chain.toList.dropRight(count))
472+
}
473+
}
474+
test("take(cnt).toList == toList.take(cnt)") {
475+
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
476+
assertEquals(chain.take(count).toList, chain.toList.take(count))
477+
}
478+
}
479+
480+
test("takeRight(cnt).toList == toList.takeRight(cnt)") {
481+
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
482+
assertEquals(chain.takeRight(count).toList, chain.toList.takeRight(count))
483+
}
484+
}
451485
}

0 commit comments

Comments
 (0)