From 6029e2fbc252a81f85bb218806b765fd32d72f9f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 9 Apr 2025 13:11:12 +0200 Subject: [PATCH] BUG: do not allow asarray of nested sequences of arrays --- array_api_strict/_creation_functions.py | 4 +++ array_api_strict/tests/test_array_object.py | 4 +-- .../tests/test_creation_functions.py | 26 ++++++++++--------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 3b80b8a..db3897c 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -121,6 +121,10 @@ def asarray( if isinstance(obj, Array): return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device) + elif isinstance(obj, list | tuple): + if any(isinstance(x, Array) for x in obj): + raise TypeError("Nested Arrays are not allowed. Use `stack` instead.") + if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index c7330d8..dbab1af 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -482,12 +482,12 @@ def test_array_conversion(): # __array__, which is only used in asarray when converting lists of # arrays. a = ones((2, 3)) - asarray([a]) + np.asarray(a) for device in ("device1", "device2"): a = ones((2, 3), device=array_api_strict.Device(device)) with pytest.raises(RuntimeError, match="Can not convert array"): - asarray([a]) + np.asarray(a) def test__array__(): # __array__ should work for now diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index fc4e3cb..573fc7f 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -22,7 +22,7 @@ zeros, zeros_like, ) -from .._dtypes import int16, float32, float64 +from .._dtypes import float32, float64 from .._array_object import Array, CPU_DEVICE, Device from .._flags import set_array_api_strict_flags @@ -97,18 +97,20 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) + def test_asarray_list_of_lists(): - a = asarray(1, dtype=int16) - b = asarray([1], dtype=int16) - res = asarray([a, a]) - assert res.shape == (2,) - assert res.dtype == int16 - assert all(res == asarray([1, 1])) - - res = asarray([b, b]) - assert res.shape == (2, 1) - assert res.dtype == int16 - assert all(res == asarray([[1], [1]])) + lst = [[1, 2, 3], [4, 5, 6]] + res = asarray(lst) + assert res.shape == (2, 3) + + +def test_asarray_nested_arrays(): + # do not allow arrays in nested sequences + with pytest.raises(TypeError): + asarray([[1, 2, 3], asarray([4, 5, 6])]) + + with pytest.raises(TypeError): + asarray([1, asarray(1)]) def test_asarray_device_inference():