16
16
assert_identical ,
17
17
has_dask ,
18
18
has_scipy ,
19
+ has_scipy_ge_1_13 ,
19
20
requires_cftime ,
20
21
requires_dask ,
21
22
requires_scipy ,
@@ -132,29 +133,66 @@ def func(obj, new_x):
132
133
assert_allclose (actual , expected )
133
134
134
135
135
- @pytest .mark .parametrize ("use_dask" , [False , True ])
136
- def test_interpolate_vectorize (use_dask : bool ) -> None :
137
- if not has_scipy :
138
- pytest .skip ("scipy is not installed." )
139
-
140
- if not has_dask and use_dask :
141
- pytest .skip ("dask is not installed in the environment." )
142
-
136
+ @requires_scipy
137
+ @pytest .mark .parametrize (
138
+ "use_dask, method" ,
139
+ (
140
+ (False , "linear" ),
141
+ (False , "akima" ),
142
+ pytest .param (
143
+ False ,
144
+ "makima" ,
145
+ marks = pytest .mark .skipif (not has_scipy_ge_1_13 , reason = "scipy too old" ),
146
+ ),
147
+ pytest .param (
148
+ True ,
149
+ "linear" ,
150
+ marks = pytest .mark .skipif (not has_dask , reason = "dask not available" ),
151
+ ),
152
+ pytest .param (
153
+ True ,
154
+ "akima" ,
155
+ marks = pytest .mark .skipif (not has_dask , reason = "dask not available" ),
156
+ ),
157
+ ),
158
+ )
159
+ def test_interpolate_vectorize (use_dask : bool , method : InterpOptions ) -> None :
143
160
# scipy interpolation for the reference
144
- def func (obj , dim , new_x ):
161
+ def func (obj , dim , new_x , method ):
162
+ scipy_kwargs = {}
163
+ interpolant_options = {
164
+ "barycentric" : scipy .interpolate .BarycentricInterpolator ,
165
+ "krogh" : scipy .interpolate .KroghInterpolator ,
166
+ "pchip" : scipy .interpolate .PchipInterpolator ,
167
+ "akima" : scipy .interpolate .Akima1DInterpolator ,
168
+ "makima" : scipy .interpolate .Akima1DInterpolator ,
169
+ }
170
+
145
171
shape = [s for i , s in enumerate (obj .shape ) if i != obj .get_axis_num (dim )]
146
172
for s in new_x .shape [::- 1 ]:
147
173
shape .insert (obj .get_axis_num (dim ), s )
148
174
149
- return scipy .interpolate .interp1d (
150
- da [dim ],
151
- obj .data ,
152
- axis = obj .get_axis_num (dim ),
153
- bounds_error = False ,
154
- fill_value = np .nan ,
155
- )(new_x ).reshape (shape )
175
+ if method in interpolant_options :
176
+ interpolant = interpolant_options [method ]
177
+ if method == "makima" :
178
+ scipy_kwargs ["method" ] = method
179
+ return interpolant (
180
+ da [dim ], obj .data , axis = obj .get_axis_num (dim ), ** scipy_kwargs
181
+ )(new_x ).reshape (shape )
182
+ else :
183
+
184
+ return scipy .interpolate .interp1d (
185
+ da [dim ],
186
+ obj .data ,
187
+ axis = obj .get_axis_num (dim ),
188
+ kind = method ,
189
+ bounds_error = False ,
190
+ fill_value = np .nan ,
191
+ ** scipy_kwargs ,
192
+ )(new_x ).reshape (shape )
156
193
157
194
da = get_example_data (0 )
195
+
158
196
if use_dask :
159
197
da = da .chunk ({"y" : 5 })
160
198
@@ -165,17 +203,17 @@ def func(obj, dim, new_x):
165
203
coords = {"z" : np .random .randn (30 ), "z2" : ("z" , np .random .randn (30 ))},
166
204
)
167
205
168
- actual = da .interp (x = xdest , method = "linear" )
206
+ actual = da .interp (x = xdest , method = method )
169
207
170
208
expected = xr .DataArray (
171
- func (da , "x" , xdest ),
209
+ func (da , "x" , xdest , method ),
172
210
dims = ["z" , "y" ],
173
211
coords = {
174
212
"z" : xdest ["z" ],
175
213
"z2" : xdest ["z2" ],
176
214
"y" : da ["y" ],
177
215
"x" : ("z" , xdest .values ),
178
- "x2" : ("z" , func (da ["x2" ], "x" , xdest )),
216
+ "x2" : ("z" , func (da ["x2" ], "x" , xdest , method )),
179
217
},
180
218
)
181
219
assert_allclose (actual , expected .transpose ("z" , "y" , transpose_coords = True ))
@@ -191,18 +229,18 @@ def func(obj, dim, new_x):
191
229
},
192
230
)
193
231
194
- actual = da .interp (x = xdest , method = "linear" )
232
+ actual = da .interp (x = xdest , method = method )
195
233
196
234
expected = xr .DataArray (
197
- func (da , "x" , xdest ),
235
+ func (da , "x" , xdest , method ),
198
236
dims = ["z" , "w" , "y" ],
199
237
coords = {
200
238
"z" : xdest ["z" ],
201
239
"w" : xdest ["w" ],
202
240
"z2" : xdest ["z2" ],
203
241
"y" : da ["y" ],
204
242
"x" : (("z" , "w" ), xdest .data ),
205
- "x2" : (("z" , "w" ), func (da ["x2" ], "x" , xdest )),
243
+ "x2" : (("z" , "w" ), func (da ["x2" ], "x" , xdest , method )),
206
244
},
207
245
)
208
246
assert_allclose (actual , expected .transpose ("z" , "w" , "y" , transpose_coords = True ))
@@ -393,19 +431,17 @@ def test_nans(use_dask: bool) -> None:
393
431
assert actual .count () > 0
394
432
395
433
434
+ @requires_scipy
396
435
@pytest .mark .parametrize ("use_dask" , [True , False ])
397
436
def test_errors (use_dask : bool ) -> None :
398
- if not has_scipy :
399
- pytest .skip ("scipy is not installed." )
400
-
401
- # akima and spline are unavailable
437
+ # spline is unavailable
402
438
da = xr .DataArray ([0 , 1 , np .nan , 2 ], dims = "x" , coords = {"x" : range (4 )})
403
439
if not has_dask and use_dask :
404
440
pytest .skip ("dask is not installed in the environment." )
405
441
da = da .chunk ()
406
442
407
- for method in ["akima" , " spline" ]:
408
- with pytest .raises (ValueError ):
443
+ for method in ["spline" ]:
444
+ with pytest .raises (ValueError ), pytest . warns ( PendingDeprecationWarning ) :
409
445
da .interp (x = [0.5 , 1.5 ], method = method ) # type: ignore[arg-type]
410
446
411
447
# not sorted
@@ -922,7 +958,10 @@ def test_interp1d_bounds_error() -> None:
922
958
(("x" , np .array ([0 , 0.5 , 1 , 2 ]), dict (unit = "s" )), False ),
923
959
],
924
960
)
925
- def test_coord_attrs (x , expect_same_attrs : bool ) -> None :
961
+ def test_coord_attrs (
962
+ x ,
963
+ expect_same_attrs : bool ,
964
+ ) -> None :
926
965
base_attrs = dict (foo = "bar" )
927
966
ds = xr .Dataset (
928
967
data_vars = dict (a = 2 * np .arange (5 )),
0 commit comments