|
16 | 16 |
|
17 | 17 | """
|
18 | 18 |
|
| 19 | +__all__ = [] |
| 20 | + |
19 | 21 | # Warning: __array_api_version__ could change globally with
|
20 | 22 | # set_array_api_strict_flags(). This should always be accessed as an
|
21 | 23 | # attribute, like xp.__array_api_version__, or using
|
22 | 24 | # array_api_strict.get_array_api_strict_flags()['api_version'].
|
23 | 25 | from ._flags import API_VERSION as __array_api_version__
|
24 | 26 |
|
25 |
| -__all__ = ["__array_api_version__"] |
| 27 | +__all__ += ["__array_api_version__"] |
26 | 28 |
|
27 | 29 | from ._constants import e, inf, nan, pi, newaxis
|
28 | 30 |
|
|
137 | 139 | bitwise_right_shift,
|
138 | 140 | bitwise_xor,
|
139 | 141 | ceil,
|
| 142 | + clip, |
140 | 143 | conj,
|
| 144 | + copysign, |
141 | 145 | cos,
|
142 | 146 | cosh,
|
143 | 147 | divide,
|
|
148 | 152 | floor_divide,
|
149 | 153 | greater,
|
150 | 154 | greater_equal,
|
| 155 | + hypot, |
151 | 156 | imag,
|
152 | 157 | isfinite,
|
153 | 158 | isinf,
|
|
163 | 168 | logical_not,
|
164 | 169 | logical_or,
|
165 | 170 | logical_xor,
|
| 171 | + maximum, |
| 172 | + minimum, |
166 | 173 | multiply,
|
167 | 174 | negative,
|
168 | 175 | not_equal,
|
|
172 | 179 | remainder,
|
173 | 180 | round,
|
174 | 181 | sign,
|
| 182 | + signbit, |
175 | 183 | sin,
|
176 | 184 | sinh,
|
177 | 185 | square,
|
|
199 | 207 | "bitwise_right_shift",
|
200 | 208 | "bitwise_xor",
|
201 | 209 | "ceil",
|
| 210 | + "clip", |
202 | 211 | "conj",
|
| 212 | + "copysign", |
203 | 213 | "cos",
|
204 | 214 | "cosh",
|
205 | 215 | "divide",
|
|
210 | 220 | "floor_divide",
|
211 | 221 | "greater",
|
212 | 222 | "greater_equal",
|
| 223 | + "hypot", |
213 | 224 | "imag",
|
214 | 225 | "isfinite",
|
215 | 226 | "isinf",
|
|
225 | 236 | "logical_not",
|
226 | 237 | "logical_or",
|
227 | 238 | "logical_xor",
|
| 239 | + "maximum", |
| 240 | + "minimum", |
228 | 241 | "multiply",
|
229 | 242 | "negative",
|
230 | 243 | "not_equal",
|
|
234 | 247 | "remainder",
|
235 | 248 | "round",
|
236 | 249 | "sign",
|
| 250 | + "signbit", |
237 | 251 | "sin",
|
238 | 252 | "sinh",
|
239 | 253 | "square",
|
|
248 | 262 |
|
249 | 263 | __all__ += ["take"]
|
250 | 264 |
|
251 |
| -# linalg is an extension in the array API spec, which is a sub-namespace. Only |
252 |
| -# a subset of functions in it are imported into the top-level namespace. |
253 |
| -from . import linalg |
| 265 | +from ._info import __array_namespace_info__ |
254 | 266 |
|
255 |
| -__all__ += ["linalg"] |
| 267 | +__all__ += [ |
| 268 | + "__array_namespace_info__", |
| 269 | +] |
256 | 270 |
|
257 | 271 | from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
|
258 | 272 |
|
259 | 273 | __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
|
260 | 274 |
|
261 |
| -from . import fft |
262 |
| -__all__ += ["fft"] |
263 |
| - |
264 | 275 | from ._manipulation_functions import (
|
265 | 276 | concat,
|
266 | 277 | expand_dims,
|
267 | 278 | flip,
|
| 279 | + moveaxis, |
268 | 280 | permute_dims,
|
| 281 | + repeat, |
269 | 282 | reshape,
|
270 | 283 | roll,
|
271 | 284 | squeeze,
|
272 | 285 | stack,
|
| 286 | + tile, |
| 287 | + unstack, |
273 | 288 | )
|
274 | 289 |
|
275 |
| -__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"] |
| 290 | +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"] |
276 | 291 |
|
277 |
| -from ._searching_functions import argmax, argmin, nonzero, where |
| 292 | +from ._searching_functions import argmax, argmin, nonzero, searchsorted, where |
278 | 293 |
|
279 |
| -__all__ += ["argmax", "argmin", "nonzero", "where"] |
| 294 | +__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"] |
280 | 295 |
|
281 | 296 | from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
|
282 | 297 |
|
|
286 | 301 |
|
287 | 302 | __all__ += ["argsort", "sort"]
|
288 | 303 |
|
289 |
| -from ._statistical_functions import max, mean, min, prod, std, sum, var |
| 304 | +from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var |
290 | 305 |
|
291 |
| -__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"] |
| 306 | +__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] |
292 | 307 |
|
293 | 308 | from ._utility_functions import all, any
|
294 | 309 |
|
|
308 | 323 | from . import _version
|
309 | 324 | __version__ = _version.get_versions()['version']
|
310 | 325 | del _version
|
| 326 | + |
| 327 | + |
| 328 | +# Extensions can be enabled or disabled dynamically. In order to make |
| 329 | +# "array_api_strict.linalg" give an AttributeError when it is disabled, we |
| 330 | +# use __getattr__. Note that linalg and fft are dynamically added and removed |
| 331 | +# from __all__ in set_array_api_strict_flags. |
| 332 | + |
| 333 | +def __getattr__(name): |
| 334 | + if name in ['linalg', 'fft']: |
| 335 | + if name in get_array_api_strict_flags()['enabled_extensions']: |
| 336 | + if name == 'linalg': |
| 337 | + from . import _linalg |
| 338 | + return _linalg |
| 339 | + elif name == 'fft': |
| 340 | + from . import _fft |
| 341 | + return _fft |
| 342 | + else: |
| 343 | + raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict") |
| 344 | + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") |
0 commit comments