Skip to content

Commit 3fe7f7f

Browse files
committed
wip
1 parent 3d5fb19 commit 3fe7f7f

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

pandas/core/groupby.py

+15 -7
Original file line number | Diff line number | Diff line change
@@ -341,10 +341,14 @@ class _GroupBy(PandasObject, SelectionMixin):
341341

342342
def __init__(self, obj, keys=None, axis=0, level=None,
343343
grouper=None, exclusions=None, selection=None, as_index=True,
344-
sort=True, group_keys=True, squeeze=False, **kwargs):
344+
sort=True, group_keys=True, squeeze=False, ref_obj=None, **kwargs):
345345

346346
self._selection = selection
347347

348+
if ref_obj is None:
349+
ref_obj = obj
350+
self.ref_obj = ref_obj
351+
348352
if isinstance(obj, NDFrame):
349353
obj._consolidate_inplace()
350354

@@ -796,7 +800,9 @@ def _cython_agg_general(self, how, weights=None, numeric_only=True):
796800
if weights is not None:
797801

798802
# TODO, need to integrate this with the exclusions
799-
_, weights = weightby.weightby(self.obj, weights=weights, axis=axis)
803+
_, weights = weightby.weightby(self.ref_obj,
804+
weights=weights,
805+
axis=self.axis)
800806

801807
output = {}
802808
for name, obj in self._iterate_slices():
@@ -3189,13 +3195,14 @@ def _wrap_agged_blocks(self, items, blocks):
31893195

31903196
_block_agg_axis = 0
31913197

3192-
def _cython_agg_blocks(self, how, weights=None, numeric_only=True):
3198+
def _cython_agg_blocks(self, how, weights=None, numeric_only=True,
3199+
**kwargs):
31933200
data, agg_axis = self._get_data_to_aggregate()
31943201

31953202
if weights is not None:
31963203

31973204
# TODO, need to integrate this with the exclusions
3198-
_, weights = weightby.weightby(self.obj,
3205+
_, weights = weightby.weightby(self.ref_obj,
31993206
weights=weights,
32003207
axis=self.axis)
32013208

@@ -3765,19 +3772,20 @@ def _gotitem(self, key, ndim, subset=None):
37653772
subset : object, default None
37663773
subset to act on
37673774
"""
3768-
37693775
if ndim == 2:
37703776
if subset is None:
37713777
subset = self.obj
37723778
return DataFrameGroupBy(subset, self.grouper, selection=key,
37733779
grouper=self.grouper,
37743780
exclusions=self.exclusions,
3775-
as_index=self.as_index)
3781+
as_index=self.as_index,
3782+
ref_obj=self.obj)
37763783
elif ndim == 1:
37773784
if subset is None:
37783785
subset = self.obj[key]
37793786
return SeriesGroupBy(subset, selection=key,
3780-
grouper=self.grouper)
3787+
grouper=self.grouper,
3788+
ref_obj=self.obj)
37813789

37823790
raise AssertionError("invalid ndim for _gotitem")
37833791

pandas/tools/tests/test_weightby.py

+7
Original file line number | Diff line number | Diff line change
@@ -61,13 +61,20 @@ def test_basic(self):
6161
def test_groupby(self):
6262

6363
for f in ['mean', 'sum']:
64+
6465
weights = (self.df3['A'] / self.df3.A.sum()).values
6566
result = getattr(self.df3.groupby('C'), f)(weights='A')
6667
adj = self.df3.assign(A=self.df3.A * weights,
6768
B=self.df3.B * weights)
6869
expected = getattr(adj.groupby('C'), f)()
6970
tm.assert_frame_equal(result, expected)
7071

72+
weights = (self.df3['A'] / self.df3.A.sum()).values
73+
result = getattr(self.df3.groupby('C').B, f)(weights='A')
74+
adj = self.df3.assign(B=self.df3.B * weights)
75+
expected = getattr(adj.groupby('C').B, f)()
76+
tm.assert_series_equal(result, expected)
77+
7178
def test_unsupported(self):
7279
for f in ['first', 'median', 'min', 'max', 'prod']:
7380

0 commit comments

Comments (0)