Skip to content

Commit 50a155a

Browse files
authored
Merge pull request #147 from ev-br/asarray_no_nested
BUG: do not allow asarray of nested sequences of arrays
2 parents 25cc3d7 + 6029e2f commit 50a155a

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

Diff for: array_api_strict/_creation_functions.py

+4
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ def asarray(
121121

122122
if isinstance(obj, Array):
123123
return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device)
124+
elif isinstance(obj, list | tuple):
125+
if any(isinstance(x, Array) for x in obj):
126+
raise TypeError("Nested Arrays are not allowed. Use `stack` instead.")
127+
124128
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
125129
# Give a better error message in this case. NumPy would convert this
126130
# to an object array. TODO: This won't handle large integers in lists.

Diff for: array_api_strict/tests/test_array_object.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -482,12 +482,12 @@ def test_array_conversion():
482482
# __array__, which is only used in asarray when converting lists of
483483
# arrays.
484484
a = ones((2, 3))
485-
asarray([a])
485+
np.asarray(a)
486486

487487
for device in ("device1", "device2"):
488488
a = ones((2, 3), device=array_api_strict.Device(device))
489489
with pytest.raises(RuntimeError, match="Can not convert array"):
490-
asarray([a])
490+
np.asarray(a)
491491

492492
def test__array__():
493493
# __array__ should work for now

Diff for: array_api_strict/tests/test_creation_functions.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
zeros,
2323
zeros_like,
2424
)
25-
from .._dtypes import int16, float32, float64
25+
from .._dtypes import float32, float64
2626
from .._array_object import Array, CPU_DEVICE, Device
2727
from .._flags import set_array_api_strict_flags
2828

@@ -97,18 +97,20 @@ def test_asarray_copy():
9797
a[0] = 0
9898
assert all(b[0] == 0)
9999

100+
100101
def test_asarray_list_of_lists():
101-
a = asarray(1, dtype=int16)
102-
b = asarray([1], dtype=int16)
103-
res = asarray([a, a])
104-
assert res.shape == (2,)
105-
assert res.dtype == int16
106-
assert all(res == asarray([1, 1]))
107-
108-
res = asarray([b, b])
109-
assert res.shape == (2, 1)
110-
assert res.dtype == int16
111-
assert all(res == asarray([[1], [1]]))
102+
lst = [[1, 2, 3], [4, 5, 6]]
103+
res = asarray(lst)
104+
assert res.shape == (2, 3)
105+
106+
107+
def test_asarray_nested_arrays():
108+
# do not allow arrays in nested sequences
109+
with pytest.raises(TypeError):
110+
asarray([[1, 2, 3], asarray([4, 5, 6])])
111+
112+
with pytest.raises(TypeError):
113+
asarray([1, asarray(1)])
112114

113115

114116
def test_asarray_device_inference():

0 commit comments

Comments
 (0)