4949from pandas .formats .printing import pprint_thing
5050from pandas .util .validators import validate_kwargs
5151
52+ from pandas .tools import weightby
5253import pandas .core .algorithms as algos
5354import pandas .core .common as com
5455from pandas .core .config import option_context
@@ -791,15 +792,21 @@ def _cython_transform(self, how, numeric_only=True):
791792
792793 return self ._wrap_transformed_output (output , names )
793794
def _cython_agg_general(self, how, weights=None, numeric_only=True):
    """
    Aggregate every slice of the grouped object through the grouper,
    optionally weighting the values first.

    Parameters
    ----------
    how : aggregation identifier
        Forwarded unchanged to ``self.grouper.aggregate`` (callers visible
        in this file pass strings such as ``'add'`` and ``'mean'``).
    weights : optional, default None
        When not None, it is resolved against ``self.obj`` with
        ``weightby.weightby`` and the second element of that result is
        used as the weights for every slice.
    numeric_only : bool, default True
        When True, slices whose dtype is not numeric are skipped.

    Raises
    ------
    GroupByError
        If ``self.grouper.aggregate`` raises an AssertionError.
    """
    if weights is not None:

        # TODO, need to integrate this with the exclusions
        # The axis must come from the groupby object itself: there is no
        # local or parameter called ``axis`` in this method, so a bare
        # ``axis`` would raise NameError as soon as weights are supplied.
        _, weights = weightby.weightby(self.obj,
                                       weights=weights,
                                       axis=self.axis)

    output = {}
    for name, obj in self._iterate_slices():
        is_numeric = is_numeric_dtype(obj.dtype)
        if numeric_only and not is_numeric:
            continue

        # weightby.weight is called even when weights is None
        # NOTE(review): presumably a pass-through in that case -- confirm
        values = weightby.weight(obj.values, weights)
        try:
            result, names = self.grouper.aggregate(values, how)
        except AssertionError as e:
            raise GroupByError(str(e))
        output[name] = self._try_cast(result, obj)
    # NOTE(review): the visible excerpt of this method stops here; whatever
    # follows in the full file (e.g. wrapping ``output``) is not shown.
@@ -1006,6 +1013,26 @@ def count(self):
10061013 # defined here for API doc
10071014 raise NotImplementedError
10081015
@Substitution(name='groupby')
@Appender(_doc_template)
def sum(self, *args, **kwargs):
    """
    Compute sum of groups, excluding missing values

    For multiple groupings, the result index will be a MultiIndex
    """

    # TODO: this differs slightly from the other cythonized reductions
    # (e.g. mean) in order to accommodate np.sum functionality
    allowed_kwargs = ('weights', 'numeric_only')
    nv.validate_groupby_func('sum', args, kwargs, allowed_kwargs)

    self._set_group_selection()

    # NOTE(review): this fallback reduction does not receive ``kwargs``,
    # so a requested ``weights`` is not applied on this path -- confirm
    # that is intended.
    numpy_fallback = lambda x: np.sum(x, axis=self.axis)

    try:
        # fast path: cython 'add' kernel, keyword options forwarded as-is
        result = self._cython_agg_general('add', **kwargs)
    except AssertionError as err:
        raise SpecificationError(str(err))
    except Exception:  # pragma: no cover
        result = self.aggregate(numpy_fallback)
    return result
1035+
10091036 @Substitution (name = 'groupby' )
10101037 @Appender (_doc_template )
10111038 def mean (self , * args , ** kwargs ):
@@ -1014,14 +1041,15 @@ def mean(self, *args, **kwargs):
10141041
10151042 For multiple groupings, the result index will be a MultiIndex
10161043 """
1017- nv .validate_groupby_func ('mean' , args , kwargs )
1044+ nv .validate_groupby_func ('mean' , args , kwargs , ( 'weights' , 'numeric_only' ) )
10181045 try :
1019- return self ._cython_agg_general ('mean' )
1046+ return self ._cython_agg_general ('mean' , ** kwargs )
10201047 except GroupByError :
10211048 raise
10221049 except Exception : # pragma: no cover
10231050 self ._set_group_selection ()
1024- f = lambda x : x .mean (axis = self .axis )
1051+ kwargs ['axis' ] = self .axis
1052+ f = lambda x : x .mean (** kwargs )
10251053 return self ._python_agg_general (f )
10261054
10271055 @Substitution (name = 'groupby' )
@@ -1107,7 +1135,6 @@ def size(self):
11071135 """Compute group sizes"""
11081136 return self .grouper .size ()
11091137
1110- sum = _groupby_function ('sum' , 'add' , np .sum )
11111138 prod = _groupby_function ('prod' , 'prod' , np .prod )
11121139 min = _groupby_function ('min' , 'min' , np .min , numeric_only = False )
11131140 max = _groupby_function ('max' , 'max' , np .max , numeric_only = False )
@@ -3134,9 +3161,9 @@ def _iterate_slices(self):
31343161 continue
31353162 yield val , slicer (val )
31363163
3137- def _cython_agg_general (self , how , numeric_only = True ):
3164+ def _cython_agg_general (self , how , ** kwargs ):
31383165 new_items , new_blocks = self ._cython_agg_blocks (
3139- how , numeric_only = numeric_only )
3166+ how , ** kwargs )
31403167 return self ._wrap_agged_blocks (new_items , new_blocks )
31413168
31423169 def _wrap_agged_blocks (self , items , blocks ):
@@ -3162,18 +3189,26 @@ def _wrap_agged_blocks(self, items, blocks):
31623189
31633190 _block_agg_axis = 0
31643191
3165- def _cython_agg_blocks (self , how , numeric_only = True ):
3192+ def _cython_agg_blocks (self , how , weights = None , numeric_only = True ):
31663193 data , agg_axis = self ._get_data_to_aggregate ()
31673194
3195+ if weights is not None :
3196+
3197+ # TODO, need to integrate this with the exclusions
3198+ _ , weights = weightby .weightby (self .obj ,
3199+ weights = weights ,
3200+ axis = self .axis )
3201+
31683202 new_blocks = []
31693203
31703204 if numeric_only :
31713205 data = data .get_numeric_data (copy = False )
31723206
31733207 for block in data .blocks :
31743208
3209+ values = weightby .weight (block .values , weights )
31753210 result , _ = self .grouper .aggregate (
3176- block . values , how , axis = agg_axis )
3211+ values , how , axis = agg_axis )
31773212
31783213 # see if we can cast the block back to the original dtype
31793214 result = block ._try_coerce_and_cast_result (result )
0 commit comments