Skip to content

Commit e5bebbe

Browse files
authored
Merge pull request #39 from asmeurer/boolean-indexing-flag
Make boolean_indexing a separate flag from data_dependent_shapes
2 parents d51c277 + 943756f commit e5bebbe

File tree

4 files changed

+57
-10
lines changed

4 files changed

+57
-10
lines changed

Diff for: array_api_strict/_array_object.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,8 @@ def _validate_index(self, key):
436436
f"{len(key)=}, but masking is only specified in the "
437437
"Array API when the array is the sole index."
438438
)
439-
if not get_array_api_strict_flags()['data_dependent_shapes']:
440-
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
439+
if not get_array_api_strict_flags()['boolean_indexing']:
440+
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict")
441441

442442
elif i.dtype in _integer_dtypes and i.ndim != 0:
443443
raise IndexError(

Diff for: array_api_strict/_flags.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
API_VERSION = default_version = "2022.12"
2727

28+
BOOLEAN_INDEXING = True
29+
2830
DATA_DEPENDENT_SHAPES = True
2931

3032
all_extensions = (
@@ -46,6 +48,7 @@
4648
def set_array_api_strict_flags(
4749
*,
4850
api_version=None,
51+
boolean_indexing=None,
4952
data_dependent_shapes=None,
5053
enabled_extensions=None,
5154
):
@@ -67,6 +70,12 @@ def set_array_api_strict_flags(
6770
Note that 2021.12 is supported, but currently gives the same thing as
6871
2022.12 (except that the fft extension will be disabled).
6972
73+
74+
- `boolean_indexing`: Whether indexing by a boolean array is supported.
75+
Note that although boolean array indexing does result in data-dependent
76+
shapes, this flag is independent of the `data_dependent_shapes` flag
77+
(see below).
78+
7079
- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
7180
array-api-strict.
7281
@@ -79,10 +88,12 @@ def set_array_api_strict_flags(
7988
8089
- `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`.
8190
- `nonzero`
82-
- Boolean array indexing
8391
- `repeat` when the `repeats` argument is an array (requires 2023.12
8492
version of the standard)
8593
94+
Note that while boolean indexing is also data-dependent, it is
95+
controlled by a separate `boolean_indexing` flag (see above).
96+
8697
See
8798
https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html
8899
for more details.
@@ -102,8 +113,8 @@ def set_array_api_strict_flags(
102113
>>> # Set the standard version to 2021.12
103114
>>> set_array_api_strict_flags(api_version="2021.12")
104115
105-
>>> # Disable data-dependent shapes
106-
>>> set_array_api_strict_flags(data_dependent_shapes=False)
116+
>>> # Disable data-dependent shapes and boolean indexing
117+
>>> set_array_api_strict_flags(data_dependent_shapes=False, boolean_indexing=False)
107118
108119
>>> # Enable only the linalg extension (disable the fft extension)
109120
>>> set_array_api_strict_flags(enabled_extensions=["linalg"])
@@ -116,7 +127,7 @@ def set_array_api_strict_flags(
116127
ArrayAPIStrictFlags: A context manager to temporarily set the flags.
117128
118129
"""
119-
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
130+
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
120131

121132
if api_version is not None:
122133
if api_version not in supported_versions:
@@ -126,6 +137,9 @@ def set_array_api_strict_flags(
126137
API_VERSION = api_version
127138
array_api_strict.__array_api_version__ = API_VERSION
128139

140+
if boolean_indexing is not None:
141+
BOOLEAN_INDEXING = boolean_indexing
142+
129143
if data_dependent_shapes is not None:
130144
DATA_DEPENDENT_SHAPES = data_dependent_shapes
131145

@@ -169,7 +183,11 @@ def get_array_api_strict_flags():
169183
>>> from array_api_strict import get_array_api_strict_flags
170184
>>> flags = get_array_api_strict_flags()
171185
>>> flags
172-
{'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
186+
{'api_version': '2022.12',
187+
'boolean_indexing': True,
188+
'data_dependent_shapes': True,
189+
'enabled_extensions': ('linalg', 'fft')
190+
}
173191
174192
See Also
175193
--------
@@ -181,6 +199,7 @@ def get_array_api_strict_flags():
181199
"""
182200
return {
183201
"api_version": API_VERSION,
202+
"boolean_indexing": BOOLEAN_INDEXING,
184203
"data_dependent_shapes": DATA_DEPENDENT_SHAPES,
185204
"enabled_extensions": ENABLED_EXTENSIONS,
186205
}
@@ -215,9 +234,10 @@ def reset_array_api_strict_flags():
215234
ArrayAPIStrictFlags: A context manager to temporarily set the flags.
216235
217236
"""
218-
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
237+
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
219238
API_VERSION = default_version
220239
array_api_strict.__array_api_version__ = API_VERSION
240+
BOOLEAN_INDEXING = True
221241
DATA_DEPENDENT_SHAPES = True
222242
ENABLED_EXTENSIONS = default_extensions
223243

@@ -242,10 +262,11 @@ class ArrayAPIStrictFlags:
242262
reset_array_api_strict_flags: Reset the flags to their default values.
243263
244264
"""
245-
def __init__(self, *, api_version=None, data_dependent_shapes=None,
246-
enabled_extensions=None):
265+
def __init__(self, *, api_version=None, boolean_indexing=None,
266+
data_dependent_shapes=None, enabled_extensions=None):
247267
self.kwargs = {
248268
"api_version": api_version,
269+
"boolean_indexing": boolean_indexing,
249270
"data_dependent_shapes": data_dependent_shapes,
250271
"enabled_extensions": enabled_extensions,
251272
}
@@ -265,6 +286,11 @@ def set_flags_from_environment():
265286
api_version=os.environ["ARRAY_API_STRICT_API_VERSION"]
266287
)
267288

289+
if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os.environ:
290+
set_array_api_strict_flags(
291+
boolean_indexing=os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true"
292+
)
293+
268294
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ:
269295
set_array_api_strict_flags(
270296
data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true"

Diff for: array_api_strict/tests/test_flags.py

+17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def test_flags():
1313
flags = get_array_api_strict_flags()
1414
assert flags == {
1515
'api_version': '2022.12',
16+
'boolean_indexing': True,
1617
'data_dependent_shapes': True,
1718
'enabled_extensions': ('linalg', 'fft'),
1819
}
@@ -22,13 +23,15 @@ def test_flags():
2223
flags = get_array_api_strict_flags()
2324
assert flags == {
2425
'api_version': '2022.12',
26+
'boolean_indexing': True,
2527
'data_dependent_shapes': False,
2628
'enabled_extensions': ('linalg', 'fft'),
2729
}
2830
set_array_api_strict_flags(enabled_extensions=('fft',))
2931
flags = get_array_api_strict_flags()
3032
assert flags == {
3133
'api_version': '2022.12',
34+
'boolean_indexing': True,
3235
'data_dependent_shapes': False,
3336
'enabled_extensions': ('fft',),
3437
}
@@ -41,6 +44,7 @@ def test_flags():
4144
flags = get_array_api_strict_flags()
4245
assert flags == {
4346
'api_version': '2021.12',
47+
'boolean_indexing': True,
4448
'data_dependent_shapes': False,
4549
'enabled_extensions': ('linalg',),
4650
}
@@ -58,12 +62,14 @@ def test_flags():
5862
with pytest.warns(UserWarning):
5963
set_array_api_strict_flags(
6064
api_version='2021.12',
65+
boolean_indexing=False,
6166
data_dependent_shapes=False,
6267
enabled_extensions=())
6368
reset_array_api_strict_flags()
6469
flags = get_array_api_strict_flags()
6570
assert flags == {
6671
'api_version': '2022.12',
72+
'boolean_indexing': True,
6773
'data_dependent_shapes': True,
6874
'enabled_extensions': ('linalg', 'fft'),
6975
}
@@ -96,6 +102,17 @@ def test_data_dependent_shapes():
96102
pytest.raises(RuntimeError, lambda: unique_inverse(a))
97103
pytest.raises(RuntimeError, lambda: unique_values(a))
98104
pytest.raises(RuntimeError, lambda: nonzero(a))
105+
a[mask] # No error (boolean indexing is a separate flag)
106+
107+
def test_boolean_indexing():
108+
a = asarray([0, 0, 1, 2, 2])
109+
mask = asarray([True, False, True, False, True])
110+
111+
# Should not error
112+
a[mask]
113+
114+
set_array_api_strict_flags(boolean_indexing=False)
115+
99116
pytest.raises(RuntimeError, lambda: a[mask])
100117

101118
linalg_examples = {

Diff for: docs/api.rst

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ used by array-api-strict initially. They will not change the defaults used by
3030

3131
A string representing the version number.
3232

33+
.. envvar:: ARRAY_API_STRICT_BOOLEAN_INDEXING
34+
35+
"True" or "False" to enable or disable boolean indexing.
36+
3337
.. envvar:: ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES
3438

3539
"True" or "False" to enable or disable data dependent shapes.

0 commit comments

Comments
 (0)