Skip to content

Commit f8e257a

Browse files
committed
Introduce unvectorized pytest marker
1 parent b3723d2 commit f8e257a

11 files changed

+46
-2
lines changed

array_api_tests/test_array_object.py

+3
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def test_getitem(shape, dtype, data):
105105
ph.assert_array_elements("__getitem__", out=out, expected=expected)
106106

107107

108+
@pytest.mark.unvectorized
108109
@given(
109110
shape=hh.shapes(),
110111
dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes),
@@ -152,6 +153,7 @@ def test_setitem(shape, dtypes, data):
152153
)
153154

154155

156+
@pytest.mark.unvectorized
155157
@pytest.mark.data_dependent_shapes
156158
@given(hh.shapes(), st.data())
157159
def test_getitem_masking(shape, data):
@@ -197,6 +199,7 @@ def test_getitem_masking(shape, data):
197199
)
198200

199201

202+
@pytest.mark.unvectorized
200203
@given(hh.shapes(), st.data())
201204
def test_setitem_masking(shape, data):
202205
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")

array_api_tests/test_indexing_functions.py

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from . import xps
1111

1212

13+
@pytest.mark.unvectorized
1314
@pytest.mark.min_version("2022.12")
1415
@given(
1516
x=hh.arrays(xps.scalar_dtypes(), hh.shapes(min_dims=1, min_side=1)),

array_api_tests/test_manipulation_functions.py

+7
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def test_concat(dtypes, base_shape, data):
119119
)
120120

121121

122+
@pytest.mark.unvectorized
122123
@given(
123124
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()),
124125
axis=shared_shapes().flatmap(
@@ -147,6 +148,7 @@ def test_expand_dims(x, axis):
147148
)
148149

149150

151+
@pytest.mark.unvectorized
150152
@given(
151153
x=hh.arrays(
152154
dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1).filter(lambda s: 1 in s)
@@ -184,6 +186,7 @@ def test_squeeze(x, data):
184186
assert_array_ndindex("squeeze", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape))
185187

186188

189+
@pytest.mark.unvectorized
187190
@given(
188191
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
189192
data=st.data(),
@@ -208,6 +211,7 @@ def test_flip(x, data):
208211
out_indices=reverse_indices, kw=kw)
209212

210213

214+
@pytest.mark.unvectorized
211215
@given(
212216
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes(min_dims=1)),
213217
axes=shared_shapes(min_dims=1).flatmap(
@@ -248,6 +252,7 @@ def reshape_shapes(draw, shape):
248252
return tuple(rshape)
249253

250254

255+
@pytest.mark.unvectorized
251256
@pytest.mark.skip("flaky") # TODO: fix!
252257
@given(
253258
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)),
@@ -280,6 +285,7 @@ def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator
280285
yield tuple((i + sh) % si for i, sh, si in zip(idx, all_shifts, shape))
281286

282287

288+
@pytest.mark.unvectorized
283289
@given(hh.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data())
284290
def test_roll(x, data):
285291
shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE)
@@ -319,6 +325,7 @@ def test_roll(x, data):
319325
assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw)
320326

321327

328+
@pytest.mark.unvectorized
322329
@given(
323330
shape=shared_shapes(min_dims=1),
324331
dtypes=hh.mutually_promotable_dtypes(None),

array_api_tests/test_operators_and_elementwise_functions.py

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from .typing import Array, DataType, Param, Scalar, ScalarType, Shape
2323

2424

25+
pytestmark = pytest.mark.unvectorized
26+
27+
2528
def all_integer_dtypes() -> st.SearchStrategy[DataType]:
2629
"""Returns a strategy for signed and unsigned integer dtype objects."""
2730
return xps.unsigned_integer_dtypes() | xps.integer_dtypes()

array_api_tests/test_searching_functions.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from . import xps
1313

1414

15+
pytestmark = pytest.mark.unvectorized
16+
17+
1518
@given(
1619
x=hh.arrays(
1720
dtype=xps.real_dtypes(),

array_api_tests/test_set_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from . import shape_helpers as sh
1414
from . import xps
1515

16-
pytestmark = pytest.mark.data_dependent_shapes
16+
pytestmark = [pytest.mark.data_dependent_shapes, pytest.mark.unvectorized]
1717

1818

1919
@given(hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))

array_api_tests/test_sorting_functions.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import cmath
22
from typing import Set
33

4+
import pytest
45
from hypothesis import given
56
from hypothesis import strategies as st
67
from hypothesis.control import assume
@@ -29,6 +30,7 @@ def assert_scalar_in_set(
2930

3031

3132
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
33+
@pytest.mark.unvectorized
3234
@given(
3335
x=hh.arrays(
3436
dtype=xps.real_dtypes(),
@@ -88,6 +90,7 @@ def test_argsort(x, data):
8890
)
8991

9092

93+
@pytest.mark.unvectorized
9194
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
9295
@given(
9396
x=hh.arrays(

array_api_tests/test_special_cases.py

+3
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
12101210
assert len(iop_params) != 0
12111211

12121212

1213+
@pytest.mark.unvectorized
12131214
@pytest.mark.parametrize("func_name, func, case", unary_params)
12141215
@given(
12151216
x=hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
@@ -1248,6 +1249,7 @@ def test_unary(func_name, func, case, x, data):
12481249
)
12491250

12501251

1252+
@pytest.mark.unvectorized
12511253
@pytest.mark.parametrize("func_name, func, case", binary_params)
12521254
@given(x1=x1_strat, x2=x2_strat, data=st.data())
12531255
def test_binary(func_name, func, case, x1, x2, data):
@@ -1292,6 +1294,7 @@ def test_binary(func_name, func, case, x1, x2, data):
12921294
assume(good_example)
12931295

12941296

1297+
@pytest.mark.unvectorized
12951298
@pytest.mark.parametrize("iop_name, iop, case", iop_params)
12961299
@given(
12971300
oneway_dtypes=hh.oneway_promotable_dtypes(dh.real_float_dtypes),

array_api_tests/test_statistical_functions.py

+5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
2323
return st.none() | st.sampled_from(dtypes)
2424

2525

26+
@pytest.mark.unvectorized
2627
@given(
2728
x=hh.arrays(
2829
dtype=xps.real_dtypes(),
@@ -75,6 +76,7 @@ def test_mean(x, data):
7576
# Values testing mean is too finicky
7677

7778

79+
@pytest.mark.unvectorized
7880
@given(
7981
x=hh.arrays(
8082
dtype=xps.real_dtypes(),
@@ -105,6 +107,7 @@ def test_min(x, data):
105107
ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected)
106108

107109

110+
@pytest.mark.unvectorized
108111
@given(
109112
x=hh.arrays(
110113
dtype=xps.numeric_dtypes(),
@@ -193,6 +196,7 @@ def test_std(x, data):
193196
# We can't easily test the result(s) as standard deviation methods vary a lot
194197

195198

199+
@pytest.mark.unvectorized
196200
@pytest.mark.skip("flaky") # TODO: fix!
197201
@given(
198202
x=hh.arrays(
@@ -245,6 +249,7 @@ def test_sum(x, data):
245249
ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)
246250

247251

252+
@pytest.mark.unvectorized
248253
@pytest.mark.skip(reason="flaky") # TODO: fix!
249254
@given(
250255
x=hh.arrays(

array_api_tests/test_utility_functions.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from hypothesis import given
23
from hypothesis import strategies as st
34

@@ -9,6 +10,7 @@
910
from . import xps
1011

1112

13+
@pytest.mark.unvectorized
1214
@given(
1315
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)),
1416
data=st.data(),
@@ -36,6 +38,7 @@ def test_all(x, data):
3638
out=result, expected=expected, kw=kw)
3739

3840

41+
@pytest.mark.unvectorized
3942
@given(
4043
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
4144
data=st.data(),

conftest.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def pytest_configure(config):
7777
"markers",
7878
"min_version(api_version): run when greater or equal to api_version",
7979
)
80+
config.addinivalue_line(
81+
"markers",
82+
"unvectorized: asserts against values via element-wise iteration (not performative!)",
83+
)
8084
# Hypothesis
8185
hypothesis_max_examples = config.getoption("--hypothesis-max-examples")
8286
disable_deadline = config.getoption("--hypothesis-disable-deadline")
@@ -104,6 +108,9 @@ def xp_has_ext(ext: str) -> bool:
104108

105109

106110
def pytest_collection_modifyitems(config, items):
111+
# 1. Prepare for iterating over items
112+
# -----------------------------------
113+
107114
skips_file = skips_path = config.getoption('--skips-file')
108115
if skips_file is None:
109116
skips_file = Path(__file__).parent / "skips.txt"
@@ -139,6 +146,9 @@ def pytest_collection_modifyitems(config, items):
139146
disabled_exts = config.getoption("--disable-extension")
140147
disabled_dds = config.getoption("--disable-data-dependent-shapes")
141148

149+
# 2. Iterate through items and apply markers accordingly
150+
# ------------------------------------------------------
151+
142152
for item in items:
143153
markers = list(item.iter_markers())
144154
# skip if specified in skips file
@@ -182,6 +192,9 @@ def pytest_collection_modifyitems(config, items):
182192
)
183193
)
184194

195+
# 3. Warn on bad skipped/xfailed ids
196+
# ----------------------------------
197+
185198
bad_ids_end_msg = (
186199
"Note the relevant tests might not of been collected by pytest, or "
187200
"another specified id might have already matched a test."
@@ -203,4 +216,4 @@ def pytest_collection_modifyitems(config, items):
203216
f"{f_bad_ids}\n"
204217
f"(xfails file: {xfails_file})\n"
205218
f"{bad_ids_end_msg}"
206-
)
219+
)

0 commit comments

Comments
 (0)