Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add topk #374

Draft
wants to merge 54 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
650088b
Add topk
dcherian Jul 27, 2024
889be0c
Negative k
dcherian Jul 28, 2024
996ff2a
dask support
dcherian Jul 28, 2024
776d233
test
dcherian Jul 28, 2024
a5eb7b9
wip
dcherian Jul 28, 2024
4fa9a4c
fix
dcherian Jul 28, 2024
4b04fde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2024
93800aa
Handle dtypes.NA properly for datetime/timedelta
dcherian Jul 31, 2024
80c67f4
Fix
dcherian Jul 31, 2024
7056d18
Merge branch 'main' into topk
dcherian Aug 7, 2024
44f5f3f
Merge branch 'main' into topk
dcherian Jan 7, 2025
c924017
Fixes
dcherian Jan 7, 2025
7a794ba
one more fix
dcherian Jan 7, 2025
eec4dd4
fix
dcherian Jan 7, 2025
6ac9a1f
one more fix
dcherian Jan 7, 2025
83594e8
Fixes.
dcherian Jan 7, 2025
740f85f
WIP
dcherian Jan 7, 2025
5d64fd9
Merge branch 'main' into topk
dcherian Jan 7, 2025
e177efd
fixes
dcherian Jan 7, 2025
9393470
fix
dcherian Jan 7, 2025
17eb915
cleanup
dcherian Jan 7, 2025
dc0df3e
works?
dcherian Jan 7, 2025
83ae5d8
fix quantile
dcherian Jan 7, 2025
95d20b8
optimize xrutils.topk
dcherian Jan 7, 2025
0b9fafc
Merge branch 'main' into topk
dcherian Jan 8, 2025
caa98b8
Update tests/test_properties.py
dcherian Jan 8, 2025
820d46c
generalize new_dims_func
dcherian Jan 13, 2025
17a4d5d
Merge branch 'main' into topk
dcherian Jan 13, 2025
6aa923a
Revert "generalize new_dims_func"
dcherian Jan 13, 2025
16b0bac
Merge branch 'main' into topk
dcherian Jan 13, 2025
2c6d486
Support bool
dcherian Jan 13, 2025
0dcd87c
more skipping
dcherian Jan 13, 2025
9b874ea
fix
dcherian Jan 14, 2025
adebbec
more xfail
dcherian Jan 15, 2025
ace2af5
Merge branch 'main' into topk
dcherian Jan 19, 2025
4f35230
cleanup
dcherian Jan 19, 2025
cd2f150
one more xfail
dcherian Jan 19, 2025
70e6f22
typing
dcherian Jan 19, 2025
5d45603
minor docs
dcherian Jan 19, 2025
096f6b9
disable log in CI
dcherian Jan 19, 2025
0277cb9
Fix boolean
dcherian Jan 19, 2025
6c7e84a
bool -> bool_
dcherian Jan 20, 2025
43c3408
update int limits
dcherian Jan 20, 2025
01eabfb
fix rtd
dcherian Jan 20, 2025
6e4ce69
Add note
dcherian Jan 20, 2025
4500c7e
Merge branch 'main' into topk
dcherian Jan 24, 2025
8f60477
Add unit test
dcherian Jan 24, 2025
15fcfa1
WIP
dcherian Jan 24, 2025
a5bcc5b
fix
dcherian Jan 24, 2025
489c843
Merge branch 'main' into topk
dcherian Mar 18, 2025
91e1d07
Switch DUMMY_AXIS to 0
dcherian Mar 18, 2025
2d868fe
More support for edge cases
dcherian Mar 18, 2025
d244d60
minor
dcherian Mar 18, 2025
8319f7f
[WIP] failing test
dcherian Mar 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
id: status
run: |
python -c "import xarray; xarray.show_versions()"
pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci --log-disable=flox
- name: Upload code coverage to Codecov
uses: codecov/[email protected]
with:
Expand Down
15 changes: 8 additions & 7 deletions docs/source/aggregations.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,16 @@ the `func` kwarg:
- `"mean"`, `"nanmean"`
- `"var"`, `"nanvar"`
- `"std"`, `"nanstd"`
- `"argmin"`
- `"argmax"`
- `"argmin"`, `"nanargmax"`
- `"argmax"`, `"nanargmin"`
- `"first"`, `"nanfirst"`
- `"last"`, `"nanlast"`
- `"median"`, `"nanmedian"`
- `"mode"`, `"nanmode"`
- `"quantile"`, `"nanquantile"`
- `"topk"`

```{tip}
We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome!
```

## Custom Aggregations
## Custom Reductions

`flox` also allows you to specify a custom Aggregation (again inspired by dask.dataframe),
though this might not be fully functional at the moment. See `aggregations.py` for examples.
Expand All @@ -46,3 +43,7 @@ mean = Aggregation(
final_fill_value=np.nan,
)
```

## Custom Scans

Coming soon!
138 changes: 92 additions & 46 deletions flox/aggregate_flox.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,32 @@ def _lerp(a, b, *, t, dtype, out=None):
return out


def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
def quantile_or_topk(
array,
inv_idx,
*,
q=None,
k=None,
axis,
skipna,
group_idx,
dtype=None,
out=None,
fill_value=None,
):
assert q is not None or k is not None
assert axis == -1

inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))

array_validmask = notnull(array)
actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis)
newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
full_sizes = np.reshape(np.diff(inv_idx), newshape)
nanmask = full_sizes != actual_sizes
if k is not None:
nanmask = actual_sizes < abs(k)
else:
full_sizes = np.reshape(np.diff(inv_idx), newshape)
nanmask = full_sizes != actual_sizes

# The approach here is to use (complex_array.partition) because
# 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
Expand All @@ -72,36 +90,48 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
# So we determine which indices we need using the fact that NaNs get sorted to the end.
# This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
# but not any more now that I use partition and avoid replacing NaNs
qin = q
q = np.atleast_1d(qin)
q = np.reshape(q, (len(q),) + (1,) * array.ndim)
if k is not None:
is_scalar_param = False
param = np.sort(np.arange(abs(k)) * np.sign(k))
else:
is_scalar_param = is_scalar(q)
param = np.atleast_1d(q)
param = np.reshape(param, (param.size,) + (1,) * array.ndim)

# This is numpy's method="linear"
# TODO: could support all the interpolations here
offset = actual_sizes.cumsum(axis=-1)
actual_sizes -= 1
virtual_index = q * actual_sizes
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]

is_scalar_q = is_scalar(qin)
if is_scalar_q:
virtual_index = virtual_index.squeeze(axis=0)
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
else:
idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
# For topk(.., k=+1 or -1), we always return the singleton dimension.
idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)

lo_ = np.floor(
virtual_index,
casting="unsafe",
out=np.empty(virtual_index.shape, dtype=np.int64),
)
hi_ = np.ceil(
virtual_index,
casting="unsafe",
out=np.empty(virtual_index.shape, dtype=np.int64),
)
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
if q is not None:
# This is numpy's method="linear"
# TODO: could support all the interpolations here
actual_sizes -= 1
virtual_index = param * actual_sizes
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]

if is_scalar_param:
virtual_index = virtual_index.squeeze(axis=0)
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)

lo_ = np.floor(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64))
hi_ = np.ceil(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64))
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))

else:
virtual_index = (actual_sizes - k) if k > 0 else (np.zeros_like(actual_sizes) + abs(k) - 1)
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]
kth = np.unique(virtual_index)
kth = kth[kth >= 0]
kth[kth >= array.shape[axis]] = array.shape[axis] - 1
k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim)
lo_ = k_offset + virtual_index[np.newaxis, ...]
not_enough_elems = actual_sizes < np.abs(k)
lo_[..., not_enough_elems] = 0
badmask = np.broadcast_to(not_enough_elems, idxshape) | nanmask

# partition the complex array in-place
labels_broadcast = np.broadcast_to(group_idx, array.shape)
Expand All @@ -111,20 +141,33 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
# a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j`
cmplx.real = labels_broadcast
cmplx.partition(kth=kth, axis=-1)
if is_scalar_q:
a_ = cmplx.imag
else:
a_ = np.broadcast_to(cmplx.imag, (q.shape[0],) + array.shape)

# get bounds, Broadcast to (num quantiles, ..., num labels)
loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)
a_ = cmplx.imag
if not is_scalar_param:
a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape)

# TODO: could support all the interpolations here
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
if not skipna and np.any(nanmask):
result[..., nanmask] = np.nan
if array.dtype.kind in "Mm":
a_ = a_.view(array.dtype)

loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
if q is not None:
# get bounds, Broadcast to (num quantiles, ..., num labels)
hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)

# TODO: could support all the interpolations here
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
if not skipna and np.any(nanmask):
result[..., nanmask] = fill_value
else:
result = loval
if badmask.any():
result[badmask] = fill_value

if k is not None:
result = result.astype(dtype, copy=False)
if out is not None:
np.copyto(out, result)
return result


Expand Down Expand Up @@ -158,12 +201,14 @@ def _np_grouped_op(

if out is None:
q = kwargs.get("q", None)
if q is None:
k = kwargs.get("k", None)
if q is None and k is None:
out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
else:
nq = len(np.atleast_1d(q))
nq = len(np.atleast_1d(q)) if q is not None else abs(k)
out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
kwargs["group_idx"] = group_idx
kwargs["fill_value"] = fill_value

if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
# The previous version of this if condition
Expand Down Expand Up @@ -200,10 +245,11 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
min = partial(_np_grouped_op, op=np.minimum.reduceat)
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True))
topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True))
# TODO: all, any


Expand Down
31 changes: 28 additions & 3 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
return (Dim(name="quantile", values=q),)


def topk_new_dims_func(k) -> tuple[Dim]:
    """Output dimension added by ``topk``: a dim named "k" of length ``abs(k)``.

    ``k`` may be negative (bottom-k); the new dimension's size is always
    ``abs(k)``, with positional values ``0..abs(k)-1``.
    """
    num_elems = abs(k)
    return (Dim(name="k", values=np.arange(num_elems)),)


# if the input contains integers or floats smaller than float64,
# the output data-type is float64. Otherwise, the output data-type is the same as that
# of the input.
Expand All @@ -572,6 +576,16 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
topk = Aggregation(
name="topk",
fill_value=(dtypes.NINF, 0),
final_fill_value=dtypes.NA,
# FIXME: set numpy
chunk=("topk", "nanlen"),
combine=(xrutils.topk, "sum"),
new_dims_func=topk_new_dims_func,
preserves_dtype=True,
)


@dataclass
Expand Down Expand Up @@ -769,6 +783,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
"nanquantile": nanquantile,
"mode": mode,
"nanmode": nanmode,
"topk": topk,
# "cumsum": cumsum,
"nancumsum": nancumsum,
"ffill": ffill,
Expand Down Expand Up @@ -823,6 +838,12 @@ def _initialize_aggregation(
),
}

if finalize_kwargs is not None:
assert isinstance(finalize_kwargs, dict)
agg.finalize_kwargs = finalize_kwargs

if agg.name == "topk" and agg.finalize_kwargs["k"] < 0:
agg.fill_value["intermediate"] = (dtypes.INF, 0)
# Replace sentinel fill values according to dtype
agg.fill_value["user"] = fill_value
agg.fill_value["intermediate"] = tuple(
Expand All @@ -838,9 +859,8 @@ def _initialize_aggregation(
else:
agg.fill_value["numpy"] = (fv,)

if finalize_kwargs is not None:
assert isinstance(finalize_kwargs, dict)
agg.finalize_kwargs = finalize_kwargs
if agg.name == "topk":
min_count = max(min_count or 0, abs(agg.finalize_kwargs["k"]))

# This is needed for the dask pathway.
# Because we use intermediate fill_value since a group could be
Expand Down Expand Up @@ -878,6 +898,11 @@ def _initialize_aggregation(
else:
simple_combine.append(getattr(np, combine))
else:
# TODO: bah, we need to pass `k` to the combine topk function
# this is ugly.
if agg.name == "topk" and not isinstance(combine, str):
assert combine is not None
combine = partial(combine, **agg.finalize_kwargs)
simple_combine.append(combine)

agg.simple_combine = tuple(simple_combine)
Expand Down
Loading
Loading