diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index e7d5e1b8..358a8eef 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -227,19 +227,17 @@ def test_irfftn(x, data): expected=dh.dtype_components[x.dtype], ) - # TODO: assert shape correctly - # _axes = sh.normalize_axis(axes, x.ndim) - # _s = x.shape if s is None else s - # expected = [] - # for i in range(x.ndim): - # if i in _axes: - # side = _s[_axes.index(i)] - # else: - # side = x.shape[i] - # expected.append(side) - # last_axis = max(_axes) - # expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1 - # ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected)) + _axes = sh.normalize_axis(axes, x.ndim) + _s = x.shape if s is None else s + expected = [] + for i in range(x.ndim): + if i in _axes: + side = _s[_axes.index(i)] + else: + side = x.shape[i] + expected.append(side) + expected[_axes[-1]] = 2*(_s[-1] - 1) if s is None else _s[-1] + ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected)) @given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())