Skip to content

Commit 3d5fb19

Browse files
committed
fix groupby
1 parent f9de1ba commit 3d5fb19

File tree

5 files changed

+93
-17
lines changed

5 files changed

+93
-17
lines changed

pandas/compat/numpy/function.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,16 @@ def validate_expanding_func(name, args, kwargs):
306306
raise UnsupportedFunctionCall(msg)
307307

308308

309-
def validate_groupby_func(name, args, kwargs):
309+
def validate_groupby_func(name, args, kwargs, allowed_kwargs=None):
310310
"""
311-
'args' and 'kwargs' should be empty because all of
311+
'args' should be empty because all of
312312
their necessary parameters are explicitly listed in
313313
the function signature
314314
"""
315-
if len(args) + len(kwargs) > 0:
315+
if allowed_kwargs:
316+
kwargs = set(kwargs) - set(allowed_kwargs)
317+
318+
if len(args) or len(kwargs):
316319
raise UnsupportedFunctionCall((
317320
"numpy operations are not valid "
318321
"with groupby. Use .groupby(...)."

pandas/core/groupby.py

+45-10
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from pandas.formats.printing import pprint_thing
5050
from pandas.util.validators import validate_kwargs
5151

52+
from pandas.tools import weightby
5253
import pandas.core.algorithms as algos
5354
import pandas.core.common as com
5455
from pandas.core.config import option_context
@@ -791,15 +792,21 @@ def _cython_transform(self, how, numeric_only=True):
791792

792793
return self._wrap_transformed_output(output, names)
793794

794-
def _cython_agg_general(self, how, numeric_only=True):
795+
def _cython_agg_general(self, how, weights=None, numeric_only=True):
796+
if weights is not None:
797+
798+
# TODO, need to integrate this with the exclusions
799+
_, weights = weightby.weightby(self.obj, weights=weights, axis=axis)
800+
795801
output = {}
796802
for name, obj in self._iterate_slices():
797803
is_numeric = is_numeric_dtype(obj.dtype)
798804
if numeric_only and not is_numeric:
799805
continue
800806

807+
values = weightby.weight(obj.values, weights)
801808
try:
802-
result, names = self.grouper.aggregate(obj.values, how)
809+
result, names = self.grouper.aggregate(values, how)
803810
except AssertionError as e:
804811
raise GroupByError(str(e))
805812
output[name] = self._try_cast(result, obj)
@@ -1006,6 +1013,26 @@ def count(self):
10061013
# defined here for API doc
10071014
raise NotImplementedError
10081015

1016+
@Substitution(name='groupby')
1017+
@Appender(_doc_template)
1018+
def sum(self, *args, **kwargs):
1019+
"""
1020+
Compute sum of groups, excluding missing values
1021+
1022+
For multiple groupings, the result index will be a MultiIndex
1023+
"""
1024+
1025+
# TODO: this is slightly different from other cythonized functions (e.g. mean)
1026+
# to accommodate np.sum functionality
1027+
nv.validate_groupby_func('sum', args, kwargs, ('weights', 'numeric_only'))
1028+
self._set_group_selection()
1029+
try:
1030+
return self._cython_agg_general('add', **kwargs)
1031+
except AssertionError as e:
1032+
raise SpecificationError(str(e))
1033+
except Exception: # pragma: no cover
1034+
return self.aggregate(lambda x: np.sum(x, axis=self.axis))
1035+
10091036
@Substitution(name='groupby')
10101037
@Appender(_doc_template)
10111038
def mean(self, *args, **kwargs):
@@ -1014,14 +1041,15 @@ def mean(self, *args, **kwargs):
10141041
10151042
For multiple groupings, the result index will be a MultiIndex
10161043
"""
1017-
nv.validate_groupby_func('mean', args, kwargs)
1044+
nv.validate_groupby_func('mean', args, kwargs, ('weights', 'numeric_only'))
10181045
try:
1019-
return self._cython_agg_general('mean')
1046+
return self._cython_agg_general('mean', **kwargs)
10201047
except GroupByError:
10211048
raise
10221049
except Exception: # pragma: no cover
10231050
self._set_group_selection()
1024-
f = lambda x: x.mean(axis=self.axis)
1051+
kwargs['axis'] = self.axis
1052+
f = lambda x: x.mean(**kwargs)
10251053
return self._python_agg_general(f)
10261054

10271055
@Substitution(name='groupby')
@@ -1107,7 +1135,6 @@ def size(self):
11071135
"""Compute group sizes"""
11081136
return self.grouper.size()
11091137

1110-
sum = _groupby_function('sum', 'add', np.sum)
11111138
prod = _groupby_function('prod', 'prod', np.prod)
11121139
min = _groupby_function('min', 'min', np.min, numeric_only=False)
11131140
max = _groupby_function('max', 'max', np.max, numeric_only=False)
@@ -3134,9 +3161,9 @@ def _iterate_slices(self):
31343161
continue
31353162
yield val, slicer(val)
31363163

3137-
def _cython_agg_general(self, how, numeric_only=True):
3164+
def _cython_agg_general(self, how, **kwargs):
31383165
new_items, new_blocks = self._cython_agg_blocks(
3139-
how, numeric_only=numeric_only)
3166+
how, **kwargs)
31403167
return self._wrap_agged_blocks(new_items, new_blocks)
31413168

31423169
def _wrap_agged_blocks(self, items, blocks):
@@ -3162,18 +3189,26 @@ def _wrap_agged_blocks(self, items, blocks):
31623189

31633190
_block_agg_axis = 0
31643191

3165-
def _cython_agg_blocks(self, how, numeric_only=True):
3192+
def _cython_agg_blocks(self, how, weights=None, numeric_only=True):
31663193
data, agg_axis = self._get_data_to_aggregate()
31673194

3195+
if weights is not None:
3196+
3197+
# TODO, need to integrate this with the exclusions
3198+
_, weights = weightby.weightby(self.obj,
3199+
weights=weights,
3200+
axis=self.axis)
3201+
31683202
new_blocks = []
31693203

31703204
if numeric_only:
31713205
data = data.get_numeric_data(copy=False)
31723206

31733207
for block in data.blocks:
31743208

3209+
values = weightby.weight(block.values, weights)
31753210
result, _ = self.grouper.aggregate(
3176-
block.values, how, axis=agg_axis)
3211+
values, how, axis=agg_axis)
31773212

31783213
# see if we can cast the block back to the original dtype
31793214
result = block._try_coerce_and_cast_result(result)

pandas/core/nanops.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pandas.types.missing import isnull, notnull
2424

2525
from pandas.core.common import _values_from_object
26-
26+
from pandas.tools import weightby
2727

2828
class disallow(object):
2929
def __init__(self, *dtypes):
@@ -204,9 +204,7 @@ def _get_values(values, skipna,
204204
mask = isnull(values)
205205

206206
# weights
207-
if weights is not None:
208-
values = values * weights.reshape(values.shape)
209-
207+
values = weightby.weight(values, weights)
210208
dtype = values.dtype
211209
dtype_ok = _na_ok_dtype(dtype)
212210

pandas/tools/tests/test_weightby.py

+13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ def setUp(self):
1313
'B': [1, 2, 3, 4]})
1414
self.df2 = DataFrame({'A': [1, 2, 3, 4],
1515
'B': [1, 2, 3, 4]})
16+
self.df3 = DataFrame({'A': [1, 2, 3, 4],
17+
'B': [1, 2, 3, 4],
18+
'C': [1, 1, 2, 2]})
1619

1720
@property
1821
def rs(self):
@@ -55,6 +58,16 @@ def test_basic(self):
5558
expected = getattr(self.df2[['B']] * weights2, f)(ddof=2)
5659
# tm.assert_series_equal(result, expected)
5760

61+
def test_groupby(self):
62+
63+
for f in ['mean', 'sum']:
64+
weights = (self.df3['A'] / self.df3.A.sum()).values
65+
result = getattr(self.df3.groupby('C'), f)(weights='A')
66+
adj = self.df3.assign(A=self.df3.A * weights,
67+
B=self.df3.B * weights)
68+
expected = getattr(adj.groupby('C'), f)()
69+
tm.assert_frame_equal(result, expected)
70+
5871
def test_unsupported(self):
5972
for f in ['first', 'median', 'min', 'max', 'prod']:
6073

pandas/tools/weightby.py

+27
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,30 @@ def weightby(obj, weights=None, axis=0):
9494
raise ValueError("Invalid weights: weights sum to zero")
9595

9696
return obj, weights.values
97+
98+
99+
def weight(values, weights):
100+
"""
101+
Return the values * weights, broadcasting if needed
102+
103+
Parameters
104+
----------
105+
values : ndarray
106+
weights : 1d-ndarray
107+
108+
Returns
109+
-------
110+
ndarray with the same shape as values
111+
"""
112+
113+
if weights is None:
114+
return values
115+
116+
if values.ndim == 1:
117+
return values * weights
118+
119+
elif values.ndim == 2:
120+
121+
return values * weights
122+
123+
raise NotImplementedError

0 commit comments

Comments
 (0)