diff --git a/Lib/ClassLibraryCommon/Core/ByteCountingStream.cs b/Lib/ClassLibraryCommon/Core/ByteCountingStream.cs index 253640e4f..5f1878297 100644 --- a/Lib/ClassLibraryCommon/Core/ByteCountingStream.cs +++ b/Lib/ClassLibraryCommon/Core/ByteCountingStream.cs @@ -20,6 +20,8 @@ namespace Microsoft.Azure.Storage.Core using Microsoft.Azure.Storage.Core.Util; using System; using System.IO; + using System.Threading; + using System.Threading.Tasks; /// /// This class provides a wrapper that will update the Ingress / Egress bytes of a given request result as the stream is used. @@ -107,6 +109,13 @@ public override int Read(byte[] buffer, int offset, int count) return read; } + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + int read = await this.wrappedStream.ReadAsync(buffer, offset, count, cancellationToken); + this.requestObject.IngressBytes += read; + return read; + } + public override int ReadByte() { int val = this.wrappedStream.ReadByte(); @@ -181,6 +190,12 @@ public override void Write(byte[] buffer, int offset, int count) this.requestObject.EgressBytes += count; } + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await this.wrappedStream.WriteAsync(buffer, offset, count, cancellationToken); + this.requestObject.EgressBytes += count; + } + public override void WriteByte(byte value) { this.wrappedStream.WriteByte(value); diff --git a/Lib/ClassLibraryCommon/Core/Executor/Executor.cs b/Lib/ClassLibraryCommon/Core/Executor/Executor.cs index 5c2584f1e..1ac1ffb2f 100644 --- a/Lib/ClassLibraryCommon/Core/Executor/Executor.cs +++ b/Lib/ClassLibraryCommon/Core/Executor/Executor.cs @@ -141,7 +141,12 @@ public static async Task ExecuteAsync(RESTCommand cmd, IRetryPolicy pol // 8. (Potentially reads stream from server) executionState.CurrentOperation = ExecutorOperation.GetResponseStream; - cmd.ResponseStream = await executionState.Resp.Content.ReadAsStreamAsync().ConfigureAwait(false); + var responseStream = await executionState.Resp.Content.ReadAsStreamAsync().ConfigureAwait(false); + if (cmd.NetworkTimeout.HasValue) + { + responseStream = new TimeoutStream(responseStream, cmd.NetworkTimeout.Value); + } + cmd.ResponseStream = responseStream; // The stream is now available in ResponseStream. Use the stream to parse out the response or error if (executionState.ExceptionRef != null) diff --git a/Lib/ClassLibraryCommon/Core/TimeoutStream.cs b/Lib/ClassLibraryCommon/Core/TimeoutStream.cs new file mode 100644 index 000000000..b818802ef --- /dev/null +++ b/Lib/ClassLibraryCommon/Core/TimeoutStream.cs @@ -0,0 +1,281 @@ +//----------------------------------------------------------------------- +// +// Copyright 2013 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//----------------------------------------------------------------------- + + +namespace Microsoft.Azure.Storage.Core +{ + using Microsoft.Azure.Storage.Core.Util; + using System; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + + /// + /// Stream that will throw a if it has to wait longer than a configurable timeout to read or write more data + /// + internal class TimeoutStream : Stream + { + private readonly Stream wrappedStream; + private TimeSpan readTimeout; + private TimeSpan writeTimeout; + private CancellationTokenSource cancellationTokenSource; + + public TimeoutStream(Stream wrappedStream, TimeSpan timeout) + : this(wrappedStream, timeout, timeout) { } + + public TimeoutStream(Stream wrappedStream, TimeSpan readTimeout, TimeSpan writeTimeout) + { + CommonUtility.AssertNotNull("WrappedStream", wrappedStream); + CommonUtility.AssertNotNull("ReadTimeout", readTimeout); + CommonUtility.AssertNotNull("WriteTimeout", writeTimeout); + this.wrappedStream = wrappedStream; + this.readTimeout = readTimeout; + this.writeTimeout = writeTimeout; + this.UpdateReadTimeout(); + this.UpdateWriteTimeout(); + this.cancellationTokenSource = new CancellationTokenSource(); + } + + public override long Position + { + get { return this.wrappedStream.Position; } + set { this.wrappedStream.Position = value; } + } + + public override long Length + { + get { return this.wrappedStream.Length; } + } + + public override bool CanWrite + { + get { return this.wrappedStream.CanWrite; } + } + + public override bool CanTimeout + { + get { return this.wrappedStream.CanTimeout; } + } + + public override bool CanSeek + { + get { return this.wrappedStream.CanSeek; } + } + + public override bool CanRead + { + get { return this.wrappedStream.CanRead; } + } + + public override int ReadTimeout + { + get { return (int) this.readTimeout.TotalMilliseconds; } + set { + this.readTimeout = TimeSpan.FromMilliseconds(value); + this.UpdateReadTimeout(); + } + } + + public override int WriteTimeout + { + get { return (int) this.writeTimeout.TotalMilliseconds; } + set + { + this.writeTimeout = TimeSpan.FromMilliseconds(value); + this.UpdateWriteTimeout(); + } + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return this.wrappedStream.BeginRead(buffer, offset, count, callback, state); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return this.wrappedStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void Close() + { + this.wrappedStream.Close(); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return this.wrappedStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return this.wrappedStream.EndRead(asyncResult); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + this.wrappedStream.EndWrite(asyncResult); + } + + public override void Flush() + { + this.wrappedStream.Flush(); + } + + public override async Task FlushAsync(CancellationToken cancellationToken) + { + var source = StartTimeout(cancellationToken, out bool dispose); + try + { + await this.wrappedStream.FlushAsync(source.Token); + } + finally + { + StopTimeout(source, dispose); + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + return wrappedStream.Read(buffer, offset, count); + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var source = StartTimeout(cancellationToken, out bool dispose); + try + { + return await this.wrappedStream.ReadAsync(buffer, offset, count, source.Token); + } + finally + { + StopTimeout(source, dispose); + } + } + + public override int ReadByte() + { + return this.wrappedStream.ReadByte(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + return this.wrappedStream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + this.wrappedStream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + this.wrappedStream.Write(buffer, offset, count); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var source = StartTimeout(cancellationToken, out bool dispose); + try + { + await this.wrappedStream.WriteAsync(buffer, offset, count, source.Token); + } + finally + { + StopTimeout(source, dispose); + } + } + + public override void WriteByte(byte value) + { + this.wrappedStream.WriteByte(value); + } + + private CancellationTokenSource StartTimeout(CancellationToken additionalToken, out bool dispose) + { + if (this.cancellationTokenSource.IsCancellationRequested) + { + this.cancellationTokenSource = new CancellationTokenSource(); + } + + CancellationTokenSource source; + if (additionalToken.CanBeCanceled) + { + source = CancellationTokenSource.CreateLinkedTokenSource(additionalToken, this.cancellationTokenSource.Token); + dispose = true; + } + else + { + source = this.cancellationTokenSource; + dispose = false; + } + + this.cancellationTokenSource.CancelAfter(this.readTimeout); + + return source; + } + + private void StopTimeout(CancellationTokenSource source, bool dispose) + { + this.cancellationTokenSource.CancelAfter(Timeout.InfiniteTimeSpan); + if (dispose) + { + source.Dispose(); + } + } + + private void UpdateReadTimeout() + { + if (this.wrappedStream.CanTimeout) + { + try + { + this.wrappedStream.ReadTimeout = (int)this.readTimeout.TotalMilliseconds; + } + catch + { + // ignore + } + } + } + + private void UpdateWriteTimeout() + { + if (this.wrappedStream.CanTimeout) + { + try + { + this.wrappedStream.WriteTimeout = (int)this.writeTimeout.TotalMilliseconds; + } + catch + { + // ignore + } + } + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (disposing) + { + this.cancellationTokenSource.Dispose(); + this.wrappedStream.Dispose(); + } + } + } +} \ No newline at end of file diff --git a/Lib/Common/Blob/BlobRequestOptions.cs b/Lib/Common/Blob/BlobRequestOptions.cs index 3c9d27cbc..f3366affb 100644 --- a/Lib/Common/Blob/BlobRequestOptions.cs +++ b/Lib/Common/Blob/BlobRequestOptions.cs @@ -64,6 +64,7 @@ public sealed class BlobRequestOptions : IRequestOptions LocationMode = RetryPolicies.LocationMode.PrimaryOnly, ServerTimeout = null, MaximumExecutionTime = null, + NetworkTimeout = Constants.DefaultNetworkTimeout, ParallelOperationThreadCount = 1, SingleBlobUploadThresholdInBytes = Constants.MaxSingleUploadBlobSize / 2, @@ -114,6 +115,7 @@ internal BlobRequestOptions(BlobRequestOptions other) this.LocationMode = other.LocationMode; this.ServerTimeout = other.ServerTimeout; this.MaximumExecutionTime = other.MaximumExecutionTime; + this.NetworkTimeout = other.NetworkTimeout; this.OperationExpiryTime = other.OperationExpiryTime; this.ChecksumOptions.CopyFrom(other.ChecksumOptions); this.ParallelOperationThreadCount = other.ParallelOperationThreadCount; @@ -162,6 +164,11 @@ internal static BlobRequestOptions ApplyDefaults(BlobRequestOptions options, Blo ?? serviceClient.DefaultRequestOptions.MaximumExecutionTime ?? BaseDefaultRequestOptions.MaximumExecutionTime; + modifiedOptions.NetworkTimeout = + modifiedOptions.NetworkTimeout + ?? serviceClient.DefaultRequestOptions.NetworkTimeout + ?? BaseDefaultRequestOptions.NetworkTimeout; + modifiedOptions.ParallelOperationThreadCount = modifiedOptions.ParallelOperationThreadCount ?? serviceClient.DefaultRequestOptions.ParallelOperationThreadCount @@ -242,6 +249,8 @@ internal void ApplyToStorageCommand(RESTCommand cmd) { cmd.OperationExpiryTime = DateTime.Now + this.MaximumExecutionTime.Value; } + + cmd.NetworkTimeout = this.NetworkTimeout; } #if !(WINDOWS_RT || NETCORE) @@ -413,6 +422,11 @@ public TimeSpan? MaximumExecutionTime } } + /// + /// Gets or sets the timeout applied to an individual network operations. + /// + public TimeSpan? NetworkTimeout { get; set; } + /// /// Gets or sets the number of blocks that may be simultaneously uploaded. /// diff --git a/Lib/Common/Core/Executor/StorageCommandBase.cs b/Lib/Common/Core/Executor/StorageCommandBase.cs index a90af64af..38927108c 100644 --- a/Lib/Common/Core/Executor/StorageCommandBase.cs +++ b/Lib/Common/Core/Executor/StorageCommandBase.cs @@ -37,6 +37,9 @@ internal abstract class StorageCommandBase // Max client timeout, enforced over entire operation on client side internal DateTime? OperationExpiryTime = null; + // Timeout applied to an individual network operations. + internal TimeSpan? NetworkTimeout = null; + // State- different than async state, this is used for ops to communicate state between invocations, i.e. bytes downloaded etc internal object OperationState = null; diff --git a/Lib/Common/Core/Util/TaskExtensions.cs b/Lib/Common/Core/Util/TaskExtensions.cs index 42be39c3a..6b178b343 100644 --- a/Lib/Common/Core/Util/TaskExtensions.cs +++ b/Lib/Common/Core/Util/TaskExtensions.cs @@ -34,8 +34,16 @@ internal static async Task WithCancellation(this Task task, Cancellatio TaskCompletionSource tcs = new TaskCompletionSource(); using (cancellationToken.Register( taskCompletionSource => ((TaskCompletionSource)taskCompletionSource).TrySetResult(true), tcs)) - if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false)) - throw new OperationCanceledException(cancellationToken); + if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false)) + { + _ = task.ContinueWith(val => + { + // Mark exceptions thrown from abandonned task as handled. + // https://tpodolak.com/blog/2015/08/10/tpl-exception-handling-and-unobservedtaskexception-issue/ + val.Exception.Handle(ex => true); + }, TaskContinuationOptions.OnlyOnFaulted); + throw new OperationCanceledException(cancellationToken); + } return await task.ConfigureAwait(false); } @@ -51,7 +59,15 @@ internal static async Task WithCancellation(this Task task, CancellationToken ca using (cancellationToken.Register( taskCompletionSource => ((TaskCompletionSource)taskCompletionSource).TrySetResult(true), tcs)) if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false)) + { + _ = task.ContinueWith(val => + { + // Mark exceptions thrown from abandonned task as handled. + // https://tpodolak.com/blog/2015/08/10/tpl-exception-handling-and-unobservedtaskexception-issue/ + val.Exception.Handle(ex => true); + }, TaskContinuationOptions.OnlyOnFaulted); throw new OperationCanceledException(cancellationToken); + } await task.ConfigureAwait(false); } } diff --git a/Lib/Common/File/FileRequestOptions.cs b/Lib/Common/File/FileRequestOptions.cs index 7425cfb3c..572b8fb35 100644 --- a/Lib/Common/File/FileRequestOptions.cs +++ b/Lib/Common/File/FileRequestOptions.cs @@ -54,6 +54,7 @@ public sealed class FileRequestOptions : IRequestOptions ServerTimeout = null, MaximumExecutionTime = null, + NetworkTimeout = Constants.DefaultNetworkTimeout, ParallelOperationThreadCount = 1, ChecksumOptions = new ChecksumOptions @@ -101,6 +102,7 @@ internal FileRequestOptions(FileRequestOptions other) this.LocationMode = other.LocationMode; this.ServerTimeout = other.ServerTimeout; this.MaximumExecutionTime = other.MaximumExecutionTime; + this.NetworkTimeout = other.NetworkTimeout; this.OperationExpiryTime = other.OperationExpiryTime; this.ChecksumOptions.CopyFrom(other.ChecksumOptions); this.ParallelOperationThreadCount = other.ParallelOperationThreadCount; @@ -138,6 +140,12 @@ internal static FileRequestOptions ApplyDefaults(FileRequestOptions options, Clo ?? serviceClient.DefaultRequestOptions.MaximumExecutionTime ?? BaseDefaultRequestOptions.MaximumExecutionTime; + + modifiedOptions.NetworkTimeout = + modifiedOptions.NetworkTimeout + ?? serviceClient.DefaultRequestOptions.NetworkTimeout + ?? BaseDefaultRequestOptions.NetworkTimeout; + modifiedOptions.ParallelOperationThreadCount = modifiedOptions.ParallelOperationThreadCount ?? serviceClient.DefaultRequestOptions.ParallelOperationThreadCount @@ -211,6 +219,8 @@ internal void ApplyToStorageCommand(RESTCommand cmd) { cmd.OperationExpiryTime = DateTime.Now + this.MaximumExecutionTime.Value; } + + cmd.NetworkTimeout = this.NetworkTimeout; } /// @@ -292,7 +302,12 @@ public TimeSpan? MaximumExecutionTime this.maximumExecutionTime = value; } - } + } + + /// + /// Gets or sets the timeout applied to an individual network operations. + /// + public TimeSpan? NetworkTimeout { get; set; } /// /// Gets or sets the number of ranges that may be simultaneously uploaded when uploading a file. diff --git a/Lib/Common/IRequestOptions.cs b/Lib/Common/IRequestOptions.cs index 27277b058..ff3ca07c3 100644 --- a/Lib/Common/IRequestOptions.cs +++ b/Lib/Common/IRequestOptions.cs @@ -50,6 +50,12 @@ public interface IRequestOptions /// A containing the maximum execution time across all potential retries. TimeSpan? MaximumExecutionTime { get; set; } + /// + /// Gets or sets the timeout applied to an individual network operations. + /// + /// A containing the timeout applied to an individual network operations. + TimeSpan? NetworkTimeout { get; set; } + #if !(WINDOWS_RT || NETCORE) /// /// Gets or sets a value to indicate whether data written and read by the client library should be encrypted. diff --git a/Lib/Common/Queue/QueueRequestOptions.cs b/Lib/Common/Queue/QueueRequestOptions.cs index 2cca78786..e43a24a7c 100644 --- a/Lib/Common/Queue/QueueRequestOptions.cs +++ b/Lib/Common/Queue/QueueRequestOptions.cs @@ -46,7 +46,8 @@ public sealed class QueueRequestOptions : IRequestOptions #endif LocationMode = RetryPolicies.LocationMode.PrimaryOnly, ServerTimeout = null, - MaximumExecutionTime = null + MaximumExecutionTime = null, + NetworkTimeout = Constants.DefaultNetworkTimeout, }; /// @@ -73,6 +74,7 @@ internal QueueRequestOptions(QueueRequestOptions other) this.ServerTimeout = other.ServerTimeout; this.LocationMode = other.LocationMode; this.MaximumExecutionTime = other.MaximumExecutionTime; + this.NetworkTimeout = other.NetworkTimeout; this.OperationExpiryTime = other.OperationExpiryTime; } } @@ -141,6 +143,8 @@ internal void ApplyToStorageCommand(RESTCommand cmd) { cmd.OperationExpiryTime = DateTime.Now + this.MaximumExecutionTime.Value; } + + cmd.NetworkTimeout = this.NetworkTimeout; } #if !(WINDOWS_RT || NETCORE) @@ -210,6 +214,11 @@ public TimeSpan? MaximumExecutionTime this.maximumExecutionTime = value; } - } + } + + /// + /// Gets or sets the timeout applied to an individual network operations. + /// + public TimeSpan? NetworkTimeout { get; set; } } } diff --git a/Lib/Common/Shared/Protocol/Constants.cs b/Lib/Common/Shared/Protocol/Constants.cs index af2d3b5ed..e6509c23f 100644 --- a/Lib/Common/Shared/Protocol/Constants.cs +++ b/Lib/Common/Shared/Protocol/Constants.cs @@ -132,6 +132,11 @@ static class Constants /// public static readonly TimeSpan MaximumAllowedTimeout = TimeSpan.FromSeconds(int.MaxValue); + /// + /// Default timeout applied to an individual network operations. + /// + public static readonly TimeSpan DefaultNetworkTimeout = TimeSpan.FromSeconds(100); + /// /// Maximum allowed value for Delete Retention Days. /// diff --git a/Lib/WindowsRuntime/Core/Executor/Executor.cs b/Lib/WindowsRuntime/Core/Executor/Executor.cs index 63ef5ae09..d3cd72e62 100644 --- a/Lib/WindowsRuntime/Core/Executor/Executor.cs +++ b/Lib/WindowsRuntime/Core/Executor/Executor.cs @@ -180,7 +180,12 @@ private async static Task ExecuteAsyncInternal(RESTCommand cmd, IRetryP // 8. (Potentially reads stream from server) executionState.CurrentOperation = ExecutorOperation.GetResponseStream; - cmd.ResponseStream = await executionState.Resp.Content.ReadAsStreamAsync().ConfigureAwait(false); + var responseStream = await executionState.Resp.Content.ReadAsStreamAsync().ConfigureAwait(false); + if (cmd.NetworkTimeout.HasValue) + { + responseStream = new TimeoutStream(responseStream, cmd.NetworkTimeout.Value); + } + cmd.ResponseStream = responseStream; // The stream is now available in ResponseStream. Use the stream to parse out the response or error if (executionState.ExceptionRef != null) diff --git a/Lib/WindowsRuntime/Core/TimeoutStream.cs b/Lib/WindowsRuntime/Core/TimeoutStream.cs new file mode 100644 index 000000000..ccff9e213 --- /dev/null +++ b/Lib/WindowsRuntime/Core/TimeoutStream.cs @@ -0,0 +1,256 @@ +//----------------------------------------------------------------------- +// +// Copyright 2013 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//----------------------------------------------------------------------- + + +namespace Microsoft.Azure.Storage.Core +{ + using Microsoft.Azure.Storage.Core.Util; + using System; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + + /// + /// Stream that will throw a if it has to wait longer than a configurable timeout to read or write more data + /// + internal class TimeoutStream : Stream + { + private readonly Stream wrappedStream; + private TimeSpan readTimeout; + private TimeSpan writeTimeout; + private CancellationTokenSource cancellationTokenSource; + + public TimeoutStream(Stream wrappedStream, TimeSpan timeout) + : this(wrappedStream, timeout, timeout) { } + + public TimeoutStream(Stream wrappedStream, TimeSpan readTimeout, TimeSpan writeTimeout) + { + CommonUtility.AssertNotNull("WrappedStream", wrappedStream); + CommonUtility.AssertNotNull("ReadTimeout", readTimeout); + CommonUtility.AssertNotNull("WriteTimeout", writeTimeout); + this.wrappedStream = wrappedStream; + this.readTimeout = readTimeout; + this.writeTimeout = writeTimeout; + this.UpdateReadTimeout(); + this.UpdateWriteTimeout(); + this.cancellationTokenSource = new CancellationTokenSource(); + } + + public override long Position + { + get { return this.wrappedStream.Position; } + set { this.wrappedStream.Position = value; } + } + + public override long Length + { + get { return this.wrappedStream.Length; } + } + + public override bool CanWrite + { + get { return this.wrappedStream.CanWrite; } + } + + public override bool CanTimeout + { + get { return this.wrappedStream.CanTimeout; } + } + + public override bool CanSeek + { + get { return this.wrappedStream.CanSeek; } + } + + public override bool CanRead + { + get { return this.wrappedStream.CanRead; } + } + + public override int ReadTimeout + { + get { return (int) this.readTimeout.TotalMilliseconds; } + set { + this.readTimeout = TimeSpan.FromMilliseconds(value); + this.UpdateReadTimeout(); + } + } + + public override int WriteTimeout + { + get { return (int) this.writeTimeout.TotalMilliseconds; } + set + { + this.writeTimeout = TimeSpan.FromMilliseconds(value); + this.UpdateWriteTimeout(); + } + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return this.wrappedStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + + public override void Flush() + { + this.wrappedStream.Flush(); + } + + public override async Task FlushAsync(CancellationToken cancellationToken) + { + var source = StartTimeout(cancellationToken, out bool dispose); + try + { + await this.wrappedStream.FlushAsync(source.Token); + } + finally + { + StopTimeout(source, dispose); + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + return wrappedStream.Read(buffer, offset, count); + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var source = StartTimeout(cancellationToken, out bool dispose); + try + { + return await this.wrappedStream.ReadAsync(buffer, offset, count, source.Token); + } + finally + { + StopTimeout(source, dispose); + } + } + + public override int ReadByte() + { + return this.wrappedStream.ReadByte(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + return this.wrappedStream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + this.wrappedStream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + this.wrappedStream.Write(buffer, offset, count); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var source = StartTimeout(cancellationToken, out bool dispose); + try + { + await this.wrappedStream.WriteAsync(buffer, offset, count, source.Token); + } + finally + { + StopTimeout(source, dispose); + } + } + + public override void WriteByte(byte value) + { + this.wrappedStream.WriteByte(value); + } + + private CancellationTokenSource StartTimeout(CancellationToken additionalToken, out bool dispose) + { + if (this.cancellationTokenSource.IsCancellationRequested) + { + this.cancellationTokenSource = new CancellationTokenSource(); + } + + CancellationTokenSource source; + if (additionalToken.CanBeCanceled) + { + source = CancellationTokenSource.CreateLinkedTokenSource(additionalToken, this.cancellationTokenSource.Token); + dispose = true; + } + else + { + source = this.cancellationTokenSource; + dispose = false; + } + + this.cancellationTokenSource.CancelAfter(this.readTimeout); + + return source; + } + + private void StopTimeout(CancellationTokenSource source, bool dispose) + { + this.cancellationTokenSource.CancelAfter(Timeout.InfiniteTimeSpan); + if (dispose) + { + source.Dispose(); + } + } + + private void UpdateReadTimeout() + { + if (this.wrappedStream.CanTimeout) + { + try + { + this.wrappedStream.ReadTimeout = (int)this.readTimeout.TotalMilliseconds; + } + catch + { + // ignore + } + } + } + + private void UpdateWriteTimeout() + { + if (this.wrappedStream.CanTimeout) + { + try + { + this.wrappedStream.WriteTimeout = (int)this.writeTimeout.TotalMilliseconds; + } + catch + { + // ignore + } + } + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (disposing) + { + this.cancellationTokenSource.Dispose(); + this.wrappedStream.Dispose(); + } + } + } +} \ No newline at end of file