Skip to content

Commit 6df8bd6

Browse files
authored
Dispatch to Dask if nanquantile is available (#9719)
* Dispatch to Dask is nanquantile is available * Fixup * Change test
1 parent 2619c0b commit 6df8bd6

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

xarray/core/variable.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
4848
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
49+
from xarray.namedarray.utils import module_available
4950
from xarray.util.deprecation_helpers import deprecate_dims
5051

5152
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
@@ -1948,7 +1949,7 @@ def _wrapper(npa, **kwargs):
19481949
output_core_dims=[["quantile"]],
19491950
output_dtypes=[np.float64],
19501951
dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
1951-
dask="parallelized",
1952+
dask="allowed" if module_available("dask", "2024.11.0") else "parallelized",
19521953
kwargs=kwargs,
19531954
)
19541955

xarray/tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _importorskip(
107107
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
108108
has_cftime, requires_cftime = _importorskip("cftime")
109109
has_dask, requires_dask = _importorskip("dask")
110+
has_dask_ge_2024_11_0, requires_dask_ge_2024_11_0 = _importorskip("dask", "2024.11.0")
110111
with warnings.catch_warnings():
111112
warnings.filterwarnings(
112113
"ignore",

xarray/tests/test_variable.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
assert_equal,
3737
assert_identical,
3838
assert_no_warnings,
39+
has_dask_ge_2024_11_0,
3940
has_pandas_3,
4041
raise_if_dask_computes,
4142
requires_bottleneck,
@@ -1871,9 +1872,16 @@ def test_quantile_interpolation_deprecation(self, method) -> None:
18711872
def test_quantile_chunked_dim_error(self):
18721873
v = Variable(["x", "y"], self.d).chunk({"x": 2})
18731874

1874-
# this checks for ValueError in dask.array.apply_gufunc
1875-
with pytest.raises(ValueError, match=r"consists of multiple chunks"):
1876-
v.quantile(0.5, dim="x")
1875+
if has_dask_ge_2024_11_0:
1876+
# Dask rechunks
1877+
np.testing.assert_allclose(
1878+
v.compute().quantile(0.5, dim="x"), v.quantile(0.5, dim="x")
1879+
)
1880+
1881+
else:
1882+
# this checks for ValueError in dask.array.apply_gufunc
1883+
with pytest.raises(ValueError, match=r"consists of multiple chunks"):
1884+
v.quantile(0.5, dim="x")
18771885

18781886
@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
18791887
@pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]])

0 commit comments

Comments
 (0)