Skip to content

Commit e199357

Browse files
authored
Merge pull request #42 from asmeurer/no-iter
Disable array iteration
2 parents c2c55ce + 935ea87 commit e199357

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

Diff for: array_api_strict/_array_object.py

+9
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,15 @@ def __invert__(self: Array, /) -> Array:
647647
res = self._array.__invert__()
648648
return self.__class__._new(res)
649649

650+
def __iter__(self: Array, /):
651+
"""
652+
Performs the operation __iter__.
653+
"""
654+
# Manually disable iteration, since __getitem__ raises IndexError on
655+
# things like ones((3, 3))[0], which causes list(ones((3, 3))) to give
656+
# [].
657+
raise TypeError("array iteration is not allowed in array-api-strict")
658+
650659
def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
651660
"""
652661
Performs the operation __le__.

Diff for: array_api_strict/tests/test_array_object.py

+4
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,7 @@ def test_array_namespace():
416416

417417
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
418418
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12"))
419+
420+
def test_no_iter():
421+
pytest.raises(TypeError, lambda: iter(ones(3)))
422+
pytest.raises(TypeError, lambda: iter(ones((3, 3))))

0 commit comments

Comments
 (0)