diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index a801697d8f..f24f374644 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4733,6 +4733,13 @@ public override Task ReadAsync(CancellationToken cancellationToken) return Task.FromException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed())); } + // Register first to catch any already expired tokens to be able to trigger cancellation event. + IDisposable registration = null; + if (cancellationToken.CanBeCanceled) + { + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + } + // If user's token is canceled, return a canceled task if (cancellationToken.IsCancellationRequested) { @@ -4831,12 +4838,6 @@ public override Task ReadAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } - ReadAsyncCallContext context = null; if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection) { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 09061277e0..090b0f7bfb 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -5326,6 +5326,13 @@ public override Task ReadAsync(CancellationToken cancellationToken) return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed("ReadAsync"))); } + // Register first to catch any already expired tokens to be able to trigger cancellation event. + IDisposable registration = null; + if (cancellationToken.CanBeCanceled) + { + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + } + // If user's token is canceled, return a canceled task if (cancellationToken.IsCancellationRequested) { @@ -5425,12 +5432,6 @@ public override Task ReadAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } - var context = Interlocked.Exchange(ref _cachedReadAsyncContext, null) ?? new ReadAsyncCallContext(); Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed"); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 704d4a28f0..1efdfeb5fc 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -86,6 +86,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderCancellationTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderCancellationTest.cs new file mode 100644 index 0000000000..38d7da418e --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderCancellationTest.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public class DataReaderCancellationTest + { + /// + /// Test ensures cancellation token is registered before ReadAsync starts processing results from TDS Stream, + /// such that when Cancel is triggered, the token is capable of canceling reading further results. + /// + /// Async Task + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + public static async Task CancellationTokenIsRespected_ReadAsync() + { + const string longRunningQuery = @" +with TenRows as (select Value from (values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)) as TenRows (Value)), + ThousandRows as (select A.Value as A, B.Value as B, C.Value as C from TenRows as A, TenRows as B, TenRows as C) +select * +from ThousandRows as A, ThousandRows as B, ThousandRows as C;"; + + using (var source = new CancellationTokenSource()) + using (var connection = new SqlConnection(DataTestUtility.TCPConnectionString)) + { + await connection.OpenAsync(source.Token); + + Stopwatch stopwatch = Stopwatch.StartNew(); + await Assert.ThrowsAsync(async () => + { + using (var command = new SqlCommand(longRunningQuery, connection)) + using (var reader = await command.ExecuteReaderAsync(source.Token)) + { + while (await reader.ReadAsync(source.Token)) + { + source.Cancel(); + } + } + }); + Assert.True(stopwatch.ElapsedMilliseconds < 10000, "Cancellation did not trigger on time."); + } + } + + /// + /// Test ensures cancellation token is registered before ReadAsync starts processing results from TDS Stream, + /// such that when Cancel is triggered, the token is capable of canceling reading further results. + /// + /// Async Task + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + public static async Task CancelledCancellationTokenIsRespected_ReadAsync() + { + const string longRunningQuery = @" +with TenRows as (select Value from (values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)) as TenRows (Value)), + ThousandRows as (select A.Value as A, B.Value as B, C.Value as C from TenRows as A, TenRows as B, TenRows as C) +select * +from ThousandRows as A, ThousandRows as B, ThousandRows as C;"; + + using (var source = new CancellationTokenSource()) + using (var connection = new SqlConnection(DataTestUtility.TCPConnectionString)) + { + await connection.OpenAsync(source.Token); + + Stopwatch stopwatch = Stopwatch.StartNew(); + await Assert.ThrowsAsync(async () => + { + using (var command = new SqlCommand(longRunningQuery, connection)) + using (var reader = await command.ExecuteReaderAsync(source.Token)) + { + source.Cancel(); + while (await reader.ReadAsync(source.Token)) + { } + } + }); + Assert.True(stopwatch.ElapsedMilliseconds < 10000, "Cancellation did not trigger on time."); + } + } + } +}