Skip to content

Commit 5cd47df

Browse files
committed
Adapt dask
1 parent 1720fb6 commit 5cd47df

File tree

2 files changed

+183
-62
lines changed

2 files changed

+183
-62
lines changed
+182-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,186 @@
1-
from dask.array import *
1+
from dask.array import * # noqa: F401, F403
2+
from dask.array import __all__ as _dask_array_all
3+
4+
from dask.array import (
5+
# Element wise aliases
6+
arccos as acos,
7+
arccosh as acosh,
8+
arcsin as asin,
9+
arcsinh as asinh,
10+
arctan as atan,
11+
arctan2 as atan2,
12+
arctanh as atanh,
13+
# Other
14+
concatenate as concat,
15+
invert as bitwise_invert,
16+
left_shift as bitwise_left_shift,
17+
power as pow,
18+
right_shift as bitwise_right_shift,
19+
bool_ as bool,
20+
)
221

322
# These imports may overwrite names from the import * above.
4-
from ._aliases import *
23+
from numpy import (
24+
can_cast,
25+
complex64,
26+
complex128,
27+
e,
28+
finfo,
29+
float32,
30+
float64,
31+
iinfo,
32+
inf,
33+
int8,
34+
int16,
35+
int32,
36+
int64,
37+
nan,
38+
newaxis,
39+
pi,
40+
result_type,
41+
uint8,
42+
uint16,
43+
uint32,
44+
uint64,
45+
)
46+
47+
from ..common._helpers import (
48+
array_namespace,
49+
device,
50+
get_namespace,
51+
is_array_api_obj,
52+
size,
53+
to_device,
54+
)
55+
from ._aliases import (
56+
UniqueAllResult,
57+
UniqueCountsResult,
58+
UniqueInverseResult,
59+
arange,
60+
asarray,
61+
astype,
62+
ceil,
63+
empty,
64+
empty_like,
65+
eye,
66+
floor,
67+
full,
68+
full_like,
69+
isdtype,
70+
linspace,
71+
matmul,
72+
matrix_transpose,
73+
nonzero,
74+
ones,
75+
ones_like,
76+
permute_dims,
77+
prod,
78+
reshape,
79+
std,
80+
sum,
81+
tensordot,
82+
trunc,
83+
unique_all,
84+
unique_counts,
85+
unique_inverse,
86+
unique_values,
87+
var,
88+
vecdot,
89+
zeros,
90+
zeros_like,
91+
)
92+
93+
__all__ = []
94+
95+
__all__ += _dask_array_all
96+
97+
__all__ += [
98+
"can_cast",
99+
"complex64",
100+
"complex128",
101+
"e",
102+
"finfo",
103+
"float32",
104+
"float64",
105+
"iinfo",
106+
"inf",
107+
"int8",
108+
"int16",
109+
"int32",
110+
"int64",
111+
"nan",
112+
"newaxis",
113+
"pi",
114+
"result_type",
115+
"uint8",
116+
"uint16",
117+
"uint32",
118+
"uint64",
119+
]
120+
121+
__all__ += [
122+
"array_namespace",
123+
"device",
124+
"get_namespace",
125+
"is_array_api_obj",
126+
"size",
127+
"to_device",
128+
]
129+
130+
# 'sort', 'argsort' are unsupported by dask.array
131+
132+
__all__ += [
133+
"UniqueAllResult",
134+
"UniqueCountsResult",
135+
"UniqueInverseResult",
136+
"acos",
137+
"acosh",
138+
"arange",
139+
"asarray",
140+
"asin",
141+
"asinh",
142+
"astype",
143+
"atan",
144+
"atan2",
145+
"atanh",
146+
"bitwise_invert",
147+
"bitwise_left_shift",
148+
"bitwise_right_shift",
149+
"bool",
150+
"ceil",
151+
"concat",
152+
"empty",
153+
"empty_like",
154+
"eye",
155+
"floor",
156+
"full",
157+
"full_like",
158+
"isdtype",
159+
"linspace",
160+
"matmul",
161+
"matrix_transpose",
162+
"nonzero",
163+
"ones",
164+
"ones_like",
165+
"permute_dims",
166+
"pow",
167+
"prod",
168+
"reshape",
169+
"std",
170+
"sum",
171+
"tensordot",
172+
"trunc",
173+
"unique_all",
174+
"unique_counts",
175+
"unique_inverse",
176+
"unique_values",
177+
"var",
178+
"vecdot",
179+
"zeros",
180+
"zeros_like",
181+
]
182+
5183

6-
__array_api_version__ = '2022.12'
184+
__array_api_version__ = "2022.12"
7185

8-
__import__(__package__ + '.linalg')
186+
__import__(__package__ + ".linalg")

array_api_compat/dask/array/_aliases.py

+1-58
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,12 @@
11
from __future__ import annotations
2+
from functools import partial
23

34
from ...common import _aliases
45
from ...common._helpers import _check_device
56

67
from ..._internal import get_xp
78

89
import numpy as np
9-
from numpy import (
10-
# Constants
11-
e,
12-
inf,
13-
nan,
14-
pi,
15-
newaxis,
16-
# Dtypes
17-
bool_ as bool,
18-
float32,
19-
float64,
20-
int8,
21-
int16,
22-
int32,
23-
int64,
24-
uint8,
25-
uint16,
26-
uint32,
27-
uint64,
28-
complex64,
29-
complex128,
30-
iinfo,
31-
finfo,
32-
can_cast,
33-
result_type,
34-
)
3510

3611
from typing import TYPE_CHECKING
3712
if TYPE_CHECKING:
@@ -75,7 +50,6 @@ def dask_arange(
7550
arange = get_xp(da)(dask_arange)
7651
eye = get_xp(da)(_aliases.eye)
7752

78-
from functools import partial
7953
asarray = partial(_aliases._asarray, namespace='dask.array')
8054
asarray.__doc__ = _aliases._asarray.__doc__
8155

@@ -112,34 +86,3 @@ def dask_arange(
11286
matmul = get_xp(np)(_aliases.matmul)
11387
tensordot = get_xp(np)(_aliases.tensordot)
11488

115-
from dask.array import (
116-
# Element wise aliases
117-
arccos as acos,
118-
arccosh as acosh,
119-
arcsin as asin,
120-
arcsinh as asinh,
121-
arctan as atan,
122-
arctan2 as atan2,
123-
arctanh as atanh,
124-
left_shift as bitwise_left_shift,
125-
right_shift as bitwise_right_shift,
126-
invert as bitwise_invert,
127-
power as pow,
128-
# Other
129-
concatenate as concat,
130-
)
131-
132-
# exclude these from all since
133-
_da_unsupported = ['sort', 'argsort']
134-
135-
common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
136-
137-
__all__ = common_aliases + ['asarray', 'bool', 'acos',
138-
'acosh', 'asin', 'asinh', 'atan', 'atan2',
139-
'atanh', 'bitwise_left_shift', 'bitwise_invert',
140-
'bitwise_right_shift', 'concat', 'pow',
141-
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
142-
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
143-
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
144-
145-
del da, partial, common_aliases, _da_unsupported,

0 commit comments

Comments
 (0)