Skip to content

Commit 293466a

Browse files
committed
add repeat()
1 parent 16fdf2d commit 293466a

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

tests/test_usage.py

+22
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,28 @@ def test_no_end(self):
841841
assert lst == [*range(77, 177)]
842842

843843

844+
class TestRepeatMethod:
845+
def test_repeat(self):
846+
en = Enumerable.repeat('a', 5)
847+
assert en.to_list() == ['a'] * 5
848+
849+
def test_no_elem(self):
850+
en = Enumerable.repeat('g', 0)
851+
assert en.to_list() == []
852+
853+
def test_1_elem(self):
854+
en = Enumerable.repeat('x', 1)
855+
assert en.to_list() == ['x']
856+
857+
def test_invalid(self):
858+
with pytest.raises(ValueError):
859+
Enumerable.repeat(99, -1)
860+
861+
def test_no_end(self):
862+
hellos = Enumerable.repeat((), None).take(107)
863+
assert hellos.to_list() == [()] * 107
864+
865+
844866
class TestSelectMethod:
845867
def test_select(self):
846868
gen_func = lambda: (i for i in range(4))

types_linq/enumerable.py

+15
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,21 @@ def inner(curr=start):
538538
curr += 1
539539
return Enumerable(inner)
540540

541+
@staticmethod
542+
def repeat(value: TResult, count: Optional[int] = None) -> Enumerable[TResult]:
543+
if count is not None:
544+
if count < 0:
545+
raise ValueError('count must be nonnegative')
546+
def inner(val=value, cnt=count):
547+
while cnt > 0:
548+
yield val
549+
cnt -= 1
550+
else:
551+
def inner(val=value):
552+
while True:
553+
yield val
554+
return Enumerable(inner)
555+
541556
def reverse(self) -> Enumerable[TSource_co]:
542557
return Enumerable(lambda: self._reversed_impl(fallback=True))
543558

types_linq/enumerable.pyi

+9-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from typing import Callable, Dict, Generic, Iterable, Iterator, List, Optional,
33
from .lookup import Lookup
44
from .grouping import Grouping
55
from .ordered_enumerable import OrderedEnumerable
6-
from .more_typing import SupportsAverage
6+
from .more_typing import SupportsAverage, TResult
77
from .more_typing import (
88
TAccumulate,
99
TCollection,
@@ -480,7 +480,14 @@ class Enumerable(Sequence[TSource_co], Generic[TSource_co]):
480480
If `count` is `None`, the sequence is infinite.
481481
'''
482482

483-
# @@@ TODO
483+
# count: Optional[int] is nonstandard behavior
484+
@staticmethod
485+
def repeat(value: TResult, count: Optional[int] = None) -> Enumerable[TResult]:
486+
'''
487+
Generates a sequence that contains one repeated value.
488+
489+
If `count` is `None`, the sequence is infinite.
490+
'''
484491

485492
def reverse(self) -> Enumerable[TSource_co]:
486493
'''

0 commit comments

Comments
 (0)