|
17 | 17 |
|
18 | 18 | import operator |
19 | 19 | from enum import IntEnum |
20 | | -import warnings |
21 | 20 |
|
22 | 21 | from ._creation_functions import asarray |
23 | 22 | from ._dtypes import ( |
|
32 | 31 | _result_type, |
33 | 32 | _dtype_categories, |
34 | 33 | ) |
| 34 | +from ._flags import get_array_api_strict_flags, set_array_api_strict_flags |
35 | 35 |
|
36 | 36 | from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex |
37 | 37 | import types |
@@ -427,13 +427,17 @@ def _validate_index(self, key): |
427 | 427 | "the Array API)" |
428 | 428 | ) |
429 | 429 | elif isinstance(i, Array): |
430 | | - if i.dtype in _boolean_dtypes and len(_key) != 1: |
431 | | - assert isinstance(key, tuple) # sanity check |
432 | | - raise IndexError( |
433 | | - f"Single-axes index {i} is a boolean array and " |
434 | | - f"{len(key)=}, but masking is only specified in the " |
435 | | - "Array API when the array is the sole index." |
436 | | - ) |
| 430 | + if i.dtype in _boolean_dtypes: |
| 431 | + if len(_key) != 1: |
| 432 | + assert isinstance(key, tuple) # sanity check |
| 433 | + raise IndexError( |
| 434 | + f"Single-axes index {i} is a boolean array and " |
| 435 | + f"{len(key)=}, but masking is only specified in the " |
| 436 | + "Array API when the array is the sole index." |
| 437 | + ) |
| 438 | + if not get_array_api_strict_flags()['data_dependent_shapes']: |
| 439 | + raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") |
| 440 | + |
437 | 441 | elif i.dtype in _integer_dtypes and i.ndim != 0: |
438 | 442 | raise IndexError( |
439 | 443 | f"Single-axes index {i} is a non-zero-dimensional " |
@@ -482,10 +486,21 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: |
482 | 486 | def __array_namespace__( |
483 | 487 | self: Array, /, *, api_version: Optional[str] = None |
484 | 488 | ) -> types.ModuleType: |
485 | | - if api_version is not None and api_version not in ["2021.12", "2022.12"]: |
486 | | - raise ValueError(f"Unrecognized array API version: {api_version!r}") |
487 | | - if api_version == "2021.12": |
488 | | - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") |
| 489 | + """ |
| 490 | + Return the array_api_strict namespace corresponding to api_version. |
| 491 | +
|
| 492 | + The default API version is '2022.12'. Note that '2021.12' is supported, |
| 493 | + but currently identical to '2022.12'. |
| 494 | +
|
| 495 | + For array_api_strict, calling this function with api_version will set |
| 496 | + the API version for the array_api_strict module globally. This can |
| 497 | + also be achieved with the |
| 498 | + {func}`array_api_strict.set_array_api_strict_flags` function. If you |
| 499 | + want to only set the version locally, use the |
| 500 | + {class}`array_api_strict.ArrayApiStrictFlags` context manager. |
| 501 | +
|
| 502 | + """ |
| 503 | + set_array_api_strict_flags(api_version=api_version) |
489 | 504 | import array_api_strict |
490 | 505 | return array_api_strict |
491 | 506 |
|
|
0 commit comments