diff --git a/server/.gitignore b/server/.gitignore index 9481ce52..a63b034d 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -27,3 +27,6 @@ target/ # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + +# These will be created by the sqlx CLI +*.sqlx diff --git a/server/Cargo.lock b/server/Cargo.lock index 3d425f62..c2a9b0c8 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -362,6 +362,20 @@ dependencies = [ "tracing-error", ] +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "config" version = "0.14.0" @@ -1741,6 +1755,29 @@ dependencies = [ "getrandom", ] +[[package]] +name = "redis" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6472825949c09872e8f2c50bde59fcefc17748b6be5c90fd67cd8b4daca73bfd" +dependencies = [ + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "percent-encoding", + "pin-project-lite", + "ryu", + "serde", + "serde_json", + "sha1_smol", + "socket2", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -2126,6 +2163,8 @@ dependencies = [ "oauth2", "once_cell", "password-auth", + "rand", + "redis", "reqwest 0.12.1", "serde", "serde_json", @@ -2154,6 +2193,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.10.8" @@ -3022,6 +3067,7 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" 
dependencies = [ + "getrandom", "serde", ] diff --git a/server/Cargo.toml b/server/Cargo.toml index abb4337b..721ac226 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -11,6 +11,8 @@ name = "server" path = "src/main.rs" [dependencies] + + anyhow = { version = "1.0.81", features = ["backtrace"] } async-trait = "0.1.78" axum = { version = "0.7.4", features = ["macros"] } @@ -27,7 +29,14 @@ reqwest = { version = "0.12.1", features = ["json"] } serde = "1.0.197" serde_json = "1.0.114" serde_urlencoded = "0.7.1" -sqlx = { version = "0.7.4", features = ["postgres", "runtime-tokio", "tls-rustls", "migrate", "uuid", "time"] } +sqlx = { version = "0.7.4", features = [ + "postgres", + "runtime-tokio", + "tls-rustls", + "migrate", + "uuid", + "time", +] } thiserror = "1.0.58" tokio = { version = "1.36.0", features = ["full"] } tower-http = { version = "0.5.2", features = ["trace", "cors"] } @@ -36,9 +45,15 @@ tower = "0.4.13" tracing-error = "0.2.0" tracing-log = "0.2.0" tracing-logfmt = "0.3.4" -tracing-subscriber = { version = "0.3.18", features = ["json", "registry", "env-filter"] } -uuid = { version = "1.8.0", features = ["serde"] } +tracing-subscriber = { version = "0.3.18", features = [ + "json", + "registry", + "env-filter", +] } +uuid = { version = "1.8.0", features = ["serde", "v4"] } log = "0.4.21" +redis = { version = "0.25.2", features = ["tokio-comp", "json"] } +rand = "0.8.5" futures = "0.3.30" axum-extra = { version = "0.9.3", features = ["typed-header"] } http = "1.1.0" diff --git a/server/config/dev.toml b/server/config/dev.toml index a820754a..5f7a8f55 100644 --- a/server/config/dev.toml +++ b/server/config/dev.toml @@ -1,4 +1,10 @@ db = "postgresql://postgres:postgres@localhost/curieo" +cache = "redis://127.0.0.1/" +cache_max_sorted_size = 100 +rag_api = "http://127.0.0.1:8000" +rag_api_username = "curieo" +rag_api_password = "curieo" +oauth2_clients = [] [log] level = "info" diff --git a/server/migrations/20240403115016_search_history.sql 
b/server/migrations/20240403115016_search_history.sql new file mode 100644 index 00000000..1bcbb603 --- /dev/null +++ b/server/migrations/20240403115016_search_history.sql @@ -0,0 +1,12 @@ +create table search_history ( + search_history_id uuid primary key default uuid_generate_v1mc(), + user_id uuid not null references users(user_id), + query text not null, + result text not null, + sources text[] not null, + created_at timestamptz not null default now(), + updated_at timestamptz not null default now() +); + +-- And applying our `updated_at` trigger is as easy as this. +SELECT trigger_updated_at('search_history'); diff --git a/server/src/err.rs b/server/src/err.rs index dbf18970..4d057488 100644 --- a/server/src/err.rs +++ b/server/src/err.rs @@ -16,6 +16,7 @@ pub enum AppError { UnprocessableEntity(ErrorMap), Sqlx(sqlx::Error), GenericError(color_eyre::eyre::Error), + Redis(redis::RedisError), } impl AppError { @@ -24,6 +25,7 @@ impl AppError { Self::Unauthorized => StatusCode::UNAUTHORIZED, Self::UnprocessableEntity { .. } => StatusCode::UNPROCESSABLE_ENTITY, Self::Sqlx(_) | Self::GenericError(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::Redis(_) => StatusCode::INTERNAL_SERVER_ERROR, } } /// Convenient constructor for `Error::UnprocessableEntity`. 
@@ -49,6 +51,12 @@ impl From for AppError { } } +impl From<redis::RedisError> for AppError { + fn from(inner: redis::RedisError) -> Self { + AppError::Redis(inner) + } + } + impl From<BackendError> for AppError { fn from(e: BackendError) -> Self { match e { @@ -67,6 +75,7 @@ impl Display for AppError { AppError::Sqlx(e) => write!(f, "{}", e), AppError::UnprocessableEntity(e) => write!(f, "{:?}", e), AppError::Unauthorized => write!(f, "Unauthorized"), + AppError::Redis(e) => write!(f, "{}", e), } } } @@ -118,6 +127,7 @@ impl IntoResponse for AppError { } AppError::Sqlx(ref e) => error!("SQLx error: {:?}", e), AppError::GenericError(ref e) => error!("Generic error: {:?}", e), + AppError::Redis(ref e) => error!("Redis error: {:?}", e), }; // Return a http status code and json body with error message. diff --git a/server/src/lib.rs b/server/src/lib.rs index 2a03549a..ef9dee7b 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -6,6 +6,7 @@ pub mod auth; mod err; mod health_check; pub mod routing; +pub mod search; pub mod secrets; pub mod settings; pub mod startup; diff --git a/server/src/routing/api.rs b/server/src/routing/api.rs index 996b5748..9486f068 100644 --- a/server/src/routing/api.rs +++ b/server/src/routing/api.rs @@ -8,7 +8,7 @@ use tracing::Level; use crate::auth::models::PostgresBackend; use crate::startup::AppState; -use crate::{auth, health_check, users}; +use crate::{auth, health_check, search, users}; pub fn router(state: AppState) -> color_eyre::Result<Router> { //sqlx::migrate!().run(&db).await?; @@ -33,6 +33,8 @@ pub fn router(state: AppState) -> color_eyre::Result<Router> { let api_routes = Router::new() //.nest("/search", search::routes()) + //.layer(middleware::from_fn(some_auth_middleware)) + .nest("/search", search::routes()) .nest("/users", users::routes()) .route_layer(login_required!( PostgresBackend, diff --git a/server/src/search/mod.rs b/server/src/search/mod.rs new file mode 100644 index 00000000..1d3d1250 --- /dev/null +++ b/server/src/search/mod.rs @@ -0,0 +1,7 @@ +pub use 
models::*; +pub use routes::*; +pub use services::*; + +pub mod models; +pub mod routes; +pub mod services; diff --git a/server/src/search/models.rs b/server/src/search/models.rs new file mode 100644 index 00000000..24d5f17e --- /dev/null +++ b/server/src/search/models.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; +use sqlx::types::time; +use sqlx::FromRow; +use std::fmt::Debug; + +#[derive(Serialize, Deserialize, Debug)] +pub struct TopSearchRequest { + pub limit: Option<u64>, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct RAGTokenResponse { + pub access_token: String, + pub token_type: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchQueryRequest { + pub query: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchHistoryRequest { + pub limit: Option<u64>, + pub offset: Option<u64>, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchResponse { + pub result: String, + pub sources: Vec<String>, +} + +#[derive(FromRow, Serialize, Deserialize, Clone, Debug)] +pub struct SearchHistory { + pub search_history_id: uuid::Uuid, + pub user_id: uuid::Uuid, + pub query: String, + pub result: String, + pub sources: Vec<String>, + + pub created_at: time::OffsetDateTime, + pub updated_at: time::OffsetDateTime, +} diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs new file mode 100644 index 00000000..37190a27 --- /dev/null +++ b/server/src/search/routes.rs @@ -0,0 +1,80 @@ +use crate::err::AppError; +use crate::search::services; +use crate::search::{SearchHistoryRequest, SearchQueryRequest, TopSearchRequest}; +use crate::startup::AppState; +use crate::users::User; +use axum::extract::{Query, State}; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::{Json, Router}; +use redis::{AsyncCommands, Client as RedisClient}; +use sqlx::PgPool; + +#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] +async fn get_search_handler( + State(pool): State<PgPool>, 
State(cache): State<RedisClient>, + user: User, + Query(search_query): Query<SearchQueryRequest>, +) -> crate::Result<impl IntoResponse> { + let user_id = user.user_id; + + let mut connection = cache + .get_multiplexed_async_connection() + .await + .map_err(|e| AppError::from(e))?; + + let search_response = services::search(&mut connection, &search_query).await?; + services::insert_search_history( + &pool, + &mut connection, + &user_id, + &search_query, + &search_response, + ) + .await?; + + connection + .zincr("search_queries", &search_query.query, 1) + .await + .map_err(|e| AppError::from(e))?; + + Ok((StatusCode::OK, Json(search_response))) +} + +#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] +async fn get_search_history_handler( + State(pool): State<PgPool>, + user: User, + Query(search_history_request): Query<SearchHistoryRequest>, +) -> crate::Result<impl IntoResponse> { + let user_id = user.user_id; + + let search_history = + services::get_search_history(&pool, &user_id, &search_history_request).await?; + + Ok((StatusCode::OK, Json(search_history))) +} + +#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] +async fn get_top_searches_handler( + State(cache): State<RedisClient>, + Query(query): Query<TopSearchRequest>, +) -> crate::Result<impl IntoResponse> { + let mut connection = cache + .get_multiplexed_async_connection() + .await + .map_err(|e| AppError::from(e))?; + + let top_searches = services::get_top_searches(&mut connection, &query).await?; + + Ok((StatusCode::OK, Json(top_searches))) +} + +pub fn routes() -> Router<AppState> { + Router::new() + .route("/", get(get_search_handler)) + .route("/history", get(get_search_history_handler)) + .route("/top", get(get_top_searches_handler)) +} diff --git a/server/src/search/services.rs b/server/src/search/services.rs new file mode 100644 index 00000000..ac8a9417 --- /dev/null +++ b/server/src/search/services.rs @@ -0,0 +1,142 @@ +use crate::err::AppError; +use crate::search::{ + RAGTokenResponse, SearchHistory, SearchHistoryRequest, SearchQueryRequest, SearchResponse, + TopSearchRequest, +}; +use crate::settings::SETTINGS; +use 
color_eyre::eyre::eyre; +use rand::Rng; +use redis::aio::MultiplexedConnection; +use redis::AsyncCommands; +use reqwest::Client as ReqwestClient; +use sqlx::PgPool; +use uuid::Uuid; + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn search( + cache: &mut MultiplexedConnection, + search_query: &SearchQueryRequest, +) -> crate::Result<SearchResponse> { + let cache_response: Option<SearchResponse> = cache + .get(search_query.query.clone()) + .await + .map(|response: Option<String>| { + response.and_then(|response| serde_json::from_str(&response).ok()) + }) + .map_err(|e| AppError::from(e))?; + + if let Some(response) = cache_response { + return Ok(response); + } + + // TODO: replace this with actual search logic using GRPC calls with backend services + let rag_api_url = SETTINGS.rag_api.clone() + "/token"; + let form_data = [ + ("username", &SETTINGS.rag_api_username.expose()), + ("password", &SETTINGS.rag_api_password.expose()), + ]; + let token: RAGTokenResponse = ReqwestClient::new() + .post(rag_api_url) + .form(&form_data) + .send() + .await + .map_err(|_| eyre!("unable to send request to rag api"))? + .json() + .await + .map_err(|_| eyre!("unable to parse json response from rag api"))?; + + let rag_api_url = SETTINGS.rag_api.clone() + "/search?query=" + &search_query.query; + let response: SearchResponse = ReqwestClient::new() + .get(rag_api_url) + .header("Authorization", format!("Bearer {}", token.access_token)) + .send() + .await + .map_err(|_| eyre!("unable to send request to rag api"))? 
+ .json() + .await + .map_err(|_| eyre!("unable to parse json response from rag api"))?; + + return Ok(response); +} + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn insert_search_history( + pool: &PgPool, + cache: &mut MultiplexedConnection, + user_id: &Uuid, + search_query: &SearchQueryRequest, + search_response: &SearchResponse, +) -> crate::Result<SearchHistory> { + cache + .set( + &search_query.query, + serde_json::to_string(&search_response) + .map_err(|_| eyre!("unable to convert string to json"))?, + ) + .await + .map_err(|e| AppError::from(e))?; + + let search_history = sqlx::query_as!( + SearchHistory, + "insert into search_history (user_id, query, result, sources) values ($1, $2, $3, $4) returning *", + user_id, + search_query.query, + search_response.result, + &search_response.sources + ) + .fetch_one(pool) + .await + .map_err(|e| AppError::from(e))?; + + return Ok(search_history); +} + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn get_search_history( + pool: &PgPool, + user_id: &Uuid, + search_history_request: &SearchHistoryRequest, +) -> crate::Result<Vec<SearchHistory>> { + let search_history = sqlx::query_as!( + SearchHistory, + "select * from search_history where user_id = $1 order by created_at desc limit $2 offset $3", + user_id, + search_history_request.limit.unwrap_or(10) as i64, + search_history_request.offset.unwrap_or(0) as i64 + ) + .fetch_all(pool) + .await + .map_err(|e| AppError::from(e))?; + + return Ok(search_history); +} + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn get_top_searches( + cache: &mut MultiplexedConnection, + top_search_request: &TopSearchRequest, +) -> crate::Result<Vec<String>> { + let random_number = rand::thread_rng().gen_range(0.0..1.0); + if random_number < 0.1 { + cache + .zremrangebyrank( + "search_history", + 0, + -SETTINGS.cache_max_sorted_size as isize - 1, + ) + .await + .map_err(|e| AppError::from(e))?; + } + + let limit = top_search_request.limit.unwrap_or(10); + if limit < 1 || limit > 100 { + 
Err(eyre!("limit must be a number between 1 and 100"))?; + } + + let top_searches: Vec<String> = cache + .zrevrange("search_queries", 0, limit as isize - 1) + .await + .map_err(|e| AppError::from(e))?; + + return Ok(top_searches); +} diff --git a/server/src/settings.rs b/server/src/settings.rs index 0a93e988..11fed010 100644 --- a/server/src/settings.rs +++ b/server/src/settings.rs @@ -1,5 +1,4 @@ use std::{env, fmt::Display}; - use crate::auth::oauth2::OAuth2Client; use crate::secrets::Secret; use config::{Config, Environment, File}; @@ -78,6 +77,11 @@ pub struct Settings { pub host: String, pub port: u16, pub db: Secret<String>, + pub cache: Secret<String>, + pub rag_api: String, + pub rag_api_username: Secret<String>, + pub rag_api_password: Secret<String>, + pub cache_max_sorted_size: i64, pub oauth2_clients: Vec<OAuth2Client>, } diff --git a/server/src/startup.rs b/server/src/startup.rs index 2d91f3e5..ffc092df 100644 --- a/server/src/startup.rs +++ b/server/src/startup.rs @@ -4,6 +4,7 @@ use crate::settings::Settings; use crate::Result; use axum::{extract::FromRef, routing::IntoMakeService, serve::Serve, Router}; use color_eyre::eyre::eyre; +use redis::Client as RedisClient; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; use tokio::net::TcpListener; @@ -44,14 +45,16 @@ impl Application { #[derive(Clone, Debug, FromRef)] pub struct AppState { pub db: PgPool, + pub cache: RedisClient, pub oauth2_clients: Vec<OAuth2Client>, pub settings: Settings, } -impl From<(PgPool, Settings)> for AppState { - fn from((db, settings): (PgPool, Settings)) -> Self { +impl From<(PgPool, RedisClient, Settings)> for AppState { + fn from((db, cache, settings): (PgPool, RedisClient, Settings)) -> Self { Self { db, + cache, oauth2_clients: settings.oauth2_clients.clone(), settings, } @@ -69,12 +72,22 @@ pub async fn db_connect(database_url: &str) -> Result<PgPool> { } } +pub async fn cache_connect(cache_url: &str) -> Result<RedisClient> { + match RedisClient::open(cache_url) { + Ok(client) => Ok(client), + Err(e) => Err(eyre!("Failed to connect to Redis: {}", 
e).into()), + } +} + async fn run( listener: TcpListener, settings: Settings, ) -> Result<Serve<IntoMakeService<Router>, Router>> { let db = db_connect(settings.db.expose()).await?; - let state = AppState::from((db, settings)); + + let cache = cache_connect(settings.cache.expose()).await?; + + let state = AppState::from((db, cache, settings)); let app = router(state)?; diff --git a/server/tests/health_check.rs b/server/tests/health_check.rs index cd7517e8..cdbcc18d 100644 --- a/server/tests/health_check.rs +++ b/server/tests/health_check.rs @@ -4,14 +4,15 @@ use tower::ServiceExt; use server::routing::router; use server::settings::Settings; -use server::startup::{db_connect, AppState}; +use server::startup::{db_connect, cache_connect, AppState}; #[tokio::test] async fn health_check_works() { let settings = Settings::new(); let db = db_connect(settings.db.expose()).await.unwrap(); - let state = AppState::from((db, settings)); + let cache = cache_connect(settings.cache.expose()).await.unwrap(); + let state = AppState::from((db, cache, settings)); let router = router(state).unwrap(); let request = Request::builder() diff --git a/server/tests/search.rs b/server/tests/search.rs new file mode 100644 index 00000000..d66264a6 --- /dev/null +++ b/server/tests/search.rs @@ -0,0 +1,109 @@ +use server::auth::models::RegisterUserRequest; +use server::auth::register; +use server::search::{get_search_history, get_top_searches, insert_search_history, search}; +use server::search::{SearchHistoryRequest, SearchQueryRequest, SearchResponse, TopSearchRequest}; +use server::settings::Settings; +use server::startup::cache_connect; +use server::Result; +use sqlx::PgPool; + +#[sqlx::test] +async fn search_test() -> Result<()> { + let settings = Settings::new(); + let mut connection = cache_connect(settings.cache.expose()) + .await + .unwrap() + .get_multiplexed_async_connection() + .await + .unwrap(); + + let search_query = SearchQueryRequest { + query: "test".to_string(), + }; + + let search_result = search(&mut 
connection, &search_query).await; + + assert!(search_result.is_ok()); + + Ok(()) +} + +#[sqlx::test] +async fn top_searches_test() -> Result<()> { + let settings = Settings::new(); + let mut connection = cache_connect(settings.cache.expose()) + .await + .unwrap() + .get_multiplexed_async_connection() + .await + .unwrap(); + + let top_search_query = TopSearchRequest { limit: Some(1) }; + + let top_searches_result = get_top_searches(&mut connection, &top_search_query).await; + + assert!(top_searches_result.is_ok()); + assert_eq!(top_searches_result.unwrap().len(), 1); + + Ok(()) +} + +#[sqlx::test] +async fn insert_search_and_get_search_history_test(pool: PgPool) -> Result<()> { + let settings = Settings::new(); + let mut connection = cache_connect(settings.cache.expose()) + .await + .unwrap() + .get_multiplexed_async_connection() + .await + .unwrap(); + + let new_user = register( + pool.clone(), + RegisterUserRequest { + email: "test-email".to_string(), + username: "test-username".to_string(), + password: Some("password".to_string().into()), + access_token: Default::default(), + }, + ) + .await?; + + let user_id = new_user.user_id; + let search_query = SearchQueryRequest { + query: "test_query".to_string(), + }; + let search_response = SearchResponse { + result: "test_result".to_string(), + sources: vec!["test_source".to_string()], + }; + + let search_insertion_result = insert_search_history( + &pool, + &mut connection, + &user_id, + &search_query, + &search_response, + ) + .await; + + assert!(search_insertion_result.is_ok()); + + let search_history_request = SearchHistoryRequest { + limit: Some(1), + offset: Some(0), + }; + + let search_history_result = get_search_history(&pool, &user_id, &search_history_request).await; + + assert!(&search_history_result.is_ok()); + let search_history_result = search_history_result.unwrap(); + + assert_eq!(&search_history_result.len(), &1); + assert_eq!(&search_history_result[0].query, &search_query.query); + 
assert_eq!(search_history_result[0].user_id, user_id); + assert_eq!(search_history_result[0].result, search_response.result); + assert_eq!(search_history_result[0].sources, search_response.sources); + + Ok(()) +} diff --git a/server/tests/users.rs b/server/tests/users.rs index 0bf29f21..688d0620 100644 --- a/server/tests/users.rs +++ b/server/tests/users.rs @@ -10,6 +10,7 @@ use server::users::selectors::get_user; use server::Result; use sqlx::PgPool; use tower::ServiceExt; +use server::startup::cache_connect; /// Helper function to create a GET request for a given URI. fn _send_get_request(uri: &str) -> Request { @@ -47,7 +48,8 @@ async fn register_and_get_users_test(pool: PgPool) -> Result<()> { #[sqlx::test] async fn register_users_works(pool: PgPool) { let settings = Settings::new(); - let state = AppState::from((pool, settings)); + let cache = cache_connect(settings.cache.expose()).await.unwrap(); + let state = AppState::from((pool, cache, settings)); let router = router(state).unwrap(); let form = &[