Skip to content

Commit e239389

Browse files
authored
Support vectorized interpolation with more scipy interpolators (#9526)
* vectorize 1d interpolators * whats new * formatting
1 parent ea06c6f commit e239389

File tree

8 files changed

+115
-53
lines changed

8 files changed

+115
-53
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ New Features
3232
`Tom Nicholas <https://github.com/TomNicholas>`_.
3333
- Added zarr backends for :py:func:`open_groups` (:issue:`9430`, :pull:`9469`).
3434
By `Eni Awowale <https://github.com/eni-awowale>`_.
35+
- Added support for vectorized interpolation using additional interpolators
36+
from the ``scipy.interpolate`` module (:issue:`9049`, :pull:`9526`).
37+
By `Holly Mandel <https://github.com/hollymandel>`_.
3538

3639
Breaking changes
3740
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -2224,12 +2224,12 @@ def interp(
22242224
22252225
Performs univariate or multivariate interpolation of a DataArray onto
22262226
new coordinates using scipy's interpolation routines. If interpolating
2227-
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
2228-
called. When interpolating along multiple existing dimensions, an
2227+
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
2228+
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
2229+
is called. When interpolating along multiple existing dimensions, an
22292230
attempt is made to decompose the interpolation into multiple
2230-
1-dimensional interpolations. If this is possible,
2231-
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
2232-
:py:func:`scipy.interpolate.interpn` is called.
2231+
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator is called.
2232+
Otherwise, :py:func:`scipy.interpolate.interpn` is called.
22332233
22342234
Parameters
22352235
----------

xarray/core/dataset.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3885,12 +3885,12 @@ def interp(
38853885
38863886
Performs univariate or multivariate interpolation of a Dataset onto
38873887
new coordinates using scipy's interpolation routines. If interpolating
3888-
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
3889-
called. When interpolating along multiple existing dimensions, an
3888+
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
3889+
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
3890+
is called. When interpolating along multiple existing dimensions, an
38903891
attempt is made to decompose the interpolation into multiple
3891-
1-dimensional interpolations. If this is possible,
3892-
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
3893-
:py:func:`scipy.interpolate.interpn` is called.
3892+
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator
3893+
is called. Otherwise, :py:func:`scipy.interpolate.interpn` is called.
38943894
38953895
Parameters
38963896
----------

xarray/core/missing.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(
138138
copy=False,
139139
bounds_error=False,
140140
order=None,
141+
axis=-1,
141142
**kwargs,
142143
):
143144
from scipy.interpolate import interp1d
@@ -173,6 +174,7 @@ def __init__(
173174
bounds_error=bounds_error,
174175
assume_sorted=assume_sorted,
175176
copy=copy,
177+
axis=axis,
176178
**self.cons_kwargs,
177179
)
178180

@@ -479,7 +481,8 @@ def _get_interpolator(
479481
interp1d_methods = get_args(Interp1dOptions)
480482
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))
481483

482-
# prioritize scipy.interpolate
484+
# prefer numpy.interp for 1d linear interpolation. This function cannot
485+
# take higher dimensional data but scipy.interp1d can.
483486
if (
484487
method == "linear"
485488
and not kwargs.get("fill_value", None) == "extrapolate"
@@ -492,21 +495,33 @@ def _get_interpolator(
492495
if method in interp1d_methods:
493496
kwargs.update(method=method)
494497
interp_class = ScipyInterpolator
495-
elif vectorizeable_only:
496-
raise ValueError(
497-
f"{method} is not a vectorizeable interpolator. "
498-
f"Available methods are {interp1d_methods}"
499-
)
500498
elif method == "barycentric":
499+
kwargs.update(axis=-1)
501500
interp_class = _import_interpolant("BarycentricInterpolator", method)
502501
elif method in ["krogh", "krog"]:
502+
kwargs.update(axis=-1)
503503
interp_class = _import_interpolant("KroghInterpolator", method)
504504
elif method == "pchip":
505+
kwargs.update(axis=-1)
505506
interp_class = _import_interpolant("PchipInterpolator", method)
506507
elif method == "spline":
508+
utils.emit_user_level_warning(
509+
"The 1d SplineInterpolator class is performing an incorrect calculation and "
510+
"is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
511+
PendingDeprecationWarning,
512+
)
513+
if vectorizeable_only:
514+
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
507515
kwargs.update(method=method)
508516
interp_class = SplineInterpolator
509517
elif method == "akima":
518+
kwargs.update(axis=-1)
519+
interp_class = _import_interpolant("Akima1DInterpolator", method)
520+
elif method == "makima":
521+
kwargs.update(method="makima", axis=-1)
522+
interp_class = _import_interpolant("Akima1DInterpolator", method)
523+
elif method == "makima":
524+
kwargs.update(method="makima", axis=-1)
510525
interp_class = _import_interpolant("Akima1DInterpolator", method)
511526
else:
512527
raise ValueError(f"{method} is not a valid scipy interpolator")
@@ -525,6 +540,7 @@ def _get_interpolator_nd(method, **kwargs):
525540

526541
if method in valid_methods:
527542
kwargs.update(method=method)
543+
kwargs.setdefault("bounds_error", False)
528544
interp_class = _import_interpolant("interpn", method)
529545
else:
530546
raise ValueError(
@@ -614,9 +630,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
614630
if not indexes_coords:
615631
return var.copy()
616632

617-
# default behavior
618-
kwargs["bounds_error"] = kwargs.get("bounds_error", False)
619-
620633
result = var
621634
# decompose the interpolation into a succession of independent interpolation
622635
for indep_indexes_coords in decompose_interp(indexes_coords):
@@ -663,8 +676,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
663676
new_x : a list of 1d array
664677
New coordinates. Should not contain NaN.
665678
method : string
666-
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
667-
1-dimensional interpolation.
679+
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima',
680+
'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation.
668681
{'linear', 'nearest'} for multidimensional interpolation
669682
**kwargs
670683
Optional keyword arguments to be passed to scipy.interpolator
@@ -756,7 +769,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
756769
def _interp1d(var, x, new_x, func, kwargs):
757770
# x, new_x are tuples of size 1.
758771
x, new_x = x[0], new_x[0]
759-
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
772+
rslt = func(x, var, **kwargs)(np.ravel(new_x))
760773
if new_x.ndim > 1:
761774
return reshape(rslt, (var.shape[:-1] + new_x.shape))
762775
if new_x.ndim == 0:

xarray/core/types.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ def copy(
228228
Interp1dOptions = Literal[
229229
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
230230
]
231-
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
231+
InterpolantOptions = Literal[
232+
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
233+
]
232234
InterpOptions = Union[Interp1dOptions, InterpolantOptions]
233235

234236
DatetimeUnitOptions = Literal[

xarray/tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def _importorskip(
8787

8888
has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
8989
has_scipy, requires_scipy = _importorskip("scipy")
90+
has_scipy_ge_1_13, requires_scipy_ge_1_13 = _importorskip("scipy", "1.13")
9091
with warnings.catch_warnings():
9192
warnings.filterwarnings(
9293
"ignore",

xarray/tests/test_interp.py

+68-29
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
assert_identical,
1717
has_dask,
1818
has_scipy,
19+
has_scipy_ge_1_13,
1920
requires_cftime,
2021
requires_dask,
2122
requires_scipy,
@@ -132,29 +133,66 @@ def func(obj, new_x):
132133
assert_allclose(actual, expected)
133134

134135

135-
@pytest.mark.parametrize("use_dask", [False, True])
136-
def test_interpolate_vectorize(use_dask: bool) -> None:
137-
if not has_scipy:
138-
pytest.skip("scipy is not installed.")
139-
140-
if not has_dask and use_dask:
141-
pytest.skip("dask is not installed in the environment.")
142-
136+
@requires_scipy
137+
@pytest.mark.parametrize(
138+
"use_dask, method",
139+
(
140+
(False, "linear"),
141+
(False, "akima"),
142+
pytest.param(
143+
False,
144+
"makima",
145+
marks=pytest.mark.skipif(not has_scipy_ge_1_13, reason="scipy too old"),
146+
),
147+
pytest.param(
148+
True,
149+
"linear",
150+
marks=pytest.mark.skipif(not has_dask, reason="dask not available"),
151+
),
152+
pytest.param(
153+
True,
154+
"akima",
155+
marks=pytest.mark.skipif(not has_dask, reason="dask not available"),
156+
),
157+
),
158+
)
159+
def test_interpolate_vectorize(use_dask: bool, method: InterpOptions) -> None:
143160
# scipy interpolation for the reference
144-
def func(obj, dim, new_x):
161+
def func(obj, dim, new_x, method):
162+
scipy_kwargs = {}
163+
interpolant_options = {
164+
"barycentric": scipy.interpolate.BarycentricInterpolator,
165+
"krogh": scipy.interpolate.KroghInterpolator,
166+
"pchip": scipy.interpolate.PchipInterpolator,
167+
"akima": scipy.interpolate.Akima1DInterpolator,
168+
"makima": scipy.interpolate.Akima1DInterpolator,
169+
}
170+
145171
shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
146172
for s in new_x.shape[::-1]:
147173
shape.insert(obj.get_axis_num(dim), s)
148174

149-
return scipy.interpolate.interp1d(
150-
da[dim],
151-
obj.data,
152-
axis=obj.get_axis_num(dim),
153-
bounds_error=False,
154-
fill_value=np.nan,
155-
)(new_x).reshape(shape)
175+
if method in interpolant_options:
176+
interpolant = interpolant_options[method]
177+
if method == "makima":
178+
scipy_kwargs["method"] = method
179+
return interpolant(
180+
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
181+
)(new_x).reshape(shape)
182+
else:
183+
184+
return scipy.interpolate.interp1d(
185+
da[dim],
186+
obj.data,
187+
axis=obj.get_axis_num(dim),
188+
kind=method,
189+
bounds_error=False,
190+
fill_value=np.nan,
191+
**scipy_kwargs,
192+
)(new_x).reshape(shape)
156193

157194
da = get_example_data(0)
195+
158196
if use_dask:
159197
da = da.chunk({"y": 5})
160198

@@ -165,17 +203,17 @@ def func(obj, dim, new_x):
165203
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
166204
)
167205

168-
actual = da.interp(x=xdest, method="linear")
206+
actual = da.interp(x=xdest, method=method)
169207

170208
expected = xr.DataArray(
171-
func(da, "x", xdest),
209+
func(da, "x", xdest, method),
172210
dims=["z", "y"],
173211
coords={
174212
"z": xdest["z"],
175213
"z2": xdest["z2"],
176214
"y": da["y"],
177215
"x": ("z", xdest.values),
178-
"x2": ("z", func(da["x2"], "x", xdest)),
216+
"x2": ("z", func(da["x2"], "x", xdest, method)),
179217
},
180218
)
181219
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
@@ -191,18 +229,18 @@ def func(obj, dim, new_x):
191229
},
192230
)
193231

194-
actual = da.interp(x=xdest, method="linear")
232+
actual = da.interp(x=xdest, method=method)
195233

196234
expected = xr.DataArray(
197-
func(da, "x", xdest),
235+
func(da, "x", xdest, method),
198236
dims=["z", "w", "y"],
199237
coords={
200238
"z": xdest["z"],
201239
"w": xdest["w"],
202240
"z2": xdest["z2"],
203241
"y": da["y"],
204242
"x": (("z", "w"), xdest.data),
205-
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
243+
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
206244
},
207245
)
208246
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
@@ -393,19 +431,17 @@ def test_nans(use_dask: bool) -> None:
393431
assert actual.count() > 0
394432

395433

434+
@requires_scipy
396435
@pytest.mark.parametrize("use_dask", [True, False])
397436
def test_errors(use_dask: bool) -> None:
398-
if not has_scipy:
399-
pytest.skip("scipy is not installed.")
400-
401-
# akima and spline are unavailable
437+
# spline is unavailable
402438
da = xr.DataArray([0, 1, np.nan, 2], dims="x", coords={"x": range(4)})
403439
if not has_dask and use_dask:
404440
pytest.skip("dask is not installed in the environment.")
405441
da = da.chunk()
406442

407-
for method in ["akima", "spline"]:
408-
with pytest.raises(ValueError):
443+
for method in ["spline"]:
444+
with pytest.raises(ValueError), pytest.warns(PendingDeprecationWarning):
409445
da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type]
410446

411447
# not sorted
@@ -922,7 +958,10 @@ def test_interp1d_bounds_error() -> None:
922958
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
923959
],
924960
)
925-
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
961+
def test_coord_attrs(
962+
x,
963+
expect_same_attrs: bool,
964+
) -> None:
926965
base_attrs = dict(foo="bar")
927966
ds = xr.Dataset(
928967
data_vars=dict(a=2 * np.arange(5)),

xarray/tests/test_missing.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,11 @@ def test_scipy_methods_function(method) -> None:
137137
# Note: Pandas does some wacky things with these methods and the full
138138
# integration tests won't work.
139139
da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True)
140-
actual = da.interpolate_na(method=method, dim="time")
140+
if method == "spline":
141+
with pytest.warns(PendingDeprecationWarning):
142+
actual = da.interpolate_na(method=method, dim="time")
143+
else:
144+
actual = da.interpolate_na(method=method, dim="time")
141145
assert (da.count("time") <= actual.count("time")).all()
142146

143147

0 commit comments

Comments
 (0)