Skip to content

Commit e1dafe6

Browse files
Fix map_blocks example (#4305)
1 parent 8cea798 commit e1dafe6

File tree

3 files changed

+45
-43
lines changed

3 files changed

+45
-43
lines changed

xarray/core/dataarray.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -3358,9 +3358,12 @@ def map_blocks(
33583358
... clim = gb.mean(dim="time")
33593359
... return gb - clim
33603360
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
3361+
>>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"])
33613362
>>> np.random.seed(123)
33623363
>>> array = xr.DataArray(
3363-
... np.random.rand(len(time)), dims="time", coords=[time]
3364+
... np.random.rand(len(time)),
3365+
... dims=["time"],
3366+
... coords={"time": time, "month": month},
33643367
... ).chunk()
33653368
>>> array.map_blocks(calculate_anomaly, template=array).compute()
33663369
<xarray.DataArray (time: 24)>
@@ -3371,21 +3374,19 @@ def map_blocks(
33713374
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
33723375
Coordinates:
33733376
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
3377+
month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12
33743378
33753379
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
33763380
to the function being applied in ``xr.map_blocks()``:
33773381
33783382
>>> array.map_blocks(
33793383
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array,
3380-
... )
3384+
... ) # doctest: +ELLIPSIS
33813385
<xarray.DataArray (time: 24)>
3382-
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
3383-
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
3384-
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
3385-
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
3386-
0.14482397, 0.35985481, 0.23487834, 0.12144652])
3386+
dask.array<calculate_anomaly-...-<this, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray>
33873387
Coordinates:
3388-
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
3388+
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
3389+
month (time) int64 dask.array<chunksize=(24,), meta=np.ndarray>
33893390
"""
33903391
from .parallel import map_blocks
33913392

@@ -3875,9 +3876,10 @@ def argmin(
38753876
>>> array.isel(array.argmin(...))
38763877
array(-1)
38773878
3878-
>>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]],
3879-
... [[1, 3, 2], [2, -5, 1], [2, 3, 1]]],
3880-
... dims=("x", "y", "z"))
3879+
>>> array = xr.DataArray(
3880+
... [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, -5, 1], [2, 3, 1]]],
3881+
... dims=("x", "y", "z"),
3882+
... )
38813883
>>> array.min(dim="x")
38823884
<xarray.DataArray (y: 3, z: 3)>
38833885
array([[ 1, 2, 1],
@@ -3977,9 +3979,10 @@ def argmax(
39773979
<xarray.DataArray ()>
39783980
array(3)
39793981
3980-
>>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]],
3981-
... [[1, 3, 2], [2, 5, 1], [2, 3, 1]]],
3982-
... dims=("x", "y", "z"))
3982+
>>> array = xr.DataArray(
3983+
... [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, 5, 1], [2, 3, 1]]],
3984+
... dims=("x", "y", "z"),
3985+
... )
39833986
>>> array.max(dim="x")
39843987
<xarray.DataArray (y: 3, z: 3)>
39853988
array([[3, 3, 2],

xarray/core/dataset.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -5817,35 +5817,36 @@ def map_blocks(
58175817
... clim = gb.mean(dim="time")
58185818
... return gb - clim
58195819
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
5820+
>>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"])
58205821
>>> np.random.seed(123)
58215822
>>> array = xr.DataArray(
5822-
... np.random.rand(len(time)), dims="time", coords=[time]
5823+
... np.random.rand(len(time)),
5824+
... dims=["time"],
5825+
... coords={"time": time, "month": month},
58235826
... ).chunk()
58245827
>>> ds = xr.Dataset({"a": array})
58255828
>>> ds.map_blocks(calculate_anomaly, template=ds).compute()
5826-
<xarray.DataArray (time: 24)>
5827-
array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862,
5828-
0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714,
5829-
-0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 ,
5830-
0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
5831-
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
5829+
<xarray.Dataset>
5830+
Dimensions: (time: 24)
58325831
Coordinates:
58335832
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
5833+
month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12
5834+
Data variables:
5835+
a (time) float64 0.1289 0.1132 -0.0856 ... 0.2287 0.1906 -0.05901
58345836
58355837
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
58365838
to the function being applied in ``xr.map_blocks()``:
58375839
58385840
>>> ds.map_blocks(
58395841
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=ds,
58405842
... )
5841-
<xarray.DataArray (time: 24)>
5842-
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
5843-
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
5844-
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
5845-
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
5846-
0.14482397, 0.35985481, 0.23487834, 0.12144652])
5843+
<xarray.Dataset>
5844+
Dimensions: (time: 24)
58475845
Coordinates:
5848-
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
5846+
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
5847+
month (time) int64 dask.array<chunksize=(24,), meta=np.ndarray>
5848+
Data variables:
5849+
a (time) float64 dask.array<chunksize=(24,), meta=np.ndarray>
58495850
"""
58505851
from .parallel import map_blocks
58515852

xarray/core/parallel.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,14 @@ def map_blocks(
235235
... clim = gb.mean(dim="time")
236236
... return gb - clim
237237
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
238+
>>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"])
238239
>>> np.random.seed(123)
239240
>>> array = xr.DataArray(
240-
... np.random.rand(len(time)), dims="time", coords=[time]
241+
... np.random.rand(len(time)),
242+
... dims=["time"],
243+
... coords={"time": time, "month": month},
241244
... ).chunk()
242-
>>> xr.map_blocks(calculate_anomaly, array, template=array).compute()
245+
>>> array.map_blocks(calculate_anomaly, template=array).compute()
243246
<xarray.DataArray (time: 24)>
244247
array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862,
245248
0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714,
@@ -248,25 +251,20 @@ def map_blocks(
248251
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
249252
Coordinates:
250253
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
254+
month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12
251255
252256
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
253257
to the function being applied in ``xr.map_blocks()``:
254258
255-
>>> xr.map_blocks(
256-
... calculate_anomaly,
257-
... array,
258-
... kwargs={"groupby_type": "time.year"},
259-
... template=array,
260-
... )
259+
>>> array.map_blocks(
260+
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array,
261+
... ) # doctest: +ELLIPSIS
261262
<xarray.DataArray (time: 24)>
262-
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
263-
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
264-
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
265-
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
266-
0.14482397, 0.35985481, 0.23487834, 0.12144652])
263+
dask.array<calculate_anomaly-...-<this, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray>
267264
Coordinates:
268-
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
269-
"""
265+
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
266+
month (time) int64 dask.array<chunksize=(24,), meta=np.ndarray>
267+
"""
270268

271269
def _wrapper(
272270
func: Callable,

0 commit comments

Comments
 (0)