-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG: Groupby min/max with nullable dtypes #42567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
6f7c961
8f98501
936df47
b325cc0
a021e58
2807b25
1988294
25307f6
921ad33
98f8782
13bd9f3
f44e77b
cc95817
362eed5
5f4ea99
3eb06f5
359c171
35def86
edb1beb
e1a447b
f556ab8
92b4617
11d7f1d
6043105
961073d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1176,7 +1176,9 @@ cdef group_min_max(groupby_t[:, ::1] out, | |
const intp_t[:] labels, | ||
Py_ssize_t min_count=-1, | ||
bint is_datetimelike=False, | ||
bint compute_max=True): | ||
bint compute_max=True, | ||
const uint8_t[:, ::1] mask_in=None, | ||
uint8_t[:, ::1] result_mask=None): | ||
""" | ||
Compute minimum/maximum of columns of `values`, in row groups `labels`. | ||
|
||
|
@@ -1197,6 +1199,12 @@ cdef group_min_max(groupby_t[:, ::1] out, | |
True if `values` contains datetime-like entries. | ||
compute_max : bint, default True | ||
True to compute group-wise max, False to compute min | ||
mask_in : ndarray[bool, ndim=2], optional | ||
If not None, indices represent missing values, | ||
otherwise the mask will not be used | ||
result_mask : ndarray[bool, ndim=2], optional | ||
If not None, these specify locations in the output that are NA. | ||
Modified in-place. | ||
|
||
Notes | ||
----- | ||
|
@@ -1209,6 +1217,8 @@ cdef group_min_max(groupby_t[:, ::1] out, | |
ndarray[groupby_t, ndim=2] group_min_or_max | ||
bint runtime_error = False | ||
int64_t[:, ::1] nobs | ||
bint uses_mask = mask_in is not None | ||
bint isna_entry | ||
|
||
# TODO(cython 3.0): | ||
# Instead of `labels.shape[0]` use `len(labels)` | ||
|
@@ -1243,7 +1253,12 @@ cdef group_min_max(groupby_t[:, ::1] out, | |
for j in range(K): | ||
val = values[i, j] | ||
|
||
if not _treat_as_na(val, is_datetimelike): | ||
if uses_mask: | ||
isna_entry = mask_in[i, j] | ||
else: | ||
isna_entry = _treat_as_na(val, is_datetimelike) | ||
|
||
if not isna_entry: | ||
nobs[lab, j] += 1 | ||
if compute_max: | ||
if val > group_min_or_max[lab, j]: | ||
|
@@ -1259,7 +1274,10 @@ cdef group_min_max(groupby_t[:, ::1] out, | |
runtime_error = True | ||
break | ||
else: | ||
out[i, j] = nan_val | ||
if uses_mask: | ||
result_mask[i, j] = True | ||
else: | ||
out[i, j] = nan_val | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How is There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, I could go either way on this. |
||
else: | ||
out[i, j] = group_min_or_max[i, j] | ||
|
||
|
@@ -1276,7 +1294,9 @@ def group_max(groupby_t[:, ::1] out, | |
ndarray[groupby_t, ndim=2] values, | ||
const intp_t[:] labels, | ||
Py_ssize_t min_count=-1, | ||
bint is_datetimelike=False) -> None: | ||
bint is_datetimelike=False, | ||
const uint8_t[:, ::1] mask=None, | ||
uint8_t[:, ::1] result_mask=None) -> None: | ||
"""See group_min_max.__doc__""" | ||
group_min_max( | ||
out, | ||
|
@@ -1286,6 +1306,8 @@ def group_max(groupby_t[:, ::1] out, | |
min_count=min_count, | ||
is_datetimelike=is_datetimelike, | ||
compute_max=True, | ||
mask_in=mask, | ||
result_mask=result_mask, | ||
) | ||
|
||
|
||
|
@@ -1296,7 +1318,9 @@ def group_min(groupby_t[:, ::1] out, | |
ndarray[groupby_t, ndim=2] values, | ||
const intp_t[:] labels, | ||
Py_ssize_t min_count=-1, | ||
bint is_datetimelike=False) -> None: | ||
bint is_datetimelike=False, | ||
const uint8_t[:, ::1] mask=None, | ||
uint8_t[:, ::1] result_mask=None) -> None: | ||
"""See group_min_max.__doc__""" | ||
group_min_max( | ||
out, | ||
|
@@ -1306,6 +1330,8 @@ def group_min(groupby_t[:, ::1] out, | |
min_count=min_count, | ||
is_datetimelike=is_datetimelike, | ||
compute_max=False, | ||
mask_in=mask, | ||
result_mask=result_mask, | ||
) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,6 +123,8 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False): | |
raise ValueError("values must be a 1D array") | ||
if mask.ndim != 1: | ||
raise ValueError("mask must be a 1D array") | ||
if values.shape != mask.shape: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this hit in tests? |
||
raise ValueError("values and mask must have same shape") | ||
|
||
if copy: | ||
values = values.copy() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,7 +138,7 @@ def __init__(self, kind: str, how: str): | |
}, | ||
} | ||
|
||
_MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax"} | ||
_MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax", "min", "max"} | ||
|
||
_cython_arity = {"ohlc": 4} # OHLC | ||
|
||
|
@@ -404,6 +404,7 @@ def _masked_ea_wrap_cython_operation( | |
|
||
# Copy to ensure input and result masks don't end up shared | ||
mask = values._mask.copy() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For an aggregation op like this, I guess we can avoid the mask copy, since the mask is not being modified in place. (This is not something that needs to be done in this PR; just a note.) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Though I suppose there's the tradeoff that we'd then lose the mask contiguity guarantee in the algo? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I haven't looked at the contiguity, but that makes sense.
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
result_mask = np.zeros(ngroups, dtype=bool) | ||
arr = values._data | ||
|
||
res_values = self._cython_op_ndim_compat( | ||
|
@@ -412,13 +413,18 @@ def _masked_ea_wrap_cython_operation( | |
ngroups=ngroups, | ||
comp_ids=comp_ids, | ||
mask=mask, | ||
result_mask=result_mask, | ||
**kwargs, | ||
) | ||
|
||
dtype = self._get_result_dtype(orig_values.dtype) | ||
assert isinstance(dtype, BaseMaskedDtype) | ||
cls = dtype.construct_array_type() | ||
|
||
return cls(res_values.astype(dtype.type, copy=False), mask) | ||
if self.kind != "aggregate": | ||
return cls(res_values.astype(dtype.type, copy=False), mask) | ||
else: | ||
return cls(res_values.astype(dtype.type, copy=False), result_mask) | ||
|
||
@final | ||
def _cython_op_ndim_compat( | ||
|
@@ -428,20 +434,24 @@ def _cython_op_ndim_compat( | |
min_count: int, | ||
ngroups: int, | ||
comp_ids: np.ndarray, | ||
mask: np.ndarray | None, | ||
mask: np.ndarray | None = None, | ||
result_mask: np.ndarray | None = None, | ||
**kwargs, | ||
) -> np.ndarray: | ||
if values.ndim == 1: | ||
# expand to 2d, dispatch, then squeeze if appropriate | ||
values2d = values[None, :] | ||
if mask is not None: | ||
mask = mask[None, :] | ||
if result_mask is not None: | ||
result_mask = result_mask[None, :] | ||
res = self._call_cython_op( | ||
values2d, | ||
min_count=min_count, | ||
ngroups=ngroups, | ||
comp_ids=comp_ids, | ||
mask=mask, | ||
result_mask=result_mask, | ||
**kwargs, | ||
) | ||
if res.shape[0] == 1: | ||
|
@@ -456,6 +466,7 @@ def _cython_op_ndim_compat( | |
ngroups=ngroups, | ||
comp_ids=comp_ids, | ||
mask=mask, | ||
result_mask=result_mask, | ||
**kwargs, | ||
) | ||
|
||
|
@@ -468,6 +479,7 @@ def _call_cython_op( | |
ngroups: int, | ||
comp_ids: np.ndarray, | ||
mask: np.ndarray | None, | ||
result_mask: np.ndarray | None, | ||
**kwargs, | ||
) -> np.ndarray: # np.ndarray[ndim=2] | ||
orig_values = values | ||
|
@@ -493,6 +505,8 @@ def _call_cython_op( | |
values = values.T | ||
if mask is not None: | ||
mask = mask.T | ||
if result_mask is not None: | ||
result_mask = result_mask.T | ||
|
||
out_shape = self._get_output_shape(ngroups, values) | ||
func, values = self.get_cython_func_and_vals(values, is_numeric) | ||
|
@@ -508,6 +522,8 @@ def _call_cython_op( | |
values, | ||
comp_ids, | ||
min_count, | ||
mask=mask, | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
result_mask=result_mask, | ||
is_datetimelike=is_datetimelike, | ||
) | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the `group_max` that wraps this uses `mask`, use `mask` here as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch, updated