|
3 | 3 | import array_api_strict as xp
|
4 | 4 |
|
5 | 5 | from array_api_strict import ArrayAPIStrictFlags
|
6 |
| -from array_api_strict._flags import draft_version |
7 | 6 |
|
8 | 7 |
|
9 | 8 | def test_where_with_scalars():
|
10 | 9 | x = xp.asarray([1, 2, 3, 1])
|
11 | 10 |
|
12 | 11 | # Versions up to and including 2023.12 don't support scalar arguments
|
13 |
| - with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): |
14 |
| - xp.where(x == 1, 42, 44) |
| 12 | + with ArrayAPIStrictFlags(api_version='2023.12'): |
| 13 | + with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): |
| 14 | + xp.where(x == 1, 42, 44) |
15 | 15 |
|
16 | 16 | # Versions after 2023.12 support scalar arguments
|
17 |
| - with (pytest.warns( |
18 |
| - UserWarning, |
19 |
| - match="The 2024.12 version of the array API specification is in draft status" |
20 |
| - ), |
21 |
| - ArrayAPIStrictFlags(api_version=draft_version), |
22 |
| - ): |
23 |
| - x_where = xp.where(x == 1, xp.asarray(42), 44) |
24 |
| - |
25 |
| - expected = xp.asarray([42, 44, 44, 42]) |
26 |
| - assert xp.all(x_where == expected) |
27 |
| - |
28 |
| - # The spec does not allow both x1 and x2 to be scalars |
29 |
| - with pytest.raises(ValueError, match="One of"): |
30 |
| - xp.where(x == 1, 42, 44) |
| 17 | + x_where = xp.where(x == 1, xp.asarray(42), 44) |
| 18 | + |
| 19 | + expected = xp.asarray([42, 44, 44, 42]) |
| 20 | + assert xp.all(x_where == expected) |
| 21 | + |
| 22 | + # The spec does not allow both x1 and x2 to be scalars |
| 23 | + with pytest.raises(ValueError, match="One of"): |
| 24 | + xp.where(x == 1, 42, 44) |
0 commit comments