Skip to content

Commit

Permalink
Refactor to allow reinstating resource manager
Browse files Browse the repository at this point in the history
  • Loading branch information
sd109 committed Feb 7, 2025
1 parent f809321 commit d9dfd2d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 58 deletions.
34 changes: 13 additions & 21 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use bytes::Bytes;
use cached::{proc_macro::io_cached, stores::DiskCacheBuilder};

use std::sync::Arc;
use tokio::sync::SemaphorePermit;
use tower::Layer;
use tower::ServiceBuilder;
use tower_http::normalize_path::NormalizePathLayer;
Expand Down Expand Up @@ -168,11 +167,7 @@ async fn schema() -> &'static str {
///
/// * `client`: S3 client object
/// * `request_data`: RequestData object for the request
#[tracing::instrument(
level = "DEBUG",
// skip(client, request_data, resource_manager, mem_permits)
skip(client, request_data)
)]
#[tracing::instrument(level = "DEBUG", skip(client, request_data, resource_manager))]
#[io_cached(
map_error = r##"|e| ActiveStorageError::CacheError{ error: format!("{:?}", e) }"##,
disk = true,
Expand All @@ -183,11 +178,15 @@ async fn schema() -> &'static str {
async fn download_object<'a>(
client: &s3_client::S3Client,
request_data: &models::RequestData,
// resource_manager: &'a ResourceManager,
// mem_permits: &mut Option<SemaphorePermit<'a>>,
resource_manager: &ResourceManager,
) -> Result<Bytes, ActiveStorageError> {
// If we're given a size in the request data then use this to
// get an initial guess at the required memory resources.
let memory = request_data.size.unwrap_or(0);
let mut mem_permits = resource_manager.memory(memory).await?;

let range = s3_client::get_range(request_data.offset, request_data.size);
// let _conn_permits = resource_manager.s3_connection().await?;
let _conn_permits = resource_manager.s3_connection().await?;

// Increment the prometheus metric for cache misses
LOCAL_CACHE_MISSES.with_label_values(&["disk"]).inc();
Expand All @@ -197,8 +196,8 @@ async fn download_object<'a>(
&request_data.bucket,
&request_data.object,
range,
// resource_manager,
// mem_permits,
resource_manager,
&mut mem_permits,
)
.await
}
Expand All @@ -222,8 +221,6 @@ async fn operation_handler<T: operation::Operation>(
auth: Option<TypedHeader<Authorization<Basic>>>,
ValidatedJson(request_data): ValidatedJson<models::RequestData>,
) -> Result<models::Response, ActiveStorageError> {
let memory = request_data.size.unwrap_or(0);
let mut _mem_permits = state.resource_manager.memory(memory).await?;
let credentials = if let Some(TypedHeader(auth)) = auth {
s3_client::S3Credentials::access_key(auth.username(), auth.password())
} else {
Expand All @@ -234,14 +231,9 @@ async fn operation_handler<T: operation::Operation>(
.get(&request_data.source, credentials)
.instrument(tracing::Span::current())
.await;
let data = download_object(
&s3_client,
&request_data,
// &state.resource_manager,
// &mut _mem_permits,
)
.instrument(tracing::Span::current())
.await?;
let data = download_object(&s3_client, &request_data, &state.resource_manager)
.instrument(tracing::Span::current())
.await?;
// All remaining work is synchronous. If the use_rayon argument was specified, delegate to the
// Rayon thread pool. Otherwise, execute as normal using Tokio.
if state.args.use_rayon {
Expand Down
43 changes: 6 additions & 37 deletions src/s3_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ impl S3Client {
bucket: &str,
key: &str,
range: Option<String>,
// resource_manager: &'a ResourceManager,
// mem_permits: &mut Option<SemaphorePermit<'a>>,
resource_manager: &'a ResourceManager,
mem_permits: &mut Option<SemaphorePermit<'a>>,
) -> Result<Bytes, ActiveStorageError> {
let mut response = self
.client
Expand All @@ -175,9 +175,10 @@ impl S3Client {
.try_into()?;

// FIXME: how to account for compressed data?
// if mem_permits.is_none() {
// *mem_permits = resource_manager.memory(content_length).await?;
// };
if mem_permits.is_none() || mem_permits.as_ref().unwrap().num_permits() == 0 {
*mem_permits = resource_manager.memory(content_length).await?;
};

// The data returned by the S3 client does not have any alignment guarantees. In order to
// reinterpret the data as an array of numbers with a higher alignment than 1, we need to
// return the data in Bytes object in which the underlying data has a higher alignment.
Expand Down Expand Up @@ -225,40 +226,8 @@ pub fn get_range(offset: Option<usize>, size: Option<usize>) -> Option<String> {
#[cfg(test)]
mod tests {
use super::*;
use cached::{proc_macro::io_cached, stores::DiskCacheBuilder};
use url::Url;

// #[cached(
// ty = "SizedCache<String, String>",
// create = "{ SizedCache::with_size(100) }",
// convert = r#"{ format!("{}{}", a, b) }"#
// )]
// fn cache_test(a: &str, b: &str) -> String {
// format!("{} - {}", a, b)
// }

/// Test helper wrapped in a disk-backed cache.
///
/// The `#[io_cached]` attribute memoises results in the "test-cache" disk
/// store under `./`, keyed on `"{a}:{b}"`; cache failures are surfaced as
/// `ActiveStorageError::CacheError`. The `println!` makes it observable
/// (when running with `--nocapture`) whether the body actually executed or
/// the result came from the cache.
#[io_cached(
    map_error = r##"|e| ActiveStorageError::CacheError{ error: format!("{:?}", e) }"##,
    disk = true,
    create = r##"{ DiskCacheBuilder::new("test-cache").set_disk_directory("./").build().expect("valid disk cache builder") }"##,
    key = "String",
    convert = r##"{ format!("{}:{}", a, b) }"##
)]
async fn cache_test(a: &str, b: &str) -> Result<String, ActiveStorageError> {
    println!("Function called");
    let joined = format!("{} - {}", a, b);
    Ok(joined)
}

/// Verify that a disk-cached async function can service repeated calls.
///
/// Calls `cache_test` twice with identical arguments; both calls must
/// succeed. The second call should be satisfied from the "test-cache"
/// disk store created by the `#[io_cached]` attribute on `cache_test`
/// (not asserted directly here — TODO: assert on cache state to make the
/// hit path observable).
#[tokio::test]
async fn disk_cache() {
    cache_test("a", "b").await.unwrap();
    // Same key again: exercises the cache-hit path.
    cache_test("a", "b").await.unwrap();
}

/// Test helper: builds an `S3Credentials` value from a fixed
/// username/password pair using the access-key constructor.
fn make_access_key() -> S3Credentials {
S3Credentials::access_key("user", "password")
}
Expand Down

0 comments on commit d9dfd2d

Please sign in to comment.