diff --git a/.changelog/1762986506.md b/.changelog/1762986506.md new file mode 100644 index 00000000000..832b7821fb0 --- /dev/null +++ b/.changelog/1762986506.md @@ -0,0 +1,11 @@ +--- +applies_to: +- client +authors: +- annahay +references: [] +breaking: false +new_feature: false +bug_fix: true +--- +Add support for static retry strategy diff --git a/.changelog/1763060740.md b/.changelog/1763060740.md new file mode 100644 index 00000000000..832bf0e5c41 --- /dev/null +++ b/.changelog/1763060740.md @@ -0,0 +1,11 @@ +--- +applies_to: +- client +authors: +- annahay +references: [] +breaking: false +new_feature: true +bug_fix: false +--- +Add support for configurable token bucket success reward and fractional token management diff --git a/aws/codegen-aws-sdk/src/test/kotlin/software/amazon/smithy/rustsdk/RetryPartitionTest.kt b/aws/codegen-aws-sdk/src/test/kotlin/software/amazon/smithy/rustsdk/RetryPartitionTest.kt index 3a3c022eded..1520b3e765a 100644 --- a/aws/codegen-aws-sdk/src/test/kotlin/software/amazon/smithy/rustsdk/RetryPartitionTest.kt +++ b/aws/codegen-aws-sdk/src/test/kotlin/software/amazon/smithy/rustsdk/RetryPartitionTest.kt @@ -114,6 +114,7 @@ class RetryPartitionTest { "RetryPartition" to RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::RetryPartition"), "RuntimeComponents" to RuntimeType.runtimeComponents(ctx.runtimeConfig), "TokenBucket" to RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::TokenBucket"), + "MAXIMUM_CAPACITY" to RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::MAXIMUM_CAPACITY"), ) crate.integrationTest("custom_retry_partition") { tokioTest("test_custom_token_bucket") { @@ -139,7 +140,8 @@ class RetryPartitionTest { ) -> Result<(), #{BoxError}> { self.called.fetch_add(1, Ordering::Relaxed); let token_bucket = cfg.load::<#{TokenBucket}>().unwrap(); - let expected = format!("permits: {}", tokio::sync::Semaphore::MAX_PERMITS); + let max_capacity = #{MAXIMUM_CAPACITY}; + let 
expected = format!("permits: {}", max_capacity); assert!( format!("{token_bucket:?}").contains(&expected), "Expected debug output to contain `{expected}`, but got: {token_bucket:?}" diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries.rs b/rust-runtime/aws-smithy-runtime/src/client/retries.rs index 5c8de0c1228..30b48639e15 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries.rs @@ -18,7 +18,7 @@ use std::fmt; pub use client_rate_limiter::{ ClientRateLimiter, ClientRateLimiterBuilder, ClientRateLimiterPartition, }; -pub use token_bucket::{TokenBucket, TokenBucketBuilder}; +pub use token_bucket::{TokenBucket, TokenBucketBuilder, MAXIMUM_CAPACITY}; use std::borrow::Cow; diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/client_rate_limiter.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/client_rate_limiter.rs index de15bb71f79..474f3178d68 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/client_rate_limiter.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/client_rate_limiter.rs @@ -3,8 +3,59 @@ * SPDX-License-Identifier: Apache-2.0 */ -//! A rate limiter for controlling the rate at which AWS requests are made. The rate changes based -//! on the number of throttling errors encountered. +//! A rate limiter for controlling the rate at which AWS requests are made. +//! +//! This module implements an adaptive token bucket rate limiter that can operate in two modes: +//! **dynamic** (default) and **static**. +//! +//! # Dynamic Mode (Default) +//! +//! In dynamic mode, the rate limiter automatically adjusts token refill rates based on measured +//! request throughput and throttling responses using cubic scaling algorithms. This provides +//! adaptive rate limiting that responds to service capacity. +//! +//! **Key behaviors:** +//! - Initially disabled - allows all requests through until first throttling error +//! 
- After first throttle: enables token bucket and dynamically adjusts refill rate +//! - Uses cubic throttle/success algorithms to scale rate based on measured throughput +//! - Enforces MIN_FILL_RATE (0.5 tokens/sec) as floor for dynamic adjustments +//! +//! # Static Mode +//! +//! In static mode, the rate limiter uses a fixed token refill rate without dynamic adjustment. +//! This provides predictable, configurable rate limiting. +//! +//! **Key behaviors:** +//! - Initially disabled - allows all requests through until first throttling error +//! - After first throttle: enables token bucket with fixed refill rate +//! - No automatic rate adjustment based on throughput +//! - Allows any refill rate (no MIN_FILL_RATE enforcement) +//! - Supports optional success token rewards +//! +//! # Token Bucket Algorithm +//! +//! Both modes use a token bucket algorithm: +//! - Tokens replenish over time at the configured rate +//! - Each request consumes tokens based on request type (all configurable): +//! - Initial request: default 1.0 tokens +//! - Retry: default 5.0 tokens +//! - Retry with timeout: default 10.0 tokens +//! - Requests are delayed if insufficient tokens available +//! +//! # Usage Examples +//! +//! ```rust,ignore +//! use aws_smithy_runtime::client::retries::client_rate_limiter::ClientRateLimiter; +//! +//! // Dynamic mode (default) - automatically adjusts based on throttling +//! let rate_limiter = ClientRateLimiter::default(); +//! +//! // Static mode with fixed rate of 10 requests/second +//! let rate_limiter = ClientRateLimiter::builder() +//! .use_dynamic_rate_adjustment(false) +//! .token_refill_rate(10.0) +//! .build(); +//! 
``` #![allow(dead_code)] @@ -27,11 +78,11 @@ impl ClientRateLimiterPartition { } } -const RETRY_COST: f64 = 5.0; -const RETRY_TIMEOUT_COST: f64 = RETRY_COST * 2.0; -const INITIAL_REQUEST_COST: f64 = 1.0; +const DEFAULT_RETRY_COST: f64 = 5.0; +const DEFAULT_RETRY_TIMEOUT_COST: f64 = 10.0; +const DEFAULT_INITIAL_REQUEST_COST: f64 = 1.0; -const MIN_FILL_RATE: f64 = 0.5; +const MIN_FILL_RATE: f64 = 0.5; // Enforced as floor in dynamic mode only; static mode allows any rate const MIN_CAPACITY: f64 = 1.0; const SMOOTH: f64 = 0.8; /// How much to scale back after receiving a throttling response @@ -42,11 +93,17 @@ const SCALE_CONSTANT: f64 = 0.4; /// Rate limiter for adaptive retry. #[derive(Clone, Debug)] pub struct ClientRateLimiter { - pub(crate) inner: Arc>, + pub(crate) inner: Arc>, } #[derive(Debug)] -pub(crate) struct Inner { +pub(crate) enum RateLimiterImpl { + Dynamic(DynamicInner), + Static(StaticInner), +} + +#[derive(Debug)] +pub(crate) struct DynamicInner { /// The rate at which token are replenished. fill_rate: f64, /// The maximum capacity allowed in the token bucket. @@ -56,8 +113,6 @@ pub(crate) struct Inner { /// The last time the token bucket was refilled. last_timestamp: Option, /// Boolean indicating if the token bucket is enabled. - /// The token bucket is initially disabled. - /// When a throttling error is encountered it is enabled. enabled: bool, /// The smoothed rate which tokens are being retrieved. measured_tx_rate: f64, @@ -69,6 +124,34 @@ pub(crate) struct Inner { last_max_rate: f64, /// The last time when the client was throttled. time_of_last_throttle: f64, + /// The cost in tokens for an initial request. + initial_request_cost: f64, + /// The cost in tokens for a retry request. + retry_cost: f64, + /// The cost in tokens for a retry request with timeout. + retry_timeout_cost: f64, +} + +#[derive(Debug)] +pub(crate) struct StaticInner { + /// The rate at which token are replenished. 
+ fill_rate: f64, + /// The maximum capacity allowed in the token bucket. + max_capacity: f64, + /// The current capacity of the token bucket. The minimum this can be is 1.0 + current_capacity: f64, + /// The last time the token bucket was refilled. + last_timestamp: Option, + /// Boolean indicating if the token bucket is enabled. + enabled: bool, + /// The cost in tokens for an initial request. + initial_request_cost: f64, + /// The cost in tokens for a retry request. + retry_cost: f64, + /// The cost in tokens for a retry request with timeout. + retry_timeout_cost: f64, + /// The number of tokens to add to the bucket for each successful request. + success_token_reward: f64, } pub(crate) enum RequestReason { @@ -77,6 +160,70 @@ pub(crate) enum RequestReason { InitialRequest, } +/// Common token bucket operations shared by both client rate limiter implementations +trait ClientRateLimiterOps { + fn enabled(&self) -> bool; + fn fill_rate(&self) -> f64; + fn max_capacity(&self) -> f64; + fn current_capacity(&self) -> f64; + fn current_capacity_mut(&mut self) -> &mut f64; + fn last_timestamp(&self) -> Option; + fn last_timestamp_mut(&mut self) -> &mut Option; + fn initial_request_cost(&self) -> f64; + fn retry_cost(&self) -> f64; + fn retry_timeout_cost(&self) -> f64; + + fn refill(&mut self, seconds_since_unix_epoch: f64) { + if let Some(last_timestamp) = self.last_timestamp() { + let fill_amount = (seconds_since_unix_epoch - last_timestamp) * self.fill_rate(); + *self.current_capacity_mut() = + f64::min(self.max_capacity(), self.current_capacity() + fill_amount); + debug!( + fill_amount, + current_capacity = self.current_capacity(), + max_capacity = self.max_capacity(), + "refilling client rate limiter tokens" + ); + } + *self.last_timestamp_mut() = Some(seconds_since_unix_epoch); + } + + fn acquire_permission( + &mut self, + seconds_since_unix_epoch: f64, + kind: RequestReason, + ) -> Result<(), Duration> { + if !self.enabled() { + return Ok(()); + } + + let amount 
= match kind { + RequestReason::Retry => self.retry_cost(), + RequestReason::RetryTimeout => self.retry_timeout_cost(), + RequestReason::InitialRequest => self.initial_request_cost(), + }; + + self.refill(seconds_since_unix_epoch); + + let res = if amount > self.current_capacity() { + let sleep_time = (amount - self.current_capacity()) / self.fill_rate(); + debug!( + amount, + current_capacity = self.current_capacity(), + fill_rate = self.fill_rate(), + sleep_time, + "client rate limiter delayed a request" + ); + Err(Duration::from_secs_f64(sleep_time)) + } else { + Ok(()) + }; + + *self.current_capacity_mut() -= amount; + res + } +} + impl Default for ClientRateLimiter { fn default() -> Self { Self::builder().build() @@ -103,37 +250,16 @@ impl ClientRateLimiter { seconds_since_unix_epoch: f64, kind: RequestReason, ) -> Result<(), Duration> { - let mut it = self.inner.lock().unwrap(); + let mut impl_lock = self.inner.lock().unwrap(); - if !it.enabled { - // return early if we haven't encountered a throttling error yet - return Ok(()); + match &mut *impl_lock { + RateLimiterImpl::Dynamic(inner) => { + inner.acquire_permission(seconds_since_unix_epoch, kind) + } + RateLimiterImpl::Static(inner) => { + inner.acquire_permission(seconds_since_unix_epoch, kind) + } } - let amount = match kind { - RequestReason::Retry => RETRY_COST, - RequestReason::RetryTimeout => RETRY_TIMEOUT_COST, - RequestReason::InitialRequest => INITIAL_REQUEST_COST, - }; - - it.refill(seconds_since_unix_epoch); - - let res = if amount > it.current_capacity { - let sleep_time = (amount - it.current_capacity) / it.fill_rate; - debug!( - amount, - it.current_capacity, - it.fill_rate, - sleep_time, - "client rate limiter delayed a request" - ); - - Err(Duration::from_secs_f64(sleep_time)) - } else { - Ok(()) - }; - - it.current_capacity -= amount; - res } pub(crate) fn update_rate_limiter( @@ -141,47 +267,51 @@ impl ClientRateLimiter { seconds_since_unix_epoch: f64, is_throttling_error: bool, ) { - 
let mut it = self.inner.lock().unwrap(); - it.update_tokens_retrieved_per_second(seconds_since_unix_epoch); + let mut impl_lock = self.inner.lock().unwrap(); - let calculated_rate; - if is_throttling_error { - let rate_to_use = if it.enabled { - f64::min(it.measured_tx_rate, it.fill_rate) - } else { - it.measured_tx_rate - }; - - // The fill_rate is from the token bucket - it.last_max_rate = rate_to_use; - it.calculate_time_window(); - it.time_of_last_throttle = seconds_since_unix_epoch; - calculated_rate = cubic_throttle(rate_to_use); - it.enable_token_bucket(); - } else { - it.calculate_time_window(); - calculated_rate = it.cubic_success(seconds_since_unix_epoch); + match &mut *impl_lock { + RateLimiterImpl::Dynamic(inner) => { + inner.update_rate_limiter(seconds_since_unix_epoch, is_throttling_error) + } + RateLimiterImpl::Static(inner) => inner.update_rate_limiter(is_throttling_error), } - - let new_rate = f64::min(calculated_rate, 2.0 * it.measured_tx_rate); - it.update_bucket_refill_rate(seconds_since_unix_epoch, new_rate); } } -impl Inner { - fn refill(&mut self, seconds_since_unix_epoch: f64) { - if let Some(last_timestamp) = self.last_timestamp { - let fill_amount = (seconds_since_unix_epoch - last_timestamp) * self.fill_rate; - self.current_capacity = - f64::min(self.max_capacity, self.current_capacity + fill_amount); - debug!( - fill_amount, - self.current_capacity, self.max_capacity, "refilling client rate limiter tokens" - ); - } - self.last_timestamp = Some(seconds_since_unix_epoch); +impl ClientRateLimiterOps for DynamicInner { + fn enabled(&self) -> bool { + self.enabled + } + fn fill_rate(&self) -> f64 { + self.fill_rate + } + fn max_capacity(&self) -> f64 { + self.max_capacity + } + fn current_capacity(&self) -> f64 { + self.current_capacity + } + fn current_capacity_mut(&mut self) -> &mut f64 { + &mut self.current_capacity + } + fn last_timestamp(&self) -> Option { + self.last_timestamp + } + fn last_timestamp_mut(&mut self) -> &mut Option { + 
&mut self.last_timestamp + } + fn initial_request_cost(&self) -> f64 { + self.initial_request_cost + } + fn retry_cost(&self) -> f64 { + self.retry_cost + } + fn retry_timeout_cost(&self) -> f64 { + self.retry_timeout_cost } +} +impl DynamicInner { fn update_bucket_refill_rate(&mut self, seconds_since_unix_epoch: f64, new_fill_rate: f64) { // Refill based on our current rate before we update to the new fill rate. self.refill(seconds_since_unix_epoch); @@ -232,6 +362,82 @@ impl Inner { seconds_since_unix_epoch - self.time_of_last_throttle - self.calculate_time_window(); (SCALE_CONSTANT * dt.powi(3)) + self.last_max_rate } + + fn update_rate_limiter(&mut self, seconds_since_unix_epoch: f64, is_throttling_error: bool) { + self.update_tokens_retrieved_per_second(seconds_since_unix_epoch); + + let calculated_rate; + if is_throttling_error { + let rate_to_use = if self.enabled { + f64::min(self.measured_tx_rate, self.fill_rate) + } else { + self.measured_tx_rate + }; + + self.last_max_rate = rate_to_use; + self.calculate_time_window(); + self.time_of_last_throttle = seconds_since_unix_epoch; + calculated_rate = cubic_throttle(rate_to_use); + self.enable_token_bucket(); + } else { + self.calculate_time_window(); + calculated_rate = self.cubic_success(seconds_since_unix_epoch); + } + + let new_rate = f64::min(calculated_rate, 2.0 * self.measured_tx_rate); + self.update_bucket_refill_rate(seconds_since_unix_epoch, new_rate); + } +} + +impl ClientRateLimiterOps for StaticInner { + fn enabled(&self) -> bool { + self.enabled + } + fn fill_rate(&self) -> f64 { + self.fill_rate + } + fn max_capacity(&self) -> f64 { + self.max_capacity + } + fn current_capacity(&self) -> f64 { + self.current_capacity + } + fn current_capacity_mut(&mut self) -> &mut f64 { + &mut self.current_capacity + } + fn last_timestamp(&self) -> Option { + self.last_timestamp + } + fn last_timestamp_mut(&mut self) -> &mut Option { + &mut self.last_timestamp + } + fn initial_request_cost(&self) -> f64 { + 
self.initial_request_cost + } + fn retry_cost(&self) -> f64 { + self.retry_cost + } + fn retry_timeout_cost(&self) -> f64 { + self.retry_timeout_cost + } +} + +impl StaticInner { + fn update_rate_limiter(&mut self, is_throttling_error: bool) { + // Enable bucket on first throttle + if is_throttling_error && !self.enabled { + self.enabled = true; + debug!("client rate limiting has been enabled"); + } + + // Add tokens for successful requests (if configured) + if !is_throttling_error && self.success_token_reward > 0.0 && self.enabled { + self.current_capacity = f64::min( + self.max_capacity, + self.current_capacity + self.success_token_reward, + ); + } + } } fn cubic_throttle(rate_to_use: f64) -> f64 { @@ -261,6 +467,16 @@ pub struct ClientRateLimiterBuilder { tokens_retrieved_per_second_at_time_of_last_throttle: Option, ///The last time when the client was throttled. time_of_last_throttle: Option, + ///Whether to use dynamic rate adjustment based on measured throughput. + use_dynamic_rate_adjustment: Option, + ///The cost in tokens for an initial request. + initial_request_cost: Option, + ///The cost in tokens for a retry request. + retry_cost: Option, + ///The cost in tokens for a retry request with timeout. + retry_timeout_cost: Option, + ///The number of tokens to add to the bucket for each successful request. + success_token_reward: Option, } impl ClientRateLimiterBuilder { @@ -394,11 +610,83 @@ impl ClientRateLimiterBuilder { self.time_of_last_throttle = time_of_last_throttle; self } + /// Enable or disable dynamic rate adjustment based on measured throughput. + /// + /// When enabled (default), the rate limiter will dynamically adjust token refill rates + /// based on measured request throughput and throttling responses using cubic scaling. + /// When disabled, the rate limiter uses a fixed token refill rate. 
+ pub fn use_dynamic_rate_adjustment(mut self, enabled: bool) -> Self { + self.set_use_dynamic_rate_adjustment(Some(enabled)); + self + } + /// Enable or disable dynamic rate adjustment based on measured throughput. + pub fn set_use_dynamic_rate_adjustment(&mut self, enabled: Option) -> &mut Self { + self.use_dynamic_rate_adjustment = enabled; + self + } + /// Set the cost in tokens for an initial request. + /// + /// This determines how many tokens are consumed from the bucket for each initial + /// (non-retry) request. Default is 1.0. + pub fn initial_request_cost(mut self, cost: f64) -> Self { + self.set_initial_request_cost(Some(cost)); + self + } + /// Set the cost in tokens for an initial request. + pub fn set_initial_request_cost(&mut self, cost: Option) -> &mut Self { + self.initial_request_cost = cost; + self + } + /// Set the cost in tokens for a retry request. + /// + /// This determines how many tokens are consumed from the bucket for each retry + /// (non-initial, non-timeout) request. Default is 5.0. + pub fn retry_cost(mut self, cost: f64) -> Self { + self.set_retry_cost(Some(cost)); + self + } + /// Set the cost in tokens for a retry request. + pub fn set_retry_cost(&mut self, cost: Option) -> &mut Self { + self.retry_cost = cost; + self + } + /// Set the cost in tokens for a retry request with timeout. + /// + /// This determines how many tokens are consumed from the bucket for each retry + /// that experienced a timeout. Default is 10.0. + pub fn retry_timeout_cost(mut self, cost: f64) -> Self { + self.set_retry_timeout_cost(Some(cost)); + self + } + /// Set the cost in tokens for a retry request with timeout. + pub fn set_retry_timeout_cost(&mut self, cost: Option) -> &mut Self { + self.retry_timeout_cost = cost; + self + } + /// Set the number of tokens to add to the bucket for each successful request. + /// + /// This provides a token reward for successful requests in addition to the time-based refill. + /// Default is 0.0 (no reward). 
+ /// + /// **Important:** Success rewards only apply after the token bucket has been enabled, + /// which happens when the first throttling error is encountered. Before that point, + /// the rate limiter allows all requests through without token consumption or rewards. + pub fn success_token_reward(mut self, reward: f64) -> Self { + self.set_success_token_reward(Some(reward)); + self + } + /// Set the number of tokens to add to the bucket for each successful request. + pub fn set_success_token_reward(&mut self, reward: Option) -> &mut Self { + self.success_token_reward = reward; + self + } /// Build the ClientRateLimiter. pub fn build(self) -> ClientRateLimiter { - ClientRateLimiter { - inner: Arc::new(Mutex::new(Inner { - fill_rate: self.token_refill_rate.unwrap_or_default(), + let use_dynamic = self.use_dynamic_rate_adjustment.unwrap_or(true); + + let inner = if use_dynamic { + RateLimiterImpl::Dynamic(DynamicInner { + fill_rate: self.token_refill_rate.unwrap_or(MIN_FILL_RATE), max_capacity: self.maximum_bucket_capacity.unwrap_or(f64::MAX), current_capacity: self.current_bucket_capacity.unwrap_or_default(), last_timestamp: self.time_of_last_refill, @@ -410,14 +698,41 @@ impl ClientRateLimiterBuilder { .tokens_retrieved_per_second_at_time_of_last_throttle .unwrap_or_default(), time_of_last_throttle: self.time_of_last_throttle.unwrap_or_default(), - })), + initial_request_cost: self + .initial_request_cost + .unwrap_or(DEFAULT_INITIAL_REQUEST_COST), + retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST), + retry_timeout_cost: self + .retry_timeout_cost + .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST), + }) + } else { + RateLimiterImpl::Static(StaticInner { + fill_rate: self.token_refill_rate.unwrap_or(MIN_FILL_RATE), + max_capacity: self.maximum_bucket_capacity.unwrap_or(f64::MAX), + current_capacity: self.current_bucket_capacity.unwrap_or_default(), + last_timestamp: self.time_of_last_refill, + enabled: self.enable_throttling.unwrap_or_default(), + 
initial_request_cost: self + .initial_request_cost + .unwrap_or(DEFAULT_INITIAL_REQUEST_COST), + retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST), + retry_timeout_cost: self + .retry_timeout_cost + .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST), + success_token_reward: self.success_token_reward.unwrap_or_default(), + }) + }; + + ClientRateLimiter { + inner: Arc::new(Mutex::new(inner)), } } } #[cfg(test)] mod tests { - use super::{cubic_throttle, ClientRateLimiter}; + use super::{cubic_throttle, ClientRateLimiter, RateLimiterImpl}; use crate::client::retries::client_rate_limiter::RequestReason; use approx::assert_relative_eq; use aws_smithy_async::rt::sleep::AsyncSleep; @@ -437,8 +752,16 @@ mod tests { .time_of_last_throttle(1.0) .build(); - rate_limiter.inner.lock().unwrap().calculate_time_window(); - let new_rate = rate_limiter.inner.lock().unwrap().cubic_success(1.0); + let new_rate = { + let mut lock = rate_limiter.inner.lock().unwrap(); + match &mut *lock { + RateLimiterImpl::Dynamic(inner) => { + inner.calculate_time_window(); + inner.cubic_success(1.0) + } + RateLimiterImpl::Static(_) => panic!("Expected dynamic rate limiter"), + } + }; assert_relative_eq!(new_rate, 7.0); } @@ -449,13 +772,20 @@ mod tests { .time_of_last_throttle(0.0) .build(); - assert!( - !rate_limiter.inner.lock().unwrap().enabled, - "rate_limiter should be disabled by default" - ); + let enabled = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Dynamic(inner) => inner.enabled, + RateLimiterImpl::Static(inner) => inner.enabled, + }; + assert!(!enabled, "rate_limiter should be disabled by default"); + rate_limiter.update_rate_limiter(0.0, true); + + let enabled = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Dynamic(inner) => inner.enabled, + RateLimiterImpl::Static(inner) => inner.enabled, + }; assert!( - rate_limiter.inner.lock().unwrap().enabled, + enabled, "rate_limiter should be enabled after throttling error" ); } @@ -507,12 +837,16 @@ mod tests { // 
was implemented. See for yourself: // https://github.com/aws/aws-sdk-go-v2/blob/844ff45cdc76182229ad098c95bf3f5ab8c20e9f/aws/retry/adaptive_ratelimit_test.go#L97 for attempt in attempts { - rate_limiter.inner.lock().unwrap().calculate_time_window(); - let calculated_rate = rate_limiter - .inner - .lock() - .unwrap() - .cubic_success(attempt.seconds_since_unix_epoch); + let calculated_rate = { + let mut lock = rate_limiter.inner.lock().unwrap(); + match &mut *lock { + RateLimiterImpl::Dynamic(inner) => { + inner.calculate_time_window(); + inner.cubic_success(attempt.seconds_since_unix_epoch) + } + RateLimiterImpl::Static(_) => panic!("Expected dynamic rate limiter"), + } + }; assert_relative_eq!(attempt.expected_calculated_rate, calculated_rate); } @@ -579,15 +913,20 @@ mod tests { // https://github.com/aws/aws-sdk-go-v2/blob/844ff45cdc76182229ad098c95bf3f5ab8c20e9f/aws/retry/adaptive_ratelimit_test.go#L97 let mut calculated_rate = 0.0; for attempt in attempts { - let mut inner = rate_limiter.inner.lock().unwrap(); - inner.calculate_time_window(); - if attempt.throttled { - calculated_rate = cubic_throttle(calculated_rate); - inner.time_of_last_throttle = attempt.seconds_since_unix_epoch; - inner.last_max_rate = calculated_rate; - } else { - calculated_rate = inner.cubic_success(attempt.seconds_since_unix_epoch); - }; + let mut lock = rate_limiter.inner.lock().unwrap(); + match &mut *lock { + RateLimiterImpl::Dynamic(inner) => { + inner.calculate_time_window(); + if attempt.throttled { + calculated_rate = cubic_throttle(calculated_rate); + inner.time_of_last_throttle = attempt.seconds_since_unix_epoch; + inner.last_max_rate = calculated_rate; + } else { + calculated_rate = inner.cubic_success(attempt.seconds_since_unix_epoch); + } + } + RateLimiterImpl::Static(_) => panic!("Expected dynamic rate limiter"), + } assert_relative_eq!(attempt.expected_calculated_rate, calculated_rate); } @@ -718,14 +1057,18 @@ mod tests { ); 
rate_limiter.update_rate_limiter(attempt.seconds_since_unix_epoch, attempt.throttled); - assert_relative_eq!( - attempt.expected_tokens_retrieved_per_second, - rate_limiter.inner.lock().unwrap().measured_tx_rate - ); - assert_relative_eq!( - attempt.expected_token_refill_rate, - rate_limiter.inner.lock().unwrap().fill_rate - ); + + let inner = rate_limiter.inner.lock().unwrap(); + match &*inner { + RateLimiterImpl::Dynamic(inner) => { + assert_relative_eq!( + attempt.expected_tokens_retrieved_per_second, + inner.measured_tx_rate + ); + assert_relative_eq!(attempt.expected_token_refill_rate, inner.fill_rate); + } + RateLimiterImpl::Static(_) => panic!("Expected dynamic rate limiter"), + } } } @@ -760,13 +1103,187 @@ mod tests { crl.update_rate_limiter(time_source.seconds_since_unix_epoch(), false); } - let inner = crl.inner.lock().unwrap(); - assert!(inner.enabled, "the rate limiter should still be enabled"); + let lock = crl.inner.lock().unwrap(); + let (enabled, last_timestamp) = match &*lock { + RateLimiterImpl::Dynamic(inner) => (inner.enabled, inner.last_timestamp), + RateLimiterImpl::Static(inner) => (inner.enabled, inner.last_timestamp), + }; + assert!(enabled, "the rate limiter should still be enabled"); // Assert that the rate limiter respects the passage of time. 
assert_relative_eq!( - inner.last_timestamp.unwrap(), + last_timestamp.unwrap(), sleep_impl.total_duration().as_secs_f64(), max_relative = 0.0001 ); } + + #[tokio::test] + async fn test_static_mode_does_not_adjust_rate() { + let rate_limiter = ClientRateLimiter::builder() + .use_dynamic_rate_adjustment(false) + .token_refill_rate(5.0) + .build(); + + // Enable the rate limiter with a throttling error + rate_limiter.update_rate_limiter(0.0, true); + let initial_rate = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.fill_rate, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(initial_rate, 5.0); + + // Process some successful requests - rate should NOT change in static mode + for i in 1..10 { + rate_limiter.update_rate_limiter(i as f64, false); + let current_rate = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.fill_rate, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(current_rate, 5.0, max_relative = 0.0001); + } + + // Process a throttling error - rate should still NOT change in static mode + rate_limiter.update_rate_limiter(10.0, true); + let final_rate = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.fill_rate, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(final_rate, 5.0, max_relative = 0.0001); + } + + #[tokio::test] + async fn test_custom_initial_request_cost() { + let rate_limiter = ClientRateLimiter::builder() + .initial_request_cost(0.5) + .token_refill_rate(10.0) + .maximum_bucket_capacity(10.0) + .current_bucket_capacity(10.0) + .enable_throttling(true) + .build(); + + // Make an initial request - should consume 0.5 tokens + let result = + rate_limiter.acquire_permission_to_send_a_request(0.0, RequestReason::InitialRequest); + assert!(result.is_ok()); + + let capacity = match 
&*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Dynamic(inner) => inner.current_capacity, + RateLimiterImpl::Static(inner) => inner.current_capacity, + }; + assert_relative_eq!(capacity, 9.5); + } + + #[tokio::test] + async fn test_success_token_reward_adds_tokens() { + let rate_limiter = ClientRateLimiter::builder() + .use_dynamic_rate_adjustment(false) + .token_refill_rate(5.0) + .maximum_bucket_capacity(10.0) + .current_bucket_capacity(5.0) + .success_token_reward(1.0) + .enable_throttling(true) + .build(); + + let initial_capacity = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.current_capacity, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(initial_capacity, 5.0); + + // Successful request should add 1.0 token + rate_limiter.update_rate_limiter(0.0, false); + + let new_capacity = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.current_capacity, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(new_capacity, 6.0); + } + + #[tokio::test] + async fn test_success_token_reward_respects_max_capacity() { + let rate_limiter = ClientRateLimiter::builder() + .use_dynamic_rate_adjustment(false) + .token_refill_rate(5.0) + .maximum_bucket_capacity(10.0) + .current_bucket_capacity(9.8) + .success_token_reward(1.0) + .enable_throttling(true) + .build(); + + // Successful request should add 1.0 token but cap at max_capacity + rate_limiter.update_rate_limiter(0.0, false); + + let capacity = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.current_capacity, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(capacity, 10.0); + } + + #[tokio::test] + async fn test_success_token_reward_does_not_apply_on_error() { + let rate_limiter = ClientRateLimiter::builder() + .use_dynamic_rate_adjustment(false) + 
.token_refill_rate(5.0) + .maximum_bucket_capacity(10.0) + .current_bucket_capacity(5.0) + .success_token_reward(1.0) + .enable_throttling(true) + .build(); + + let initial_capacity = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.current_capacity, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(initial_capacity, 5.0); + + // Any error (including throttling) should NOT add tokens + rate_limiter.update_rate_limiter(0.0, true); + + let capacity = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.current_capacity, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(capacity, 5.0); + } + + #[tokio::test] + async fn test_success_token_reward_only_when_enabled() { + let rate_limiter = ClientRateLimiter::builder() + .use_dynamic_rate_adjustment(false) + .token_refill_rate(5.0) + .maximum_bucket_capacity(10.0) + .current_bucket_capacity(5.0) + .success_token_reward(1.0) + .enable_throttling(false) + .build(); + + // Successful request should NOT add tokens when rate limiter is disabled + rate_limiter.update_rate_limiter(0.0, false); + + let capacity = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.current_capacity, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + assert_relative_eq!(capacity, 5.0); + } + + #[tokio::test] + async fn test_static_mode_allows_very_low_refill_rate() { + let rate_limiter = ClientRateLimiter::builder() + .use_dynamic_rate_adjustment(false) + .token_refill_rate(0.01) + .maximum_bucket_capacity(1.0) + .current_bucket_capacity(1.0) + .enable_throttling(true) + .build(); + + // Verify the low rate was accepted + let fill_rate = match &*rate_limiter.inner.lock().unwrap() { + RateLimiterImpl::Static(inner) => inner.fill_rate, + RateLimiterImpl::Dynamic(_) => panic!("Expected static rate limiter"), + }; + 
assert_relative_eq!(fill_rate, 0.01); + } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs index c8700ee275c..4aece478047 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs @@ -51,14 +51,28 @@ impl StandardRetryStrategy { Default::default() } - fn release_retry_permit(&self) -> ReleaseResult { + fn release_retry_permit(&self, token_bucket: &TokenBucket) -> ReleaseResult { let mut retry_permit = self.retry_permit.lock().unwrap(); match retry_permit.take() { Some(p) => { - drop(p); + // Retry succeeded: reward success and forget permit if configured, otherwise release permit back + if token_bucket.success_reward() > 0.0 { + token_bucket.reward_success(); + p.forget(); + } else { + drop(p); // Original behavior - release back to bucket + } APermitWasReleased } - None => NoPermitWasReleased, + None => { + // First-attempt success: reward success or regenerate token + if token_bucket.success_reward() > 0.0 { + token_bucket.reward_success(); + } else { + token_bucket.regenerate_a_token(); + } + NoPermitWasReleased + } } } @@ -210,15 +224,9 @@ impl RetryStrategy for StandardRetryStrategy { .unwrap_or(false); update_rate_limiter_if_exists(runtime_components, cfg, is_throttling_error); - // on success release any retry quota held by previous attempts + // on success release any retry quota held by previous attempts, reward success when indicated if !ctx.is_failed() { - if let NoPermitWasReleased = self.release_retry_permit() { - // In the event that there was no retry permit to release, we generate new - // permits from nothing. We do this to make up for permits we had to "forget". - // Otherwise, repeated retries would empty the bucket and nothing could fill it - // back up again. 
- token_bucket.regenerate_a_token(); - } + self.release_retry_permit(token_bucket); } // end bookkeeping @@ -313,7 +321,7 @@ fn check_rate_limiter_for_delay( None } -fn calculate_exponential_backoff( +pub(super) fn calculate_exponential_backoff( base: f64, initial_backoff: f64, retry_attempts: u32, @@ -338,7 +346,7 @@ fn calculate_exponential_backoff( result.mul_f64(base) } -fn get_seconds_since_unix_epoch(runtime_components: &RuntimeComponents) -> f64 { +pub(super) fn get_seconds_since_unix_epoch(runtime_components: &RuntimeComponents) -> f64 { let request_time = runtime_components .time_source() .expect("time source required for retries"); diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs index 50a6fe8a343..23250c41b3b 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs @@ -3,16 +3,27 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_async::time::SharedTimeSource; use aws_smithy_types::config_bag::{Storable, StoreReplace}; use aws_smithy_types::retry::ErrorKind; +use std::fmt; +use std::sync::atomic::AtomicU32; +use std::sync::atomic::Ordering; use std::sync::Arc; +use std::time::{Duration, SystemTime}; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; -use tracing::trace; const DEFAULT_CAPACITY: usize = 500; -const RETRY_COST: u32 = 5; -const RETRY_TIMEOUT_COST: u32 = RETRY_COST * 2; +// On a 32 bit architecture, the value of Semaphore::MAX_PERMITS is 536,870,911. +// Therefore, we will enforce a value lower than that to ensure behavior is +// identical across platforms. +// This also allows room for slight bucket overfill in the case where a bucket +// is at maximum capacity and another thread drops a permit it was holding. 
+pub const MAXIMUM_CAPACITY: usize = 500_000_000;
+const DEFAULT_RETRY_COST: u32 = 5;
+const DEFAULT_RETRY_TIMEOUT_COST: u32 = DEFAULT_RETRY_COST * 2;
 const PERMIT_REGENERATION_AMOUNT: usize = 1;
+const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
 
 /// Token bucket used for standard and adaptive retry.
 #[derive(Clone, Debug)]
@@ -21,6 +32,50 @@ pub struct TokenBucket {
     max_permits: usize,
     timeout_retry_cost: u32,
     retry_cost: u32,
+    success_reward: f32,
+    fractional_tokens: Arc<AtomicF32>,
+    refill_rate: f32,
+    time_source: SharedTimeSource,
+    creation_time: SystemTime,
+    last_refill_age_secs: Arc<AtomicU32>,
+}
for formatting + f.debug_struct("AtomicF32") + .field("value", &self.load()) + .finish() + } +} + +impl Clone for AtomicF32 { + fn clone(&self) -> Self { + // Manually clone each field + AtomicF32 { + storage: AtomicU32::new(self.storage.load(Ordering::Relaxed)), + } + } } impl Storable for TokenBucket { @@ -29,11 +84,18 @@ impl Storable for TokenBucket { impl Default for TokenBucket { fn default() -> Self { + let time_source = SharedTimeSource::default(); Self { semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)), max_permits: DEFAULT_CAPACITY, - timeout_retry_cost: RETRY_TIMEOUT_COST, - retry_cost: RETRY_COST, + timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST, + retry_cost: DEFAULT_RETRY_COST, + success_reward: DEFAULT_SUCCESS_REWARD, + fractional_tokens: Arc::new(AtomicF32::new(0.0)), + refill_rate: 0.0, + time_source: time_source.clone(), + creation_time: time_source.now(), + last_refill_age_secs: Arc::new(AtomicU32::new(0)), } } } @@ -41,20 +103,30 @@ impl Default for TokenBucket { impl TokenBucket { /// Creates a new `TokenBucket` with the given initial quota. pub fn new(initial_quota: usize) -> Self { + let time_source = SharedTimeSource::default(); Self { semaphore: Arc::new(Semaphore::new(initial_quota)), max_permits: initial_quota, + time_source: time_source.clone(), + creation_time: time_source.now(), ..Default::default() } } /// A token bucket with unlimited capacity that allows retries at no cost. 
pub fn unlimited() -> Self { + let time_source = SharedTimeSource::default(); Self { - semaphore: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)), - max_permits: Semaphore::MAX_PERMITS, + semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)), + max_permits: MAXIMUM_CAPACITY, timeout_retry_cost: 0, retry_cost: 0, + success_reward: 0.0, + fractional_tokens: Arc::new(AtomicF32::new(0.0)), + refill_rate: 0.0, + time_source: time_source.clone(), + creation_time: time_source.now(), + last_refill_age_secs: Arc::new(AtomicU32::new(0)), } } @@ -64,6 +136,11 @@ impl TokenBucket { } pub(crate) fn acquire(&self, err: &ErrorKind) -> Option { + // Add time-based tokens to fractional accumulator + self.refill_tokens_based_on_time(); + // Convert accumulated fractional tokens to whole tokens + self.convert_fractional_tokens(); + let retry_cost = if err == &ErrorKind::TransientError { self.timeout_retry_cost } else { @@ -76,16 +153,114 @@ impl TokenBucket { .ok() } + pub(crate) fn success_reward(&self) -> f32 { + self.success_reward + } + pub(crate) fn regenerate_a_token(&self) { - if self.semaphore.available_permits() < self.max_permits { - trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket"); - self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT) + self.add_permits(PERMIT_REGENERATION_AMOUNT); + } + + /// Converts accumulated fractional tokens to whole tokens and adds them as permits. + /// Stores the remaining fractional amount back. + /// This is shared by both time-based refill and success rewards. 
+ #[inline] + fn convert_fractional_tokens(&self) { + let mut calc_fractional_tokens = self.fractional_tokens.load(); + // Verify that fractional tokens have not become corrupted - if they have, reset to zero + if !calc_fractional_tokens.is_finite() { + tracing::error!( + "Fractional tokens corrupted to: {}, resetting to 0.0", + calc_fractional_tokens + ); + self.fractional_tokens.store(0.0); + return; + } + + let full_tokens_accumulated = calc_fractional_tokens.floor(); + if full_tokens_accumulated >= 1.0 { + self.add_permits(full_tokens_accumulated as usize); + calc_fractional_tokens -= full_tokens_accumulated; } + // Always store the updated fractional tokens back, even if no conversion happened + self.fractional_tokens.store(calc_fractional_tokens); } - #[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))] - pub(crate) fn available_permits(&self) -> usize { - self.semaphore.available_permits() + /// Refills tokens based on elapsed time since last refill. + /// This method implements lazy evaluation - tokens are only calculated when accessed. + /// Uses a single compare-and-swap to ensure only one thread processes each time window. 
+ #[inline] + fn refill_tokens_based_on_time(&self) { + if self.refill_rate > 0.0 { + let last_refill_secs = self.last_refill_age_secs.load(Ordering::Relaxed); + + // Get current time from TimeSource and calculate current age + let current_time = self.time_source.now(); + let current_age_secs = current_time + .duration_since(self.creation_time) + .unwrap_or(Duration::ZERO) + .as_secs() as u32; + + // Early exit if no time elapsed - most threads take this path + if current_age_secs == last_refill_secs { + return; + } + // Try to atomically claim this time window with a single CAS + // If we lose, another thread is handling the refill, so we can exit + if self + .last_refill_age_secs + .compare_exchange( + last_refill_secs, + current_age_secs, + Ordering::Relaxed, + Ordering::Relaxed, + ) + .is_err() + { + // Another thread claimed this time window, we're done + return; + } + + // We won the CAS - we're responsible for adding tokens for this time window + let current_fractional = self.fractional_tokens.load(); + let max_fractional = self.max_permits as f32; + + // Skip token addition if already at cap + if current_fractional >= max_fractional { + return; + } + + let elapsed_secs = current_age_secs - last_refill_secs; + let tokens_to_add = elapsed_secs as f32 * self.refill_rate; + + // Add tokens to fractional accumulator, capping at max_permits to prevent unbounded growth + let new_fractional = (current_fractional + tokens_to_add).min(max_fractional); + self.fractional_tokens.store(new_fractional); + } + } + + #[inline] + pub(crate) fn reward_success(&self) { + if self.success_reward > 0.0 { + let current = self.fractional_tokens.load(); + let max_fractional = self.max_permits as f32; + // Early exit if already at cap - no point calculating + if current >= max_fractional { + return; + } + // Cap fractional tokens at max_permits to prevent unbounded growth + let new_fractional = (current + self.success_reward).min(max_fractional); + 
self.fractional_tokens.store(new_fractional); + } + } + + pub(crate) fn add_permits(&self, amount: usize) { + let available = self.semaphore.available_permits(); + if available >= self.max_permits { + return; + } + self.semaphore + .add_permits(amount.min(self.max_permits - available)); } /// Returns true if the token bucket is full, false otherwise @@ -97,6 +272,12 @@ impl TokenBucket { pub fn is_empty(&self) -> bool { self.semaphore.available_permits() == 0 } + + #[allow(dead_code)] // only used in tests + #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))] + pub(crate) fn available_permits(&self) -> usize { + self.semaphore.available_permits() + } } /// Builder for constructing a `TokenBucket`. @@ -105,6 +286,9 @@ pub struct TokenBucketBuilder { capacity: Option, retry_cost: Option, timeout_retry_cost: Option, + success_reward: Option, + refill_rate: Option, + time_source: Option, } impl TokenBucketBuilder { @@ -114,7 +298,10 @@ impl TokenBucketBuilder { } /// Sets the maximum bucket capacity for the builder. - pub fn capacity(mut self, capacity: usize) -> Self { + pub fn capacity(mut self, mut capacity: usize) -> Self { + if capacity > MAXIMUM_CAPACITY { + capacity = MAXIMUM_CAPACITY; + } self.capacity = Some(capacity); self } @@ -131,13 +318,49 @@ impl TokenBucketBuilder { self } + /// Sets the reward for any successful request for the builder. + pub fn success_reward(mut self, reward: f32) -> Self { + self.success_reward = Some(reward); + self + } + + /// Sets the refill rate (tokens per second) for time-based token regeneration. + /// + /// Negative values are clamped to 0.0. A refill rate of 0.0 disables time-based regeneration. + /// Non-finite values (NaN, infinity) are treated as 0.0. + pub fn refill_rate(mut self, rate: f32) -> Self { + let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 }; + self.refill_rate = Some(validated_rate); + self + } + + /// Sets the time source for the token bucket. 
+ /// + /// If not set, defaults to `SystemTimeSource`. + pub fn time_source( + mut self, + time_source: impl aws_smithy_async::time::TimeSource + 'static, + ) -> Self { + self.time_source = Some(SharedTimeSource::new(time_source)); + self + } + /// Builds a `TokenBucket`. pub fn build(self) -> TokenBucket { + let time_source = self.time_source.unwrap_or_default(); TokenBucket { semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))), max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY), - retry_cost: self.retry_cost.unwrap_or(RETRY_COST), - timeout_retry_cost: self.timeout_retry_cost.unwrap_or(RETRY_TIMEOUT_COST), + retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST), + timeout_retry_cost: self + .timeout_retry_cost + .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST), + success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD), + fractional_tokens: Arc::new(AtomicF32::new(0.0)), + refill_rate: self.refill_rate.unwrap_or(0.0), + time_source: time_source.clone(), + creation_time: time_source.now(), + last_refill_age_secs: Arc::new(AtomicU32::new(0)), } } } @@ -145,6 +368,7 @@ impl TokenBucketBuilder { #[cfg(test)] mod tests { use super::*; + use aws_smithy_async::time::TimeSource; #[test] fn test_unlimited_token_bucket() { @@ -155,7 +379,7 @@ mod tests { assert!(bucket.acquire(&ErrorKind::TransientError).is_some()); // Should have maximum capacity - assert_eq!(bucket.max_permits, Semaphore::MAX_PERMITS); + assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY); // Should have zero retry costs assert_eq!(bucket.retry_cost, 0); @@ -168,10 +392,7 @@ mod tests { assert!(permit.is_some()); permits.push(permit); // Available permits should stay constant - assert_eq!( - tokio::sync::Semaphore::MAX_PERMITS, - bucket.semaphore.available_permits() - ); + assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits()); } } @@ -194,4 +415,541 @@ mod tests { // Verify next acquisition fails 
assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none()); } + + #[test] + fn test_fractional_tokens_accumulate_and_convert() { + let bucket = TokenBucket::builder() + .capacity(10) + .success_reward(0.4) + .build(); + + // acquire 10 tokens to bring capacity below max so we can test accumulation + let _hold_permit = bucket.acquire(&ErrorKind::TransientError); + assert_eq!(bucket.semaphore.available_permits(), 0); + + // First success: 0.4 fractional tokens + bucket.reward_success(); + bucket.convert_fractional_tokens(); + assert_eq!(bucket.semaphore.available_permits(), 0); + + // Second success: 0.8 fractional tokens + bucket.reward_success(); + bucket.convert_fractional_tokens(); + assert_eq!(bucket.semaphore.available_permits(), 0); + + // Third success: 1.2 fractional tokens -> 1 full token added + bucket.reward_success(); + bucket.convert_fractional_tokens(); + assert_eq!(bucket.semaphore.available_permits(), 1); + } + + #[test] + fn test_fractional_tokens_respect_max_capacity() { + let bucket = TokenBucket::builder() + .capacity(10) + .success_reward(2.0) + .build(); + + for _ in 0..20 { + bucket.reward_success(); + } + + assert!(bucket.semaphore.available_permits() == 10); + } + + #[test] + fn test_convert_fractional_tokens() { + // (input, expected_permits_added, expected_remaining) + let test_cases = [ + (0.7, 0, 0.7), + (1.0, 1, 0.0), + (2.3, 2, 0.3), + (5.8, 5, 0.8), + (10.0, 10, 0.0), + // verify that if fractional permits are corrupted, we reset to 0 gracefully + (f32::NAN, 0, 0.0), + (f32::INFINITY, 0, 0.0), + ]; + + for (input, expected_permits, expected_remaining) in test_cases { + let bucket = TokenBucket::builder().capacity(10).build(); + let _hold_permit = bucket.acquire(&ErrorKind::TransientError); + let initial = bucket.semaphore.available_permits(); + + bucket.fractional_tokens.store(input); + bucket.convert_fractional_tokens(); + + assert_eq!( + bucket.semaphore.available_permits() - initial, + expected_permits + ); + 
assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001); + } + } + + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_builder_with_custom_values() { + let bucket = TokenBucket::builder() + .capacity(100) + .retry_cost(10) + .timeout_retry_cost(20) + .success_reward(0.5) + .refill_rate(2.5) + .build(); + + assert_eq!(bucket.max_permits, 100); + assert_eq!(bucket.retry_cost, 10); + assert_eq!(bucket.timeout_retry_cost, 20); + assert_eq!(bucket.success_reward, 0.5); + assert_eq!(bucket.refill_rate, 2.5); + } + + #[test] + fn test_builder_refill_rate_validation() { + // Test negative values are clamped to 0.0 + let bucket = TokenBucket::builder().refill_rate(-5.0).build(); + assert_eq!(bucket.refill_rate, 0.0); + + // Test valid positive value + let bucket = TokenBucket::builder().refill_rate(1.5).build(); + assert_eq!(bucket.refill_rate, 1.5); + + // Test zero is valid + let bucket = TokenBucket::builder().refill_rate(0.0).build(); + assert_eq!(bucket.refill_rate, 0.0); + } + + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_builder_custom_time_source() { + use aws_smithy_async::test_util::ManualTimeSource; + use std::time::UNIX_EPOCH; + + // Test that TokenBucket uses provided TimeSource when specified via builder + let manual_time = ManualTimeSource::new(UNIX_EPOCH); + let bucket = TokenBucket::builder() + .capacity(100) + .refill_rate(1.0) + .time_source(manual_time.clone()) + .build(); + + // Verify the bucket uses the manual time source + assert_eq!(bucket.creation_time, UNIX_EPOCH); + + // Consume all tokens to test refill from empty state + let _permits = bucket.semaphore.try_acquire_many(100).unwrap(); + assert_eq!(bucket.available_permits(), 0); + + // Advance time and verify tokens are added based on manual time + manual_time.advance(Duration::from_secs(5)); + + bucket.refill_tokens_based_on_time(); + bucket.convert_fractional_tokens(); + + // Should have 5 tokens 
(5 seconds * 1 token/sec) + assert_eq!(bucket.available_permits(), 5); + } + + #[test] + fn test_atomicf32_f32_to_bits_conversion_correctness() { + // This is the core functionality + let test_values = vec![ + 0.0, + -0.0, + 1.0, + -1.0, + f32::INFINITY, + f32::NEG_INFINITY, + f32::NAN, + f32::MIN, + f32::MAX, + f32::MIN_POSITIVE, + f32::EPSILON, + std::f32::consts::PI, + std::f32::consts::E, + // Test values that could expose bit manipulation bugs + 1.23456789e-38, // Very small normal number + 1.23456789e38, // Very large number (within f32 range) + 1.1754944e-38, // Near MIN_POSITIVE for f32 + ]; + + for &expected in &test_values { + let atomic = AtomicF32::new(expected); + let actual = atomic.load(); + + // For NaN, we can't use == but must check bit patterns + if expected.is_nan() { + assert!(actual.is_nan(), "Expected NaN, got {}", actual); + // Different NaN bit patterns should be preserved exactly + assert_eq!(expected.to_bits(), actual.to_bits()); + } else { + assert_eq!(expected.to_bits(), actual.to_bits()); + } + } + } + + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_atomicf32_store_load_preserves_exact_bits() { + let atomic = AtomicF32::new(0.0); + + // Test that store/load cycle preserves EXACT bit patterns + // This would catch bugs in the to_bits/from_bits conversion + let critical_bit_patterns = vec![ + 0x00000000u32, // +0.0 + 0x80000000u32, // -0.0 + 0x7F800000u32, // +infinity + 0xFF800000u32, // -infinity + 0x7FC00000u32, // Quiet NaN + 0x7FA00000u32, // Signaling NaN + 0x00000001u32, // Smallest positive subnormal + 0x007FFFFFu32, // Largest subnormal + 0x00800000u32, // Smallest positive normal (MIN_POSITIVE) + ]; + + for &expected_bits in &critical_bit_patterns { + let expected_f32 = f32::from_bits(expected_bits); + atomic.store(expected_f32); + let loaded_f32 = atomic.load(); + let actual_bits = loaded_f32.to_bits(); + + assert_eq!(expected_bits, actual_bits); + } + } + + #[cfg(any(feature = 
"test-util", feature = "legacy-test-util"))] + #[test] + fn test_atomicf32_concurrent_store_load_safety() { + use std::sync::Arc; + use std::thread; + + let atomic = Arc::new(AtomicF32::new(0.0)); + let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let mut handles = Vec::new(); + + // Start multiple threads that continuously write different values + for &value in &test_values { + let atomic_clone = Arc::clone(&atomic); + let handle = thread::spawn(move || { + for _ in 0..1000 { + atomic_clone.store(value); + } + }); + handles.push(handle); + } + + // Start a reader thread that continuously reads + let atomic_reader = Arc::clone(&atomic); + let reader_handle = thread::spawn(move || { + let mut readings = Vec::new(); + for _ in 0..5000 { + let value = atomic_reader.load(); + readings.push(value); + } + readings + }); + + // Wait for all writers to complete + for handle in handles { + handle.join().expect("Writer thread panicked"); + } + + let readings = reader_handle.join().expect("Reader thread panicked"); + + // Verify that all read values are valid (one of the written values) + // This tests that there's no data corruption from concurrent access + for &reading in &readings { + assert!(test_values.contains(&reading) || reading == 0.0); + + // More importantly, verify the reading is a valid f32 + // (not corrupted bits that happen to parse as valid) + assert!( + reading.is_finite() || reading == 0.0, + "Corrupted reading detected" + ); + } + } + + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_atomicf32_stress_concurrent_access() { + use std::sync::{Arc, Barrier}; + use std::thread; + + let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let atomic = Arc::new(AtomicF32::new(0.0)); + let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads + let mut handles = Vec::new(); + + // Launch threads that all start simultaneously + for i in 0..10 { + let atomic_clone = Arc::clone(&atomic); + let 
barrier_clone = Arc::clone(&barrier); + let handle = thread::spawn(move || { + barrier_clone.wait(); // All threads start at same time + + // Tight loop increases chance of race conditions + for _ in 0..10000 { + let value = i as f32; + atomic_clone.store(value); + let loaded = atomic_clone.load(); + // Verify no corruption occurred + assert!(loaded >= 0.0 && loaded <= 9.0); + assert!( + expected_values.contains(&loaded), + "Got unexpected value: {}, expected one of {:?}", + loaded, + expected_values + ); + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_atomicf32_integration_with_token_bucket_usage() { + let atomic = AtomicF32::new(0.0); + let success_reward = 0.3; + let iterations = 5; + + // Accumulate fractional tokens + for _ in 1..=iterations { + let current = atomic.load(); + atomic.store(current + success_reward); + } + + let accumulated = atomic.load(); + let expected_total = iterations as f32 * success_reward; // 1.5 + + // Test the floor() operation pattern + let full_tokens = accumulated.floor(); + atomic.store(accumulated - full_tokens); + let remaining = atomic.load(); + + // These assertions should be general: + assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc. 
+ assert!(remaining >= 0.0 && remaining < 1.0); + assert_eq!(remaining, expected_total - expected_total.floor()); + } + + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_atomicf32_clone_creates_independent_copy() { + let original = AtomicF32::new(123.456); + let cloned = original.clone(); + + // Verify they start with the same value + assert_eq!(original.load(), cloned.load()); + + // Verify they're independent - modifying one doesn't affect the other + original.store(999.0); + assert_eq!( + cloned.load(), + 123.456, + "Clone should be unaffected by original changes" + ); + assert_eq!(original.load(), 999.0, "Original should have new value"); + } + + #[test] + fn test_combined_time_and_success_rewards() { + use aws_smithy_async::test_util::ManualTimeSource; + use std::time::UNIX_EPOCH; + + let time_source = ManualTimeSource::new(UNIX_EPOCH); + let bucket = TokenBucket { + refill_rate: 1.0, + success_reward: 0.5, + time_source: time_source.clone().into(), + creation_time: time_source.now(), + semaphore: Arc::new(Semaphore::new(0)), + max_permits: 100, + ..Default::default() + }; + + // Add success rewards: 2 * 0.5 = 1.0 token + bucket.reward_success(); + bucket.reward_success(); + + // Advance time by 2 seconds + time_source.advance(Duration::from_secs(2)); + + // Trigger time-based refill: 2 sec * 1.0 = 2.0 tokens + // Total: 1.0 + 2.0 = 3.0 tokens + bucket.refill_tokens_based_on_time(); + bucket.convert_fractional_tokens(); + + assert_eq!(bucket.available_permits(), 3); + assert!(bucket.fractional_tokens.load().abs() < 0.0001); + } + + #[test] + fn test_refill_rates() { + use aws_smithy_async::test_util::ManualTimeSource; + use std::time::UNIX_EPOCH; + // (refill_rate, elapsed_secs, expected_permits, expected_fractional) + let test_cases = [ + (10.0, 2, 20, 0.0), // Basic: 2 sec * 10 tokens/sec = 20 tokens + (0.001, 1100, 1, 0.1), // Small: 1100 * 0.001 = 1.1 tokens + (0.0001, 11000, 1, 0.1), // Tiny: 11000 * 0.0001 = 1.1 
tokens + (0.001, 1200, 1, 0.2), // 1200 * 0.001 = 1.2 tokens + (0.0001, 10000, 1, 0.0), // 10000 * 0.0001 = 1.0 tokens + (0.001, 500, 0, 0.5), // Fractional only: 500 * 0.001 = 0.5 tokens + ]; + + for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases { + let time_source = ManualTimeSource::new(UNIX_EPOCH); + let bucket = TokenBucket { + refill_rate, + time_source: time_source.clone().into(), + creation_time: time_source.now(), + semaphore: Arc::new(Semaphore::new(0)), + max_permits: 100, + ..Default::default() + }; + + // Advance time by the specified duration + time_source.advance(Duration::from_secs(elapsed_secs)); + + bucket.refill_tokens_based_on_time(); + bucket.convert_fractional_tokens(); + + assert_eq!( + bucket.available_permits(), + expected_permits, + "Rate {}: After {}s expected {} permits", + refill_rate, + elapsed_secs, + expected_permits + ); + assert!( + (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001, + "Rate {}: After {}s expected {} fractional, got {}", + refill_rate, + elapsed_secs, + expected_fractional, + bucket.fractional_tokens.load() + ); + } + } + + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_rewards_capped_at_max_capacity() { + use aws_smithy_async::test_util::ManualTimeSource; + use std::time::UNIX_EPOCH; + + let time_source = ManualTimeSource::new(UNIX_EPOCH); + let bucket = TokenBucket { + refill_rate: 50.0, + success_reward: 2.0, + time_source: time_source.clone().into(), + creation_time: time_source.now(), + semaphore: Arc::new(Semaphore::new(5)), + max_permits: 10, + ..Default::default() + }; + + // Add success rewards: 50 * 2.0 = 100 tokens (without cap) + for _ in 0..50 { + bucket.reward_success(); + } + + // Fractional tokens capped at 10 from success rewards + assert_eq!(bucket.fractional_tokens.load(), 10.0); + + // Advance time by 100 seconds + time_source.advance(Duration::from_secs(100)); + + // Time-based refill: 100 * 50 = 5000 
tokens (without cap) + // But fractional is already at 10, so it stays at 10 + bucket.refill_tokens_based_on_time(); + + // Fractional tokens should be capped at max_permits (10) + assert_eq!( + bucket.fractional_tokens.load(), + 10.0, + "Fractional tokens should be capped at max_permits" + ); + // Convert should add 5 tokens (bucket at 5, can add 5 more to reach max 10) + bucket.convert_fractional_tokens(); + assert_eq!(bucket.available_permits(), 10); + } + + #[cfg(any(feature = "test-util", feature = "legacy-test-util"))] + #[test] + fn test_concurrent_time_based_refill_no_over_generation() { + use aws_smithy_async::test_util::ManualTimeSource; + use std::sync::{Arc, Barrier}; + use std::thread; + use std::time::UNIX_EPOCH; + + let time_source = ManualTimeSource::new(UNIX_EPOCH); + // Create bucket with 1 token/sec refill + let bucket = Arc::new(TokenBucket { + refill_rate: 1.0, + time_source: time_source.clone().into(), + creation_time: time_source.now(), + semaphore: Arc::new(Semaphore::new(0)), + max_permits: 100, + ..Default::default() + }); + + // Advance time by 10 seconds + time_source.advance(Duration::from_secs(10)); + + // Launch 100 threads that all try to refill simultaneously + let barrier = Arc::new(Barrier::new(100)); + let mut handles = Vec::new(); + + for _ in 0..100 { + let bucket_clone1 = Arc::clone(&bucket); + let barrier_clone1 = Arc::clone(&barrier); + let bucket_clone2 = Arc::clone(&bucket); + let barrier_clone2 = Arc::clone(&barrier); + + let handle1 = thread::spawn(move || { + // Wait for all threads to be ready + barrier_clone1.wait(); + + // All threads call refill at the same time + bucket_clone1.refill_tokens_based_on_time(); + }); + + let handle2 = thread::spawn(move || { + // Wait for all threads to be ready + barrier_clone2.wait(); + + // All threads call refill at the same time + bucket_clone2.refill_tokens_based_on_time(); + }); + handles.push(handle1); + handles.push(handle2); + } + + // Wait for all threads to complete + for 
handle in handles { + handle.join().unwrap(); + } + + // Convert fractional tokens to whole tokens + bucket.convert_fractional_tokens(); + + // Should have exactly 10 tokens (10 seconds * 1 token/sec) + // Not 1000 tokens (100 threads * 10 tokens each) + assert_eq!( + bucket.available_permits(), + 10, + "Only one thread should have added tokens, not all 100" + ); + + // Fractional should be 0 after conversion + assert!(bucket.fractional_tokens.load().abs() < 0.0001); + } }