diff --git a/server/.gitignore b/server/.gitignore index 9481ce52..a63b034d 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -27,3 +27,6 @@ target/ # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + +# These will be created by the sqlx CLI +*.sqlx diff --git a/server/Cargo.lock b/server/Cargo.lock index 3dd73cef..972c5c5c 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -339,6 +339,20 @@ dependencies = [ "tracing-error", ] +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "config" version = "0.14.0" @@ -1692,6 +1706,29 @@ dependencies = [ "getrandom", ] +[[package]] +name = "redis" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6472825949c09872e8f2c50bde59fcefc17748b6be5c90fd67cd8b4daca73bfd" +dependencies = [ + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "percent-encoding", + "pin-project-lite", + "ryu", + "serde", + "serde_json", + "sha1_smol", + "socket2", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -2074,6 +2111,7 @@ dependencies = [ "oauth2", "once_cell", "password-auth", + "redis", "reqwest 0.12.1", "serde", "serde_json", @@ -2102,6 +2140,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.10.8" diff --git a/server/Cargo.toml b/server/Cargo.toml index b385f9cd..00399630 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -27,7 +27,14 @@ reqwest = { version = "0.12.1", features = ["json"] } serde = "1.0.197" serde_json = "1.0.114" serde_urlencoded = "0.7.1" -sqlx = { version = "0.7.4", features = ["postgres", "runtime-tokio", "tls-rustls", "migrate", "uuid", "time"] } +sqlx = { version = "0.7.4", features = [ + "postgres", + "runtime-tokio", + "tls-rustls", + "migrate", + "uuid", + "time", +] } thiserror = "1.0.58" tokio = { version = "1.36.0", features = ["full"] } tower-http = { version = "0.5.2", features = ["trace", "cors"] } @@ -36,6 +43,11 @@ tower = "0.4.13" tracing-error = "0.2.0" tracing-log = "0.2.0" tracing-logfmt = "0.3.4" -tracing-subscriber = { version = "0.3.18", features = ["json", "registry", "env-filter"] } +tracing-subscriber = { version = "0.3.18", features = [ + "json", + "registry", + "env-filter", +] } uuid = { version = "1.8.0", features = ["serde"] } log = "0.4.21" +redis = { version = "0.25.2", features = ["tokio-comp", "json"] } diff --git a/server/config/dev.toml b/server/config/dev.toml index a820754a..456ee9d1 100644 --- a/server/config/dev.toml +++ b/server/config/dev.toml @@ -1,4 +1,5 @@ db = "postgresql://postgres:postgres@localhost/curieo" +cache = "redis://127.0.0.1/" [log] level = "info" diff --git a/server/migrations/20240403115016_search_history.sql b/server/migrations/20240403115016_search_history.sql new file mode 100644 index 00000000..36a52a50 --- /dev/null +++ b/server/migrations/20240403115016_search_history.sql @@ -0,0 +1,11 @@ +create table search_history ( + search_history_id uuid primary key default uuid_generate_v1mc(), + search_text text not null, + response_text text not null, + response_sources text[] not null, + created_at timestamptz not null default now(), + updated_at timestamptz not null default now() +); + +-- And applying our `updated_at` trigger is as easy as this. +SELECT trigger_updated_at('search_history'); diff --git a/server/src/err.rs b/server/src/err.rs index 9562d515..5087db1c 100644 --- a/server/src/err.rs +++ b/server/src/err.rs @@ -14,6 +14,7 @@ pub enum AppError { UnprocessableEntity(ErrorMap), Sqlx(sqlx::Error), GenericError(color_eyre::eyre::Error), + Redis(redis::RedisError), } impl AppError { @@ -22,6 +23,7 @@ impl AppError { Self::Unauthorized => StatusCode::UNAUTHORIZED, Self::UnprocessableEntity { .. } => StatusCode::UNPROCESSABLE_ENTITY, Self::Sqlx(_) | Self::GenericError(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::Redis(_) => StatusCode::INTERNAL_SERVER_ERROR, } } /// Convenient constructor for `Error::UnprocessableEntity`. @@ -47,6 +49,12 @@ impl From for AppError { } } +impl From for AppError { + fn from(inner: redis::RedisError) -> Self { + AppError::Redis(inner) + } +} + impl Display for AppError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -54,6 +62,7 @@ impl Display for AppError { AppError::Sqlx(e) => write!(f, "{}", e), AppError::UnprocessableEntity(e) => write!(f, "{:?}", e), AppError::Unauthorized => write!(f, "Unauthorized"), + AppError::Redis(e) => write!(f, "{}", e), } } } @@ -105,6 +114,7 @@ impl IntoResponse for AppError { } AppError::Sqlx(ref e) => error!("SQLx error: {:?}", e), AppError::GenericError(ref e) => error!("Generic error: {:?}", e), + AppError::Redis(ref e) => error!("Redis error: {:?}", e), }; // Return a http status code and json body with error message. diff --git a/server/src/lib.rs b/server/src/lib.rs index e2fbfcc3..5a865443 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -6,6 +6,7 @@ pub mod auth; mod err; mod health_check; pub mod routing; +pub mod search; pub mod secrets; pub mod settings; pub mod startup; diff --git a/server/src/routing/api.rs b/server/src/routing/api.rs index 608e9aa4..dca9bae3 100644 --- a/server/src/routing/api.rs +++ b/server/src/routing/api.rs @@ -8,7 +8,7 @@ use tracing::Level; use crate::auth::models::{OAuth2Clients, PostgresBackend}; use crate::startup::AppState; -use crate::{auth, health_check, users}; +use crate::{auth, health_check, search, users}; pub fn router(state: AppState) -> color_eyre::Result { // Session layer. @@ -33,6 +33,7 @@ pub fn router(state: AppState) -> color_eyre::Result { let api_routes = Router::new() //.nest("/search", search::routes()) //.layer(middleware::from_fn(some_auth_middleware)) + .nest("/search", search::routes()) .nest("/users", users::routes()) .route_layer(login_required!( PostgresBackend, diff --git a/server/src/search/mod.rs b/server/src/search/mod.rs new file mode 100644 index 00000000..1d3d1250 --- /dev/null +++ b/server/src/search/mod.rs @@ -0,0 +1,7 @@ +pub use models::*; +pub use routes::*; +pub use services::*; + +pub mod models; +pub mod routes; +pub mod services; diff --git a/server/src/search/models.rs b/server/src/search/models.rs new file mode 100644 index 00000000..94fef466 --- /dev/null +++ b/server/src/search/models.rs @@ -0,0 +1,33 @@ +use serde::{Deserialize, Serialize}; +use sqlx::types::time; +use sqlx::FromRow; +use std::fmt::Debug; + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchQueryRequest { + pub query: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchHistoryRequest { + pub limit: Option, + pub offset: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SearchResponse { + pub response_text: String, + pub response_sources: Vec, +} + +#[derive(FromRow, Serialize, Deserialize, Clone, Debug)] +pub struct SearchHistory { + pub search_history_id: uuid::Uuid, + // pub user_id: uuid::Uuid, + pub search_text: String, + pub response_text: String, + pub response_sources: Vec, + + pub created_at: time::OffsetDateTime, + pub updated_at: time::OffsetDateTime, +} diff --git a/server/src/search/routes.rs b/server/src/search/routes.rs new file mode 100644 index 00000000..83b4b524 --- /dev/null +++ b/server/src/search/routes.rs @@ -0,0 +1,53 @@ +use crate::err::AppError; +use crate::search::services; +use crate::search::{SearchHistory, SearchHistoryRequest, SearchQueryRequest}; +use crate::startup::AppState; +use axum::extract::{Query, State}; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::{Json, Router}; +use redis::Client as RedisClient; +use sqlx::PgPool; + +#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] +async fn get_search_handler( + State(pool): State, + State(cache): State, + Query(search_query): Query, +) -> crate::Result { + let mut connection = cache + .get_multiplexed_async_connection() + .await + .map_err(|e| AppError::from(e))?; + + let search_response = services::search(&mut connection, &search_query).await?; + services::insert_search_history(&pool, &mut connection, &search_query, &search_response) + .await?; + + Ok((StatusCode::OK, Json(search_response))) +} + +#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] +async fn get_search_history_handler( + State(pool): State, + Query(search_history_request): Query, +) -> crate::Result { + let search_history = sqlx::query_as!( + SearchHistory, + "select * from search_history order by created_at desc limit $1 offset $2", + search_history_request.limit.unwrap_or(10) as i64, + search_history_request.offset.unwrap_or(0) as i64 + ) + .fetch_all(&pool) + .await + .map_err(|e| AppError::from(e))?; + + Ok((StatusCode::OK, Json(search_history))) +} + +pub fn routes() -> Router { + Router::new() + .route("/", get(get_search_handler)) + .route("/history", get(get_search_history_handler)) +} diff --git a/server/src/search/services.rs b/server/src/search/services.rs new file mode 100644 index 00000000..19153c21 --- /dev/null +++ b/server/src/search/services.rs @@ -0,0 +1,70 @@ +use crate::err::AppError; +use crate::search::{SearchHistory, SearchQueryRequest, SearchResponse}; +use color_eyre::eyre::eyre; +use redis::aio::MultiplexedConnection; +use redis::AsyncCommands; +use sqlx::PgPool; + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn search( + cache: &mut MultiplexedConnection, + search_query: &SearchQueryRequest, +) -> crate::Result { + let cache_response: Option = cache + .get(search_query.query.clone()) + .await + .map_err(|e| AppError::from(e))?; + + let cache_response = match cache_response { + Some(response) => response, + None => String::new(), + }; + + let cache_response: Option = + serde_json::from_str(&cache_response).unwrap_or(None); + + if let Some(response) = cache_response { + return Ok(response); + } + + // sleep for 3 seconds to simulate a slow search + // TODO: replace this with actual search logic using GRPC calls with backend services + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + + let response = SearchResponse { + response_text: "sample response".to_string(), + response_sources: vec!["www.source1.com".to_string(), "www.source2.com".to_string()], + }; + + return Ok(response); +} + +#[tracing::instrument(level = "debug", ret, err)] +pub async fn insert_search_history( + pool: &PgPool, + cache: &mut MultiplexedConnection, + search_query: &SearchQueryRequest, + search_response: &SearchResponse, +) -> crate::Result { + cache + .set( + &search_query.query, + serde_json::to_string(&search_response) + .map_err(|_| eyre!("unable to convert string to json"))?, + ) + .await + .map_err(|e| AppError::from(e))?; + + let search_history = sqlx::query_as!( + SearchHistory, + "insert into search_history (search_text, response_text, response_sources) values ($1, $2, $3) returning *", + &search_query.query, + &search_response.response_text, + &search_response.response_sources + ) + .fetch_one(pool) + .await + .map_err(|e| AppError::from(e))?; + + return Ok(search_history); +} diff --git a/server/src/settings.rs b/server/src/settings.rs index 7c0703d2..e2433770 100644 --- a/server/src/settings.rs +++ b/server/src/settings.rs @@ -78,6 +78,7 @@ pub struct Settings { pub host: String, pub port: u16, pub db: Secret, + pub cache: Secret, } impl Settings { diff --git a/server/src/startup.rs b/server/src/startup.rs index d54dda0b..7a20e994 100644 --- a/server/src/startup.rs +++ b/server/src/startup.rs @@ -1,5 +1,6 @@ use axum::{extract::FromRef, routing::IntoMakeService, serve::Serve, Router}; use color_eyre::eyre::eyre; +use redis::Client as RedisClient; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; use tokio::net::TcpListener; @@ -44,6 +45,7 @@ impl Application { #[derive(Clone, Debug, FromRef)] pub struct AppState { pub db: PgPool, + pub cache: redis::Client, pub settings: Settings, } @@ -58,13 +60,26 @@ pub async fn db_connect(database_url: &str) -> Result { } } +pub async fn cache_connect(cache_url: &str) -> Result { + match RedisClient::open(cache_url) { + Ok(client) => Ok(client), + Err(e) => Err(eyre!("Failed to connect to Redis: {}", e).into()), + } +} + async fn run( listener: TcpListener, settings: Settings, ) -> Result, Router>> { let db = db_connect(settings.db.expose()).await?; - let state = AppState { db, settings }; + let cache = cache_connect(settings.cache.expose()).await?; + + let state = AppState { + db, + cache, + settings, + }; let app = router(state)?;