diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index feb6adb1e..21009e108 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -74,7 +74,7 @@ jobs:
         id: status
         run: |
           python -c "import xarray; xarray.show_versions()"
-          pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
+          pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci --log-disable=flox
       - name: Upload code coverage to Codecov
         uses: codecov/codecov-action@v5.3.1
         with:
diff --git a/docs/source/aggregations.md b/docs/source/aggregations.md
index d3591d2dc..82562cc3a 100644
--- a/docs/source/aggregations.md
+++ b/docs/source/aggregations.md
@@ -9,19 +9,16 @@ the `func` kwarg:
 - `"mean"`, `"nanmean"`
 - `"var"`, `"nanvar"`
 - `"std"`, `"nanstd"`
-- `"argmin"`
-- `"argmax"`
+- `"argmin"`, `"nanargmin"`
+- `"argmax"`, `"nanargmax"`
 - `"first"`, `"nanfirst"`
 - `"last"`, `"nanlast"`
 - `"median"`, `"nanmedian"`
 - `"mode"`, `"nanmode"`
 - `"quantile"`, `"nanquantile"`
+- `"topk"`
 
-```{tip}
-We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome!
-```
-
-## Custom Aggregations
+## Custom Reductions
 
 `flox` also allows you to specify a custom Aggregation (again inspired by dask.dataframe),
 though this might not be fully functional at the moment. See `aggregations.py` for examples.
@@ -46,3 +43,7 @@ mean = Aggregation(
     final_fill_value=np.nan,
 )
 ```
+
+## Custom Scans
+
+Coming soon!
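With this change, `"topk"` joins the reductions accepted through the `func` kwarg. A minimal sketch of how it is called, assuming this PR is applied; per the diffs below, `k` must be supplied via `finalize_kwargs`, and the result gains a new leading `k` dimension:

```python
import numpy as np
import flox

array = np.array([10.0, 20.0, 30.0, 40.0, 1.0, 2.0])
by = np.array([0, 0, 0, 0, 1, 1])

# Two largest values in each group; the result has shape (k, num_groups).
top2, groups = flox.groupby_reduce(array, by, func="topk", finalize_kwargs={"k": 2})

# A negative k selects the k smallest values instead.
bottom2, _ = flox.groupby_reduce(array, by, func="topk", finalize_kwargs={"k": -2})
```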
diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py
index 938bd6fcc..9987704d6 100644
--- a/flox/aggregate_flox.py
+++ b/flox/aggregate_flox.py
@@ -47,14 +47,32 @@ def _lerp(a, b, *, t, dtype, out=None):
     return out
 
 
-def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
+def quantile_or_topk(
+    array,
+    inv_idx,
+    *,
+    q=None,
+    k=None,
+    axis,
+    skipna,
+    group_idx,
+    dtype=None,
+    out=None,
+    fill_value=None,
+):
+    assert q is not None or k is not None
+    assert axis == -1
+
     inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))
 
     array_validmask = notnull(array)
     actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis)
     newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
-    full_sizes = np.reshape(np.diff(inv_idx), newshape)
-    nanmask = full_sizes != actual_sizes
+    if k is not None:
+        nanmask = actual_sizes < abs(k)
+    else:
+        full_sizes = np.reshape(np.diff(inv_idx), newshape)
+        nanmask = full_sizes != actual_sizes
 
     # The approach here is to use (complex_array.partition) because
     # 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
@@ -72,36 +90,48 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
     # So we determine which indices we need using the fact that NaNs get sorted to the end.
     # This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
     # but not any more now that I use partition and avoid replacing NaNs
-    qin = q
-    q = np.atleast_1d(qin)
-    q = np.reshape(q, (len(q),) + (1,) * array.ndim)
+    if k is not None:
+        is_scalar_param = False
+        param = np.sort(np.arange(abs(k)) * np.sign(k))
+    else:
+        is_scalar_param = is_scalar(q)
+        param = np.atleast_1d(q)
+    param = np.reshape(param, (param.size,) + (1,) * array.ndim)
 
     # This is numpy's method="linear"
     # TODO: could support all the interpolations here
     offset = actual_sizes.cumsum(axis=-1)
-    actual_sizes -= 1
-    virtual_index = q * actual_sizes
-    # virtual_index is relative to group starts, so now offset that
-    virtual_index[..., 1:] += offset[..., :-1]
-
-    is_scalar_q = is_scalar(qin)
-    if is_scalar_q:
-        virtual_index = virtual_index.squeeze(axis=0)
-        idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
-    else:
-        idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
+    # For topk(.., k=+1 or -1), we always return the singleton dimension.
+    idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
 
-    lo_ = np.floor(
-        virtual_index,
-        casting="unsafe",
-        out=np.empty(virtual_index.shape, dtype=np.int64),
-    )
-    hi_ = np.ceil(
-        virtual_index,
-        casting="unsafe",
-        out=np.empty(virtual_index.shape, dtype=np.int64),
-    )
-    kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
+    if q is not None:
+        # This is numpy's method="linear"
+        # TODO: could support all the interpolations here
+        actual_sizes -= 1
+        virtual_index = param * actual_sizes
+        # virtual_index is relative to group starts, so now offset that
+        virtual_index[..., 1:] += offset[..., :-1]
+
+        if is_scalar_param:
+            virtual_index = virtual_index.squeeze(axis=0)
+            idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
+
+        lo_ = np.floor(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64))
+        hi_ = np.ceil(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64))
+        kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
+
+    else:
+        virtual_index = (actual_sizes - k) if k > 0 else (np.zeros_like(actual_sizes) + abs(k) - 1)
+        # virtual_index is relative to group starts, so now offset that
+        virtual_index[..., 1:] += offset[..., :-1]
+        kth = np.unique(virtual_index)
+        kth = kth[kth >= 0]
+        kth[kth >= array.shape[axis]] = array.shape[axis] - 1
+        k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim)
+        lo_ = k_offset + virtual_index[np.newaxis, ...]
+        not_enough_elems = actual_sizes < np.abs(k)
+        lo_[..., not_enough_elems] = 0
+        badmask = np.broadcast_to(not_enough_elems, idxshape) | nanmask
 
     # partition the complex array in-place
     labels_broadcast = np.broadcast_to(group_idx, array.shape)
@@ -111,20 +141,33 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
     # a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j`
     cmplx.real = labels_broadcast
     cmplx.partition(kth=kth, axis=-1)
-    if is_scalar_q:
-        a_ = cmplx.imag
-    else:
-        a_ = np.broadcast_to(cmplx.imag, (q.shape[0],) + array.shape)
 
-    # get bounds, Broadcast to (num quantiles, ..., num labels)
-    loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
-    hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)
+    a_ = cmplx.imag
+    if not is_scalar_param:
+        a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape)
 
-    # TODO: could support all the interpolations here
-    gamma = np.broadcast_to(virtual_index, idxshape) - lo_
-    result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
-    if not skipna and np.any(nanmask):
-        result[..., nanmask] = np.nan
+    if array.dtype.kind in "Mm":
+        a_ = a_.view(array.dtype)
+
+    loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
+    if q is not None:
+        # get bounds, Broadcast to (num quantiles, ..., num labels)
+        hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)
+
+        # TODO: could support all the interpolations here
+        gamma = np.broadcast_to(virtual_index, idxshape) - lo_
+        result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
+        if not skipna and np.any(nanmask):
+            result[..., nanmask] = fill_value
+    else:
+        result = loval
+        if badmask.any():
+            result[badmask] = fill_value
+
+    if k is not None:
+        result = result.astype(dtype, copy=False)
+        if out is not None:
+            np.copyto(out, result)
     return result
 
 
@@ -158,12 +201,14 @@ def _np_grouped_op(
     if out is None:
         q = kwargs.get("q", None)
-        if q is None:
+        k = kwargs.get("k", None)
+        if q is None and k is None:
             out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
         else:
-            nq = len(np.atleast_1d(q))
+            nq = len(np.atleast_1d(q)) if q is not None else abs(k)
             out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
         kwargs["group_idx"] = group_idx
+        kwargs["fill_value"] = fill_value
 
     if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
         # The previous version of this if condition
@@ -200,10 +245,11 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
 nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
 min = partial(_np_grouped_op, op=np.minimum.reduceat)
 nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
-quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
-nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
-median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
-nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True))
+topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
+quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False))
+nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
+median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False))
+nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True))
 # TODO: all, any
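The core trick in `quantile_or_topk` above is worth seeing in isolation: group labels go in the real part and values in the imaginary part, so a single `partition` orders elements first by group and then by value. A standalone illustration (plain numpy, not flox API):

```python
import numpy as np

array = np.array([3.0, 1.0, 2.0, 9.0, 7.0])
group_idx = np.array([0, 0, 0, 1, 1])

# numpy compares complex numbers lexicographically, real part first and
# imaginary part second, so this is effectively a two-key partition.
cmplx = group_idx + 1j * array
cmplx.partition(kth=2, axis=-1)

# Group 0 occupies positions 0..2, so position 2 now holds its maximum.
print(cmplx.imag[2])  # 3.0
```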
diff --git a/flox/aggregations.py b/flox/aggregations.py
index 575a5252e..d8e2ff094 100644
--- a/flox/aggregations.py
+++ b/flox/aggregations.py
@@ -551,6 +551,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
     return (Dim(name="quantile", values=q),)
 
 
+def topk_new_dims_func(k) -> tuple[Dim]:
+    return (Dim(name="k", values=np.arange(abs(k))),)
+
+
 # if the input contains integers or floats smaller than float64,
 # the output data-type is float64. Otherwise, the output data-type is the same as that
 # of the input.
@@ -572,6 +576,16 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
 )
 mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
 nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
+topk = Aggregation(
+    name="topk",
+    fill_value=(dtypes.NINF, 0),
+    final_fill_value=dtypes.NA,
+    # FIXME: set numpy
+    chunk=("topk", "nanlen"),
+    combine=(xrutils.topk, "sum"),
+    new_dims_func=topk_new_dims_func,
+    preserves_dtype=True,
+)
 
 
 @dataclass
@@ -769,6 +783,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
     "nanquantile": nanquantile,
     "mode": mode,
     "nanmode": nanmode,
+    "topk": topk,
     # "cumsum": cumsum,
     "nancumsum": nancumsum,
     "ffill": ffill,
@@ -823,6 +838,12 @@ def _initialize_aggregation(
         ),
     }
 
+    if finalize_kwargs is not None:
+        assert isinstance(finalize_kwargs, dict)
+        agg.finalize_kwargs = finalize_kwargs
+
+    if agg.name == "topk" and agg.finalize_kwargs["k"] < 0:
+        agg.fill_value["intermediate"] = (dtypes.INF, 0)
     # Replace sentinel fill values according to dtype
     agg.fill_value["user"] = fill_value
     agg.fill_value["intermediate"] = tuple(
@@ -838,9 +859,8 @@ def _initialize_aggregation(
     else:
         agg.fill_value["numpy"] = (fv,)
 
-    if finalize_kwargs is not None:
-        assert isinstance(finalize_kwargs, dict)
-        agg.finalize_kwargs = finalize_kwargs
+    if agg.name == "topk":
+        min_count = max(min_count or 0, abs(agg.finalize_kwargs["k"]))
 
     # This is needed for the dask pathway.
     # Because we use intermediate fill_value since a group could be
@@ -878,6 +898,11 @@ def _initialize_aggregation(
             else:
                 simple_combine.append(getattr(np, combine))
         else:
+            # TODO: we need to pass `k` to the topk combine function; this is ugly.
+            if agg.name == "topk" and not isinstance(combine, str):
+                assert combine is not None
+                combine = partial(combine, **agg.finalize_kwargs)
             simple_combine.append(combine)
 
     agg.simple_combine = tuple(simple_combine)
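The `new_dims_func` hook is how the result grows an extra dimension. A quick check of what `topk_new_dims_func` produces for either sign of `k`, assuming this PR is applied:

```python
from flox.aggregations import topk_new_dims_func

# The new dimension is named "k" and has size |k| regardless of sign.
(dim,) = topk_new_dims_func(k=-3)
print(dim.name, dim.values)  # k [0 1 2]
```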
diff --git a/flox/core.py b/flox/core.py
index 0daf77b23..acb81fd49 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -42,6 +42,7 @@
     _initialize_aggregation,
     generic_aggregate,
     quantile_new_dims_func,
+    topk_new_dims_func,
 )
 from .cache import memoize
 from .lib import ArrayLayer
@@ -112,7 +113,7 @@
 # This dummy axis is inserted using np.expand_dims
 # and then reduced over during the combine stage by
 # _simple_combine.
-DUMMY_AXIS = -2
+DUMMY_AXIS = 0
 
 logger = logging.getLogger("flox")
 
@@ -176,7 +177,7 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
     if isinstance(func, Aggregation):
         func = func.name
     return (
-        func in ["all", "any"]
+        func in ["all", "any", "topk"]
         # TODO: enable in npg
         # or _is_first_last_reduction(func)
         # or _is_minmax_reduction(func)
@@ -982,7 +983,7 @@ def chunk_reduce(
     nfuncs = len(funcs)
     dtypes = _atleast_1d(dtype, nfuncs)
     fill_values = _atleast_1d(fill_value, nfuncs)
-    kwargss = _atleast_1d({}, nfuncs) if kwargs is None else kwargs
+    kwargss = _atleast_1d({} if kwargs is None else kwargs, nfuncs)
 
     if isinstance(axis, Sequence):
         axes: T_Axes = axis
@@ -1070,8 +1071,16 @@ def chunk_reduce(
     # optimize that out.
     previous_reduction: T_Func = ""
     for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
+        # TODO: Figure out how to generalize this
+        if reduction in ("quantile", "nanquantile"):
+            new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar)
+        elif reduction == "topk":
+            new_dims_shape = tuple(dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar)
+        else:
+            new_dims_shape = tuple()
+
         if empty:
-            result = np.full(shape=final_array_shape, fill_value=fv)
+            result = np.full(shape=new_dims_shape + final_array_shape, fill_value=fv)
         elif is_nanlen(reduction) and is_nanlen(previous_reduction):
             result = results["intermediates"][-1]
         else:
@@ -1100,11 +1109,6 @@ def chunk_reduce(
             if hasnan:
                 # remove NaN group label which should be last
                 result = result[..., :-1]
-            # TODO: Figure out how to generalize this
-            if reduction in ("quantile", "nanquantile"):
-                new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar)
-            else:
-                new_dims_shape = tuple()
             result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape)
         results["intermediates"].append(result)
         previous_reduction = reduction
@@ -1159,6 +1163,7 @@ def _finalize_results(
         if count_mask.any():
             # For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
             # necessary
+            fill_value = agg.fill_value[agg.name] if fill_value is None else fill_value
             if fill_value is None:
                 raise ValueError("Filling is required but fill_value is None.")
             # This allows us to match xarray's type promotion rules
@@ -1198,8 +1203,15 @@ def _aggregate(
     return _finalize_results(results, agg, axis, expected_groups, reindex)
 
 
-def _expand_dims(results: IntermediateDict) -> IntermediateDict:
-    results["intermediates"] = tuple(np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"])
+def _expand_dims(results: IntermediateDict, agg: Aggregation) -> IntermediateDict:
+    if agg.name == "topk":
+        # The topk intermediate already carries the new `k` dim in the DUMMY_AXIS
+        # position, so insert the dummy axis only for the remaining intermediates.
+        results["intermediates"] = tuple(results["intermediates"][:1]) + tuple(
+            np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"][1:]
+        )
+    else:
+        results["intermediates"] = tuple(
+            np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"]
+        )
     return results
@@ -1249,7 +1261,7 @@ def _simple_combine(
     results: IntermediateDict = {"groups": unique_groups}
     results["intermediates"] = []
-    axis_ = axis[:-1] + (DUMMY_AXIS,)
+    axis_ = (DUMMY_AXIS,) + tuple(a + 1 for a in axis[:-1])
     for idx, combine in enumerate(agg.simple_combine):
         array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis_)
         assert array.ndim >= 2
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
             assert callable(combine)
             result = combine(array, axis=axis_, keepdims=True)
-        if is_aggregate:
+        # FIXME: The `idx > 0` check assumes that DUMMY_AXIS == 0 and that any
+        # new dims are inserted by the first element of `simple_combine`.
+        if is_aggregate and (agg.new_dims_func is None or idx > 0):
             # squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
             result = result.squeeze(axis=DUMMY_AXIS)
         results["intermediates"].append(result)
@@ -1657,6 +1671,9 @@ def dask_groupby_agg(
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex)
     else:
+        extra = {}
+        if agg.name == "topk":
+            extra["kwargs"] = (agg.finalize_kwargs, *(({},) * (len(agg.chunk) - 1)))
         # choose `chunk_reduce` or `chunk_argreduce`
         blockwise_method = partial(
             _get_chunk_reduction(agg.reduction_type),
             fill_value=agg.fill_value["intermediate"],
             dtype=agg.dtype["intermediate"],
             reindex=reindex,
             user_dtype=agg.dtype["user"],
+            **extra,
         )
         if do_simple_combine:
             # Add a dummy dimension that then gets reduced over
-            blockwise_method = tlz.compose(_expand_dims, blockwise_method)
+            blockwise_method = tlz.compose(partial(_expand_dims, agg=agg), blockwise_method)
 
     # apply reduction on chunk
     intermediate = dask.array.blockwise(
@@ -2239,12 +2257,12 @@ def _choose_method(
     return method
 
 
-def _choose_engine(by, agg: Aggregation):
+def _choose_engine(by, agg: Aggregation) -> T_Engine:
     dtype = agg.dtype["user"]
 
     not_arg_reduce = not _is_arg_reduction(agg)
 
-    if agg.name in ["quantile", "nanquantile", "median", "nanmedian"]:
+    if agg.name in ["quantile", "nanquantile", "median", "nanmedian", "topk"]:
         logger.debug(f"_choose_engine: Choosing 'flox' since {agg.name}")
         return "flox"
 
@@ -2295,7 +2313,7 @@ def groupby_reduce(
         equality check are for dimensions of size 1 in `by`.
     func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
         "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
-        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \
        "first", "nanfirst", "last", "nanlast"} or Aggregation
         Single function name or an Aggregation instance
     expected_groups : (optional) Sequence
@@ -2367,6 +2385,11 @@ def groupby_reduce(
     finalize_kwargs : dict, optional
         Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile.
 
+    Notes
+    -----
+    ``topk`` and ``quantile`` are implemented by converting the array to a complex number, and so are
+    limited to values between +-``2**53 - 1``, the largest integers exactly representable by ``float64``.
+    Offset your data appropriately if you need a larger range.
+
     Returns
     -------
     result
@@ -2403,6 +2426,8 @@ def groupby_reduce(
             "Use engine='flox' instead (it is also much faster), "
             "or set engine=None to use the default."
         )
+    if func == "topk" and (finalize_kwargs is None or "k" not in finalize_kwargs):
+        raise ValueError("Please pass `k` in ``finalize_kwargs`` for topk calculations.")
 
     bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
     nby = len(bys)
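Two behaviours added above are easy to exercise directly: the early `ValueError` when `k` is missing, and `_choose_engine` routing `topk` to the `"flox"` engine. A short sketch, assuming the PR is applied:

```python
import numpy as np
from flox.core import groupby_reduce

array = np.arange(6, dtype=np.float64)
by = np.array([0, 0, 0, 1, 1, 1])

# Omitting `k` now fails fast rather than deep inside the aggregation.
try:
    groupby_reduce(array, by, func="topk")
except ValueError as e:
    print(e)

# With `k` supplied, the result carries the new leading "k" dimension.
result, groups = groupby_reduce(array, by, func="topk", finalize_kwargs={"k": 2})
print(result.shape)  # (2, 2): (k, number of groups)
```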
diff --git a/flox/xarray.py b/flox/xarray.py
index fbeeedba6..d51bf95d6 100644
--- a/flox/xarray.py
+++ b/flox/xarray.py
@@ -8,7 +8,13 @@
 import xarray as xr
 from packaging.version import Version
 
-from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
+from .aggregations import (
+    Aggregation,
+    Dim,
+    _atleast_1d,
+    quantile_new_dims_func,
+    topk_new_dims_func,
+)
 from .core import (
     _convert_expected_groups_to_index,
     _get_expected_groups,
@@ -90,7 +96,7 @@ def xarray_reduce(
         Variables with which to group by ``obj``
     func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
         "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
-        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \
        "first", "nanfirst", "last", "nanlast"} or Aggregation
         Single function name or an Aggregation instance
     expected_groups : str or sequence
@@ -175,6 +181,11 @@ def xarray_reduce(
     DataArray or Dataset
         Reduced object
 
+    Notes
+    -----
+    ``topk`` and ``quantile`` are implemented by converting the array to a complex number, and so are
+    limited to values between +-``2**53 - 1``, the largest integers exactly representable by ``float64``.
+    Offset your data appropriately if you need a larger range.
+
     See Also
     --------
     flox.core.groupby_reduce
@@ -366,16 +377,20 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
 
         result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
 
-        # Transpose the new quantile dimension to the end. This is ugly.
+        # Transpose the new quantile or topk dimension to the end. This is ugly.
         # but new core dimensions are expected at the end :/
         # but groupby_reduce inserts them at the beginning
         if func in ["quantile", "nanquantile"]:
             (newdim,) = quantile_new_dims_func(**finalize_kwargs)
-            if not newdim.is_scalar:
-                # NOTE: _restore_dim_order will move any new dims to the end anyway.
-                # This transpose is simply makes it easy to specify output_core_dims
-                # output dim order: (*broadcast_dims, *group_dims, quantile_dim)
-                result = np.moveaxis(result, 0, -1)
+        elif func == "topk":
+            (newdim,) = topk_new_dims_func(**finalize_kwargs)
+        else:
+            newdim = None
+        if newdim is not None and not newdim.is_scalar:
+            # NOTE: _restore_dim_order will move any new dims to the end anyway.
+            # This transpose simply makes it easy to specify output_core_dims
+            # output dim order: (*broadcast_dims, *group_dims, newdim)
+            result = np.moveaxis(result, 0, -1)
 
         return result
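At the xarray level the same reduction goes through `xarray_reduce`; per the transpose logic above, the new `k` dimension should land last. A sketch assuming the PR is applied:

```python
import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

da = xr.DataArray(
    np.arange(6.0),
    dims="x",
    coords={"labels": ("x", ["a", "a", "a", "b", "b", "b"])},
)

out = xarray_reduce(da, "labels", func="topk", finalize_kwargs={"k": 2})
print(out.dims)  # e.g. ("labels", "k"): new core dims are moved to the end
```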
diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py
index e1b9bccec..05a060733 100644
--- a/flox/xrdtypes.py
+++ b/flox/xrdtypes.py
@@ -109,6 +109,9 @@ def get_pos_infinity(dtype, max_for_int=False):
     if issubclass(dtype.type, np.complexfloating):
         return np.inf + 1j * np.inf
 
+    if issubclass(dtype.type, np.bool_):
+        return True
+
     return INF
 
 
@@ -142,6 +145,9 @@ def get_neg_infinity(dtype, min_for_int=False):
     if issubclass(dtype.type, np.complexfloating):
         return -np.inf - 1j * np.inf
 
+    if issubclass(dtype.type, np.bool_):
+        return False
+
     return NINF
 
 
diff --git a/flox/xrutils.py b/flox/xrutils.py
index 28c9667b2..d58b1da6a 100644
--- a/flox/xrutils.py
+++ b/flox/xrutils.py
@@ -8,6 +8,7 @@
 
 import numpy as np
 import pandas as pd
+from numpy.lib.array_utils import normalize_axis_tuple
 from packaging.version import Version
 
@@ -396,3 +397,25 @@ def nanlast(values, axis, keepdims=False):
         return np.expand_dims(result, axis=axis)
     else:
         return result
+
+
+def topk(a: np.ndarray, k: int, axis, keepdims: bool = True) -> np.ndarray:
+    """Chunk and combine function for topk.
+
+    Extract the k largest elements from a on the given axis.
+    If k is negative, extract the -k smallest elements instead.
+    Note that, unlike in the parent function, the returned elements
+    are not sorted internally.
+
+    NOTE: This function was copied from the dask project under the terms
+    of their LICENSE.
+    """
+    assert keepdims is True
+    (axis,) = normalize_axis_tuple(axis, a.ndim)
+    if abs(k) >= a.shape[axis]:
+        return a
+
+    # partition so the k largest (or -k smallest) land at the end (or start)
+    # of the axis, without mutating the caller's array
+    a = np.partition(a, -k, axis=axis)
+    k_slice = slice(-k, None) if k > 0 else slice(-k)
+    result = a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]
+    return result.astype(a.dtype, copy=False)
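`topk` above intentionally returns the selected elements unsorted; any sorting happens in the caller. A quick demonstration, assuming the PR is applied:

```python
import numpy as np
from flox.xrutils import topk

a = np.array([[5.0, 1.0, 3.0, 4.0, 2.0]])

# k largest along the last axis; order within the k elements is unspecified.
print(np.sort(topk(a, k=3, axis=-1)))   # [[3. 4. 5.]]
# Negative k selects the k smallest instead.
print(np.sort(topk(a, k=-3, axis=-1)))  # [[1. 2. 3.]]
```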
diff --git a/tests/test_core.py b/tests/test_core.py
index 6f8c22962..0552af142 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -83,7 +83,7 @@ def npfunc(x, **kwargs):
         x = np.asarray(x)
         return (~xrutils.isnull(x)).sum(**kwargs)
 
-    elif func in ["nanfirst", "nanlast"]:
+    elif func in ["nanfirst", "nanlast", "topk"]:
         npfunc = getattr(xrutils, func)
 
     elif func in SCIPY_STATS_FUNCS:
@@ -252,6 +252,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         ]
         fill_value = None
         tolerance = None
+    elif func == "topk":
+        finalize_kwargs = [{"k": 3}, {"k": -3}]
+        fill_value = None
+        tolerance = None
     else:
         fill_value = None
         tolerance = None
@@ -262,6 +266,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
     for kwargs in finalize_kwargs:
         if "quantile" in func and isinstance(kwargs["q"], list) and engine != "flox":
             continue
+        if "topk" in func and engine != "flox":
+            continue
         flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
         with np.errstate(invalid="ignore", divide="ignore"):
             with warnings.catch_warnings():
@@ -281,6 +287,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 expected = getattr(np, func_)(array_, axis=-1, **kwargs)
         else:
             expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs)
+        if func == "topk":
+            if nanmask.all():
+                expected = np.full(expected.shape[:-1] + (abs(kwargs["k"]),), np.nan)
+            expected = np.sort(np.swapaxes(expected, array.ndim - 1, 0), axis=0)
         for _ in range(nby):
             expected = np.expand_dims(expected, -1)
 
@@ -288,7 +298,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
             assert chunks == -1
 
         actual, *groups = groupby_reduce(array, *by, **flox_kwargs)
-        if "quantile" in func and isinstance(kwargs["q"], list):
+        if ("quantile" in func and isinstance(kwargs["q"], list)) or func == "topk":
             assert actual.ndim == expected.ndim == (array.ndim + nby)
         else:
             assert actual.ndim == expected.ndim == (array.ndim + nby - 1)
@@ -298,9 +308,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
             assert_equal(actual_group, expect)
         if "arg" in func:
             assert actual.dtype.kind == "i"
+        if func == "topk":
+            actual = np.sort(actual, axis=0)
         assert_equal(expected, actual, tolerance)
 
-        if "nan" not in func and "arg" not in func:
+        # FIXME: topk vs nantopk
+        if "nan" not in func and "arg" not in func and "topk" not in func:
             # test non-NaN skipping behaviour when NaNs are present
             nanned = array_.copy()
             # remove nans in by to reduce complexity
@@ -310,6 +323,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
             nanned.reshape(-1)[0] = np.nan
             actual, *_ = groupby_reduce(nanned, *by_, **flox_kwargs)
             expected_0 = array_func(nanned, axis=-1, **kwargs)
+            if func == "topk":
+                expected_0 = np.sort(np.swapaxes(expected_0, array.ndim - 1, 0), axis=-1)
+                actual = np.sort(actual, axis=-1)
+
             for _ in range(nby):
                 expected_0 = np.expand_dims(expected_0, -1)
             assert_equal(expected_0, actual, tolerance)
diff --git a/tests/test_properties.py b/tests/test_properties.py
index 5f1095f52..689cd3b47 100644
--- a/tests/test_properties.py
+++ b/tests/test_properties.py
@@ -327,3 +327,28 @@ def test_agg_dtype_specified(func, array_dtype, dtype, engine):
     )
     expected = getattr(np, func)(counts, keepdims=True, dtype=dtype)
     assert actual.dtype == expected.dtype
+
+
+@given(data=st.data(), array=chunked_arrays())
+def test_topk_max_min(data, array):
+    "top 1 == nanmax; top -1 == nanmin"
+
+    if array.dtype.kind in "iu":
+        # we cast to float and back, so this is the effective limit
+        assume((np.abs(array) < 2**53).all())
+    elif array.dtype.kind in "Mm":
+        # we cast to float and back, so this is the effective limit
+        assume((np.abs(array.view(np.int64)) < 2**53).all())
+    elif _contains_cftime_datetimes(array):
+        asint = datetime_to_numeric(array, datetime_unit="us")
+        assume((np.abs(asint.view(np.int64)) < 2**53).all())
+
+    size = array.shape[-1]
+    by = data.draw(by_arrays(shape=(size,)))
+    k, npfunc = data.draw(st.sampled_from([(1, "nanmax"), (-1, "nanmin")]))
+
+    for a in (array, array.compute()):
+        actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k})
+        # TODO: also test the numbagg and flox engines
+        expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy")
+        assert_equal(actual, expected[np.newaxis, :])
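The property exercised above can also be checked without hypothesis. A quick numpy-only cross-check, assuming the PR is applied:

```python
import numpy as np
from flox.core import groupby_reduce

rng = np.random.default_rng(0)
array = rng.normal(size=24)
by = rng.integers(0, 4, size=24)

# topk with k=1 reduces to nanmax (and k=-1 to nanmin).
top1, _ = groupby_reduce(array, by, func="topk", finalize_kwargs={"k": 1})
expected, _ = groupby_reduce(array, by, func="nanmax", engine="numpy")
np.testing.assert_allclose(top1, expected[np.newaxis, :])
```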