Skip to content

Commit 32f0241

Browse files
Improve PartialSort() performance; fixes #18
1 parent e2c65fd commit 32f0241

File tree

3 files changed

+82
-67
lines changed

3 files changed

+82
-67
lines changed

Source/SuperLinq/PartialSort.cs

+80-65
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ public static partial class SuperEnumerable
1212
/// <param name="count">Number of (maximum) elements to return.</param>
1313
/// <returns>A sequence containing at most top <paramref name="count"/>
1414
/// elements from source, in their ascending order.</returns>
15+
/// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
16+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
1517
/// <remarks>
1618
/// This operator uses deferred execution and streams its results.
1719
/// </remarks>
18-
1920
public static IEnumerable<T> PartialSort<T>(this IEnumerable<T> source, int count)
2021
{
2122
return source.PartialSort(count, comparer: null);
@@ -33,10 +34,11 @@ public static IEnumerable<T> PartialSort<T>(this IEnumerable<T> source, int coun
3334
/// <param name="direction">The direction in which to sort the elements</param>
3435
/// <returns>A sequence containing at most top <paramref name="count"/>
3536
/// elements from source, in the specified order.</returns>
37+
/// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
38+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
3639
/// <remarks>
3740
/// This operator uses deferred execution and streams its results.
3841
/// </remarks>
39-
4042
public static IEnumerable<T> PartialSort<T>(
4143
this IEnumerable<T> source, int count, OrderByDirection direction)
4244
{
@@ -55,10 +57,11 @@ public static IEnumerable<T> PartialSort<T>(
5557
/// <param name="comparer">An <see cref="IComparer{T}"/> to compare elements.</param>
5658
/// <returns>A sequence containing at most top <paramref name="count"/>
5759
/// elements from source, in their ascending order.</returns>
60+
/// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
61+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
5862
/// <remarks>
5963
/// This operator uses deferred execution and streams its results.
6064
/// </remarks>
61-
6265
public static IEnumerable<T> PartialSort<T>(
6366
this IEnumerable<T> source,
6467
int count, IComparer<T>? comparer)
@@ -80,19 +83,52 @@ public static IEnumerable<T> PartialSort<T>(
8083
/// <param name="direction">The direction in which to sort the elements</param>
8184
/// <returns>A sequence containing at most top <paramref name="count"/>
8285
/// elements from source, in the specified order.</returns>
86+
/// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
87+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
8388
/// <remarks>
8489
/// This operator uses deferred execution and streams its results.
8590
/// </remarks>
86-
8791
public static IEnumerable<T> PartialSort<T>(
8892
this IEnumerable<T> source, int count,
8993
IComparer<T>? comparer, OrderByDirection direction)
9094
{
9195
source.ThrowIfNull();
96+
count.ThrowIfLessThan(1);
97+
9298
comparer ??= Comparer<T>.Default;
9399
if (direction == OrderByDirection.Descending)
94100
comparer = new ReverseComparer<T>(comparer);
95-
return PartialSortByImpl<T, T>(source, count, keySelector: null, keyComparer: null, comparer);
101+
102+
return _(source, count, comparer);
103+
104+
static IEnumerable<T> _(IEnumerable<T> source, int count, IComparer<T> comparer)
105+
{
106+
var top = new SortedSet<(T item, int index)>(
107+
Comparer<(T item, int index)>.Create((x, y) =>
108+
{
109+
var result = comparer.Compare(x.item, y.item);
110+
return result != 0 ? result :
111+
Comparer<long>.Default.Compare(x.index, y.index);
112+
}));
113+
114+
foreach (var (index, item) in source.Index())
115+
{
116+
if (top.Count < count)
117+
{
118+
top.Add((item, index));
119+
continue;
120+
}
121+
122+
if (comparer.Compare(item, top.Max.item) >= 0)
123+
continue;
124+
125+
top.Remove(top.Max);
126+
top.Add((item, index));
127+
}
128+
129+
foreach (var (item, _) in top)
130+
yield return item;
131+
}
96132
}
97133

98134
/// <summary>
@@ -106,10 +142,12 @@ public static IEnumerable<T> PartialSort<T>(
106142
/// <param name="count">Number of (maximum) elements to return.</param>
107143
/// <returns>A sequence containing at most top <paramref name="count"/>
108144
/// elements from source, in ascending order of their keys.</returns>
145+
/// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
146+
/// <exception cref="ArgumentNullException"><paramref name="keySelector"/> is <see langword="null"/>.</exception>
147+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
109148
/// <remarks>
110149
/// This operator uses deferred execution and streams its results.
111150
/// </remarks>
112-
113151
public static IEnumerable<TSource> PartialSortBy<TSource, TKey>(
114152
this IEnumerable<TSource> source, int count,
115153
Func<TSource, TKey> keySelector)
@@ -130,10 +168,12 @@ public static IEnumerable<TSource> PartialSortBy<TSource, TKey>(
130168
/// <param name="direction">The direction in which to sort the elements</param>
131169
/// <returns>A sequence containing at most top <paramref name="count"/>
132170
/// elements from source, in the specified order of their keys.</returns>
171+
/// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
172+
/// <exception cref="ArgumentNullException"><paramref name="keySelector"/> is <see langword="null"/>.</exception>
173+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
133174
/// <remarks>
134175
/// This operator uses deferred execution and streams its results.
135176
/// </remarks>
136-
137177
public static IEnumerable<TSource> PartialSortBy<TSource, TKey>(
138178
this IEnumerable<TSource> source, int count,
139179
Func<TSource, TKey> keySelector, OrderByDirection direction)
@@ -154,10 +194,12 @@ public static IEnumerable<TSource> PartialSortBy<TSource, TKey>(
154194
/// <param name="comparer">An <see cref="IComparer{T}"/> to compare elements.</param>
155195
/// <returns>A sequence containing at most top <paramref name="count"/>
156196
/// elements from source, in ascending order of their keys.</returns>
197+
/// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
198+
/// <exception cref="ArgumentNullException"><paramref name="keySelector"/> is <see langword="null"/>.</exception>
199+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
157200
/// <remarks>
158201
/// This operator uses deferred execution and streams its results.
159202
/// </remarks>
160-
161203
public static IEnumerable<TSource> PartialSortBy<TSource, TKey>(
162204
this IEnumerable<TSource> source, int count,
163205
Func<TSource, TKey> keySelector,
@@ -181,87 +223,60 @@ public static IEnumerable<TSource> PartialSortBy<TSource, TKey>(
181223
/// <param name="direction">The direction in which to sort the elements</param>
182224
/// <returns>A sequence containing at most top <paramref name="count"/>
183225
/// elements from source, in the specified order of their keys.</returns>
226+
/// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
227+
/// <exception cref="ArgumentNullException"><paramref name="keySelector"/> is <see langword="null"/>.</exception>
228+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
184229
/// <remarks>
185230
/// This operator uses deferred execution and streams its results.
186231
/// </remarks>
187-
188232
public static IEnumerable<TSource> PartialSortBy<TSource, TKey>(
189233
this IEnumerable<TSource> source, int count,
190234
Func<TSource, TKey> keySelector,
191235
IComparer<TKey>? comparer,
192236
OrderByDirection direction)
193237
{
194238
source.ThrowIfNull();
239+
count.ThrowIfLessThan(1);
195240
keySelector.ThrowIfNull();
196241

197242
comparer ??= Comparer<TKey>.Default;
198243
if (direction == OrderByDirection.Descending)
199244
comparer = new ReverseComparer<TKey>(comparer);
200-
return PartialSortByImpl(source, count, keySelector, keyComparer: comparer, comparer: null);
201-
}
202245

203-
static IEnumerable<TSource> PartialSortByImpl<TSource, TKey>(
204-
IEnumerable<TSource> source, int count,
205-
Func<TSource, TKey>? keySelector,
206-
IComparer<TKey>? keyComparer,
207-
IComparer<TSource>? comparer)
208-
{
209-
var top = new List<TSource>(count);
246+
return _(source, count, keySelector, comparer);
210247

211-
static int? Insert<T>(List<T> list, T item, IComparer<T> comparer, int count)
248+
static IEnumerable<TSource> _(IEnumerable<TSource> source, int count, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
212249
{
213-
var i = list.BinarySearch(item, comparer);
214-
// find the place to insert
215-
if (i < 0 && (i = ~i) >= count)
216-
return null;
217-
// move forward until we get to next larger
218-
while (i < list.Count && comparer.Compare(item, list[i]) == 0)
219-
i++;
220-
// is the list full?
221-
if (list.Count == count)
222-
{
223-
// if our insert location is at the end of the list
224-
if (i == list.Count
225-
// and we're _not larger_ than the last item
226-
&& comparer.Compare(item, list[^1]) <= 0)
250+
var top = new SortedSet<(TKey Item, int Index)>(
251+
Comparer<(TKey item, int index)>.Create((x, y) =>
227252
{
228-
// then don't affect the list
229-
return null;
230-
}
231-
// remove last item
232-
list.RemoveAt(count - 1);
233-
}
253+
var result = comparer.Compare(x.item, y.item);
254+
return result != 0 ? result :
255+
Comparer<long>.Default.Compare(x.index, y.index);
256+
}));
257+
var dic = new Dictionary<(TKey Item, int Index), TSource>(count);
234258

235-
list.Insert(i, item);
236-
return i;
237-
}
238-
239-
if (keyComparer != null)
240-
{
241-
var keys = new List<TKey>(count);
242-
243-
foreach (var item in source)
259+
foreach (var (index, item) in source.Index())
244260
{
245-
var key = keySelector!(item);
246-
if (Insert(keys, key, keyComparer, count) is { } i)
261+
var key = (key: keySelector(item), index);
262+
if (top.Count < count)
247263
{
248-
if (top.Count == count)
249-
top.RemoveAt(count - 1);
250-
top.Insert(i, item);
264+
top.Add(key);
265+
dic[key] = item;
266+
continue;
251267
}
268+
269+
if (comparer.Compare(key.key, top.Max.Item) >= 0)
270+
continue;
271+
272+
dic.Remove(top.Max);
273+
top.Remove(top.Max);
274+
top.Add(key);
275+
dic[key] = item;
252276
}
253-
}
254-
else if (comparer != null)
255-
{
256-
foreach (var item in source)
257-
_ = Insert(top, item, comparer, count);
258-
}
259-
else
260-
{
261-
throw new NotSupportedException("Should not be able to reach here.");
262-
}
263277

264-
foreach (var item in top)
265-
yield return item;
278+
foreach (var entry in top)
279+
yield return dic[entry];
280+
}
266281
}
267282
}

Tests/SuperLinq.Test/PartialSortByTest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public void PartialSortByIsStable()
8787
for (var i = 1; i <= 10; i++)
8888
{
8989
var sorted = list.PartialSortBy(i, x => x.key);
90-
Assert.True(sorted.SequenceEqual(stableSort.Take(i)));
90+
Assert.Equal(stableSort.Take(i), sorted);
9191
}
9292
}
9393
}

Tests/SuperLinq.Test/PartialSortTest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public void PartialSortIsStable()
9494
for (var i = 1; i <= 10; i++)
9595
{
9696
var sorted = list.PartialSort(i, comparer);
97-
Assert.True(sorted.SequenceEqual(stableSort.Take(i)));
97+
Assert.Equal(stableSort.Take(i), sorted);
9898
}
9999
}
100100
}

0 commit comments

Comments
 (0)