Skip to content

Commit ddc14d8

Browse files
committed
TST: adapt tests for the lack of __array__
1 parent a3c04ab commit ddc14d8

File tree

2 files changed

+3
-32
lines changed

2 files changed

+3
-32
lines changed

array_api_strict/_creation_functions.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,6 @@
1616

1717
import numpy as np
1818

19-
@contextmanager
20-
def allow_array():
21-
"""
22-
Temporarily enable Array.__array__. This is needed for np.array to parse
23-
list of lists of Array objects.
24-
"""
25-
from . import _array_object
26-
original_value = _array_object._allow_array
27-
try:
28-
_array_object._allow_array = True
29-
yield
30-
finally:
31-
_array_object._allow_array = original_value
3219

3320
def _check_valid_dtype(dtype):
3421
# Note: Only spelling dtypes as the dtype objects is supported.
@@ -112,8 +99,8 @@ def asarray(
11299
# Give a better error message in this case. NumPy would convert this
113100
# to an object array. TODO: This won't handle large integers in lists.
114101
raise OverflowError("Integer out of bounds for array dtypes")
115-
with allow_array():
116-
res = np.array(obj, dtype=_np_dtype, copy=copy)
102+
103+
res = np.array(obj, dtype=_np_dtype, copy=copy)
117104
return Array._new(res, device=device)
118105

119106

array_api_strict/tests/test_array_object.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -361,25 +361,9 @@ def test_array_conversion():
361361

362362
for device in ("device1", "device2"):
363363
a = ones((2, 3), device=array_api_strict.Device(device))
364-
with pytest.raises(RuntimeError, match="Can not convert array"):
364+
with pytest.raises((RuntimeError, TypeError)):
365365
asarray([a])
366366

367-
def test__array__():
368-
# __array__ should work for now
369-
a = ones((2, 3))
370-
np.array(a)
371-
372-
# Test the _allow_array private global flag for disabling it in the
373-
# future.
374-
from .. import _array_object
375-
original_value = _array_object._allow_array
376-
try:
377-
_array_object._allow_array = False
378-
a = ones((2, 3))
379-
with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"):
380-
np.array(a)
381-
finally:
382-
_array_object._allow_array = original_value
383367

384368
def test_allow_newaxis():
385369
a = ones(5)

0 commit comments

Comments
 (0)