@@ -341,10 +341,14 @@ class _GroupBy(PandasObject, SelectionMixin):
341
341
342
342
def __init__ (self , obj , keys = None , axis = 0 , level = None ,
343
343
grouper = None , exclusions = None , selection = None , as_index = True ,
344
- sort = True , group_keys = True , squeeze = False , ** kwargs ):
344
+ sort = True , group_keys = True , squeeze = False , ref_obj = None , ** kwargs ):
345
345
346
346
self ._selection = selection
347
347
348
+ if ref_obj is None :
349
+ ref_obj = obj
350
+ self .ref_obj = ref_obj
351
+
348
352
if isinstance (obj , NDFrame ):
349
353
obj ._consolidate_inplace ()
350
354
@@ -796,7 +800,9 @@ def _cython_agg_general(self, how, weights=None, numeric_only=True):
796
800
if weights is not None :
797
801
798
802
# TODO, need to integrate this with the exclusions
799
- _ , weights = weightby .weightby (self .obj , weights = weights , axis = axis )
803
+ _ , weights = weightby .weightby (self .ref_obj ,
804
+ weights = weights ,
805
+ axis = self .axis )
800
806
801
807
output = {}
802
808
for name , obj in self ._iterate_slices ():
@@ -3189,13 +3195,14 @@ def _wrap_agged_blocks(self, items, blocks):
3189
3195
3190
3196
_block_agg_axis = 0
3191
3197
3192
- def _cython_agg_blocks (self , how , weights = None , numeric_only = True ):
3198
+ def _cython_agg_blocks (self , how , weights = None , numeric_only = True ,
3199
+ ** kwargs ):
3193
3200
data , agg_axis = self ._get_data_to_aggregate ()
3194
3201
3195
3202
if weights is not None :
3196
3203
3197
3204
# TODO, need to integrate this with the exclusions
3198
- _ , weights = weightby .weightby (self .obj ,
3205
+ _ , weights = weightby .weightby (self .ref_obj ,
3199
3206
weights = weights ,
3200
3207
axis = self .axis )
3201
3208
@@ -3765,19 +3772,20 @@ def _gotitem(self, key, ndim, subset=None):
3765
3772
subset : object, default None
3766
3773
subset to act on
3767
3774
"""
3768
-
3769
3775
if ndim == 2 :
3770
3776
if subset is None :
3771
3777
subset = self .obj
3772
3778
return DataFrameGroupBy (subset , self .grouper , selection = key ,
3773
3779
grouper = self .grouper ,
3774
3780
exclusions = self .exclusions ,
3775
- as_index = self .as_index )
3781
+ as_index = self .as_index ,
3782
+ ref_obj = self .obj )
3776
3783
elif ndim == 1 :
3777
3784
if subset is None :
3778
3785
subset = self .obj [key ]
3779
3786
return SeriesGroupBy (subset , selection = key ,
3780
- grouper = self .grouper )
3787
+ grouper = self .grouper ,
3788
+ ref_obj = self .obj )
3781
3789
3782
3790
raise AssertionError ("invalid ndim for _gotitem" )
3783
3791
0 commit comments