25
25
26
26
API_VERSION = default_version = "2022.12"
27
27
28
+ BOOLEAN_INDEXING = True
29
+
28
30
DATA_DEPENDENT_SHAPES = True
29
31
30
32
all_extensions = (
46
48
def set_array_api_strict_flags (
47
49
* ,
48
50
api_version = None ,
51
+ boolean_indexing = None ,
49
52
data_dependent_shapes = None ,
50
53
enabled_extensions = None ,
51
54
):
@@ -67,6 +70,12 @@ def set_array_api_strict_flags(
67
70
Note that 2021.12 is supported, but currently gives the same thing as
68
71
2022.12 (except that the fft extension will be disabled).
69
72
73
+
74
+ - `boolean_indexing`: Whether indexing by a boolean array is supported.
75
+ Note that although boolean array indexing does result in data-dependent
76
+ shapes, this flag is independent of the `data_dependent_shapes` flag
77
+ (see below).
78
+
70
79
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
71
80
array-api-strict.
72
81
@@ -79,10 +88,12 @@ def set_array_api_strict_flags(
79
88
80
89
- `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`.
81
90
- `nonzero`
82
- - Boolean array indexing
83
91
- `repeat` when the `repeats` argument is an array (requires 2023.12
84
92
version of the standard)
85
93
94
+ Note that while boolean indexing is also data-dependent, it is
95
+ controlled by a separate `boolean_indexing` flag (see above).
96
+
86
97
See
87
98
https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html
88
99
for more details.
@@ -102,8 +113,8 @@ def set_array_api_strict_flags(
102
113
>>> # Set the standard version to 2021.12
103
114
>>> set_array_api_strict_flags(api_version="2021.12")
104
115
105
- >>> # Disable data-dependent shapes
106
- >>> set_array_api_strict_flags(data_dependent_shapes=False)
116
+ >>> # Disable data-dependent shapes and boolean indexing
117
+ >>> set_array_api_strict_flags(data_dependent_shapes=False, boolean_indexing=False )
107
118
108
119
>>> # Enable only the linalg extension (disable the fft extension)
109
120
>>> set_array_api_strict_flags(enabled_extensions=["linalg"])
@@ -116,7 +127,7 @@ def set_array_api_strict_flags(
116
127
ArrayAPIStrictFlags: A context manager to temporarily set the flags.
117
128
118
129
"""
119
- global API_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
130
+ global API_VERSION , BOOLEAN_INDEXING , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
120
131
121
132
if api_version is not None :
122
133
if api_version not in supported_versions :
@@ -126,6 +137,9 @@ def set_array_api_strict_flags(
126
137
API_VERSION = api_version
127
138
array_api_strict .__array_api_version__ = API_VERSION
128
139
140
+ if boolean_indexing is not None :
141
+ BOOLEAN_INDEXING = boolean_indexing
142
+
129
143
if data_dependent_shapes is not None :
130
144
DATA_DEPENDENT_SHAPES = data_dependent_shapes
131
145
@@ -169,7 +183,11 @@ def get_array_api_strict_flags():
169
183
>>> from array_api_strict import get_array_api_strict_flags
170
184
>>> flags = get_array_api_strict_flags()
171
185
>>> flags
172
- {'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
186
+ {'api_version': '2022.12',
187
+ 'boolean_indexing': True,
188
+ 'data_dependent_shapes': True,
189
+ 'enabled_extensions': ('linalg', 'fft')
190
+ }
173
191
174
192
See Also
175
193
--------
@@ -181,6 +199,7 @@ def get_array_api_strict_flags():
181
199
"""
182
200
return {
183
201
"api_version" : API_VERSION ,
202
+ "boolean_indexing" : BOOLEAN_INDEXING ,
184
203
"data_dependent_shapes" : DATA_DEPENDENT_SHAPES ,
185
204
"enabled_extensions" : ENABLED_EXTENSIONS ,
186
205
}
@@ -215,9 +234,10 @@ def reset_array_api_strict_flags():
215
234
ArrayAPIStrictFlags: A context manager to temporarily set the flags.
216
235
217
236
"""
218
- global API_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
237
+ global API_VERSION , BOOLEAN_INDEXING , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
219
238
API_VERSION = default_version
220
239
array_api_strict .__array_api_version__ = API_VERSION
240
+ BOOLEAN_INDEXING = True
221
241
DATA_DEPENDENT_SHAPES = True
222
242
ENABLED_EXTENSIONS = default_extensions
223
243
@@ -242,10 +262,11 @@ class ArrayAPIStrictFlags:
242
262
reset_array_api_strict_flags: Reset the flags to their default values.
243
263
244
264
"""
245
- def __init__ (self , * , api_version = None , data_dependent_shapes = None ,
246
- enabled_extensions = None ):
265
+ def __init__ (self , * , api_version = None , boolean_indexing = None ,
266
+ data_dependent_shapes = None , enabled_extensions = None ):
247
267
self .kwargs = {
248
268
"api_version" : api_version ,
269
+ "boolean_indexing" : boolean_indexing ,
249
270
"data_dependent_shapes" : data_dependent_shapes ,
250
271
"enabled_extensions" : enabled_extensions ,
251
272
}
@@ -265,6 +286,11 @@ def set_flags_from_environment():
265
286
api_version = os .environ ["ARRAY_API_STRICT_API_VERSION" ]
266
287
)
267
288
289
+ if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os .environ :
290
+ set_array_api_strict_flags (
291
+ boolean_indexing = os .environ ["ARRAY_API_STRICT_BOOLEAN_INDEXING" ].lower () == "true"
292
+ )
293
+
268
294
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
269
295
set_array_api_strict_flags (
270
296
data_dependent_shapes = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
0 commit comments