Skip to content

Commit 29fe679

Browse files
dcherianIllviljanpre-commit-ci[bot]headtr1ck
authored
Rewrite interp to use apply_ufunc (pydata#9881)
* Don't eagerly compute dask arrays in localize * Clean up test * Clean up Variable handling * Silence test warning * Use apply_ufunc instead * Add test for pydata#4463 Closes pydata#4463 * complete tests * Add comments * Clear up broadcasting * typing * try a different warning filter * one more fix * types + more duck_array_ops * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Michael Niklas <[email protected]> * Apply suggestions from code review Co-authored-by: Illviljan <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Illviljan <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Revert "Apply suggestions from code review" This reverts commit 1b9845d. --------- Co-authored-by: Illviljan <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas <[email protected]>
1 parent a90fff9 commit 29fe679

File tree

3 files changed

+269
-204
lines changed

3 files changed

+269
-204
lines changed

xarray/core/dataset.py

+18-37
Original file line numberDiff line numberDiff line change
@@ -2921,19 +2921,11 @@ def _validate_interp_indexers(
29212921
"""Variant of _validate_indexers to be used for interpolation"""
29222922
for k, v in self._validate_indexers(indexers):
29232923
if isinstance(v, Variable):
2924-
if v.ndim == 1:
2925-
yield k, v.to_index_variable()
2926-
else:
2927-
yield k, v
2928-
elif isinstance(v, int):
2924+
yield k, v
2925+
elif is_scalar(v):
29292926
yield k, Variable((), v, attrs=self.coords[k].attrs)
29302927
elif isinstance(v, np.ndarray):
2931-
if v.ndim == 0:
2932-
yield k, Variable((), v, attrs=self.coords[k].attrs)
2933-
elif v.ndim == 1:
2934-
yield k, IndexVariable((k,), v, attrs=self.coords[k].attrs)
2935-
else:
2936-
raise AssertionError() # Already tested by _validate_indexers
2928+
yield k, Variable(dims=(k,), data=v, attrs=self.coords[k].attrs)
29372929
else:
29382930
raise TypeError(type(v))
29392931

@@ -4127,18 +4119,6 @@ def interp(
41274119

41284120
coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
41294121
indexers = dict(self._validate_interp_indexers(coords))
4130-
4131-
if coords:
4132-
# This avoids broadcasting over coordinates that are both in
4133-
# the original array AND in the indexing array. It essentially
4134-
# forces interpolation along the shared coordinates.
4135-
sdims = (
4136-
set(self.dims)
4137-
.intersection(*[set(nx.dims) for nx in indexers.values()])
4138-
.difference(coords.keys())
4139-
)
4140-
indexers.update({d: self.variables[d] for d in sdims})
4141-
41424122
obj = self if assume_sorted else self.sortby(list(coords))
41434123

41444124
def maybe_variable(obj, k):
@@ -4169,16 +4149,18 @@ def _validate_interp_indexer(x, new_x):
41694149
for k, v in indexers.items()
41704150
}
41714151

4172-
# optimization: subset to coordinate range of the target index
4173-
if method in ["linear", "nearest"]:
4174-
for k, v in validated_indexers.items():
4175-
obj, newidx = missing._localize(obj, {k: v})
4176-
validated_indexers[k] = newidx[k]
4177-
4178-
# optimization: create dask coordinate arrays once per Dataset
4179-
# rather than once per Variable when dask.array.unify_chunks is called later
4180-
# GH4739
4181-
if obj.__dask_graph__():
4152+
has_chunked_array = bool(
4153+
any(is_chunked_array(v._data) for v in obj._variables.values())
4154+
)
4155+
if has_chunked_array:
4156+
# optimization: subset to coordinate range of the target index
4157+
if method in ["linear", "nearest"]:
4158+
for k, v in validated_indexers.items():
4159+
obj, newidx = missing._localize(obj, {k: v})
4160+
validated_indexers[k] = newidx[k]
4161+
# optimization: create dask coordinate arrays once per Dataset
4162+
# rather than once per Variable when dask.array.unify_chunks is called later
4163+
# GH4739
41824164
dask_indexers = {
41834165
k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk())
41844166
for k, (index, dest) in validated_indexers.items()
@@ -4190,10 +4172,9 @@ def _validate_interp_indexer(x, new_x):
41904172
if name in indexers:
41914173
continue
41924174

4193-
if is_duck_dask_array(var.data):
4194-
use_indexers = dask_indexers
4195-
else:
4196-
use_indexers = validated_indexers
4175+
use_indexers = (
4176+
dask_indexers if is_duck_dask_array(var.data) else validated_indexers
4177+
)
41974178

41984179
dtype_kind = var.dtype.kind
41994180
if dtype_kind in "uifc":

0 commit comments

Comments
 (0)