Skip to content

Commit 776d233

Browse files
committed
test
1 parent 996ff2a commit 776d233

File tree

5 files changed

+38 -11 lines changed

5 files changed

+38 -11 lines changed

flox/aggregate_flox.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -98,10 +98,8 @@ def quantile_or_topk(
9898
param = np.atleast_1d(param)
9999
param = np.reshape(param, (param.size,) + (1,) * array.ndim)
100100

101-
if is_scalar_param:
102-
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
103-
else:
104-
idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
101+
# For topk(..., k=+1 or k=-1), the result always keeps the leading singleton dimension.
102+
idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
105103

106104
if q is not None:
107105
# This is numpy's method="linear"
@@ -110,6 +108,7 @@ def quantile_or_topk(
110108

111109
if is_scalar_param:
112110
virtual_index = virtual_index.squeeze(axis=0)
111+
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
113112

114113
lo_ = np.floor(
115114
virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
@@ -122,7 +121,7 @@ def quantile_or_topk(
122121
else:
123122
virtual_index = inv_idx[:-1] + ((actual_sizes - k) if k > 0 else abs(k) - 1)
124123
kth = np.unique(virtual_index)
125-
kth = kth[kth > 0]
124+
kth = kth[kth >= 0]
126125
k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim)
127126
lo_ = k_offset + virtual_index[np.newaxis, ...]
128127

@@ -147,12 +146,18 @@ def quantile_or_topk(
147146
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
148147
else:
149148
result = loval
150-
result[lo_ < 0] = fill_value
149+
# This happens when a group has fewer than abs(k) elements.
150+
badmask = lo_ < 0
151+
if badmask.any():
152+
result[badmask] = fill_value
153+
151154
if not skipna and np.any(nanmask):
152155
result[..., nanmask] = fill_value
156+
153157
if k is not None:
154158
result = result.astype(dtype, copy=False)
155-
np.copyto(out, result)
159+
if out is not None:
160+
np.copyto(out, result)
156161
return result
157162

158163

flox/aggregations.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -830,7 +830,7 @@ def _initialize_aggregation(
830830
)
831831

832832
final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
833-
if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]:
833+
if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax", "topk"]:
834834
final_dtype = _maybe_promote_int(final_dtype)
835835
agg.dtype = {
836836
"user": dtype, # Save to automatically choose an engine
@@ -892,6 +892,8 @@ def _initialize_aggregation(
892892
if isinstance(combine, str):
893893
simple_combine.append(getattr(np, combine))
894894
else:
895+
if agg.name == "topk":
896+
combine = partial(combine, **finalize_kwargs)
895897
simple_combine.append(combine)
896898

897899
agg.simple_combine = tuple(simple_combine)

flox/core.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -958,7 +958,7 @@ def chunk_reduce(
958958
nfuncs = len(funcs)
959959
dtypes = _atleast_1d(dtype, nfuncs)
960960
fill_values = _atleast_1d(fill_value, nfuncs)
961-
kwargss = _atleast_1d({}, nfuncs) if kwargs is None else kwargs
961+
kwargss = _atleast_1d({} if kwargs is None else kwargs, nfuncs)
962962

963963
if isinstance(axis, Sequence):
964964
axes: T_Axes = axis
@@ -1645,6 +1645,7 @@ def dask_groupby_agg(
16451645
dtype=agg.dtype["intermediate"],
16461646
reindex=reindex,
16471647
user_dtype=agg.dtype["user"],
1648+
kwargs=agg.finalize_kwargs if agg.name == "topk" else None,
16481649
)
16491650
if do_simple_combine:
16501651
# Add a dummy dimension that then gets reduced over
@@ -2372,6 +2373,9 @@ def groupby_reduce(
23722373
"Use engine='flox' instead (it is also much faster), "
23732374
"or set engine=None to use the default."
23742375
)
2376+
if func == "topk":
2377+
if finalize_kwargs is None or "k" not in finalize_kwargs:
2378+
raise ValueError("Please pass `k` for topk calculations.")
23752379

23762380
bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
23772381
nby = len(bys)

flox/xrutils.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -389,10 +389,13 @@ def topk(a, k, axis, keepdims):
389389
are not sorted internally.
390390
"""
391391
assert keepdims is True
392-
axis = axis[0]
392+
(axis,) = axis
393+
axis = normalize_axis_index(axis, a.ndim)
393394
if abs(k) >= a.shape[axis]:
394395
return a
395396

397+
# TODO: handle NaNs
396398
a = np.partition(a, -k, axis=axis)
397399
k_slice = slice(-k, None) if k > 0 else slice(-k)
398-
return a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]
400+
result = a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]
401+
return result

tests/test_properties.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -210,3 +210,16 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None:
210210
first, *_ = groupby_reduce(array, by, func=func, engine="flox")
211211
second, *_ = groupby_reduce(array, by, func=mate, engine="flox")
212212
assert_equal(first, second)
213+
214+
215+
@given(data=st.data(), array=chunked_arrays())
216+
def test_topk_max_min(data, array):
217+
"topk with k=1 matches max; topk with k=-1 matches min"
218+
size = array.shape[-1]
219+
by = data.draw(by_arrays(shape=(size,)))
220+
k, npfunc = data.draw(st.sampled_from([(1, "max"), (-1, "min")]))
221+
222+
for a in (array, array.compute()):
223+
actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k})
224+
expected, _ = groupby_reduce(a, by, func=npfunc)
225+
assert_equal(actual, expected[np.newaxis, :])

0 commit comments

Comments
 (0)