Skip to content

Commit 71ff7a1

Browse files
committedNov 13, 2021
add support for negative index to element_at()
reflect .net6 ElementAt's index overload
1 parent 0a77256 commit 71ff7a1

File tree

4 files changed

+43
-8
lines changed

4 files changed

+43
-8
lines changed
 

‎doc/api/types_linq.enumerable.rst

+7
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ Returns
639639
Returns the element at specified index in the sequence. `IndexOutOfRangeError` is raised if
640640
no such element exists.
641641

642+
If the index is negative, it means counting from the end.
643+
642644
This method always uses a generic list element-finding method (O(n)) regardless the
643645
implementation of the wrapped iterable.
644646

@@ -651,6 +653,9 @@ Example
651653
>>> Enumerable(gen()).element_at(1)
652654
10
653655
656+
>>> Enumerable(gen()).element_at(-1)
657+
100
658+
654659
----
655660

656661
instancemethod ``element_at[TDefault](index, __default)``
@@ -666,6 +671,8 @@ Returns
666671
Returns the element at specified index in the sequence. Default value is returned if no
667672
such element exists.
668673

674+
If the index is negative, it means counting from the end.
675+
669676
This method always uses a generic list element-finding method (O(n)) regardless the
670677
implementation of the wrapped iterable.
671678

‎tests/test_usage.py

+10
Original file line numberDiff line numberDiff line change
@@ -327,17 +327,27 @@ def test_overload1_0(self):
327327
gen = lambda: (i for i in range(7, 12))
328328
en = Enumerable(gen)
329329
assert en.element_at(0) == 7
330+
assert en.element_at(-5) == 7
330331

331332
def test_overload1_end(self):
332333
gen = lambda: (i for i in range(7, 12))
333334
en = Enumerable(gen)
334335
assert en.element_at(4) == 11
336+
assert en.element_at(-1) == 11
335337

336338
def test_overload1_out(self):
337339
gen = lambda: (i for i in range(7, 12))
338340
en = Enumerable(gen)
339341
with pytest.raises(IndexOutOfRangeError):
340342
en.element_at(5)
343+
with pytest.raises(IndexOutOfRangeError):
344+
en.element_at(-6)
345+
346+
def test_overload2_within(self):
347+
gen = lambda: (i for i in range(7, 12))
348+
en = Enumerable(gen)
349+
assert en.element_at(2, 0) == 9
350+
assert en.element_at(-3, 0) == 9
341351

342352
def test_overload2_out(self):
343353
gen = lambda: (i for i in range(7, 12))

‎types_linq/enumerable.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def _contains_impl(self, value: object, fallback: bool) -> bool:
5555
def __contains__(self, value: object) -> bool:
5656
return self._contains_impl(value, fallback=False)
5757

58+
@staticmethod
59+
def _raise_not_enough_elements() -> NoReturn:
60+
raise IndexOutOfRangeError('Not enough elements in the sequence')
61+
5862
def _every(self, step: int) -> Enumerable[TSource_co]:
5963
return self.where2(lambda _, i: i % step == 0)
6064

@@ -71,20 +75,27 @@ def _getitem_impl(self,
7175
return iterable[index]
7276
except IndexError as e:
7377
raise IndexOutOfRangeError from e
74-
iterator = iter(iterable)
75-
try:
76-
for _ in range(index):
77-
next(iterator)
78-
return next(iterator)
79-
except StopIteration:
80-
raise IndexOutOfRangeError('Not enough elements in the sequence')
78+
if index >= 0:
79+
iterator = iter(iterable)
80+
try:
81+
for _ in range(index):
82+
next(iterator)
83+
return next(iterator)
84+
except StopIteration:
85+
self._raise_not_enough_elements()
86+
else:
87+
en = iterable if isinstance(iterable, Enumerable) else Enumerable(iterable)
88+
last = en.take_last(-index).to_list()
89+
if len(last) < -index:
90+
self._raise_not_enough_elements()
91+
return last[0]
8192

8293
else: # isinstance(index, slice)
8394
if not fallback and isinstance(iterable, Sequence):
8495
try:
8596
res = iterable[index]
8697
except IndexError as e:
87-
raise IndexOutOfRangeError(e)
98+
raise IndexOutOfRangeError from e
8899
return res if isinstance(res, Enumerable) else Enumerable(res)
89100
# we do not enumerate all values if the begin and the end only involve
90101
# nonnegative indices since in which case the sliced part can be obtained

‎types_linq/enumerable.pyi

+7
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,8 @@ class Enumerable(Sequence[TSource_co], Generic[TSource_co]):
476476
Returns the element at specified index in the sequence. `IndexOutOfRangeError` is raised if
477477
no such element exists.
478478
479+
If the index is negative, it means counting from the end.
480+
479481
This method always uses a generic list element-finding method (O(n)) regardless the
480482
implementation of the wrapped iterable.
481483
@@ -487,6 +489,9 @@ class Enumerable(Sequence[TSource_co], Generic[TSource_co]):
487489
488490
>>> Enumerable(gen()).element_at(1)
489491
10
492+
493+
>>> Enumerable(gen()).element_at(-1)
494+
100
490495
'''
491496

492497
@overload
@@ -495,6 +500,8 @@ class Enumerable(Sequence[TSource_co], Generic[TSource_co]):
495500
Returns the element at specified index in the sequence. Default value is returned if no
496501
such element exists.
497502
503+
If the index is negative, it means counting from the end.
504+
498505
This method always uses a generic list element-finding method (O(n)) regardless the
499506
implementation of the wrapped iterable.
500507

0 commit comments

Comments
 (0)
Please sign in to comment.