Skip to content

Commit 7d4fae5

Browse files
authored
Merge branch 'main' into dependabot/github_actions/actions/checkout-4
2 parents c377e58 + fae63d2 commit 7d4fae5

File tree

7 files changed

+61
-24
lines changed

7 files changed

+61
-24
lines changed

array_api_compat/common/_aliases.py

+18-19
Original file line numberDiff line numberDiff line change
@@ -325,31 +325,30 @@ def _asarray(
325325
else:
326326
COPY_FALSE = (False,)
327327
COPY_TRUE = (True,)
328-
if copy in COPY_FALSE:
328+
if copy in COPY_FALSE and namespace != "dask.array":
329329
# copy=False is not yet implemented in xp.asarray
330330
raise NotImplementedError("copy=False is not yet implemented")
331-
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
332-
#print('hit me')
331+
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
333332
if dtype is not None and obj.dtype != dtype:
334333
copy = True
335-
#print(copy)
336334
if copy in COPY_TRUE:
337-
copy_kwargs = {}
338-
if namespace != "dask.array":
339-
copy_kwargs["copy"] = True
340-
else:
341-
# No copy kw in dask.asarray so we go thorugh np.asarray first
342-
# (like dask also does) but copy after
343-
if dtype is None:
344-
# Same dtype copy is no-op in dask
345-
#print("in here?")
346-
return obj.copy()
347-
import numpy as np
348-
#print(obj)
349-
obj = np.asarray(obj).copy()
350-
#print(obj)
351-
return xp.array(obj, dtype=dtype, **copy_kwargs)
335+
return xp.array(obj, copy=True, dtype=dtype)
352336
return obj
337+
elif namespace == "dask.array":
338+
if copy in COPY_TRUE:
339+
if dtype is None:
340+
return obj.copy()
341+
# Go through numpy, since dask copy is no-op by default
342+
import numpy as np
343+
obj = np.array(obj, dtype=dtype, copy=True)
344+
return xp.array(obj, dtype=dtype)
345+
else:
346+
import dask.array as da
347+
import numpy as np
348+
if not isinstance(obj, da.Array):
349+
obj = np.asarray(obj, dtype=dtype)
350+
return da.from_array(obj)
351+
return obj
353352

354353
return xp.asarray(obj, dtype=dtype, **kwargs)
355354

array_api_compat/dask/array/linalg.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
from typing import TYPE_CHECKING
1717
if TYPE_CHECKING:
1818
from ...common._typing import Array
19+
from typing import Literal
1920

20-
# cupy.linalg doesn't have __all__. If it is added, replace this with
21+
# dask.array.linalg doesn't have __all__. If it is added, replace this with
2122
#
22-
# from cupy.linalg import __all__ as linalg_all
23+
# from dask.array.linalg import __all__ as linalg_all
2324
_n = {}
2425
exec('from dask.array.linalg import *', _n)
2526
del _n['__builtins__']
@@ -32,7 +33,15 @@
3233
QRResult = _linalg.QRResult
3334
SlogdetResult = _linalg.SlogdetResult
3435
SVDResult = _linalg.SVDResult
35-
qr = get_xp(da)(_linalg.qr)
36+
# TODO: use the QR wrapper once dask
37+
# supports the mode keyword on QR
38+
# https://github.com/dask/dask/issues/10388
39+
#qr = get_xp(da)(_linalg.qr)
40+
def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
41+
**kwargs) -> QRResult:
42+
if mode != "reduced":
43+
raise ValueError("dask arrays only support using mode='reduced'")
44+
return QRResult(*da.linalg.qr(x, **kwargs))
3645
cholesky = get_xp(da)(_linalg.cholesky)
3746
matrix_rank = get_xp(da)(_linalg.matrix_rank)
3847
matrix_norm = get_xp(da)(_linalg.matrix_norm)
@@ -44,7 +53,7 @@
4453
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
4554
if full_matrices:
4655
raise ValueError("full_matrics=True is not supported by dask.")
47-
return da.linalg.svd(x, **kwargs)
56+
return da.linalg.svd(x, coerce_signs=False, **kwargs)
4857

4958
def svdvals(x: Array) -> Array:
5059
# TODO: can't avoid computing U or V for dask

dask-xfails.txt

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32]
3737
# (I think the test is not forcing the op to be computed?)
3838
array_api_tests/test_creation_functions.py::test_linspace
3939

40+
# out.shape=(2,) but should be (1,)
41+
array_api_tests/test_indexing_functions.py::test_take
42+
4043
# out=-0, but should be +0
4144
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
4245
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]

numpy-1-21-xfails.txt

+5-1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ array_api_tests/test_set_functions.py::test_unique_values
8989
# The test suite is incorrectly checking sums that have loss of significance
9090
# (https://github.com/data-apis/array-api-tests/issues/168)
9191
array_api_tests/test_statistical_functions.py::test_sum
92+
array_api_tests/test_statistical_functions.py::test_prod
9293

9394
# NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with
9495
# type promotion issues
@@ -121,21 +122,24 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bi
121122
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
122123
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
123124
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
124-
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
125125
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
126+
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
126127
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
127128
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
128129
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
129130
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
130131
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
131132
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
132133
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
134+
array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)]
133135
array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
134136
array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp
135137
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
136138
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)]
137139
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
138140
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
141+
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
142+
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
139143
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
140144
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)]
141145
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)]

numpy-dev-xfails.txt

+1
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
4242
# The test suite is incorrectly checking sums that have loss of significance
4343
# (https://github.com/data-apis/array-api-tests/issues/168)
4444
array_api_tests/test_statistical_functions.py::test_sum
45+
array_api_tests/test_statistical_functions.py::test_prod

numpy-xfails.txt

+1
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
4444
# The test suite is incorrectly checking sums that have loss of significance
4545
# (https://github.com/data-apis/array-api-tests/issues/168)
4646
array_api_tests/test_statistical_functions.py::test_sum
47+
array_api_tests/test_statistical_functions.py::test_prod

tests/test_common.py

+20
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,23 @@ def test_to_device_host(library):
6060
# here is that we can test portably after calling
6161
# to_device(x, "cpu") to return to host
6262
assert_allclose(x, expected)
63+
64+
65+
@pytest.mark.parametrize("target_library,func", is_functions.items())
66+
@pytest.mark.parametrize("source_library", is_functions.keys())
67+
def test_asarray(source_library, target_library, func, request):
68+
if source_library == "dask.array" and target_library == "torch":
69+
# Allow rest of test to execute instead of immediately xfailing
70+
# xref https://github.com/pandas-dev/pandas/issues/38902
71+
72+
# TODO: remove xfail once
73+
# https://github.com/dask/dask/issues/8260 is resolved
74+
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
75+
src_lib = import_(source_library, wrapper=True)
76+
tgt_lib = import_(target_library, wrapper=True)
77+
is_tgt_type = globals()[func]
78+
79+
a = src_lib.asarray([1, 2, 3])
80+
b = tgt_lib.asarray(a)
81+
82+
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"

0 commit comments

Comments
 (0)