
Commit 3b36a2f

fix(groupby): preserve metadata for subclassed DataFrames and Series
- Update metadata preservation logic for DataFrames and Series in groupby operations
- Fix DataFrame.__setitem__ with MultiIndex columns and scalar indexer
- Adjust formatting and naming conventions in the code
1 parent 72514b0 commit 3b36a2f
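For context: pandas subclasses declare the attributes they want carried across operations in a `_metadata` class attribute. A minimal sketch of the scenario this commit targets (the class and attribute names below are illustrative, not taken from the commit):

import pandas as pd


class TaggedFrame(pd.DataFrame):
    # attributes listed here are the ones pandas should propagate to results
    _metadata = ["source"]

    @property
    def _constructor(self):
        return TaggedFrame


df = TaggedFrame({"X": [1, 1, 2], "Y": [10, 20, 30]})
df.source = "sensor-a"

# Aggregations such as groupby().sum() are expected to carry the attribute over
# (the new test in this commit asserts that); groupby().apply() is the path the
# commit patches to do the same. include_groups requires pandas >= 2.2.
agg = df.groupby("X").sum()
print(getattr(agg, "source", None))

applied = df.groupby("X").apply(lambda g: g.sum(), include_groups=False)
print(getattr(applied, "source", None))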

File tree: 4 files changed (+85, -32 lines)

pandas/core/groupby/generic.py

Lines changed: 21 additions & 21 deletions
@@ -2070,13 +2070,13 @@ def _wrap_applied_output(

            result = self.obj._constructor(index=res_index, columns=data.columns)
            result = result.astype(data.dtypes)
-
+
            # Preserve metadata for subclassed DataFrames
-            if hasattr(self.obj, '_metadata'):
+            if hasattr(self.obj, "_metadata"):
                for attr in self.obj._metadata:
                    if hasattr(self.obj, attr):
                        setattr(result, attr, getattr(self.obj, attr))
-
+
            return result

        # GH12824

@@ -2088,27 +2088,27 @@ def _wrap_applied_output(
            # GH57775 - Ensure that columns and dtypes from original frame are kept.
            result = self.obj._constructor(columns=data.columns)
            result = result.astype(data.dtypes)
-
+
            # Preserve metadata for subclassed DataFrames
-            if hasattr(self.obj, '_metadata'):
+            if hasattr(self.obj, "_metadata"):
                for attr in self.obj._metadata:
                    if hasattr(self.obj, attr):
                        setattr(result, attr, getattr(self.obj, attr))
-
+
            return result
        elif isinstance(first_not_none, DataFrame):
            result = self._concat_objects(
                values,
                not_indexed_same=not_indexed_same,
                is_transform=is_transform,
            )
-
+
            # Preserve metadata for subclassed DataFrames
-            if hasattr(self.obj, '_metadata'):
+            if hasattr(self.obj, "_metadata"):
                for attr in self.obj._metadata:
                    if hasattr(self.obj, attr):
                        setattr(result, attr, getattr(self.obj, attr))
-
+
            return result

        key_index = self._grouper.result_index if self.as_index else None

@@ -2128,13 +2128,13 @@ def _wrap_applied_output(
                # has type "Tuple[Any, ...]")
                name = self._selection  # type: ignore[assignment]
            result = self.obj._constructor_sliced(values, index=key_index, name=name)
-
+
            # Preserve metadata for subclassed Series
-            if hasattr(self.obj, '_metadata'):
+            if hasattr(self.obj, "_metadata"):
                for attr in self.obj._metadata:
                    if hasattr(self.obj, attr):
                        setattr(result, attr, getattr(self.obj, attr))
-
+
            return result
        elif not isinstance(first_not_none, Series):
            # values are not series or array-like but scalars

@@ -2143,24 +2143,24 @@ def _wrap_applied_output(
            #  of columns
            if self.as_index:
                result = self.obj._constructor_sliced(values, index=key_index)
-
+
                # Preserve metadata for subclassed Series
-                if hasattr(self.obj, '_metadata'):
+                if hasattr(self.obj, "_metadata"):
                    for attr in self.obj._metadata:
                        if hasattr(self.obj, attr):
                            setattr(result, attr, getattr(self.obj, attr))
-
+
                return result
            else:
                result = self.obj._constructor(values, columns=[self._selection])
                result = self._insert_inaxis_grouper(result)
-
+
                # Preserve metadata for subclassed DataFrames
-                if hasattr(self.obj, '_metadata'):
+                if hasattr(self.obj, "_metadata"):
                    for attr in self.obj._metadata:
                        if hasattr(self.obj, attr):
                            setattr(result, attr, getattr(self.obj, attr))
-
+
                return result
        else:
            # values are Series

@@ -2171,13 +2171,13 @@ def _wrap_applied_output(
                key_index,
                is_transform,
            )
-
+
            # Preserve metadata for subclassed DataFrames/Series
-            if hasattr(self.obj, '_metadata'):
+            if hasattr(self.obj, "_metadata"):
                for attr in self.obj._metadata:
                    if hasattr(self.obj, attr):
                        setattr(result, attr, getattr(self.obj, attr))
-
+
            return result

    def _wrap_applied_output_series(
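The same four-line copy loop is inlined at every return site above. Read on its own, the pattern amounts to a helper like the following (hypothetical; the commit does not add such a function):

def _copy_metadata(source, result):
    # Copy every attribute named in source._metadata onto result, when present.
    # Hypothetical refactoring of the block repeated in _wrap_applied_output.
    if hasattr(source, "_metadata"):
        for attr in source._metadata:
            if hasattr(source, attr):
                setattr(result, attr, getattr(source, attr))
    return result

pandas also exposes `__finalize__` as the general hook for this kind of metadata propagation; the committed change instead copies the attributes explicitly at each return.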

pandas/tests/frame/indexing/test_setitem.py

Lines changed: 53 additions & 0 deletions
@@ -607,6 +607,59 @@ def test_setitem_multi_index(self):
        df[("joe", "last")] = df[("jolie", "first")].loc[i, j]
        tm.assert_frame_equal(df[("joe", "last")], df[("jolie", "first")])

+    def test_setitem_multiindex_scalar_indexer(self):
+        # GH#62135: Fix DataFrame.__setitem__ with MultiIndex columns and scalar indexer
+        # Test scalar key assignment with MultiIndex columns
+        columns = MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("B", "a")])
+        df = DataFrame(np.arange(15).reshape(5, 3), columns=columns)
+
+        # Test setting new column with scalar tuple key
+        df[("C", "c")] = 100
+        expected_new = DataFrame(
+            np.array(
+                [
+                    [0, 1, 2, 100],
+                    [3, 4, 5, 100],
+                    [6, 7, 8, 100],
+                    [9, 10, 11, 100],
+                    [12, 13, 14, 100],
+                ]
+            ),
+            columns=MultiIndex.from_tuples(
+                [("A", "a"), ("A", "b"), ("B", "a"), ("C", "c")]
+            ),
+        )
+        tm.assert_frame_equal(df, expected_new)
+
+        # Test setting existing column with scalar tuple key
+        df[("A", "a")] = 999
+        expected_existing = expected_new.copy()
+        expected_existing[("A", "a")] = 999
+        tm.assert_frame_equal(df, expected_existing)
+
+        # Test setting with Series using scalar tuple key
+        series_data = Series([10, 20, 30, 40, 50])
+        df[("D", "d")] = series_data
+        expected_series = expected_existing.copy()
+        expected_series[("D", "d")] = series_data
+        tm.assert_frame_equal(df, expected_series)
+
+        # Test with 3-level MultiIndex
+        columns_3level = MultiIndex.from_tuples(
+            [("X", "A", "1"), ("X", "A", "2"), ("Y", "B", "1")]
+        )
+        df_3level = DataFrame(np.arange(12).reshape(4, 3), columns=columns_3level)
+
+        # Test scalar assignment with 3-level MultiIndex
+        df_3level[("Z", "C", "3")] = 42
+        assert ("Z", "C", "3") in df_3level.columns
+        tm.assert_series_equal(df_3level[("Z", "C", "3")], Series([42, 42, 42, 42]), check_names=False)
+
+        # Test Series assignment with 3-level MultiIndex
+        new_series = Series([1, 2, 3, 4])
+        df_3level[("W", "D", "4")] = new_series
+        tm.assert_series_equal(df_3level[("W", "D", "4")], new_series, check_names=False)
+
    @pytest.mark.parametrize(
        "columns,box,expected",
        [
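Outside the test harness, the behaviour covered here reduces to treating a full tuple as a scalar column key on MultiIndex columns. A small sketch with illustrative data:

import numpy as np
import pandas as pd

columns = pd.MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("B", "a")])
df = pd.DataFrame(np.arange(15).reshape(5, 3), columns=columns)

# A full tuple acts as a scalar key: this adds exactly one new column ("C", "c")
df[("C", "c")] = 100

# An existing column can be overwritten the same way, including from a Series
df[("A", "a")] = pd.Series([10, 20, 30, 40, 50])

print(df.columns.tolist())
# [('A', 'a'), ('A', 'b'), ('B', 'a'), ('C', 'c')]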

pandas/tests/frame/test_query_eval.py

Lines changed: 2 additions & 2 deletions
@@ -168,7 +168,7 @@ def test_query_duplicate_column_name(self, engine, parser):
            }
        ).rename(columns={"B": "A"})

-        res = df.query('C == 1', engine=engine, parser=parser)
+        res = df.query("C == 1", engine=engine, parser=parser)

        expect = DataFrame(
            [[1, 1, 1]],

@@ -1411,7 +1411,7 @@ def test_expr_with_column_name_with_backtick_and_hash(self):
    def test_expr_with_column_name_with_backtick(self):
        # GH 59285
        df = DataFrame({"a`b": (1, 2, 3), "ab": (4, 5, 6)})
-        result = df.query("`a``b` < 2")  # noqa
+        result = df.query("`a``b` < 2")
        # Note: Formatting checks may wrongly consider the above ``inline code``.
        expected = df[df["a`b"] < 2]
        tm.assert_frame_equal(result, expected)
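The line that lost its `# noqa` exercises `DataFrame.query`'s backtick escaping: a column name that itself contains a backtick is written with the backtick doubled inside the backtick-quoted identifier. A minimal sketch of that usage:

import pandas as pd

df = pd.DataFrame({"a`b": (1, 2, 3), "ab": (4, 5, 6)})

# `a``b` refers to the column literally named "a`b"; doubling the backtick
# escapes it inside the quoted identifier (GH 59285).
result = df.query("`a``b` < 2")
expected = df[df["a`b"] < 2]
print(result.equals(expected))  # True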

pandas/tests/groupby/test_groupby_metadata.py

Lines changed: 9 additions & 9 deletions
@@ -3,12 +3,8 @@
"""

import numpy as np
-import pytest

-import pandas as pd
import pandas._testing as tm
-from pandas import DataFrame
-from pandas.tests.groupby import test_groupby_subclass


class TestGroupByMetadataPreservation:

@@ -19,14 +15,18 @@ def test_groupby_apply_preserves_metadata(self):
            {"X": [1, 1, 2, 2, 3], "Y": np.arange(0, 5), "Z": np.arange(10, 15)}
        )
        subdf.testattr = "test"
-
+
        # Apply groupby operation
        result = subdf.groupby("X").apply(np.sum, axis=0, include_groups=False)
-
+
        # Check that metadata is preserved
-        assert hasattr(result, 'testattr'), "Metadata attribute 'testattr' should be preserved"
+        assert hasattr(result, "testattr"), (
+            "Metadata attribute 'testattr' should be preserved"
+        )
        assert result.testattr == "test", "Metadata value should be preserved"
-
+
        # Compare with equivalent operation that preserves metadata
        expected = subdf.groupby("X").sum()
-        assert expected.testattr == "test", "Equivalent operation should preserve metadata"
+        assert expected.testattr == "test", (
+            "Equivalent operation should preserve metadata"
+        )
