Skip to content

Commit 136889e

Browse files
committed
feat: Series.rolling_mean
1 parent d5feb6f commit 136889e

File tree

4 files changed

+172
-1
lines changed

4 files changed

+172
-1
lines changed

narwhals/_arrow/series.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Sequence
99
from typing import overload
1010

11+
from narwhals._arrow.utils import _rolling
1112
from narwhals._arrow.utils import cast_for_truediv
1213
from narwhals._arrow.utils import floordiv_compat
1314
from narwhals._arrow.utils import narwhals_to_native_dtype
@@ -714,13 +715,41 @@ def clip(
714715
def to_arrow(self: Self) -> pa.Array:
    """Return the native series with all of its chunks combined."""
    chunked = self._native_series
    return chunked.combine_chunks()
716717

717-
def mode(self: Self) -> ArrowSeries:
718+
def mode(self: Self) -> Self:
    """Return the value(s) of the series whose occurrence count is the highest."""
    namespace = self.__narwhals_namespace__()
    # Temporary name for the counts column, generated so that it does not
    # collide with the series' own name.
    count_column = generate_temporary_column_name(n_bytes=8, columns=[self.name])
    counts = self.value_counts(name=count_column, normalize=False)
    # Keep only the rows whose count equals the maximum count.
    is_most_frequent = namespace.col(count_column) == namespace.col(count_column).max()
    return counts.filter(is_most_frequent)[self.name]
723724

725+
def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute the rolling mean of the series, one result per element.

    The windows are produced by `_rolling`; a window reported as `None`
    (not enough valid values) becomes a null in the result.

    Arguments:
        window_size: The length of the window in number of elements.
        weights: Optional multipliers forwarded to `_rolling`.
        min_periods: Forwarded to `_rolling`.
        center: Forwarded to `_rolling`.

    Returns:
        A new series wrapping a single-chunk chunked array of the means.
    """
    import pyarrow as pa
    import pyarrow.compute as pc

    windows = _rolling(
        self._native_series,
        window_size=window_size,
        weights=weights,
        min_periods=min_periods,
        center=center,
    )
    means = [None if window is None else pc.mean(window) for window in windows]
    return self._from_native_series(pa.chunked_array([means]))
752+
724753
def __iter__(self: Self) -> Iterator[Any]:
    """Yield every item produced by the native series' own iterator."""
    native_iterator = self._native_series.__iter__()
    yield from native_iterator
726755

narwhals/_arrow/utils.py

+43
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5+
from typing import Generator
56
from typing import Sequence
67

78
from narwhals.utils import isinstance_or_issubclass
@@ -420,3 +421,45 @@ def _parse_time_format(arr: pa.Array) -> str:
420421

421422
matches = pc.extract_regex(arr, pattern=TIME_RE)
422423
return "%H:%M:%S" if pc.all(matches.is_valid()).as_py() else ""
424+
425+
426+
def _rolling(
    array: pa.chunked_array,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Generator[pa.array | None, None, None]:
    """Yield the non-null, optionally weighted, values of each rolling window.

    One item is produced per element of `array`, in order. Windows are
    truncated at both ends of the array rather than padded.

    Arguments:
        array: Chunked array to roll over; it is flattened with
            `combine_chunks` before slicing.
        window_size: Number of elements in a full window.
        weights: Optional multipliers, one per window position, so it must
            contain exactly `window_size` entries. A falsy value (`None` or
            an empty list) disables weighting.
        min_periods: Minimum number of non-null values a window needs in
            order to be yielded. `None` means `window_size`.
        center: If `False`, the window for row `i` ends at row `i`. If
            `True`, the window is centred on row `i`; for an even
            `window_size` the extra element is taken from the left.

    Yields:
        The window's values multiplied by the matching weights with nulls
        dropped, or `None` when fewer than `min_periods` values remain.

    Raises:
        ValueError: If `weights` is given and its length differs from
            `window_size`.
    """
    import pyarrow as pa
    import pyarrow.compute as pc

    if min_periods is None:
        min_periods = window_size

    # Truthiness check on purpose: an empty list is treated like `None`.
    weights_array = pa.array(weights) if weights else None
    if weights_array is not None and len(weights_array) != window_size:
        msg = (
            f"`weights` must have length `window_size` ({window_size}), "
            f"got {len(weights_array)}"
        )
        raise ValueError(msg)

    # Number of elements taken before / after the current row. Together with
    # the row itself they always add up to exactly `window_size`.
    if center:
        offset_left = window_size // 2
        offset_right = (window_size - 1) // 2
    else:
        offset_left = window_size - 1
        offset_right = 0

    # Flatten the chunked array so that every window is a contiguous slice.
    flat_array = array.combine_chunks()
    size = len(flat_array)
    unit_weight = pa.scalar(1)

    for i in range(size):
        window_start = i - offset_left  # may fall before the array start
        start = max(0, window_start)
        end = min(size, i + offset_right + 1)
        length = end - start
        window = flat_array.slice(start, length)

        if weights_array is None:
            factor = unit_weight
        else:
            # Keep only the weights whose window positions survived the
            # truncation, so that both operands have the same length.
            factor = weights_array.slice(start - window_start, length)

        weighted_window = pc.drop_null(pc.multiply(window, factor))
        yield weighted_window if len(weighted_window) >= min_periods else None

narwhals/_pandas_like/series.py

+19
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,25 @@ def mode(self: Self) -> Self:
683683
result.name = native_series.name
684684
return self._from_native_series(result)
685685

686+
def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute the rolling mean using the native series' `rolling(...).mean()`.

    Arguments:
        window_size: Passed to the native `rolling` as `window`.
        weights: Must be `None`; weighted windows are not implemented here.
        min_periods: Passed through to the native `rolling`.
        center: Passed through to the native `rolling`.

    Returns:
        A new series wrapping the native result.

    Raises:
        NotImplementedError: If `weights` is not `None`.
    """
    if weights is None:
        roller = self._native_series.rolling(
            window=window_size, min_periods=min_periods, center=center
        )
        return self._from_native_series(roller.mean())

    msg = f"`weights` argument is not supported for {self._implementation} backend"
    raise NotImplementedError(msg)
704+
686705
def __iter__(self: Self) -> Iterator[Any]:
    """Yield every item produced by the native series' own iterator."""
    native_iterator = self._native_series.__iter__()
    yield from native_iterator
688707

narwhals/series.py

+80
Original file line numberDiff line numberDiff line change
@@ -2525,6 +2525,86 @@ def mode(self: Self) -> Self:
25252525
"""
25262526
return self._from_compliant_series(self._compliant_series.mode())
25272527

2528+
def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None = None,
    *,
    min_periods: int | None = None,
    center: bool = False,
) -> Self:
    """
    Apply a rolling mean (moving mean) over the values of the series.

    A window of length `window_size` traverses the series. The values that
    fill this window are (optionally) multiplied with the weights given by
    the `weights` vector, and the resulting values are aggregated to their
    mean.

    The window at a given row includes the row itself and the
    `window_size - 1` elements before it.

    Arguments:
        window_size: The length of the window in number of elements.
        weights: An optional list with the same length as the window that
            will be multiplied elementwise with the values in the window.
        min_periods: The number of values in the window that should be
            non-null before computing a result. If set to `None` (default),
            it will be set equal to `window_size`.
        center: Set the labels at the center of the window.

    Examples:
        >>> import narwhals as nw
        >>> import pandas as pd
        >>> import polars as pl
        >>> import pyarrow as pa

        >>> data = [100, 200, 300]
        >>> s_pd = pd.Series(name="a", data=data)
        >>> s_pl = pl.Series(name="a", values=data)
        >>> s_pa = pa.chunked_array([data])

        We define a library agnostic function:

        >>> @nw.narwhalify
        ... def func(s):
        ...     return s.rolling_mean(window_size=2)

        We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:

        >>> func(s_pd)
        0      NaN
        1    150.0
        2    250.0
        Name: a, dtype: float64

        >>> func(s_pl)  # doctest:+NORMALIZE_WHITESPACE
        shape: (3,)
        Series: 'a' [f64]
        [
           null
           150.0
           250.0
        ]

        >>> func(s_pa)  # doctest:+ELLIPSIS
        <pyarrow.lib.ChunkedArray object at ...>
        [
          [
            null,
            150,
            250
          ]
        ]
    """
    # Delegate to the backend-specific series, then wrap its result.
    compliant_result = self._compliant_series.rolling_mean(
        window_size=window_size,
        weights=weights,
        min_periods=min_periods,
        center=center,
    )
    return self._from_compliant_series(compliant_result)
2607+
25282608
def __iter__(self: Self) -> Iterator[Any]:
    """Yield every item produced by the compliant series' own iterator."""
    compliant_iterator = self._compliant_series.__iter__()
    yield from compliant_iterator
25302610

0 commit comments

Comments
 (0)