Skip to content

Commit cc565d1

Browse files
authored
Merge pull request #272 from hameerabbasi/fix-nan-propagation-immutable
Fix `test_nan_propagation` for immutable arrays
2 parents b2fd7fa + e73f038 commit cc565d1

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

array_api_tests/test_special_cases.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from warnings import warn
2424

2525
import pytest
26-
from hypothesis import given, note, settings
26+
from hypothesis import given, note, settings, assume
2727
from hypothesis import strategies as st
2828

2929
from array_api_tests.typing import Array, DataType
@@ -1331,10 +1331,11 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp
13311331
)
13321332
def test_nan_propagation(func_name, x, data):
13331333
func = getattr(xp, func_name)
1334-
set_idx = data.draw(
1335-
xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx"
1334+
nan_positions = data.draw(
1335+
hh.arrays(dtype=hh.bool_dtype, shape=x.shape), label="nan_positions"
13361336
)
1337-
x[set_idx] = float("nan")
1337+
assume(xp.any(nan_positions))
1338+
x = xp.where(nan_positions, xp.asarray(float("nan")), x)
13381339
note(f"{x=}")
13391340

13401341
out = func(x)

0 commit comments

Comments
 (0)