Skip to content

Commit 9cb5a13

Browse files
authoredFeb 6, 2024
Merge pull request #76 from lithomas1/add-dask
Add dask to array-api-compat
2 parents 916a84b + 54f4838 commit 9cb5a13

15 files changed

+427
-10
lines changed
 
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
name: Array API Tests (Dask)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-dask:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: dask
10+
module-name: dask.array
11+
extra-requires: numpy
12+
pytest-extra-args: --disable-deadline --max-examples=5

‎.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- name: Install Dependencies
1616
run: |
1717
python -m pip install --upgrade pip
18-
python -m pip install pytest numpy torch
18+
python -m pip install pytest numpy torch dask[array]
1919
2020
- name: Run Tests
2121
run: |

‎array_api_compat/common/_aliases.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,8 @@ def _asarray(
303303
import numpy as xp
304304
elif namespace == 'cupy':
305305
import cupy as xp
306+
elif namespace == 'dask.array':
307+
import dask.array as xp
306308
else:
307309
raise ValueError("Unrecognized namespace argument to asarray()")
308310

@@ -322,11 +324,27 @@ def _asarray(
322324
if copy in COPY_FALSE:
323325
# copy=False is not yet implemented in xp.asarray
324326
raise NotImplementedError("copy=False is not yet implemented")
325-
if isinstance(obj, xp.ndarray):
327+
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
328+
#print('hit me')
326329
if dtype is not None and obj.dtype != dtype:
327330
copy = True
331+
#print(copy)
328332
if copy in COPY_TRUE:
329-
return xp.array(obj, copy=True, dtype=dtype)
333+
copy_kwargs = {}
334+
if namespace != "dask.array":
335+
copy_kwargs["copy"] = True
336+
else:
337+
# No copy kw in dask.asarray so we go thorugh np.asarray first
338+
# (like dask also does) but copy after
339+
if dtype is None:
340+
# Same dtype copy is no-op in dask
341+
#print("in here?")
342+
return obj.copy()
343+
import numpy as np
344+
#print(obj)
345+
obj = np.asarray(obj).copy()
346+
#print(obj)
347+
return xp.array(obj, dtype=dtype, **copy_kwargs)
330348
return obj
331349

332350
return xp.asarray(obj, dtype=dtype, **kwargs)

‎array_api_compat/common/_helpers.py

+24
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,23 @@ def _is_torch_array(x):
4040
# TODO: Should we reject ndarray subclasses?
4141
return isinstance(x, torch.Tensor)
4242

43+
def _is_dask_array(x):
44+
# Avoid importing dask if it isn't already
45+
if 'dask.array' not in sys.modules:
46+
return False
47+
48+
import dask.array
49+
50+
return isinstance(x, dask.array.Array)
51+
4352
def is_array_api_obj(x):
4453
"""
4554
Check if x is an array API compatible array object.
4655
"""
4756
return _is_numpy_array(x) \
4857
or _is_cupy_array(x) \
4958
or _is_torch_array(x) \
59+
or _is_dask_array(x) \
5060
or hasattr(x, '__array_namespace__')
5161

5262
def _check_api_version(api_version):
@@ -95,6 +105,13 @@ def your_function(x, y):
95105
else:
96106
import torch
97107
namespaces.add(torch)
108+
elif _is_dask_array(x):
109+
_check_api_version(api_version)
110+
if _use_compat:
111+
from ..dask import array as dask_namespace
112+
namespaces.add(dask_namespace)
113+
else:
114+
raise TypeError("_use_compat cannot be False if input array is a dask array!")
98115
elif hasattr(x, '__array_namespace__'):
99116
namespaces.add(x.__array_namespace__(api_version=api_version))
100117
else:
@@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
219236
return _cupy_to_device(x, device, stream=stream)
220237
elif _is_torch_array(x):
221238
return _torch_to_device(x, device, stream=stream)
239+
elif _is_dask_array(x):
240+
if stream is not None:
241+
raise ValueError("The stream argument to to_device() is not supported")
242+
# TODO: What if our array is on the GPU already?
243+
if device == 'cpu':
244+
return x
245+
raise ValueError(f"Unsupported device {device!r}")
222246
return x.to_device(device, stream=stream)
223247

224248
def size(x):

‎array_api_compat/common/_linalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def matrix_rank(x: ndarray,
7777
# dimensional arrays.
7878
if x.ndim < 2:
7979
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
80-
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
80+
S = get_xp(xp)(svdvals)(x, **kwargs)
8181
if rtol is None:
8282
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
8383
else:
+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from dask.array import *
2+
3+
# These imports may overwrite names from the import * above.
4+
from ._aliases import *
5+
6+
__array_api_version__ = '2022.12'
7+
8+
__import__(__package__ + '.linalg')
+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from __future__ import annotations
2+
3+
from ...common import _aliases
4+
from ...common._helpers import _check_device
5+
6+
from ..._internal import get_xp
7+
8+
import numpy as np
9+
from numpy import (
10+
# Constants
11+
e,
12+
inf,
13+
nan,
14+
pi,
15+
newaxis,
16+
# Dtypes
17+
bool_ as bool,
18+
float32,
19+
float64,
20+
int8,
21+
int16,
22+
int32,
23+
int64,
24+
uint8,
25+
uint16,
26+
uint32,
27+
uint64,
28+
complex64,
29+
complex128,
30+
iinfo,
31+
finfo,
32+
can_cast,
33+
result_type,
34+
)
35+
36+
from typing import TYPE_CHECKING
37+
if TYPE_CHECKING:
38+
from typing import Optional, Union
39+
from ...common._typing import ndarray, Device, Dtype
40+
41+
import dask.array as da
42+
43+
isdtype = get_xp(np)(_aliases.isdtype)
44+
astype = _aliases.astype
45+
46+
# Common aliases
47+
48+
# This arange func is modified from the common one to
49+
# not pass stop/step as keyword arguments, which will cause
50+
# an error with dask
51+
52+
# TODO: delete the xp stuff, it shouldn't be necessary
53+
def dask_arange(
54+
start: Union[int, float],
55+
/,
56+
stop: Optional[Union[int, float]] = None,
57+
step: Union[int, float] = 1,
58+
*,
59+
xp,
60+
dtype: Optional[Dtype] = None,
61+
device: Optional[Device] = None,
62+
**kwargs
63+
) -> ndarray:
64+
_check_device(xp, device)
65+
args = [start]
66+
if stop is not None:
67+
args.append(stop)
68+
else:
69+
# stop is None, so start is actually stop
70+
# prepend the default value for start which is 0
71+
args.insert(0, 0)
72+
args.append(step)
73+
return xp.arange(*args, dtype=dtype, **kwargs)
74+
75+
arange = get_xp(da)(dask_arange)
76+
eye = get_xp(da)(_aliases.eye)
77+
78+
from functools import partial
79+
asarray = partial(_aliases._asarray, namespace='dask.array')
80+
asarray.__doc__ = _aliases._asarray.__doc__
81+
82+
linspace = get_xp(da)(_aliases.linspace)
83+
eye = get_xp(da)(_aliases.eye)
84+
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
85+
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
86+
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
87+
unique_all = get_xp(da)(_aliases.unique_all)
88+
unique_counts = get_xp(da)(_aliases.unique_counts)
89+
unique_inverse = get_xp(da)(_aliases.unique_inverse)
90+
unique_values = get_xp(da)(_aliases.unique_values)
91+
permute_dims = get_xp(da)(_aliases.permute_dims)
92+
std = get_xp(da)(_aliases.std)
93+
var = get_xp(da)(_aliases.var)
94+
empty = get_xp(da)(_aliases.empty)
95+
empty_like = get_xp(da)(_aliases.empty_like)
96+
full = get_xp(da)(_aliases.full)
97+
full_like = get_xp(da)(_aliases.full_like)
98+
ones = get_xp(da)(_aliases.ones)
99+
ones_like = get_xp(da)(_aliases.ones_like)
100+
zeros = get_xp(da)(_aliases.zeros)
101+
zeros_like = get_xp(da)(_aliases.zeros_like)
102+
reshape = get_xp(da)(_aliases.reshape)
103+
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
104+
vecdot = get_xp(da)(_aliases.vecdot)
105+
106+
nonzero = get_xp(da)(_aliases.nonzero)
107+
sum = get_xp(np)(_aliases.sum)
108+
prod = get_xp(np)(_aliases.prod)
109+
ceil = get_xp(np)(_aliases.ceil)
110+
floor = get_xp(np)(_aliases.floor)
111+
trunc = get_xp(np)(_aliases.trunc)
112+
matmul = get_xp(np)(_aliases.matmul)
113+
tensordot = get_xp(np)(_aliases.tensordot)
114+
115+
from dask.array import (
116+
# Element wise aliases
117+
arccos as acos,
118+
arccosh as acosh,
119+
arcsin as asin,
120+
arcsinh as asinh,
121+
arctan as atan,
122+
arctan2 as atan2,
123+
arctanh as atanh,
124+
left_shift as bitwise_left_shift,
125+
right_shift as bitwise_right_shift,
126+
invert as bitwise_invert,
127+
power as pow,
128+
# Other
129+
concatenate as concat,
130+
)
131+
132+
# exclude these from all since
133+
_da_unsupported = ['sort', 'argsort']
134+
135+
common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
136+
137+
__all__ = common_aliases + ['asarray', 'bool', 'acos',
138+
'acosh', 'asin', 'asinh', 'atan', 'atan2',
139+
'atanh', 'bitwise_left_shift', 'bitwise_invert',
140+
'bitwise_right_shift', 'concat', 'pow',
141+
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
142+
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
143+
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
144+
145+
del da, partial, common_aliases, _da_unsupported,

‎array_api_compat/dask/array/linalg.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from __future__ import annotations
2+
3+
from dask.array.linalg import *
4+
from ...common import _linalg
5+
from ..._internal import get_xp
6+
from dask.array import matmul, tensordot, trace, outer
7+
from ._aliases import matrix_transpose, vecdot
8+
9+
import dask.array as da
10+
11+
from typing import TYPE_CHECKING
12+
if TYPE_CHECKING:
13+
from typing import Union, Tuple
14+
from ...common._typing import ndarray
15+
16+
# cupy.linalg doesn't have __all__. If it is added, replace this with
17+
#
18+
# from cupy.linalg import __all__ as linalg_all
19+
_n = {}
20+
exec('from dask.array.linalg import *', _n)
21+
del _n['__builtins__']
22+
linalg_all = list(_n)
23+
del _n
24+
25+
EighResult = _linalg.EighResult
26+
QRResult = _linalg.QRResult
27+
SlogdetResult = _linalg.SlogdetResult
28+
SVDResult = _linalg.SVDResult
29+
qr = get_xp(da)(_linalg.qr)
30+
cholesky = get_xp(da)(_linalg.cholesky)
31+
matrix_rank = get_xp(da)(_linalg.matrix_rank)
32+
matrix_norm = get_xp(da)(_linalg.matrix_norm)
33+
34+
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
35+
# TODO: can't avoid computing U or V for dask
36+
_, s, _ = svd(x)
37+
return s
38+
39+
vector_norm = get_xp(da)(_linalg.vector_norm)
40+
diagonal = get_xp(da)(_linalg.diagonal)
41+
42+
__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult",
43+
"SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm",
44+
"svdvals", "vector_norm", "diagonal"]
45+
46+
del get_xp
47+
del da
48+
del _linalg

‎dask-skips.txt

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# FFT isn't conformant
2+
array_api_tests/test_fft.py
3+
4+
# slow and not implemented in dask
5+
array_api_tests/test_linalg.py::test_matrix_power

‎dask-xfails.txt

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# This fails in dask
2+
# import dask.array as da
3+
# a = da.array([1]).reshape((1,1))
4+
# key = (0, slice(None, None, -1))
5+
# a[key] = da.array([1])
6+
7+
# Failing hypothesis test case
8+
#x=dask.array<zeros_like, shape=(0, 2), dtype=bool, chunksize=(0, 2), chunktype=numpy.ndarray>
9+
#| Draw 1 (key): (slice(None, None, None), slice(None, None, None))
10+
#| Draw 2 (value): dask.array<zeros_like, shape=(0, 2), dtype=bool, chunksize=(0, 2), chunktype=numpy.ndarray>
11+
12+
# Various shape mismatches e.g.
13+
ValueError: shape mismatch: value array of shape (0, 2) could not be broadcast to indexing result of shape (0, 2)
14+
array_api_tests/test_array_object.py::test_setitem
15+
16+
# Fails since bad upcast from uint8 -> int64
17+
# MRE:
18+
# a = da.array(0, dtype="uint8")
19+
# b = da.array(False)
20+
# a[b] = 0
21+
array_api_tests/test_array_object.py::test_setitem_masking
22+
23+
# Various indexing errors
24+
array_api_tests/test_array_object.py::test_getitem_masking
25+
26+
# asarray(copy=False) is not yet implemented
27+
# copied from numpy xfails, TODO: should this pass with dask?
28+
array_api_tests/test_creation_functions.py::test_asarray_arrays
29+
30+
# zero division error, and typeerror: tuple indices must be integers or slices not tuple
31+
array_api_tests/test_creation_functions.py::test_eye
32+
33+
# finfo(float32).eps returns float32 but should return float
34+
array_api_tests/test_data_type_functions.py::test_finfo[float32]
35+
36+
# out[-1]=dask.aray<getitem ...> but should be some floating number
37+
# (I think the test is not forcing the op to be computed?)
38+
array_api_tests/test_creation_functions.py::test_linspace
39+
40+
# out=-0, but should be +0
41+
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
42+
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
43+
44+
# output is nan but should be infinity
45+
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
46+
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
47+
48+
# No sorting in dask
49+
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
50+
array_api_tests/test_has_names.py::test_has_names[sorting-sort]
51+
array_api_tests/test_sorting_functions.py::test_argsort
52+
array_api_tests/test_sorting_functions.py::test_sort
53+
array_api_tests/test_signatures.py::test_func_signature[argsort]
54+
array_api_tests/test_signatures.py::test_func_signature[sort]
55+
56+
# Array methods and attributes not already on np.ndarray cannot be wrapped
57+
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
58+
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
59+
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
60+
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
61+
62+
# Fails because shape is NaN since we don't materialize it yet
63+
array_api_tests/test_searching_functions.py::test_nonzero
64+
array_api_tests/test_set_functions.py::test_unique_all
65+
array_api_tests/test_set_functions.py::test_unique_counts
66+
67+
# Different error but same cause as above, we're just trying to do ndindex on nan shape
68+
array_api_tests/test_set_functions.py::test_unique_inverse
69+
array_api_tests/test_set_functions.py::test_unique_values
70+
71+
# Linalg failures (signature failures/missing methods)
72+
73+
# fails for ndim > 2
74+
array_api_tests/test_linalg.py::test_svdvals
75+
array_api_tests/test_linalg.py::test_cholesky
76+
# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
77+
array_api_tests/test_linalg.py::test_tensordot
78+
# probably same reason for failing as numpy
79+
array_api_tests/test_linalg.py::test_trace
80+
81+
# Linalg - these don't exist in dask
82+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
83+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
84+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh]
85+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh]
86+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power]
87+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
88+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet]
89+
array_api_tests/test_linalg.py::test_cross
90+
array_api_tests/test_linalg.py::test_det
91+
array_api_tests/test_linalg.py::test_eigvalsh
92+
array_api_tests/test_linalg.py::test_pinv
93+
array_api_tests/test_linalg.py::test_slogdet
94+
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
95+
array_api_tests/test_has_names.py::test_has_names[linalg-det]
96+
array_api_tests/test_has_names.py::test_has_names[linalg-eigh]
97+
array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh]
98+
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
99+
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
100+
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
101+
102+
array_api_tests/test_linalg.py::test_matrix_norm
103+
array_api_tests/test_linalg.py::test_matrix_rank
104+
105+
# missing mode kw
106+
# https://github.com/dask/dask/issues/10388
107+
array_api_tests/test_linalg.py::test_qr
108+
109+
# Constructing the input arrays fails to a weird shape error...
110+
array_api_tests/test_linalg.py::test_solve
111+
112+
# missing full_matrics kw
113+
# https://github.com/dask/dask/issues/10389
114+
# also only supports 2-d inputs
115+
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd]
116+
array_api_tests/test_linalg.py::test_svd
117+
118+
# Missing dlpack stuff
119+
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
120+
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
121+
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
122+
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
123+
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
124+
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
125+
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
126+
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
127+
128+
# Some cases unsupported by dask
129+
array_api_tests/test_manipulation_functions.py::test_roll
130+
131+
# No mT on dask array
132+
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices

‎tests/test_array_namespace.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
import pytest
77

8-
9-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
8+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
109
@pytest.mark.parametrize("api_version", [None, '2021.12'])
1110
def test_array_namespace(library, api_version):
1211
lib = import_(library)
@@ -17,7 +16,10 @@ def test_array_namespace(library, api_version):
1716
if 'array_api' in library:
1817
assert namespace == lib
1918
else:
20-
assert namespace == getattr(array_api_compat, library)
19+
if library == "dask.array":
20+
assert namespace == array_api_compat.dask.array
21+
else:
22+
assert namespace == getattr(array_api_compat, library)
2123

2224

2325
def test_array_namespace_errors():

‎tests/test_common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from numpy.testing import assert_allclose
77

8-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
8+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
99
def test_to_device_host(library):
1010
# different libraries have different semantics
1111
# for DtoH transfers; ensure that we support a portable

‎tests/test_isdtype.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def isdtype_(dtype_, kind):
6464
assert type(res) is bool
6565
return res
6666

67-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
67+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
6868
def test_isdtype_spec_dtypes(library):
6969
xp = import_('array_api_compat.' + library)
7070

@@ -98,7 +98,7 @@ def test_isdtype_spec_dtypes(library):
9898
'bfloat16',
9999
]
100100

101-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
101+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
102102
@pytest.mark.parametrize("dtype_", additional_dtypes)
103103
def test_isdtype_additional_dtypes(library, dtype_):
104104
xp = import_('array_api_compat.' + library)

‎tests/test_vendoring.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@ def test_vendoring_cupy():
1717
def test_vendoring_torch():
1818
from vendor_test import uses_torch
1919
uses_torch._test_torch()
20+
21+
def test_vendoring_dask():
22+
from vendor_test import uses_dask
23+
uses_dask._test_dask()

‎vendor_test/uses_dask.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Basic test that vendoring works
2+
3+
from .vendored._compat.dask import array as dask_compat
4+
5+
import dask.array as da
6+
import numpy as np
7+
8+
def _test_dask():
9+
a = dask_compat.asarray([1., 2., 3.])
10+
b = dask_compat.arange(3, dtype=dask_compat.float32)
11+
12+
# np.pow does not exist. Update this to use something else if it is added
13+
res = dask_compat.pow(a, b)
14+
assert res.dtype == dask_compat.float64 == np.float64
15+
assert isinstance(a, da.Array)
16+
assert isinstance(b, da.Array)
17+
assert isinstance(res, da.Array)
18+
19+
np.testing.assert_allclose(res, [1., 2., 9.])

0 commit comments

Comments
 (0)
Please sign in to comment.