Skip to content

Commit 217bd7d

Browse files
authored
refactor: Simplify ArrowGroupBy.__iter__ (#2133)
1 parent e34433a commit 217bd7d

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

narwhals/_arrow/group_by.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pyarrow as pa
1111
import pyarrow.compute as pc
1212

13+
from narwhals._arrow.utils import extract_py_scalar
1314
from narwhals._expression_parsing import evaluate_output_names_and_aliases
1415
from narwhals._expression_parsing import is_simple_aggregation
1516
from narwhals.utils import generate_temporary_column_name
@@ -156,20 +157,9 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
156157
*it, "", null_handling="replace", null_replacement=null_token
157158
)
158159
table = table.add_column(i=0, field_=col_token, column=key_values)
159-
160-
yield from (
161-
(
162-
next(
163-
(
164-
t := self._df._from_native_frame(
165-
table.filter(pc.equal(table[col_token], v)).drop([col_token])
166-
)
167-
)
168-
.simple_select(*self._keys)
169-
.head(1)
170-
.iter_rows(named=False, buffer_size=512)
171-
),
172-
t,
160+
for v in pc.unique(key_values):
161+
t = self._df._from_native_frame(
162+
table.filter(pc.equal(table[col_token], v)).drop([col_token])
173163
)
174-
for v in pc.unique(key_values)
175-
)
164+
row = t.simple_select(*self._keys).row(0)
165+
yield tuple(extract_py_scalar(el) for el in row), t

narwhals/_arrow/series.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
from narwhals.utils import Version
5757

5858

59+
# TODO @dangotbanned: move into `_arrow.utils`
60+
# Lots of modules are importing inline
5961
@overload
6062
def maybe_extract_py_scalar(
6163
value: pa.Scalar[_BasicDataType[_AsPyType]],

narwhals/_arrow/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def extract_regex(
6565
"""Alias for `pyarrow.scalar`."""
6666

6767

68+
def extract_py_scalar(value: Any, /) -> Any:
69+
from narwhals._arrow.series import maybe_extract_py_scalar
70+
71+
return maybe_extract_py_scalar(value, return_py_scalar=True)
72+
73+
6874
def chunked_array(
6975
arr: ArrowArray | list[Iterable[pa.Scalar[Any]]] | ArrowChunkedArray,
7076
) -> ArrowChunkedArray:

0 commit comments

Comments
 (0)