Skip to content

Add LINQ Shuffle #112173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ public static partial class AsyncEnumerable
public static System.Collections.Generic.IAsyncEnumerable<TResult> Select<TSource, TResult>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Func<TSource, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<TResult>> selector) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<TResult> Select<TSource, TResult>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Func<TSource, TResult> selector) { throw null; }
public static System.Threading.Tasks.ValueTask<bool> SequenceEqualAsync<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> first, System.Collections.Generic.IAsyncEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<TSource> Shuffle<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> source) { throw null; }
public static System.Threading.Tasks.ValueTask<TSource> SingleAsync<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Func<TSource, bool> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public static System.Threading.Tasks.ValueTask<TSource> SingleAsync<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Func<TSource, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<bool>> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public static System.Threading.Tasks.ValueTask<TSource> SingleAsync<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
</PropertyGroup>

<ItemGroup>
<Compile Include="System\Linq\Shuffle.cs" />
<Compile Include="System\Linq\SkipLast.cs" />
<Compile Include="System\Linq\SkipWhile.cs" />
<Compile Include="System\Linq\Append.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;

namespace System.Linq
{
public static partial class AsyncEnumerable
{
#if !NET
[ThreadStatic]
private static Random? t_random;
#endif

/// <summary>Shuffles the order of the elements of a sequence.</summary>
/// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
/// <param name="source">A sequence of values to shuffle.</param>
/// <returns>A sequence whose elements correspond to those of the input sequence in randomized order.</returns>
/// <remarks>Randomization is performed using a non-cryptographically-secure random number generator.</remarks>
public static IAsyncEnumerable<TSource> Shuffle<TSource>(
this IAsyncEnumerable<TSource> source)
{
ThrowHelper.ThrowIfNull(source);

return Impl(source, default);

static async IAsyncEnumerable<TSource> Impl(
IAsyncEnumerable<TSource> source,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
TSource[] array = await source.ToArrayAsync(cancellationToken).ConfigureAwait(false);

#if NET
Random.Shared.Shuffle(array);
#else
Random random = t_random ??= new Random(Environment.TickCount ^ Environment.CurrentManagedThreadId);
int n = array.Length;
for (int i = 0; i < n - 1; i++)
{
int j = random.Next(i, n);
if (j != i)
{
TSource temp = array[i];
array[i] = array[j];
array[j] = temp;
}
}
#endif

for (int i = 0; i < array.Length; i++)
{
yield return array[i];
}
}
}
}
}
69 changes: 69 additions & 0 deletions src/libraries/System.Linq.AsyncEnumerable/tests/ShuffleTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace System.Linq.Tests
{
public class ShuffleTests : AsyncEnumerableTests
{
[Fact]
public void InvalidInputs_Throws()
{
AssertExtensions.Throws<ArgumentNullException>("source", () => AsyncEnumerable.Shuffle<int>(null));
}

[Theory]
[InlineData(new int[0])]
[InlineData(new int[] { 1 })]
[InlineData(new int[] { 2, 4, 8 })]
[InlineData(new int[] { -1, 2, 5, 6, 7, 8 })]
public async Task VariousValues_ContainsAllInputValues(int[] values)
{
foreach (IAsyncEnumerable<int> source in CreateSources(values))
{
int[] shuffled = await source.Shuffle().ToArrayAsync();
Array.Sort(shuffled);
Assert.Equal(values, shuffled);
}
}

[Fact]
public async Task ToArrayAsync_ElementsAreRandomized()
{
// The chance that shuffling a thousand elements produces the same order twice is infinitesimal.
int length = 1000;
foreach (IAsyncEnumerable<int> source in CreateSources(Enumerable.Range(0, length).ToArray()))
{
int[] first = await source.Shuffle().ToArrayAsync();
int[] second = await source.Shuffle().ToArrayAsync();
Assert.Equal(length, first.Length);
Assert.Equal(length, second.Length);
Assert.NotEqual(first, second);
}
}

[Fact]
public async Task Cancellation_Cancels()
{
IAsyncEnumerable<int> source = CreateSource(2, 4, 8, 16);
await Assert.ThrowsAsync<OperationCanceledException>(async () =>
{
await ConsumeAsync(source.Shuffle().WithCancellation(new CancellationToken(true)));
});
}

[Fact]
public async Task InterfaceCalls_ExpectedCounts()
{
TrackingAsyncEnumerable<int> source = CreateSource(2, 4, 8, 16).Track();
await ConsumeAsync(source.Shuffle());
Assert.Equal(5, source.MoveNextAsyncCount);
Assert.Equal(4, source.CurrentCount);
Assert.Equal(1, source.DisposeAsyncCount);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
<Compile Include="ElementAtOrDefaultAsyncTests.cs" />
<Compile Include="ElementAtAsyncTests.cs" />
<Compile Include="GroupByTests.cs" />
<Compile Include="ShuffleTests.cs" />
<Compile Include="SingleOrDefaultAsyncTests.cs" />
<Compile Include="SingleAsyncTests.cs" />
<Compile Include="LastOrDefaultAsyncTests.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ public static partial class Queryable
public static System.Linq.IQueryable<TResult> Select<TSource, TResult>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TResult>> selector) { throw null; }
public static bool SequenceEqual<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2) { throw null; }
public static bool SequenceEqual<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static System.Linq.IQueryable<TSource> Shuffle<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static TSource? SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static TSource? SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static TSource SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate, TSource defaultValue) { throw null; }
Expand Down
12 changes: 12 additions & 0 deletions src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1587,6 +1587,18 @@ public static bool SequenceEqual<TSource>(this IQueryable<TSource> source1, IEnu
Expression.Constant(comparer, typeof(IEqualityComparer<TSource>))));
}

[DynamicDependency("Shuffle`1", typeof(Enumerable))]
public static IQueryable<TSource> Shuffle<TSource>(this IQueryable<TSource> source)
{
ArgumentNullException.ThrowIfNull(source);

return source.Provider.CreateQuery<TSource>(
Expression.Call(
null,
new Func<IQueryable<TSource>, IQueryable<TSource>>(Shuffle).Method,
source.Expression));
}

[DynamicDependency("Any`1", typeof(Enumerable))]
public static bool Any<TSource>(this IQueryable<TSource> source)
{
Expand Down
37 changes: 37 additions & 0 deletions src/libraries/System.Linq.Queryable/tests/ShuffleTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Xunit;

namespace System.Linq.Tests
{
public class ShuffleTests : EnumerableBasedTests
{
[Fact]
public void InvalidArguments()
{
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<string>)null).Shuffle());
}

[Fact]
public void ProducesAllElements()
{
int[] shuffled = Enumerable.Range(0, 1000).AsQueryable().Shuffle().ToArray();
Array.Sort(shuffled);
Assert.Equal(Enumerable.Range(0, shuffled.Length), shuffled);
}

[Fact]
public void ElementsAreRandomized()
{
// The chance that shuffling a thousand elements produces the same order twice is infinitesimal.
const int Length = 1000;
IQueryable<int> source = Enumerable.Range(0, Length).AsQueryable().Shuffle();
int[] first = source.ToArray();
int[] second = source.ToArray();
Assert.Equal(Length, first.Length);
Assert.Equal(Length, second.Length);
Assert.NotEqual(first, second);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
<Compile Include="JoinTests.cs" />
<Compile Include="LastOrDefaultTests.cs" />
<Compile Include="LastTests.cs" />
<Compile Include="ShuffleTests.cs" />
<Compile Include="RightJoinTests.cs" />
<Compile Include="LongCountTests.cs" />
<Compile Include="MaxTests.cs" />
Expand Down Expand Up @@ -62,7 +63,6 @@
<Compile Include="UnionTests.cs" />
<Compile Include="WhereTests.cs" />
<Compile Include="ZipTests.cs" />
<Compile Include="$(CommonTestPath)System\Linq\SkipTakeData.cs"
Link="Common\System\Linq\SkipTakeData.cs" />
<Compile Include="$(CommonTestPath)System\Linq\SkipTakeData.cs" Link="Common\System\Linq\SkipTakeData.cs" />
</ItemGroup>
</Project>
3 changes: 3 additions & 0 deletions src/libraries/System.Linq/ref/System.Linq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// Changes to this file must follow the https://aka.ms/api-review process.
// ------------------------------------------------------------------------------

using System.Collections.Generic;

namespace System.Linq
{
public static partial class Enumerable
Expand Down Expand Up @@ -172,6 +174,7 @@ public static System.Collections.Generic.IEnumerable<
public static System.Collections.Generic.IEnumerable<TResult> Select<TSource, TResult>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TResult> selector) { throw null; }
public static bool SequenceEqual<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second) { throw null; }
public static bool SequenceEqual<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Shuffle<TSource>(this System.Collections.Generic.IEnumerable<TSource> source) { throw null; }
public static TSource? SingleOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source) { throw null; }
public static TSource SingleOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, TSource defaultValue) { throw null; }
public static TSource? SingleOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
Expand Down
2 changes: 2 additions & 0 deletions src/libraries/System.Linq/src/System.Linq.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
<Compile Include="System\Linq\Range.SpeedOpt.cs" />
<Compile Include="System\Linq\Repeat.cs" />
<Compile Include="System\Linq\Repeat.SpeedOpt.cs" />
<Compile Include="System\Linq\Shuffle.SpeedOpt.cs" />
<Compile Include="System\Linq\Shuffle.cs" />
<Compile Include="System\Linq\Reverse.cs" />
<Compile Include="System\Linq\Reverse.SpeedOpt.cs" />
<Compile Include="System\Linq\RightJoin.cs" />
Expand Down
Loading
Loading