Skip to content

Commit

Permalink
Use try / trace pattern for caching
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Apr 10, 2024
1 parent 15f4922 commit 4cbc428
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 81 deletions.
98 changes: 49 additions & 49 deletions server/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use axum::extract::FromRef;
use color_eyre::eyre::eyre;
use redis::AsyncCommands;
use redis::Client as RedisClient;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::future::Future;

#[derive(Debug, Clone, Deserialize)]
#[allow(unused)]
Expand Down Expand Up @@ -39,96 +39,96 @@ impl Cache {
})
}

pub async fn zincr(&self, space: &str, key: &str, value: i64) -> Result<(), AppError> {
pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
match self.try_get(key).await {
Ok(response) => response,
Err(e) => {
tracing::error!("Failed to get cache key {}: {}", key, e);
None
}
}
}

pub async fn try_get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, AppError> {
if !self.settings.enabled {
return Ok(());
return Ok(None);
}

let mut connection = self.client.get_multiplexed_tokio_connection().await?;

connection
.zincr(space, key, value)
.await
.map_err(|e| AppError::from(e))?;
let cache_response: Option<T> =
connection.get(key).await.map(|response: Option<String>| {
response.and_then(|response| serde_json::from_str(&response).ok())
})?;

Ok(())
Ok(cache_response)
}

pub async fn zrevrange(
&self,
space: &str,
start: i64,
stop: i64,
) -> Result<Vec<String>, AppError> {
pub async fn set<T: Serialize>(&self, key: &str, value: &T) {
if let Err(e) = self.try_set(key, value).await {
tracing::error!("Failed to set cache key {}: {}", key, e);
}
}

pub async fn try_set<T: Serialize>(&self, key: &str, value: &T) -> Result<(), AppError> {
if !self.settings.enabled {
return Ok(vec![]);
return Ok(());
}

let mut connection = self.client.get_multiplexed_tokio_connection().await?;

let cache_response: Vec<String> = connection
.zrevrange(space, start as isize - 1, stop as isize - 1)
.await
.map_err(|e| AppError::from(e))?;
connection
.set(
key,
serde_json::to_string(value)
.map_err(|_| eyre!("unable to convert string to json"))?,
)
.await?;

Ok(cache_response)
Ok(())
}

pub async fn zremrangebyrank(&self, space: &str) -> Result<(), AppError> {
pub async fn zincr(&self, space: &str, key: &str, value: i64) -> Result<(), AppError> {
if !self.settings.enabled {
return Ok(());
}

let mut connection = self.client.get_multiplexed_tokio_connection().await?;

connection
.zremrangebyrank(space, 0, -self.settings.max_sorted_size as isize - 1)
.await
.map_err(|e| AppError::from(e))?;
connection.zincr(space, key, value).await?;

Ok(())
}
}

pub trait CacheFn<T> {
fn get(&self, key: &str) -> impl Future<Output = Result<Option<T>, AppError>>;
fn set(&self, key: &str, value: &T) -> impl Future<Output = Result<(), AppError>>;
}

impl<T: Serialize + for<'de> Deserialize<'de>> CacheFn<T> for Cache {
async fn get(&self, key: &str) -> Result<Option<T>, AppError> {
pub async fn zrevrange(
&self,
space: &str,
start: i64,
stop: i64,
) -> Result<Vec<String>, AppError> {
if !self.settings.enabled {
return Ok(None);
return Ok(vec![]);
}

let mut connection = self.client.get_multiplexed_tokio_connection().await?;

let cache_response: Option<T> = connection
.get(key)
.await
.map(|response: Option<String>| {
response.and_then(|response| serde_json::from_str(&response).ok())
})
.map_err(|e| AppError::from(e))?;
let cache_response: Vec<String> = connection
.zrevrange(space, start as isize - 1, stop as isize - 1)
.await?;

Ok(cache_response)
}

async fn set(&self, key: &str, value: &T) -> Result<(), AppError> {
pub async fn zremrangebyrank(&self, space: &str) -> Result<(), AppError> {
if !self.settings.enabled {
return Ok(());
}

let mut connection = self.client.get_multiplexed_tokio_connection().await?;

connection
.set(
key,
serde_json::to_string(value)
.map_err(|_| eyre!("unable to convert string to json"))?,
)
.await
.map_err(|e| AppError::from(e))?;
.zremrangebyrank(space, 0, -self.settings.max_sorted_size as isize - 1)
.await?;

Ok(())
}
Expand Down
42 changes: 10 additions & 32 deletions server/src/search/services.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::cache::{Cache, CacheFn};
use crate::err::AppError;
use crate::cache::Cache;
use crate::search::{
RAGTokenResponse, SearchHistory, SearchHistoryRequest, SearchQueryRequest,
SearchReactionRequest, SearchResponse, TopSearchRequest,
SearchHistory, SearchHistoryRequest, SearchQueryRequest, SearchReactionRequest, SearchResponse,
TopSearchRequest,
};
use crate::settings::SETTINGS;
use color_eyre::eyre::eyre;
Expand All @@ -16,38 +15,22 @@ pub async fn search(
cache: &Cache,
search_query: &SearchQueryRequest,
) -> crate::Result<SearchResponse> {
let cache_response: Option<SearchResponse> = cache.get(&search_query.query).await?;
if let Some(response) = cache_response {
if let Some(response) = cache.get(&search_query.query).await {
return Ok(response);
}

// TODO: replace this with actual search logic using GRPC calls with backend services
let rag_api_url = SETTINGS.rag_api.clone() + "/token";
let form_data = [
("username", &SETTINGS.rag_api_username.expose()),
("password", &SETTINGS.rag_api_password.expose()),
];
let token: RAGTokenResponse = ReqwestClient::new()
.post(rag_api_url)
.form(&form_data)
.send()
.await
.map_err(|_| eyre!("unable to send request to rag api"))?
.json()
.await
.map_err(|_| eyre!("unable to parse json response from rag api"))?;

let rag_api_url = SETTINGS.rag_api.clone() + "/search?query=" + &search_query.query;
let response: SearchResponse = ReqwestClient::new()
.get(rag_api_url)
.header("Authorization", format!("Bearer {}", token.access_token))
.send()
.await
.map_err(|_| eyre!("unable to send request to rag api"))?
.json()
.await
.map_err(|_| eyre!("unable to parse json response from rag api"))?;

cache.set(&search_query.query, &response).await;

return Ok(response);
}

Expand All @@ -59,8 +42,6 @@ pub async fn insert_search_history(
search_query: &SearchQueryRequest,
search_response: &SearchResponse,
) -> crate::Result<SearchHistory> {
cache.set(&search_query.query, search_response).await?;

let session_id = search_query.session_id.unwrap_or(Uuid::new_v4());

let search_history = sqlx::query_as!(
Expand All @@ -73,8 +54,7 @@ pub async fn insert_search_history(
&search_response.sources
)
.fetch_one(pool)
.await
.map_err(|e| AppError::from(e))?;
.await?;

return Ok(search_history);
}
Expand All @@ -93,8 +73,7 @@ pub async fn get_search_history(
search_history_request.offset.unwrap_or(0) as i64
)
.fetch_all(pool)
.await
.map_err(|e| AppError::from(e))?;
.await?;

return Ok(search_history);
}
Expand All @@ -110,7 +89,7 @@ pub async fn get_top_searches(
}

let limit = top_search_request.limit.unwrap_or(10);
if limit < 1 || limit > 100 {
if !(1..=100).contains(&limit) {
Err(eyre!("limit must be a number between 1 and 100"))?;
}

Expand All @@ -133,8 +112,7 @@ pub async fn update_search_reaction(
user_id
)
.fetch_one(pool)
.await
.map_err(|e| AppError::from(e))?;
.await?;

return Ok(search_history);
}

0 comments on commit 4cbc428

Please sign in to comment.