6
6
import pytest
7
7
8
8
import xarray as xr
9
- from xarray .core import dtypes
9
+ from xarray .core import dtypes , duck_array_ops
10
10
11
11
from . import assert_allclose , assert_duckarray_allclose , assert_equal , assert_identical
12
12
from .test_variable import _PAD_XR_NP_ARGS
@@ -276,13 +276,13 @@ class method:
276
276
This is works a bit similar to using `partial(Class.method, arg, kwarg)`
277
277
"""
278
278
279
- def __init__ (self , name , * args , ** kwargs ):
279
+ def __init__ (self , name , * args , fallback_func = None , ** kwargs ):
280
280
self .name = name
281
+ self .fallback = fallback_func
281
282
self .args = args
282
283
self .kwargs = kwargs
283
284
284
285
def __call__ (self , obj , * args , ** kwargs ):
285
- from collections .abc import Callable
286
286
from functools import partial
287
287
288
288
all_args = merge_args (self .args , args )
@@ -298,21 +298,23 @@ def __call__(self, obj, *args, **kwargs):
298
298
if not isinstance (obj , xarray_classes ):
299
299
# remove typical xarray args like "dim"
300
300
exclude_kwargs = ("dim" , "dims" )
301
+ # TODO: figure out a way to replace dim / dims with axis
301
302
all_kwargs = {
302
303
key : value
303
304
for key , value in all_kwargs .items ()
304
305
if key not in exclude_kwargs
305
306
}
306
-
307
- func = getattr (obj , self .name , None )
308
-
309
- if func is None or not isinstance (func , Callable ):
310
- # fall back to module level numpy functions if not a xarray object
311
- if not isinstance (obj , (xr .Variable , xr .DataArray , xr .Dataset )):
312
- numpy_func = getattr (np , self .name )
313
- func = partial (numpy_func , obj )
307
+ if self .fallback is not None :
308
+ func = partial (self .fallback , obj )
314
309
else :
315
- raise AttributeError (f"{ obj } has no method named '{ self .name } '" )
310
+ func = getattr (obj , self .name , None )
311
+
312
+ if func is None or not callable (func ):
313
+ # fall back to module level numpy functions
314
+ numpy_func = getattr (np , self .name )
315
+ func = partial (numpy_func , obj )
316
+ else :
317
+ func = getattr (obj , self .name )
316
318
317
319
return func (* all_args , ** all_kwargs )
318
320
@@ -3662,6 +3664,65 @@ def test_stacking_reordering(self, func, dtype):
3662
3664
assert_units_equal (expected , actual )
3663
3665
assert_identical (expected , actual )
3664
3666
3667
+ @pytest .mark .parametrize (
3668
+ "variant" ,
3669
+ (
3670
+ pytest .param (
3671
+ "dims" , marks = pytest .mark .skip (reason = "indexes don't support units" )
3672
+ ),
3673
+ "coords" ,
3674
+ ),
3675
+ )
3676
+ @pytest .mark .parametrize (
3677
+ "func" ,
3678
+ (
3679
+ method ("differentiate" , fallback_func = np .gradient ),
3680
+ method ("integrate" , fallback_func = duck_array_ops .cumulative_trapezoid ),
3681
+ method ("cumulative_integrate" , fallback_func = duck_array_ops .trapz ),
3682
+ ),
3683
+ ids = repr ,
3684
+ )
3685
+ def test_differentiate_integrate (self , func , variant , dtype ):
3686
+ data_unit = unit_registry .m
3687
+ unit = unit_registry .s
3688
+
3689
+ variants = {
3690
+ "dims" : ("x" , unit , 1 ),
3691
+ "coords" : ("u" , 1 , unit ),
3692
+ }
3693
+ coord , dim_unit , coord_unit = variants .get (variant )
3694
+
3695
+ array = np .linspace (0 , 10 , 5 * 10 ).reshape (5 , 10 ).astype (dtype ) * data_unit
3696
+
3697
+ x = np .arange (array .shape [0 ]) * dim_unit
3698
+ y = np .arange (array .shape [1 ]) * dim_unit
3699
+
3700
+ u = np .linspace (0 , 1 , array .shape [0 ]) * coord_unit
3701
+
3702
+ data_array = xr .DataArray (
3703
+ data = array , coords = {"x" : x , "y" : y , "u" : ("x" , u )}, dims = ("x" , "y" )
3704
+ )
3705
+ # we want to make sure the output unit is correct
3706
+ units = extract_units (data_array )
3707
+ units .update (
3708
+ extract_units (
3709
+ func (
3710
+ data_array .data ,
3711
+ getattr (data_array , coord ).data ,
3712
+ axis = 0 ,
3713
+ )
3714
+ )
3715
+ )
3716
+
3717
+ expected = attach_units (
3718
+ func (strip_units (data_array ), coord = strip_units (coord )),
3719
+ units ,
3720
+ )
3721
+ actual = func (data_array , coord = coord )
3722
+
3723
+ assert_units_equal (expected , actual )
3724
+ assert_identical (expected , actual )
3725
+
3665
3726
@pytest .mark .parametrize (
3666
3727
"variant" ,
3667
3728
(
@@ -3676,8 +3737,6 @@ def test_stacking_reordering(self, func, dtype):
3676
3737
"func" ,
3677
3738
(
3678
3739
method ("diff" , dim = "x" ),
3679
- method ("differentiate" , coord = "x" ),
3680
- method ("integrate" , coord = "x" ),
3681
3740
method ("quantile" , q = [0.25 , 0.75 ]),
3682
3741
method ("reduce" , func = np .sum , dim = "x" ),
3683
3742
pytest .param (lambda x : x .dot (x ), id = "method_dot" ),
0 commit comments