Skip to content

Commit

Permalink
Honor retry delay on Azure/OpenAI status codes 429 and 503 (#795)
Browse files Browse the repository at this point in the history
## Motivation and Context (Why the change? What's the scenario?)

When sending too many requests to Azure / OpenAI and receiving status
code 429, the response includes how long to wait before retrying. The
same might be true for status code 503.

KM's default retry strategy ignores this information, often retrying too
soon and causing unnecessary errors and potential pipeline failures.

## High level description (Approach, Design)

Fix the retry policy to honor these headers in case of 429 and 503:

* retry-after-ms
* x-ms-retry-after-ms
* Retry-After
  • Loading branch information
dluc authored Sep 23, 2024
1 parent ccfb815 commit c367cd0
Show file tree
Hide file tree
Showing 12 changed files with 199 additions and 23 deletions.
67 changes: 67 additions & 0 deletions examples/002-dotnet-Serverless/Utils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http.Headers;

namespace Microsoft.KernelMemory.Utils;

#pragma warning disable CA1303
#pragma warning disable CA1812

// TMP workaround for Azure SDK bug
// See https://github.com/Azure/azure-sdk-for-net/issues/46109
internal sealed class AuthFixHandler : DelegatingHandler
{
    /// <summary>
    /// If the SDK attached the Authorization header more than once, collapse
    /// the duplicates into a single header before the request goes out.
    /// </summary>
    protected override Task<HttpResponseMessage> SendAsync(
        HttpRequestMessage request, CancellationToken cancellationToken)
    {
        bool hasDuplicateAuth = request.Headers.TryGetValues("Authorization", out var values)
                                && values.Count() > 1;
        if (hasDuplicateAuth)
        {
            // Re-assigning the Authorization property replaces all existing
            // values with a single, well-formed header.
            var auth = request.Headers.Authorization!;
            request.Headers.Authorization = new AuthenticationHeaderValue(auth.Scheme, auth.Parameter);
        }

        return base.SendAsync(request, cancellationToken);
    }
}

internal sealed class HttpLogger : DelegatingHandler
{
    /// <summary>
    /// Logs the full request (headers + content) to the console, forwards the
    /// request through the pipeline, then logs the full response.
    /// </summary>
    protected override async Task<HttpResponseMessage> SendAsync(
        HttpRequestMessage request, CancellationToken cancellationToken)
    {
        await LogRequestAsync(request, cancellationToken).ConfigureAwait(false);

        // Proceed with the request
        HttpResponseMessage response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);

        await LogResponseAsync(response, cancellationToken).ConfigureAwait(false);

        return response;
    }

    // Dump request line, content and headers to the console.
    private static async Task LogRequestAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        Console.WriteLine("## Request:");
        Console.WriteLine(request.ToString());
        if (request.Content != null)
        {
            Console.WriteLine(await request.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false));
        }

        Console.WriteLine("Headers");
        foreach (var header in request.Headers)
        {
            foreach (string value in header.Value)
            {
                Console.WriteLine($"{header.Key}: {value}");
            }
        }

        Console.WriteLine();
    }

    // Dump response status line and content to the console.
    private static async Task LogResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken)
    {
        Console.WriteLine("\n\n## Response:");
        Console.WriteLine(response.ToString());
        if (response.Content != null)
        {
            Console.WriteLine(await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false));
        }

        Console.WriteLine();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public AzureOpenAITextEmbeddingGenerator(
HttpClient? httpClient = null)
: this(
config,
AzureOpenAIClientBuilder.Build(config, httpClient),
AzureOpenAIClientBuilder.Build(config, httpClient, loggerFactory),
textTokenizer,
loggerFactory)
{
Expand Down
2 changes: 1 addition & 1 deletion extensions/AzureOpenAI/AzureOpenAITextGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public AzureOpenAITextGenerator(
HttpClient? httpClient = null)
: this(
config,
AzureOpenAIClientBuilder.Build(config, httpClient),
AzureOpenAIClientBuilder.Build(config, httpClient, loggerFactory),
textTokenizer,
loggerFactory)
{
Expand Down
8 changes: 6 additions & 2 deletions extensions/AzureOpenAI/Internals/AzureOpenAIClientBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
using Azure;
using Azure.AI.OpenAI;
using Azure.Identity;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;

namespace Microsoft.KernelMemory.AI.AzureOpenAI.Internals;

internal static class AzureOpenAIClientBuilder
{
internal static AzureOpenAIClient Build(AzureOpenAIConfig config, HttpClient? httpClient = null)
internal static AzureOpenAIClient Build(
AzureOpenAIConfig config,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
if (string.IsNullOrEmpty(config.Endpoint))
{
Expand All @@ -21,7 +25,7 @@ internal static AzureOpenAIClient Build(AzureOpenAIConfig config, HttpClient? ht

AzureOpenAIClientOptions options = new()
{
RetryPolicy = new ClientSequentialRetryPolicy(maxRetries: Math.Max(0, config.MaxRetries)),
RetryPolicy = new ClientSequentialRetryPolicy(maxRetries: Math.Max(0, config.MaxRetries), loggerFactory),
ApplicationId = Telemetry.HttpUserAgent,
};

Expand Down
60 changes: 56 additions & 4 deletions extensions/AzureOpenAI/Internals/ClientSequentialRetryPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

using System;
using System.ClientModel.Primitives;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;

namespace Microsoft.KernelMemory.AI.AzureOpenAI.Internals;

internal sealed class ClientSequentialRetryPolicy : ClientRetryPolicy
{
private static readonly TimeSpan[] s_pollingSequence =
private static readonly TimeSpan[] s_retryDelaySequence =
{
TimeSpan.FromSeconds(1),
TimeSpan.FromSeconds(1),
Expand All @@ -19,15 +21,65 @@ internal sealed class ClientSequentialRetryPolicy : ClientRetryPolicy
TimeSpan.FromSeconds(8)
};

private static readonly TimeSpan s_maxDelay = s_pollingSequence[^1];
private static readonly TimeSpan s_maxDelay = s_retryDelaySequence[^1];

public ClientSequentialRetryPolicy(int maxRetries = 3) : base(maxRetries)
// Logger used to report when a server-provided retry delay is honored.
private readonly ILogger<ClientSequentialRetryPolicy> _log;

/// <summary>
/// Create a new instance of the retry policy.
/// </summary>
/// <param name="maxRetries">Max number of retries, passed through to the base retry policy</param>
/// <param name="loggerFactory">Optional factory used to create the internal logger; falls back to the app default</param>
public ClientSequentialRetryPolicy(
    int maxRetries = 3,
    ILoggerFactory? loggerFactory = null) : base(maxRetries)
{
    this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<ClientSequentialRetryPolicy>();
}

protected override TimeSpan GetNextDelay(PipelineMessage message, int tryCount)
{
// Check if the remote service specified how long to wait before retrying
if (this.TryGetDelayFromResponse(message.Response, out TimeSpan delay))
{
this._log.LogWarning("Delay extracted from HTTP response: {0} msecs", delay.TotalMilliseconds);
return delay;
}

// Use predefined delay, increasing on each attempt up to a max value
int index = Math.Max(0, tryCount - 1);
return index >= s_pollingSequence.Length ? s_maxDelay : s_pollingSequence[index];
return index >= s_retryDelaySequence.Length ? s_maxDelay : s_retryDelaySequence[index];
}

/// <summary>
/// Check whether the response carries a server-specified retry delay.
/// Only HTTP 429 and 503 responses are considered; the delay is read from
/// retry-after-ms, x-ms-retry-after-ms or Retry-After, in that order.
/// </summary>
private bool TryGetDelayFromResponse(PipelineResponse? response, out TimeSpan delay)
{
    delay = TimeSpan.Zero;

    // Only throttling (429) and service-unavailable (503) responses are expected
    // to carry retry hints.
    if (response is not { Status: 429 or 503 }) { return false; }

    TimeSpan? fromHeaders = this.TryGetTimeSpanFromHeader(response, "retry-after-ms")
                            ?? this.TryGetTimeSpanFromHeader(response, "x-ms-retry-after-ms")
                            ?? this.TryGetTimeSpanFromHeader(response, "Retry-After", msecsMultiplier: 1000, allowDateTimeOffset: true);

    delay = fromHeaders ?? TimeSpan.Zero;
    return delay > TimeSpan.Zero;
}

/// <summary>
/// Read a retry delay from the given response header, if present.
/// </summary>
/// <param name="response">HTTP response to inspect</param>
/// <param name="headerName">Name of the header to read</param>
/// <param name="msecsMultiplier">Factor converting the numeric header value to milliseconds,
/// e.g. 1000 when the header is expressed in seconds (Retry-After)</param>
/// <param name="allowDateTimeOffset">Whether the header may also contain an HTTP date
/// (Retry-After supports both delta-seconds and HTTP-date formats)</param>
/// <returns>The delay to wait, or null if the header is missing or cannot be parsed</returns>
private TimeSpan? TryGetTimeSpanFromHeader(
    PipelineResponse response,
    string headerName,
    int msecsMultiplier = 1,
    bool allowDateTimeOffset = false)
{
    if (!response.Headers.TryGetValue(headerName, out string? strValue) || string.IsNullOrEmpty(strValue))
    {
        return null;
    }

    // Numeric form, e.g. "retry-after-ms: 1500" or "Retry-After: 120" (seconds)
    if (double.TryParse(strValue, out double doubleValue))
    {
        this._log.LogWarning("Header {0} found, value {1}", headerName, doubleValue);
        return TimeSpan.FromMilliseconds(msecsMultiplier * doubleValue);
    }

    // HTTP-date form, e.g. "Retry-After: Fri, 31 Dec 1999 23:59:59 GMT".
    // Bug fix: parse the header VALUE (strValue); the previous code parsed
    // headerName, so the date form was never honored.
    if (allowDateTimeOffset && DateTimeOffset.TryParse(strValue, out DateTimeOffset delayUntil))
    {
        this._log.LogWarning("Header {0} found, value {1}", headerName, delayUntil);
        return delayUntil - DateTimeOffset.UtcNow;
    }

    return null;
}
}
3 changes: 0 additions & 3 deletions extensions/OpenAI/OpenAI/Internals/.editorconfig

This file was deleted.

2 changes: 1 addition & 1 deletion extensions/OpenAI/OpenAI/Internals/ChangeEndpointPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System.Collections.Generic;
using System.Threading.Tasks;

namespace Microsoft.KernelMemory.AI.OpenAI;
namespace Microsoft.KernelMemory.AI.OpenAI.Internals;

internal sealed class ChangeEndpointPolicy : PipelinePolicy
{
Expand Down
62 changes: 57 additions & 5 deletions extensions/OpenAI/OpenAI/Internals/ClientSequentialRetryPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

using System;
using System.ClientModel.Primitives;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;

namespace Microsoft.KernelMemory.AI.OpenAI;
namespace Microsoft.KernelMemory.AI.OpenAI.Internals;

internal sealed class ClientSequentialRetryPolicy : ClientRetryPolicy
{
private static readonly TimeSpan[] s_pollingSequence =
private static readonly TimeSpan[] s_retryDelaySequence =
{
TimeSpan.FromSeconds(1),
TimeSpan.FromSeconds(1),
Expand All @@ -19,15 +21,65 @@ internal sealed class ClientSequentialRetryPolicy : ClientRetryPolicy
TimeSpan.FromSeconds(8)
};

private static readonly TimeSpan s_maxDelay = s_pollingSequence[^1];
private static readonly TimeSpan s_maxDelay = s_retryDelaySequence[^1];

public ClientSequentialRetryPolicy(int maxRetries = 3) : base(maxRetries)
// Logger used to report when a server-provided retry delay is honored.
private readonly ILogger<ClientSequentialRetryPolicy> _log;

/// <summary>
/// Create a new instance of the retry policy.
/// </summary>
/// <param name="maxRetries">Max number of retries, passed through to the base retry policy</param>
/// <param name="loggerFactory">Optional factory used to create the internal logger; falls back to the app default</param>
public ClientSequentialRetryPolicy(
    int maxRetries = 3,
    ILoggerFactory? loggerFactory = null) : base(maxRetries)
{
    this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<ClientSequentialRetryPolicy>();
}

protected override TimeSpan GetNextDelay(PipelineMessage message, int tryCount)
{
// Check if the remote service specified how long to wait before retrying
if (this.TryGetDelayFromResponse(message.Response, out TimeSpan delay))
{
this._log.LogWarning("Delay extracted from HTTP response: {0} msecs", delay.TotalMilliseconds);
return delay;
}

// Use predefined delay, increasing on each attempt up to a max value
int index = Math.Max(0, tryCount - 1);
return index >= s_pollingSequence.Length ? s_maxDelay : s_pollingSequence[index];
return index >= s_retryDelaySequence.Length ? s_maxDelay : s_retryDelaySequence[index];
}

/// <summary>
/// Check whether the response carries a server-specified retry delay.
/// Only HTTP 429 and 503 responses are considered; the delay is read from
/// retry-after-ms, x-ms-retry-after-ms or Retry-After, in that order.
/// </summary>
private bool TryGetDelayFromResponse(PipelineResponse? response, out TimeSpan delay)
{
    delay = TimeSpan.Zero;

    // Only throttling (429) and service-unavailable (503) responses are expected
    // to carry retry hints.
    if (response is not { Status: 429 or 503 }) { return false; }

    TimeSpan? fromHeaders = this.TryGetTimeSpanFromHeader(response, "retry-after-ms")
                            ?? this.TryGetTimeSpanFromHeader(response, "x-ms-retry-after-ms")
                            ?? this.TryGetTimeSpanFromHeader(response, "Retry-After", msecsMultiplier: 1000, allowDateTimeOffset: true);

    delay = fromHeaders ?? TimeSpan.Zero;
    return delay > TimeSpan.Zero;
}

/// <summary>
/// Read a retry delay from the given response header, if present.
/// </summary>
/// <param name="response">HTTP response to inspect</param>
/// <param name="headerName">Name of the header to read</param>
/// <param name="msecsMultiplier">Factor converting the numeric header value to milliseconds,
/// e.g. 1000 when the header is expressed in seconds (Retry-After)</param>
/// <param name="allowDateTimeOffset">Whether the header may also contain an HTTP date
/// (Retry-After supports both delta-seconds and HTTP-date formats)</param>
/// <returns>The delay to wait, or null if the header is missing or cannot be parsed</returns>
private TimeSpan? TryGetTimeSpanFromHeader(
    PipelineResponse response,
    string headerName,
    int msecsMultiplier = 1,
    bool allowDateTimeOffset = false)
{
    if (!response.Headers.TryGetValue(headerName, out string? strValue) || string.IsNullOrEmpty(strValue))
    {
        return null;
    }

    // Numeric form, e.g. "retry-after-ms: 1500" or "Retry-After: 120" (seconds)
    if (double.TryParse(strValue, out double doubleValue))
    {
        this._log.LogWarning("Header {0} found, value {1}", headerName, doubleValue);
        return TimeSpan.FromMilliseconds(msecsMultiplier * doubleValue);
    }

    // HTTP-date form, e.g. "Retry-After: Fri, 31 Dec 1999 23:59:59 GMT".
    // Bug fix: parse the header VALUE (strValue); the previous code parsed
    // headerName, so the date form was never honored.
    if (allowDateTimeOffset && DateTimeOffset.TryParse(strValue, out DateTimeOffset delayUntil))
    {
        this._log.LogWarning("Header {0} found, value {1}", headerName, delayUntil);
        return delayUntil - DateTimeOffset.UtcNow;
    }

    return null;
}
}
8 changes: 5 additions & 3 deletions extensions/OpenAI/OpenAI/Internals/OpenAIClientBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
using System;
using System.ClientModel.Primitives;
using System.Net.Http;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;
using OpenAI;

namespace Microsoft.KernelMemory.AI.OpenAI;
namespace Microsoft.KernelMemory.AI.OpenAI.Internals;

internal static class OpenAIClientBuilder
{
internal static OpenAIClient Build(
OpenAIConfig config,
HttpClient? httpClient = null)
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
OpenAIClientOptions options = new()
{
RetryPolicy = new ClientSequentialRetryPolicy(maxRetries: Math.Max(0, config.MaxRetries)),
RetryPolicy = new ClientSequentialRetryPolicy(maxRetries: Math.Max(0, config.MaxRetries), loggerFactory),
ApplicationId = Telemetry.HttpUserAgent,
};

Expand Down
2 changes: 1 addition & 1 deletion extensions/OpenAI/OpenAI/Internals/SkClientBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using Microsoft.SemanticKernel.Connectors.OpenAI;
using OpenAI;

namespace Microsoft.KernelMemory.AI.OpenAI;
namespace Microsoft.KernelMemory.AI.OpenAI.Internals;

internal static class SkClientBuilder
{
Expand Down
3 changes: 2 additions & 1 deletion extensions/OpenAI/OpenAI/OpenAITextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI.Internals;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Embeddings;
Expand Down Expand Up @@ -45,7 +46,7 @@ public OpenAITextEmbeddingGenerator(
HttpClient? httpClient = null)
: this(
config,
OpenAIClientBuilder.Build(config, httpClient),
OpenAIClientBuilder.Build(config, httpClient, loggerFactory),
textTokenizer,
loggerFactory)
{
Expand Down
3 changes: 2 additions & 1 deletion extensions/OpenAI/OpenAI/OpenAITextGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI.Internals;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
Expand Down Expand Up @@ -41,7 +42,7 @@ public OpenAITextGenerator(
HttpClient? httpClient = null)
: this(
config,
OpenAIClientBuilder.Build(config, httpClient),
OpenAIClientBuilder.Build(config, httpClient, loggerFactory),
textTokenizer,
loggerFactory)
{
Expand Down

0 comments on commit c367cd0

Please sign in to comment.