From e4ae3345701c684f3a3f6d92ae9bf1fae6cf9c86 Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Thu, 4 Apr 2024 18:07:11 +0600 Subject: [PATCH 1/9] added initial version of search APIs --- server/.gitignore | 3 + server/Cargo.lock | 44 ++++++++++++ server/Cargo.toml | 16 ++++- server/config/dev.toml | 1 + .../20240403115016_search_history.sql | 11 +++ server/src/err.rs | 10 +++ server/src/lib.rs | 1 + server/src/routing/api.rs | 3 +- server/src/search/mod.rs | 7 ++ server/src/search/models.rs | 33 +++++++++ server/src/search/routes.rs | 53 ++++++++++++++ server/src/search/services.rs | 70 +++++++++++++++++++ server/src/settings.rs | 1 + server/src/startup.rs | 17 ++++- 14 files changed, 266 insertions(+), 4 deletions(-) create mode 100644 server/migrations/20240403115016_search_history.sql create mode 100644 server/src/search/mod.rs create mode 100644 server/src/search/models.rs create mode 100644 server/src/search/routes.rs create mode 100644 server/src/search/services.rs diff --git a/server/.gitignore b/server/.gitignore index 9481ce52..a63b034d 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -27,3 +27,6 @@ target/ # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + +# These will be created by the sqlx CLI +*.sqlx diff --git a/server/Cargo.lock b/server/Cargo.lock index 3dd73cef..972c5c5c 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -339,6 +339,20 @@ dependencies = [ "tracing-error", ] +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "config" version = "0.14.0" @@ -1692,6 +1706,29 @@ dependencies = [ "getrandom", ] +[[package]] +name = "redis" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6472825949c09872e8f2c50bde59fcefc17748b6be5c90fd67cd8b4daca73bfd" +dependencies = [ + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "percent-encoding", + "pin-project-lite", + "ryu", + "serde", + "serde_json", + "sha1_smol", + "socket2", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -2074,6 +2111,7 @@ dependencies = [ "oauth2", "once_cell", "password-auth", + "redis", "reqwest 0.12.1", "serde", "serde_json", @@ -2102,6 +2140,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.10.8" diff --git a/server/Cargo.toml b/server/Cargo.toml index b385f9cd..00399630 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -27,7 +27,14 @@ reqwest = { version = "0.12.1", features = ["json"] } serde = "1.0.197" serde_json = "1.0.114" serde_urlencoded = "0.7.1" -sqlx = { version = "0.7.4", features = ["postgres", "runtime-tokio", "tls-rustls", "migrate", "uuid", "time"] } +sqlx = { version = "0.7.4", features = [ + "postgres", + "runtime-tokio", + "tls-rustls", + "migrate", + "uuid", + "time", +] } thiserror = "1.0.58" tokio = { version = "1.36.0", features = ["full"] } tower-http = { version = "0.5.2", features = ["trace", "cors"] } @@ -36,6 +43,11 @@ tower = "0.4.13" tracing-error = "0.2.0" tracing-log = "0.2.0" tracing-logfmt = "0.3.4" -tracing-subscriber = { version = "0.3.18", features = ["json", "registry", "env-filter"] } +tracing-subscriber = { version = "0.3.18", features = [ + "json", + "registry", + "env-filter", +] } uuid = { version = "1.8.0", features = ["serde"] } log = "0.4.21" +redis = { version = "0.25.2", features = ["tokio-comp", "json"] } diff --git a/server/config/dev.toml b/server/config/dev.toml index a820754a..456ee9d1 100644 --- a/server/config/dev.toml +++ b/server/config/dev.toml @@ -1,4 +1,5 @@ db = "postgresql://postgres:postgres@localhost/curieo" +cache = "redis://127.0.0.1/" [log] level = "info" diff --git a/server/migrations/20240403115016_search_history.sql b/server/migrations/20240403115016_search_history.sql new file mode 100644 index 00000000..36a52a50 --- /dev/null +++ b/server/migrations/20240403115016_search_history.sql @@ -0,0 +1,11 @@ +create table search_history ( + search_history_id uuid primary key default uuid_generate_v1mc(), + search_text text not null, + response_text text not null, + response_sources text[] not null, + created_at timestamptz not null default now(), + updated_at timestamptz not null default now() +); + +-- And applying our `updated_at` trigger is as easy as this. +SELECT trigger_updated_at('search_history'); diff --git a/server/src/err.rs b/server/src/err.rs index 9562d515..5087db1c 100644 --- a/server/src/err.rs +++ b/server/src/err.rs @@ -14,6 +14,7 @@ pub enum AppError { UnprocessableEntity(ErrorMap), Sqlx(sqlx::Error), GenericError(color_eyre::eyre::Error), + Redis(redis::RedisError), } impl AppError { @@ -22,6 +23,7 @@ impl AppError { Self::Unauthorized => StatusCode::UNAUTHORIZED, Self::UnprocessableEntity { .. } => StatusCode::UNPROCESSABLE_ENTITY, Self::Sqlx(_) | Self::GenericError(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::Redis(_) => StatusCode::INTERNAL_SERVER_ERROR, } } /// Convenient constructor for `Error::UnprocessableEntity`. @@ -47,6 +49,12 @@ impl From for AppError { } } +impl From for AppError { + fn from(inner: redis::RedisError) -> Self { + AppError::Redis(inner) + } +} + impl Display for AppError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -54,6 +62,7 @@ impl Display for AppError { AppError::Sqlx(e) => write!(f, "{}", e), AppError::UnprocessableEntity(e) => write!(f, "{:?}", e), AppError::Unauthorized => write!(f, "Unauthorized"), + AppError::Redis(e) => write!(f, "{}", e), } } } @@ -105,6 +114,7 @@ impl IntoResponse for AppError { } AppError::Sqlx(ref e) => error!("SQLx error: {:?}", e), AppError::GenericError(ref e) => error!("Generic error: {:?}", e), + AppError::Redis(ref e) => error!("Redis error: {:?}", e), }; // Return a http status code and json body with error message. diff --git a/server/src/lib.rs b/server/src/lib.rs index e2fbfcc3..5a865443 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -6,6 +6,7 @@ pub mod auth; mod err; mod health_check; pub mod routing; +pub mod search; pub mod secrets; pub mod settings; pub mod startup; diff --git a/server/src/routing/api.rs b/server/src/routing/api.rs index 608e9aa4..dca9bae3 100644 --- a/server/src/routing/api.rs +++ b/server/src/routing/api.rs @@ -8,7 +8,7 @@ use tracing::Level; use crate::auth::models::{OAuth2Clients, PostgresBackend}; use crate::startup::AppState; -use crate::{auth, health_check, users}; +use crate::{auth, health_check, search, users}; pub fn router(state: AppState) -> color_eyre::Result { // Session layer. @@ -33,6 +33,7 @@ pub fn router(state: AppState) -> color_eyre::Result { let api_routes = Router::new() //.nest("/search", search::routes()) //.layer(middleware::from_fn(some_auth_middleware)) + .nest("/search", search::routes()) .nest("/users", users::routes()) .route_layer(login_required!( PostgresBackend, diff --git a/server/src/search/mod.rs b/server/src/search/mod.rs new file mode 100644 index 00000000..1d3d1250 --- /dev/null +++ b/server/src/search/mod.rs @@ -0,0 +1,7 @@ +pub use models::*; +pub use routes::*; +pub use services::*; + +pub mod models; +pub mod routes; +pub mod services; diff --git a/server/src/search/models.rs b/server/src/search/models.rs new file mode 100644 index 00000000..94fef466 --- /dev/null +++ b/server/src/search/models.rs @@ -0,0 +1,33 @@ +use serde::{Deserialize, Serialize}; +use sqlx::types::time; +use sqlx::FromRow; +use std::fmt::Debug; + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchQueryRequest { + pub query: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchHistoryRequest { + pub limit: Option, + pub offset: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchResponse { + pub response_text: String, + pub response_sources: Vec, +} + +#[derive(FromRow, Serialize, Deserialize, Clone, Debug)] +pub struct SearchHistory { + pub search_history_id: uuid::Uuid, + // pub user_id: uuid::Uuid, + pub search_text: String, + pub response_text: String, + pub response_sources: Vec, + + pub created_at: time::OffsetDateTime, + pub updated_at: time::OffsetDateTime, +} diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs new file mode 100644 index 00000000..83b4b524 --- /dev/null +++ b/server/src/search/routes.rs @@ -0,0 +1,53 @@ +use crate::err::AppError; +use crate::search::services; +use crate::search::{SearchHistory, SearchHistoryRequest, SearchQueryRequest}; +use crate::startup::AppState; +use axum::extract::{Query, State}; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::{Json, Router}; +use redis::Client as RedisClient; +use sqlx::PgPool; + +#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] +async fn get_search_handler( + State(pool): State, + State(cache): State, + Query(search_query): Query, +) -> crate::Result { + let mut connection = cache + .get_multiplexed_async_connection() + .await + .map_err(|e| AppError::from(e))?; + + let search_response = services::search(&mut connection, &search_query).await?; + services::insert_search_history(&pool, &mut connection, &search_query, &search_response) + .await?; + + Ok((StatusCode::OK, Json(search_response))) +} + +#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] +async fn get_search_history_handler( + State(pool): State, + Query(search_history_request): Query, +) -> crate::Result { + let search_history = sqlx::query_as!( + SearchHistory, + "select * from search_history order by created_at desc limit $1 offset $2", + search_history_request.limit.unwrap_or(10) as i64, + search_history_request.offset.unwrap_or(0) as i64 + ) + .fetch_all(&pool) + .await + .map_err(|e| AppError::from(e))?; + + Ok((StatusCode::OK, Json(search_history))) +} + +pub fn routes() -> Router { + Router::new() + .route("/", get(get_search_handler)) + .route("/history", get(get_search_history_handler)) +} diff --git a/server/src/search/services.rs b/server/src/search/services.rs new file mode 100644 index 00000000..19153c21 --- /dev/null +++ b/server/src/search/services.rs @@ -0,0 +1,70 @@ +use crate::err::AppError; +use crate::search::{SearchHistory, SearchQueryRequest, SearchResponse}; +use color_eyre::eyre::eyre; +use redis::aio::MultiplexedConnection; +use redis::AsyncCommands; +use sqlx::PgPool; + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn search( + cache: &mut MultiplexedConnection, + search_query: &SearchQueryRequest, +) -> crate::Result { + let cache_response: Option = cache + .get(search_query.query.clone()) + .await + .map_err(|e| AppError::from(e))?; + + let cache_response = match cache_response { + Some(response) => response, + None => String::new(), + }; + + let cache_response: Option = + serde_json::from_str(&cache_response).unwrap_or(None); + + if let Some(response) = cache_response { + return Ok(response); + } + + // sleep for 3 seconds to simulate a slow search + // TODO: replace this with actual search logic using GRPC calls with backend services + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + + let response = SearchResponse { + response_text: "sample response".to_string(), + response_sources: vec!["www.source1.com".to_string(), "www.source2.com".to_string()], + }; + + return Ok(response); +} + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn insert_search_history( + pool: &PgPool, + cache: &mut MultiplexedConnection, + search_query: &SearchQueryRequest, + search_response: &SearchResponse, +) -> crate::Result { + cache + .set( + &search_query.query, + serde_json::to_string(&search_response) + .map_err(|_| eyre!("unable to convert string to json"))?, + ) + .await + .map_err(|e| AppError::from(e))?; + + let search_history = sqlx::query_as!( + SearchHistory, + "insert into search_history (search_text, response_text, response_sources) values ($1, $2, $3) returning *", + &search_query.query, + &search_response.response_text, + &search_response.response_sources + ) + .fetch_one(pool) + .await + .map_err(|e| AppError::from(e))?; + + return Ok(search_history); +} diff --git a/server/src/settings.rs b/server/src/settings.rs index 7c0703d2..e2433770 100644 --- a/server/src/settings.rs +++ b/server/src/settings.rs @@ -78,6 +78,7 @@ pub struct Settings { pub host: String, pub port: u16, pub db: Secret, + pub cache: Secret, } impl Settings { diff --git a/server/src/startup.rs b/server/src/startup.rs index d54dda0b..7a20e994 100644 --- a/server/src/startup.rs +++ b/server/src/startup.rs @@ -1,5 +1,6 @@ use axum::{extract::FromRef, routing::IntoMakeService, serve::Serve, Router}; use color_eyre::eyre::eyre; +use redis::Client as RedisClient; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; use tokio::net::TcpListener; @@ -44,6 +45,7 @@ impl Application { #[derive(Clone, Debug, FromRef)] pub struct AppState { pub db: PgPool, + pub cache: redis::Client, pub settings: Settings, } @@ -58,13 +60,26 @@ pub async fn db_connect(database_url: &str) -> Result { } } +pub async fn cache_connect(cache_url: &str) -> Result { + match RedisClient::open(cache_url) { + Ok(client) => Ok(client), + Err(e) => Err(eyre!("Failed to connect to Redis: {}", e).into()), + } +} + async fn run( listener: TcpListener, settings: Settings, ) -> Result, Router>> { let db = db_connect(settings.db.expose()).await?; - let state = AppState { db, settings }; + let cache = cache_connect(settings.cache.expose()).await?; + + let state = AppState { + db, + cache, + settings, + }; let app = router(state)?; From 66574d0dbdddc782cb650319852b22b0ee3dde7c Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Fri, 5 Apr 2024 11:36:35 +0600 Subject: [PATCH 2/9] added API call to RAG Backend --- server/config/dev.toml | 3 ++ .../20240403115016_search_history.sql | 6 +-- server/src/search/models.rs | 16 ++++-- server/src/search/services.rs | 53 ++++++++++++------- server/src/settings.rs | 3 ++ 5 files changed, 54 insertions(+), 27 deletions(-) diff --git a/server/config/dev.toml b/server/config/dev.toml index 456ee9d1..f197d372 100644 --- a/server/config/dev.toml +++ b/server/config/dev.toml @@ -1,5 +1,8 @@ db = "postgresql://postgres:postgres@localhost/curieo" cache = "redis://127.0.0.1/" +rag_api = "http://127.0.0.1:8000" +rag_api_username = "curieo" +rag_api_password = "curieo" [log] level = "info" diff --git a/server/migrations/20240403115016_search_history.sql b/server/migrations/20240403115016_search_history.sql index 36a52a50..d40ca7ea 100644 --- a/server/migrations/20240403115016_search_history.sql +++ b/server/migrations/20240403115016_search_history.sql @@ -1,8 +1,8 @@ create table search_history ( search_history_id uuid primary key default uuid_generate_v1mc(), - search_text text not null, - response_text text not null, - response_sources text[] not null, + query text not null, + result text not null, + sources text[] not null, created_at timestamptz not null default now(), updated_at timestamptz not null default now() ); diff --git a/server/src/search/models.rs b/server/src/search/models.rs index 94fef466..1dbbfd18 100644 --- a/server/src/search/models.rs +++ b/server/src/search/models.rs @@ -3,6 +3,12 @@ use sqlx::types::time; use sqlx::FromRow; use std::fmt::Debug; +#[derive(Serialize, Deserialize, Debug)] +pub struct RAGTokenResponse { + pub access_token: String, + pub token_type: String, +} + #[derive(Serialize, Deserialize, Debug)] pub struct SearchQueryRequest { pub query: String, @@ -16,17 +22,17 @@ pub struct SearchHistoryRequest { #[derive(Serialize, Deserialize, Debug)] pub struct SearchResponse { - pub response_text: String, - pub response_sources: Vec, + pub result: String, + pub sources: Vec, } #[derive(FromRow, Serialize, Deserialize, Clone, Debug)] pub struct SearchHistory { pub search_history_id: uuid::Uuid, // pub user_id: uuid::Uuid, - pub search_text: String, - pub response_text: String, - pub response_sources: Vec, + pub query: String, + pub result: String, + pub sources: Vec, pub created_at: time::OffsetDateTime, pub updated_at: time::OffsetDateTime, diff --git a/server/src/search/services.rs b/server/src/search/services.rs index 19153c21..0e0f2bc7 100644 --- a/server/src/search/services.rs +++ b/server/src/search/services.rs @@ -1,8 +1,10 @@ use crate::err::AppError; -use crate::search::{SearchHistory, SearchQueryRequest, SearchResponse}; +use crate::search::{RAGTokenResponse, SearchHistory, SearchQueryRequest, SearchResponse}; +use crate::settings::SETTINGS; use color_eyre::eyre::eyre; use redis::aio::MultiplexedConnection; use redis::AsyncCommands; +use reqwest::Client as ReqwestClient; use sqlx::PgPool; #[tracing::instrument(level = "debug", ret, err)] @@ -10,31 +12,44 @@ pub async fn search( cache: &mut MultiplexedConnection, search_query: &SearchQueryRequest, ) -> crate::Result { - let cache_response: Option = cache + let cache_response: Option = cache .get(search_query.query.clone()) .await + .map(|response: Option| { + response.and_then(|response| serde_json::from_str(&response).ok()) + }) .map_err(|e| AppError::from(e))?; - let cache_response = match cache_response { - Some(response) => response, - None => String::new(), - }; - - let cache_response: Option = - serde_json::from_str(&cache_response).unwrap_or(None); - if let Some(response) = cache_response { return Ok(response); } - // sleep for 3 seconds to simulate a slow search // TODO: replace this with actual search logic using GRPC calls with backend services - tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + let rag_api_url = SETTINGS.rag_api.clone() + "/token"; + let form_data = [ + ("username", &SETTINGS.rag_api_username.expose()), + ("password", &SETTINGS.rag_api_password.expose()), + ]; + let token: RAGTokenResponse = ReqwestClient::new() + .post(rag_api_url) + .form(&form_data) + .send() + .await + .map_err(|_| eyre!("unable to send request to rag api"))? + .json() + .await + .map_err(|_| eyre!("unable to parse json response from rag api"))?; - let response = SearchResponse { - response_text: "sample response".to_string(), - response_sources: vec!["www.source1.com".to_string(), "www.source2.com".to_string()], - }; + let rag_api_url = SETTINGS.rag_api.clone() + "/search?query=" + &search_query.query; + let response: SearchResponse = ReqwestClient::new() + .get(rag_api_url) + .header("Authorization", format!("Bearer {}", token.access_token)) + .send() + .await + .map_err(|_| eyre!("unable to send request to rag api"))? + .json() + .await + .map_err(|_| eyre!("unable to parse json response from rag api"))?; return Ok(response); } @@ -57,10 +72,10 @@ pub async fn insert_search_history( let search_history = sqlx::query_as!( SearchHistory, - "insert into search_history (search_text, response_text, response_sources) values ($1, $2, $3) returning *", + "insert into search_history (query, result, sources) values ($1, $2, $3) returning *", &search_query.query, - &search_response.response_text, - &search_response.response_sources + &search_response.result, + &search_response.sources ) .fetch_one(pool) .await diff --git a/server/src/settings.rs b/server/src/settings.rs index e2433770..4db5d72b 100644 --- a/server/src/settings.rs +++ b/server/src/settings.rs @@ -79,6 +79,9 @@ pub struct Settings { pub port: u16, pub db: Secret, pub cache: Secret, + pub rag_api: String, + pub rag_api_username: Secret, + pub rag_api_password: Secret, } impl Settings { From 4e494bf4648b2af307372281cd79ce44e117f60f Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Fri, 5 Apr 2024 11:54:41 +0600 Subject: [PATCH 3/9] added user_id foreign key in the search_history --- server/Cargo.lock | 1 + server/Cargo.toml | 2 +- .../20240403115016_search_history.sql | 1 + server/src/search/models.rs | 2 +- server/src/search/routes.rs | 17 ++++++++++++++--- server/src/search/services.rs | 9 ++++++--- 6 files changed, 24 insertions(+), 8 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index 972c5c5c..d1bd4961 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -3014,6 +3014,7 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" dependencies = [ + "getrandom", "serde", ] diff --git a/server/Cargo.toml b/server/Cargo.toml index 00399630..7c8ccd45 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -48,6 +48,6 @@ tracing-subscriber = { version = "0.3.18", features = [ "registry", "env-filter", ] } -uuid = { version = "1.8.0", features = ["serde"] } +uuid = { version = "1.8.0", features = ["serde", "v4"] } log = "0.4.21" redis = { version = "0.25.2", features = ["tokio-comp", "json"] } diff --git a/server/migrations/20240403115016_search_history.sql b/server/migrations/20240403115016_search_history.sql index d40ca7ea..1bcbb603 100644 --- a/server/migrations/20240403115016_search_history.sql +++ b/server/migrations/20240403115016_search_history.sql @@ -1,5 +1,6 @@ create table search_history ( search_history_id uuid primary key default uuid_generate_v1mc(), + user_id uuid not null references users(user_id), query text not null, result text not null, sources text[] not null, diff --git a/server/src/search/models.rs b/server/src/search/models.rs index 1dbbfd18..414b657c 100644 --- a/server/src/search/models.rs +++ b/server/src/search/models.rs @@ -29,7 +29,7 @@ pub struct SearchResponse { #[derive(FromRow, Serialize, Deserialize, Clone, Debug)] pub struct SearchHistory { pub search_history_id: uuid::Uuid, - // pub user_id: uuid::Uuid, + pub user_id: uuid::Uuid, pub query: String, pub result: String, pub sources: Vec, diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs index 83b4b524..6696f979 100644 --- a/server/src/search/routes.rs +++ b/server/src/search/routes.rs @@ -16,14 +16,22 @@ async fn get_search_handler( State(cache): State, Query(search_query): Query, ) -> crate::Result { + let user_id = uuid::Uuid::parse_str("78c4c766-f310-11ee-a6ee-5f4062fc15f2").unwrap(); + let mut connection = cache .get_multiplexed_async_connection() .await .map_err(|e| AppError::from(e))?; let search_response = services::search(&mut connection, &search_query).await?; - services::insert_search_history(&pool, &mut connection, &search_query, &search_response) - .await?; + services::insert_search_history( + &pool, + &mut connection, + &user_id, + &search_query, + &search_response, + ) + .await?; Ok((StatusCode::OK, Json(search_response))) } @@ -33,9 +41,12 @@ async fn get_search_history_handler( State(pool): State, Query(search_history_request): Query, ) -> crate::Result { + let user_id = uuid::Uuid::parse_str("78c4c766-f310-11ee-a6ee-5f4062fc15f2").unwrap(); + let search_history = sqlx::query_as!( SearchHistory, - "select * from search_history order by created_at desc limit $1 offset $2", + "select * from search_history where user_id = $1 order by created_at desc limit $2 offset $3", + user_id, search_history_request.limit.unwrap_or(10) as i64, search_history_request.offset.unwrap_or(0) as i64 ) diff --git a/server/src/search/services.rs b/server/src/search/services.rs index 0e0f2bc7..76aeb8f0 100644 --- a/server/src/search/services.rs +++ b/server/src/search/services.rs @@ -6,6 +6,7 @@ use redis::aio::MultiplexedConnection; use redis::AsyncCommands; use reqwest::Client as ReqwestClient; use sqlx::PgPool; +use uuid::Uuid; #[tracing::instrument(level = "debug", ret, err)] pub async fn search( @@ -58,6 +59,7 @@ pub async fn search( pub async fn insert_search_history( pool: &PgPool, cache: &mut MultiplexedConnection, + user_id: &Uuid, search_query: &SearchQueryRequest, search_response: &SearchResponse, ) -> crate::Result { @@ -72,9 +74,10 @@ pub async fn insert_search_history( let search_history = sqlx::query_as!( SearchHistory, - "insert into search_history (query, result, sources) values ($1, $2, $3) returning *", - &search_query.query, - &search_response.result, + "insert into search_history (user_id, query, result, sources) values ($1, $2, $3, $4) returning *", + user_id, + search_query.query, + search_response.result, &search_response.sources ) .fetch_one(pool) From 48819fe3b338128b0a54bf0508aac62c146c96cc Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Fri, 5 Apr 2024 13:00:35 +0600 Subject: [PATCH 4/9] added top search queries api --- server/Cargo.lock | 1 + server/Cargo.toml | 3 +++ server/config/dev.toml | 1 + server/src/search/models.rs | 5 ++++ server/src/search/routes.rs | 48 +++++++++++++++++++++++++++++++++++-- server/src/settings.rs | 1 + 6 files changed, 57 insertions(+), 2 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index d1bd4961..e3c38d13 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -2111,6 +2111,7 @@ dependencies = [ "oauth2", "once_cell", "password-auth", + "rand", "redis", "reqwest 0.12.1", "serde", diff --git a/server/Cargo.toml b/server/Cargo.toml index 7c8ccd45..05f548e2 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,6 +11,8 @@ name = "server" path = "src/main.rs" [dependencies] + + anyhow = { version = "1.0.81", features = ["backtrace"] } async-trait = "0.1.78" axum = { version = "0.7.4", features = ["macros"] } @@ -51,3 +53,4 @@ tracing-subscriber = { version = "0.3.18", features = [ uuid = { version = "1.8.0", features = ["serde", "v4"] } log = "0.4.21" redis = { version = "0.25.2", features = ["tokio-comp", "json"] } +rand = "0.8.5" diff --git a/server/config/dev.toml b/server/config/dev.toml index f197d372..970a1a6e 100644 --- a/server/config/dev.toml +++ b/server/config/dev.toml @@ -1,5 +1,6 @@ db = "postgresql://postgres:postgres@localhost/curieo" cache = "redis://127.0.0.1/" +cache_max_sorted_size = 100 rag_api = "http://127.0.0.1:8000" rag_api_username = "curieo" rag_api_password = "curieo" diff --git a/server/src/search/models.rs b/server/src/search/models.rs index 414b657c..24d5f17e 100644 --- a/server/src/search/models.rs +++ b/server/src/search/models.rs @@ -3,6 +3,11 @@ use sqlx::types::time; use sqlx::FromRow; use std::fmt::Debug; +#[derive(Serialize, Deserialize, Debug)] +pub struct TopSearchRequest { + pub limit: Option, +} + #[derive(Serialize, Deserialize, Debug)] pub struct RAGTokenResponse { pub access_token: String, diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs index 6696f979..415e7e6e 100644 --- a/server/src/search/routes.rs +++ b/server/src/search/routes.rs @@ -1,13 +1,16 @@ use crate::err::AppError; use crate::search::services; -use crate::search::{SearchHistory, SearchHistoryRequest, SearchQueryRequest}; +use crate::search::{SearchHistory, SearchHistoryRequest, SearchQueryRequest, TopSearchRequest}; +use crate::settings::SETTINGS; use crate::startup::AppState; use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::get; use axum::{Json, Router}; -use redis::Client as RedisClient; +use color_eyre::eyre::eyre; +use rand::Rng; +use redis::{AsyncCommands, Client as RedisClient}; use sqlx::PgPool; #[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] @@ -33,6 +36,11 @@ async fn get_search_handler( ) .await?; + connection + .zincr("search_queries", &search_query.query, 1) + .await + .map_err(|e| AppError::from(e))?; + Ok((StatusCode::OK, Json(search_response))) } @@ -57,8 +65,44 @@ async fn get_search_history_handler( Ok((StatusCode::OK, Json(search_history))) } +#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] +async fn get_top_searches_handler( + State(cache): State, + Query(query): Query, +) -> crate::Result { + let mut connection = cache + .get_multiplexed_async_connection() + .await + .map_err(|e| AppError::from(e))?; + + let random_number = rand::thread_rng().gen_range(0.0..1.0); + if random_number < 0.1 { + connection + .zremrangebyrank( + "search_history", + 0, + -SETTINGS.cache_max_sorted_size as isize - 1, + ) + .await + .map_err(|e| AppError::from(e))?; + } + + let limit = query.limit.unwrap_or(10); + if limit < 1 || limit > 100 { + Err(eyre!("limit must be a number between 1 and 100"))?; + } + + let top_searches: Vec = connection + .zrevrange("search_queries", 0, limit as isize - 1) + .await + .map_err(|e| AppError::from(e))?; + + Ok((StatusCode::OK, Json(top_searches))) +} + pub fn routes() -> Router { Router::new() .route("/", get(get_search_handler)) .route("/history", get(get_search_history_handler)) + .route("/top", get(get_top_searches_handler)) } diff --git a/server/src/settings.rs b/server/src/settings.rs index 4db5d72b..22aa9143 100644 --- a/server/src/settings.rs +++ b/server/src/settings.rs @@ -82,6 +82,7 @@ pub struct Settings { pub rag_api: String, pub rag_api_username: Secret, pub rag_api_password: Secret, + pub cache_max_sorted_size: i64, } impl Settings { From a1fa661fe4dcf34587631a2ee6a8c1b3084ead28 Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Fri, 5 Apr 2024 13:06:01 +0600 Subject: [PATCH 5/9] modularized code of top search api --- server/src/search/routes.rs | 25 +------------------------ server/src/search/services.rs | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs index 415e7e6e..14db471f 100644 --- a/server/src/search/routes.rs +++ b/server/src/search/routes.rs @@ -1,15 +1,12 @@ use crate::err::AppError; use crate::search::services; use crate::search::{SearchHistory, SearchHistoryRequest, SearchQueryRequest, TopSearchRequest}; -use crate::settings::SETTINGS; use crate::startup::AppState; use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::get; use axum::{Json, Router}; -use color_eyre::eyre::eyre; -use rand::Rng; use redis::{AsyncCommands, Client as RedisClient}; use sqlx::PgPool; @@ -75,27 +72,7 @@ async fn get_top_searches_handler( .await .map_err(|e| AppError::from(e))?; - let random_number = rand::thread_rng().gen_range(0.0..1.0); - if random_number < 0.1 { - connection - .zremrangebyrank( - "search_history", - 0, - -SETTINGS.cache_max_sorted_size as isize - 1, - ) - .await - .map_err(|e| AppError::from(e))?; - } - - let limit = query.limit.unwrap_or(10); - if limit < 1 || limit > 100 { - Err(eyre!("limit must be a number between 1 and 100"))?; - } - - let top_searches: Vec = connection - .zrevrange("search_queries", 0, limit as isize - 1) - .await - .map_err(|e| AppError::from(e))?; + let top_searches = services::get_top_searches(&mut connection, &query).await?; Ok((StatusCode::OK, Json(top_searches))) } diff --git a/server/src/search/services.rs b/server/src/search/services.rs index 76aeb8f0..4e12fbe1 100644 --- a/server/src/search/services.rs +++ b/server/src/search/services.rs @@ -1,7 +1,10 @@ use crate::err::AppError; -use crate::search::{RAGTokenResponse, SearchHistory, SearchQueryRequest, SearchResponse}; +use crate::search::{ + RAGTokenResponse, SearchHistory, SearchQueryRequest, SearchResponse, TopSearchRequest, +}; use crate::settings::SETTINGS; use color_eyre::eyre::eyre; +use rand::Rng; use redis::aio::MultiplexedConnection; use redis::AsyncCommands; use reqwest::Client as ReqwestClient; @@ -86,3 +89,33 @@ pub async fn insert_search_history( return Ok(search_history); } + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn get_top_searches( + cache: &mut MultiplexedConnection, + top_search_request: &TopSearchRequest, +) -> crate::Result> { + let random_number = rand::thread_rng().gen_range(0.0..1.0); + if random_number < 0.1 { + cache + .zremrangebyrank( + "search_history", + 0, + -SETTINGS.cache_max_sorted_size as isize - 1, + ) + .await + .map_err(|e| AppError::from(e))?; + } + + let limit = top_search_request.limit.unwrap_or(10); + if limit < 1 || limit > 100 { + Err(eyre!("limit must be a number between 1 and 100"))?; + } + + let top_searches: Vec = cache + .zrevrange("search_queries", 0, limit as isize - 1) + .await + .map_err(|e| AppError::from(e))?; + + return Ok(top_searches); +} From 1090dbf49288a056521d34816ba48ea734ccda0a Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Fri, 5 Apr 2024 17:24:01 +0600 Subject: [PATCH 6/9] :bug: fix merge conflict --- server/config/dev.toml | 1 + server/src/settings.rs | 1 - server/src/startup.rs | 14 +++++--------- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/server/config/dev.toml b/server/config/dev.toml index 970a1a6e..5f7a8f55 100644 --- a/server/config/dev.toml +++ b/server/config/dev.toml @@ -4,6 +4,7 @@ cache_max_sorted_size = 100 rag_api = "http://127.0.0.1:8000" rag_api_username = "curieo" rag_api_password = "curieo" +oauth2_clients = [] [log] level = "info" diff --git a/server/src/settings.rs b/server/src/settings.rs index eb9f1ca4..11fed010 100644 --- a/server/src/settings.rs +++ b/server/src/settings.rs @@ -1,5 +1,4 @@ use std::{env, fmt::Display}; - use crate::auth::oauth2::OAuth2Client; use crate::secrets::Secret; use config::{Config, Environment, File}; diff --git a/server/src/startup.rs b/server/src/startup.rs index 72843907..ffc092df 100644 --- a/server/src/startup.rs +++ b/server/src/startup.rs @@ -45,15 +45,16 @@ impl Application { #[derive(Clone, Debug, FromRef)] pub struct AppState { pub db: PgPool, - pub cache: redis::Client, + pub cache: RedisClient, pub oauth2_clients: Vec, pub settings: Settings, } -impl From<(PgPool, Settings)> for AppState { - fn from((db, settings): (PgPool, Settings)) -> Self { +impl From<(PgPool, RedisClient, Settings)> for AppState { + fn from((db, cache, settings): (PgPool, RedisClient, Settings)) -> Self { Self { db, + cache, oauth2_clients: settings.oauth2_clients.clone(), settings, } @@ -86,12 +87,7 @@ async fn run( let cache = cache_connect(settings.cache.expose()).await?; - let state = AppState { - db, - cache, - settings, - }; - let state = AppState::from((db, settings)); + let state = AppState::from((db, cache, settings)); let app = router(state)?; From 8dc9017911d26e689209f2edfbe7d0b47797b6e1 Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Fri, 5 Apr 2024 17:32:54 +0600 Subject: [PATCH 7/9] :bug: added cache in the settings in test --- server/tests/health_check.rs | 5 +++-- server/tests/users.rs | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/server/tests/health_check.rs b/server/tests/health_check.rs index cd7517e8..cdbcc18d 100644 --- a/server/tests/health_check.rs +++ b/server/tests/health_check.rs @@ -4,14 +4,15 @@ use tower::ServiceExt; use server::routing::router; use server::settings::Settings; -use server::startup::{db_connect, AppState}; +use server::startup::{db_connect, cache_connect, AppState}; #[tokio::test] async fn health_check_works() { let settings = Settings::new(); let db = db_connect(settings.db.expose()).await.unwrap(); - let state = AppState::from((db, settings)); + let cache = cache_connect(settings.cache.expose()).await.unwrap(); + let state = AppState::from((db, cache, settings)); let router = router(state).unwrap(); let request = Request::builder() diff --git a/server/tests/users.rs b/server/tests/users.rs index 0bf29f21..688d0620 100644 --- a/server/tests/users.rs +++ b/server/tests/users.rs @@ -10,6 +10,7 @@ use server::users::selectors::get_user; use server::Result; use sqlx::PgPool; use tower::ServiceExt; +use server::startup::cache_connect; /// Helper function to create a GET request for a given URI. fn _send_get_request(uri: &str) -> Request { @@ -47,7 +48,8 @@ async fn register_and_get_users_test(pool: PgPool) -> Result<()> { #[sqlx::test] async fn register_users_works(pool: PgPool) { let settings = Settings::new(); - let state = AppState::from((pool, settings)); + let cache = cache_connect(settings.cache.expose()).await.unwrap(); + let state = AppState::from((pool, cache, settings)); let router = router(state).unwrap(); let form = &[ From e8d59e36019214a1a824cbc00bf9a2ed77badf4f Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Fri, 5 Apr 2024 17:39:38 +0600 Subject: [PATCH 8/9] added auth middleware with the search endpoints --- server/src/search/routes.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs index 14db471f..157da0ec 100644 --- a/server/src/search/routes.rs +++ b/server/src/search/routes.rs @@ -1,4 +1,5 @@ use crate::err::AppError; +use crate::users::User; use crate::search::services; use crate::search::{SearchHistory, SearchHistoryRequest, SearchQueryRequest, TopSearchRequest}; use crate::startup::AppState; @@ -14,9 +15,10 @@ use sqlx::PgPool; async fn get_search_handler( State(pool): State, State(cache): State, + user: User, Query(search_query): Query, ) -> crate::Result { - let user_id = uuid::Uuid::parse_str("78c4c766-f310-11ee-a6ee-5f4062fc15f2").unwrap(); + let user_id = user.user_id; let mut connection = cache .get_multiplexed_async_connection() @@ -44,9 +46,10 @@ async fn get_search_handler( #[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] async fn get_search_history_handler( State(pool): State, + user: User, Query(search_history_request): Query, ) -> crate::Result { - let user_id = uuid::Uuid::parse_str("78c4c766-f310-11ee-a6ee-5f4062fc15f2").unwrap(); + let user_id = user.user_id; let search_history = sqlx::query_as!( SearchHistory, From da516b73d1eb0d9e0e128578468542f3cceda953 Mon Sep 17 00:00:00 2001 From: Rathijit Paul <30369246+rathijitpapon@users.noreply.github.com> Date: Fri, 5 Apr 2024 19:27:15 +0600 Subject: [PATCH 9/9] add test code in the search functions --- server/src/search/routes.rs | 16 ++--- server/src/search/services.rs | 23 ++++++- server/tests/search.rs | 109 ++++++++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 13 deletions(-) create mode 100644 server/tests/search.rs diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs index 157da0ec..37190a27 100644 --- a/server/src/search/routes.rs +++ b/server/src/search/routes.rs @@ -1,8 +1,8 @@ use crate::err::AppError; -use crate::users::User; use crate::search::services; -use crate::search::{SearchHistory, SearchHistoryRequest, SearchQueryRequest, TopSearchRequest}; +use crate::search::{SearchHistoryRequest, SearchQueryRequest, TopSearchRequest}; use crate::startup::AppState; +use crate::users::User; use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::IntoResponse; @@ -51,16 +51,8 @@ async fn get_search_history_handler( ) -> crate::Result { let user_id = user.user_id; - let search_history = sqlx::query_as!( - SearchHistory, - "select * from search_history where user_id = $1 order by created_at desc limit $2 offset $3", - user_id, - search_history_request.limit.unwrap_or(10) as i64, - search_history_request.offset.unwrap_or(0) as i64 - ) - .fetch_all(&pool) - .await - .map_err(|e| AppError::from(e))?; + let search_history = + services::get_search_history(&pool, &user_id, &search_history_request).await?; Ok((StatusCode::OK, Json(search_history))) } diff --git a/server/src/search/services.rs b/server/src/search/services.rs index 4e12fbe1..ac8a9417 100644 --- a/server/src/search/services.rs +++ b/server/src/search/services.rs @@ -1,6 +1,7 @@ use crate::err::AppError; use crate::search::{ - RAGTokenResponse, SearchHistory, SearchQueryRequest, SearchResponse, TopSearchRequest, + RAGTokenResponse, SearchHistory, SearchHistoryRequest, SearchQueryRequest, SearchResponse, + TopSearchRequest, }; use crate::settings::SETTINGS; use color_eyre::eyre::eyre; @@ -90,6 +91,26 @@ pub async fn insert_search_history( return Ok(search_history); } +#[tracing::instrument(level = "debug", ret, err)] +pub async fn get_search_history( + pool: &PgPool, + user_id: &Uuid, + search_history_request: &SearchHistoryRequest, +) -> crate::Result> { + let search_history = sqlx::query_as!( + SearchHistory, + "select * from search_history where user_id = $1 order by created_at desc limit $2 offset $3", + user_id, + search_history_request.limit.unwrap_or(10) as i64, + search_history_request.offset.unwrap_or(0) as i64 + ) + .fetch_all(pool) + .await + .map_err(|e| AppError::from(e))?; + + return Ok(search_history); +} + #[tracing::instrument(level = "debug", ret, err)] pub async fn get_top_searches( cache: &mut MultiplexedConnection, diff --git a/server/tests/search.rs b/server/tests/search.rs new file mode 100644 index 00000000..d66264a6 --- /dev/null +++ b/server/tests/search.rs @@ -0,0 +1,109 @@ +use server::auth::models::RegisterUserRequest; +use server::auth::register; +use server::search::{get_search_history, get_top_searches, insert_search_history, search}; +use server::search::{SearchHistoryRequest, SearchQueryRequest, SearchResponse, TopSearchRequest}; +use server::settings::Settings; +use server::startup::cache_connect; +use server::Result; +use sqlx::PgPool; + +#[sqlx::test] +async fn search_test() -> Result<()> { + let settings = Settings::new(); + let mut connection = cache_connect(settings.cache.expose()) + .await + .unwrap() + .get_multiplexed_async_connection() + .await + .unwrap(); + + let search_query = SearchQueryRequest { + query: "test".to_string(), + }; + + let search_result = search(&mut connection, &search_query).await; + + assert!(search_result.is_ok()); + + Ok(()) +} + +#[sqlx::test] +async fn top_searches_test() -> Result<()> { + let settings = Settings::new(); + let mut connection = cache_connect(settings.cache.expose()) + .await + .unwrap() + .get_multiplexed_async_connection() + .await + .unwrap(); + + let top_search_query = TopSearchRequest { limit: Some(1) }; + + let top_searches_result = get_top_searches(&mut connection, &top_search_query).await; + + assert!(top_searches_result.is_ok()); + assert_eq!(top_searches_result.unwrap().len(), 1); + + Ok(()) +} + +#[sqlx::test] +async fn insert_search_and_get_search_history_test(pool: PgPool) -> Result<()> { + let settings = Settings::new(); + let mut connection = cache_connect(settings.cache.expose()) + .await + .unwrap() + .get_multiplexed_async_connection() + .await + .unwrap(); + + let new_user = register( + pool.clone(), + RegisterUserRequest { + email: "test-email".to_string(), + username: "test-username".to_string(), + password: Some("password".to_string().into()), + access_token: Default::default(), + }, + ) + .await?; + + let user_id = new_user.user_id; + let search_query = SearchQueryRequest { + query: "test_query".to_string(), + }; + let search_response = SearchResponse { + result: "test_result".to_string(), + sources: vec!["test_source".to_string()], + }; + + let search_insertion_result = insert_search_history( + &pool, + &mut connection, + &user_id, + &search_query, + &search_response, + ) + .await; + + assert!(search_insertion_result.is_ok()); + + let search_history_request = SearchHistoryRequest { + limit: Some(1), + offset: Some(0), + }; + + let search_history_result = get_search_history(&pool, &user_id, &search_history_request).await; + + assert!(&search_history_result.is_ok()); + let search_history_result = search_history_result.unwrap(); + + assert_eq!(&search_history_result.len(), &1); + assert_eq!(&search_history_result[0].query, &search_query.query); + assert_eq!(search_history_result[0].user_id, user_id); + assert_eq!(search_history_result[0].result, search_response.result); + assert_eq!(search_history_result[0].sources, search_response.sources); + + Ok(()) +}