Commit 9a4313b

Better rolling reductions (pydata#4915)
1 parent 070d815 · commit 9a4313b

File tree

5 files changed, +166 -25 lines changed

asv_bench/benchmarks/rolling.py

Lines changed: 41 additions & 0 deletions
@@ -67,3 +67,44 @@ def setup(self, *args, **kwargs):
         super().setup(**kwargs)
         self.ds = self.ds.chunk({"x": 100, "y": 50, "t": 50})
         self.da_long = self.da_long.chunk({"x": 10000})
+
+
+class RollingMemory:
+    def setup(self, *args, **kwargs):
+        self.ds = xr.Dataset(
+            {
+                "var1": (("x", "y"), randn_xy),
+                "var2": (("x", "t"), randn_xt),
+                "var3": (("t",), randn_t),
+            },
+            coords={
+                "x": np.arange(nx),
+                "y": np.linspace(0, 1, ny),
+                "t": pd.date_range("1970-01-01", periods=nt, freq="D"),
+                "x_coords": ("x", np.linspace(1.1, 2.1, nx)),
+            },
+        )
+
+
+class DataArrayRollingMemory(RollingMemory):
+    @parameterized("func", ["sum", "max", "mean"])
+    def peakmem_ndrolling_reduce(self, func):
+        roll = self.ds.var1.rolling(x=10, y=4)
+        getattr(roll, func)()
+
+    @parameterized("func", ["sum", "max", "mean"])
+    def peakmem_1drolling_reduce(self, func):
+        roll = self.ds.var3.rolling(t=100)
+        getattr(roll, func)()
+
+
+class DatasetRollingMemory(RollingMemory):
+    @parameterized("func", ["sum", "max", "mean"])
+    def peakmem_ndrolling_reduce(self, func):
+        roll = self.ds.rolling(x=10, y=4)
+        getattr(roll, func)()
+
+    @parameterized("func", ["sum", "max", "mean"])
+    def peakmem_1drolling_reduce(self, func):
+        roll = self.ds.rolling(t=100)
+        getattr(roll, func)()

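For context, here is a minimal, self-contained sketch of the operations these peakmem benchmarks exercise. The real benchmark reuses nx, ny, nt and the randn_* arrays defined earlier in asv_bench/benchmarks/rolling.py; the names and sizes below are illustrative only.

import numpy as np
import pandas as pd
import xarray as xr

nx, ny, nt = 300, 200, 100
ds = xr.Dataset(
    {
        "var1": (("x", "y"), np.random.randn(nx, ny)),
        "var3": (("t",), np.random.randn(nt)),
    },
    coords={"t": pd.date_range("1970-01-01", periods=nt, freq="D")},
)

ds.var1.rolling(x=10, y=4).sum()   # nd DataArray reduction (DataArrayRollingMemory)
ds.var3.rolling(t=100).mean()      # 1d reduction (peakmem_1drolling_reduce)
ds.rolling(x=10, y=4).max()        # Dataset-level reduction (DatasetRollingMemory)
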
doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ New Features
 - Xarray now leverages updates as of cftime version 1.4.1, which enable exact I/O
   roundtripping of ``cftime.datetime`` objects (:pull:`4758`).
   By `Spencer Clark <https://github.com/spencerkclark>`_.
+- Most rolling operations use significantly less memory. (:issue:`4325`).
+  By `Deepak Cherian <https://github.com/dcherian>`_.
 - :py:meth:`~xarray.cftime_range` and :py:meth:`DataArray.resample` now support
   millisecond (``"L"`` or ``"ms"``) and microsecond (``"U"`` or ``"us"``) frequencies
   for ``cftime.datetime`` coordinates (:issue:`4097`, :pull:`4758`).

xarray/core/dtypes.py

Lines changed: 20 additions & 4 deletions
@@ -96,40 +96,56 @@ def get_fill_value(dtype):
     return fill_value
 
 
-def get_pos_infinity(dtype):
+def get_pos_infinity(dtype, max_for_int=False):
     """Return an appropriate positive infinity for this dtype.
 
     Parameters
     ----------
     dtype : np.dtype
+    max_for_int : bool
+        Return np.iinfo(dtype).max instead of np.inf
 
     Returns
     -------
     fill_value : positive infinity value corresponding to this dtype.
     """
-    if issubclass(dtype.type, (np.floating, np.integer)):
+    if issubclass(dtype.type, np.floating):
         return np.inf
 
+    if issubclass(dtype.type, np.integer):
+        if max_for_int:
+            return np.iinfo(dtype).max
+        else:
+            return np.inf
+
     if issubclass(dtype.type, np.complexfloating):
         return np.inf + 1j * np.inf
 
     return INF
 
 
-def get_neg_infinity(dtype):
+def get_neg_infinity(dtype, min_for_int=False):
     """Return an appropriate positive infinity for this dtype.
 
     Parameters
     ----------
     dtype : np.dtype
+    min_for_int : bool
+        Return np.iinfo(dtype).min instead of -np.inf
 
     Returns
     -------
     fill_value : positive infinity value corresponding to this dtype.
     """
-    if issubclass(dtype.type, (np.floating, np.integer)):
+    if issubclass(dtype.type, np.floating):
        return -np.inf
 
+    if issubclass(dtype.type, np.integer):
+        if min_for_int:
+            return np.iinfo(dtype).min
+        else:
+            return -np.inf
+
     if issubclass(dtype.type, np.complexfloating):
         return -np.inf - 1j * np.inf

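For context, a quick illustration of the new max_for_int/min_for_int flags, following the diff above. Note that xarray.core.dtypes is a private module, so this is only a sketch of the internal behaviour; the values in the comments correspond to the dtypes shown.

import numpy as np
from xarray.core import dtypes

# Float dtypes still return true infinities.
dtypes.get_pos_infinity(np.dtype("float64"))                   # inf
# Integer dtypes can now return a sentinel that is representable in the
# same dtype, so it can be used as a fill value without upcasting to float.
dtypes.get_pos_infinity(np.dtype("int32"), max_for_int=True)   # 2147483647
dtypes.get_neg_infinity(np.dtype("int32"), min_for_int=True)   # -2147483648
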
xarray/core/rolling.py

Lines changed: 93 additions & 21 deletions
@@ -111,32 +111,51 @@ def __repr__(self):
     def __len__(self):
         return self.obj.sizes[self.dim]
 
-    def _reduce_method(name: str) -> Callable:  # type: ignore
-        array_agg_func = getattr(duck_array_ops, name)
+    def _reduce_method(name: str, fillna, rolling_agg_func: Callable = None) -> Callable:  # type: ignore
+        """Constructs reduction methods built on a numpy reduction function (e.g. sum),
+        a bottleneck reduction function (e.g. move_sum), or a Rolling reduction (_mean)."""
+        if rolling_agg_func:
+            array_agg_func = None
+        else:
+            array_agg_func = getattr(duck_array_ops, name)
+
         bottleneck_move_func = getattr(bottleneck, "move_" + name, None)
 
         def method(self, keep_attrs=None, **kwargs):
 
             keep_attrs = self._get_keep_attrs(keep_attrs)
 
             return self._numpy_or_bottleneck_reduce(
-                array_agg_func, bottleneck_move_func, keep_attrs=keep_attrs, **kwargs
+                array_agg_func,
+                bottleneck_move_func,
+                rolling_agg_func,
+                keep_attrs=keep_attrs,
+                fillna=fillna,
+                **kwargs,
             )
 
         method.__name__ = name
         method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name)
         return method
 
-    argmax = _reduce_method("argmax")
-    argmin = _reduce_method("argmin")
-    max = _reduce_method("max")
-    min = _reduce_method("min")
-    mean = _reduce_method("mean")
-    prod = _reduce_method("prod")
-    sum = _reduce_method("sum")
-    std = _reduce_method("std")
-    var = _reduce_method("var")
-    median = _reduce_method("median")
+    def _mean(self, keep_attrs, **kwargs):
+        result = self.sum(keep_attrs=False, **kwargs) / self.count(keep_attrs=False)
+        if keep_attrs:
+            result.attrs = self.obj.attrs
+        return result
+
+    _mean.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="mean")
+
+    argmax = _reduce_method("argmax", dtypes.NINF)
+    argmin = _reduce_method("argmin", dtypes.INF)
+    max = _reduce_method("max", dtypes.NINF)
+    min = _reduce_method("min", dtypes.INF)
+    prod = _reduce_method("prod", 1)
+    sum = _reduce_method("sum", 0)
+    mean = _reduce_method("mean", None, _mean)
+    std = _reduce_method("std", None)
+    var = _reduce_method("var", None)
+    median = _reduce_method("median", None)
 
     def count(self, keep_attrs=None):
         keep_attrs = self._get_keep_attrs(keep_attrs)
@@ -301,6 +320,24 @@ def construct(
 
         """
 
+        return self._construct(
+            self.obj,
+            window_dim=window_dim,
+            stride=stride,
+            fill_value=fill_value,
+            keep_attrs=keep_attrs,
+            **window_dim_kwargs,
+        )
+
+    def _construct(
+        self,
+        obj,
+        window_dim=None,
+        stride=1,
+        fill_value=dtypes.NA,
+        keep_attrs=None,
+        **window_dim_kwargs,
+    ):
         from .dataarray import DataArray
 
         keep_attrs = self._get_keep_attrs(keep_attrs)
@@ -317,18 +354,18 @@ def construct(
         )
         stride = self._mapping_to_list(stride, default=1)
 
-        window = self.obj.variable.rolling_window(
+        window = obj.variable.rolling_window(
             self.dim, self.window, window_dim, self.center, fill_value=fill_value
         )
 
-        attrs = self.obj.attrs if keep_attrs else {}
+        attrs = obj.attrs if keep_attrs else {}
 
         result = DataArray(
             window,
-            dims=self.obj.dims + tuple(window_dim),
-            coords=self.obj.coords,
+            dims=obj.dims + tuple(window_dim),
+            coords=obj.coords,
             attrs=attrs,
-            name=self.obj.name,
+            name=obj.name,
         )
         return result.isel(
             **{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
@@ -393,7 +430,18 @@ def reduce(self, func, keep_attrs=None, **kwargs):
             d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}")
             for d in self.dim
         }
-        windows = self.construct(rolling_dim, keep_attrs=keep_attrs)
+
+        # save memory with reductions GH4325
+        fillna = kwargs.pop("fillna", dtypes.NA)
+        if fillna is not dtypes.NA:
+            obj = self.obj.fillna(fillna)
+        else:
+            obj = self.obj
+
+        windows = self._construct(
+            obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
+        )
+
         result = windows.reduce(
             func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs
         )
@@ -470,7 +518,13 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
         return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name)
 
     def _numpy_or_bottleneck_reduce(
-        self, array_agg_func, bottleneck_move_func, keep_attrs, **kwargs
+        self,
+        array_agg_func,
+        bottleneck_move_func,
+        rolling_agg_func,
+        keep_attrs,
+        fillna,
+        **kwargs,
     ):
         if "dim" in kwargs:
             warnings.warn(
@@ -494,6 +548,18 @@ def _numpy_or_bottleneck_reduce(
                 bottleneck_move_func, keep_attrs=keep_attrs, **kwargs
             )
         else:
+            if rolling_agg_func:
+                return rolling_agg_func(
+                    self, keep_attrs=self._get_keep_attrs(keep_attrs)
+                )
+            if fillna is not None:
+                if fillna is dtypes.INF:
+                    fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True)
+                elif fillna is dtypes.NINF:
+                    fillna = dtypes.get_neg_infinity(self.obj.dtype, min_for_int=True)
+                kwargs.setdefault("skipna", False)
+                kwargs.setdefault("fillna", fillna)
+
             return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs)
 
 
@@ -600,13 +666,19 @@ def _counts(self, keep_attrs):
         )
 
     def _numpy_or_bottleneck_reduce(
-        self, array_agg_func, bottleneck_move_func, keep_attrs, **kwargs
+        self,
+        array_agg_func,
+        bottleneck_move_func,
+        rolling_agg_func,
+        keep_attrs,
+        **kwargs,
    ):
         return self._dataset_implementation(
             functools.partial(
                 DataArrayRolling._numpy_or_bottleneck_reduce,
                 array_agg_func=array_agg_func,
                 bottleneck_move_func=bottleneck_move_func,
+                rolling_agg_func=rolling_agg_func,
             ),
             keep_attrs=keep_attrs,
             **kwargs,

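For context, the memory saving comes from a fill-then-reduce strategy: reduce() now fills NaNs with the reduction's identity, constructs the rolling window view with that same fill value, and reduces with skipna=False, so the nan-aware reduction no longer allocates a masked copy of the whole windowed array. Below is a rough sketch of the equivalence this relies on, mirroring the style of the check added to test_ndrolling_reduce further down (which does the same for mean); the names roll, windows and by_hand are illustrative.

import numpy as np
import xarray as xr

da = xr.DataArray(np.random.randn(300, 200), dims=("x", "y"))
roll = da.rolling(x=10, y=4)

# 1. fill NaNs with the identity of the reduction (0 for sum),
# 2. construct the window view with the same fill value,
# 3. reduce with skipna=False.
windows = da.fillna(0).rolling(x=10, y=4).construct(x="xw", y="yw", fill_value=0)
by_hand = windows.sum(["xw", "yw"], skipna=False)
# min_periods defaults to the window size (10 * 4) and is enforced via count().
by_hand = by_hand.where(roll.count() >= 10 * 4)

xr.testing.assert_allclose(by_hand, roll.sum())
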
xarray/tests/test_dataarray.py

Lines changed: 10 additions & 0 deletions
@@ -6623,6 +6623,16 @@ def test_ndrolling_reduce(da, center, min_periods, name):
     assert_allclose(actual, expected)
     assert actual.dims == expected.dims
 
+    if name in ["mean"]:
+        # test our reimplementation of nanmean using np.nanmean
+        expected = getattr(rolling_obj.construct({"time": "tw", "x": "xw"}), name)(
+            ["tw", "xw"]
+        )
+        count = rolling_obj.count()
+        if min_periods is None:
+            min_periods = 1
+        assert_allclose(actual, expected.where(count >= min_periods))
+
 
 @pytest.mark.parametrize("center", (True, False, (True, False)))
 @pytest.mark.parametrize("fill_value", (np.nan, 0.0))
