Skip to content

Commit da0489f

Browse files
authored
differentiate should not cast to numpy.array (#5408)
1 parent 34dc577 commit da0489f

File tree

3 files changed

+80
-15
lines changed

3 files changed

+80
-15
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ Bug fixes
5454
- Fix 1-level multi-index incorrectly converted to single index (:issue:`5384`,
5555
:pull:`5385`).
5656
By `Benoit Bovy <https://github.com/benbovy>`_.
57+
- Don't cast a duck array in a coordinate to :py:class:`numpy.ndarray` in
58+
:py:meth:`DataArray.differentiate` (:pull:`5408`)
59+
By `Justus Magin <https://github.com/keewis>`_.
5760
- Fix the ``repr`` of :py:class:`Variable` objects with ``display_expand_data=True``
5861
(:pull:`5406`)
5962
By `Justus Magin <https://github.com/keewis>`_.

xarray/core/dataset.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6244,7 +6244,10 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None):
62446244
if _contains_datetime_like_objects(v):
62456245
v = v._to_numeric(datetime_unit=datetime_unit)
62466246
grad = duck_array_ops.gradient(
6247-
v.data, coord_var, edge_order=edge_order, axis=v.get_axis_num(dim)
6247+
v.data,
6248+
coord_var.data,
6249+
edge_order=edge_order,
6250+
axis=v.get_axis_num(dim),
62486251
)
62496252
variables[k] = Variable(v.dims, grad)
62506253
else:

xarray/tests/test_units.py

+73-14
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
import xarray as xr
9-
from xarray.core import dtypes
9+
from xarray.core import dtypes, duck_array_ops
1010

1111
from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical
1212
from .test_variable import _PAD_XR_NP_ARGS
@@ -276,13 +276,13 @@ class method:
276276
This is works a bit similar to using `partial(Class.method, arg, kwarg)`
277277
"""
278278

279-
def __init__(self, name, *args, **kwargs):
279+
def __init__(self, name, *args, fallback_func=None, **kwargs):
280280
self.name = name
281+
self.fallback = fallback_func
281282
self.args = args
282283
self.kwargs = kwargs
283284

284285
def __call__(self, obj, *args, **kwargs):
285-
from collections.abc import Callable
286286
from functools import partial
287287

288288
all_args = merge_args(self.args, args)
@@ -298,21 +298,23 @@ def __call__(self, obj, *args, **kwargs):
298298
if not isinstance(obj, xarray_classes):
299299
# remove typical xarray args like "dim"
300300
exclude_kwargs = ("dim", "dims")
301+
# TODO: figure out a way to replace dim / dims with axis
301302
all_kwargs = {
302303
key: value
303304
for key, value in all_kwargs.items()
304305
if key not in exclude_kwargs
305306
}
306-
307-
func = getattr(obj, self.name, None)
308-
309-
if func is None or not isinstance(func, Callable):
310-
# fall back to module level numpy functions if not a xarray object
311-
if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)):
312-
numpy_func = getattr(np, self.name)
313-
func = partial(numpy_func, obj)
307+
if self.fallback is not None:
308+
func = partial(self.fallback, obj)
314309
else:
315-
raise AttributeError(f"{obj} has no method named '{self.name}'")
310+
func = getattr(obj, self.name, None)
311+
312+
if func is None or not callable(func):
313+
# fall back to module level numpy functions
314+
numpy_func = getattr(np, self.name)
315+
func = partial(numpy_func, obj)
316+
else:
317+
func = getattr(obj, self.name)
316318

317319
return func(*all_args, **all_kwargs)
318320

@@ -3662,6 +3664,65 @@ def test_stacking_reordering(self, func, dtype):
36623664
assert_units_equal(expected, actual)
36633665
assert_identical(expected, actual)
36643666

3667+
@pytest.mark.parametrize(
3668+
"variant",
3669+
(
3670+
pytest.param(
3671+
"dims", marks=pytest.mark.skip(reason="indexes don't support units")
3672+
),
3673+
"coords",
3674+
),
3675+
)
3676+
@pytest.mark.parametrize(
3677+
"func",
3678+
(
3679+
method("differentiate", fallback_func=np.gradient),
3680+
method("integrate", fallback_func=duck_array_ops.cumulative_trapezoid),
3681+
method("cumulative_integrate", fallback_func=duck_array_ops.trapz),
3682+
),
3683+
ids=repr,
3684+
)
3685+
def test_differentiate_integrate(self, func, variant, dtype):
3686+
data_unit = unit_registry.m
3687+
unit = unit_registry.s
3688+
3689+
variants = {
3690+
"dims": ("x", unit, 1),
3691+
"coords": ("u", 1, unit),
3692+
}
3693+
coord, dim_unit, coord_unit = variants.get(variant)
3694+
3695+
array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit
3696+
3697+
x = np.arange(array.shape[0]) * dim_unit
3698+
y = np.arange(array.shape[1]) * dim_unit
3699+
3700+
u = np.linspace(0, 1, array.shape[0]) * coord_unit
3701+
3702+
data_array = xr.DataArray(
3703+
data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y")
3704+
)
3705+
# we want to make sure the output unit is correct
3706+
units = extract_units(data_array)
3707+
units.update(
3708+
extract_units(
3709+
func(
3710+
data_array.data,
3711+
getattr(data_array, coord).data,
3712+
axis=0,
3713+
)
3714+
)
3715+
)
3716+
3717+
expected = attach_units(
3718+
func(strip_units(data_array), coord=strip_units(coord)),
3719+
units,
3720+
)
3721+
actual = func(data_array, coord=coord)
3722+
3723+
assert_units_equal(expected, actual)
3724+
assert_identical(expected, actual)
3725+
36653726
@pytest.mark.parametrize(
36663727
"variant",
36673728
(
@@ -3676,8 +3737,6 @@ def test_stacking_reordering(self, func, dtype):
36763737
"func",
36773738
(
36783739
method("diff", dim="x"),
3679-
method("differentiate", coord="x"),
3680-
method("integrate", coord="x"),
36813740
method("quantile", q=[0.25, 0.75]),
36823741
method("reduce", func=np.sum, dim="x"),
36833742
pytest.param(lambda x: x.dot(x), id="method_dot"),

0 commit comments

Comments
 (0)