Skip to content

Commit

Permalink
Merge pull request #25 from curieo-org/search_endpoint
Browse files Browse the repository at this point in the history
added initial version of search APIs
  • Loading branch information
rathijitpapon authored Apr 5, 2024
2 parents 8e581a9 + da516b7 commit 305041d
Show file tree
Hide file tree
Showing 17 changed files with 508 additions and 11 deletions.
3 changes: 3 additions & 0 deletions server/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ target/

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

# These will be created by the sqlx CLI
*.sqlx
46 changes: 46 additions & 0 deletions server/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 18 additions & 3 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ name = "server"
path = "src/main.rs"

[dependencies]


anyhow = { version = "1.0.81", features = ["backtrace"] }
async-trait = "0.1.78"
axum = { version = "0.7.4", features = ["macros"] }
Expand All @@ -27,7 +29,14 @@ reqwest = { version = "0.12.1", features = ["json"] }
serde = "1.0.197"
serde_json = "1.0.114"
serde_urlencoded = "0.7.1"
sqlx = { version = "0.7.4", features = ["postgres", "runtime-tokio", "tls-rustls", "migrate", "uuid", "time"] }
sqlx = { version = "0.7.4", features = [
"postgres",
"runtime-tokio",
"tls-rustls",
"migrate",
"uuid",
"time",
] }
thiserror = "1.0.58"
tokio = { version = "1.36.0", features = ["full"] }
tower-http = { version = "0.5.2", features = ["trace", "cors"] }
Expand All @@ -36,9 +45,15 @@ tower = "0.4.13"
tracing-error = "0.2.0"
tracing-log = "0.2.0"
tracing-logfmt = "0.3.4"
tracing-subscriber = { version = "0.3.18", features = ["json", "registry", "env-filter"] }
uuid = { version = "1.8.0", features = ["serde"] }
tracing-subscriber = { version = "0.3.18", features = [
"json",
"registry",
"env-filter",
] }
uuid = { version = "1.8.0", features = ["serde", "v4"] }
log = "0.4.21"
redis = { version = "0.25.2", features = ["tokio-comp", "json"] }
rand = "0.8.5"
futures = "0.3.30"
axum-extra = { version = "0.9.3", features = ["typed-header"] }
http = "1.1.0"
6 changes: 6 additions & 0 deletions server/config/dev.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
db = "postgresql://postgres:postgres@localhost/curieo"
cache = "redis://127.0.0.1/"
cache_max_sorted_size = 100
rag_api = "http://127.0.0.1:8000"
rag_api_username = "curieo"
rag_api_password = "curieo"
oauth2_clients = []

[log]
level = "info"
Expand Down
12 changes: 12 additions & 0 deletions server/migrations/20240403115016_search_history.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
create table search_history (
search_history_id uuid primary key default uuid_generate_v1mc(),
user_id uuid not null references users(user_id),
query text not null,
result text not null,
sources text[] not null,
created_at timestamptz not null default now(),
updated_at timestamptz not null default now()
);

-- And applying our `updated_at` trigger is as easy as this.
SELECT trigger_updated_at('search_history');
10 changes: 10 additions & 0 deletions server/src/err.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub enum AppError {
UnprocessableEntity(ErrorMap),
Sqlx(sqlx::Error),
GenericError(color_eyre::eyre::Error),
Redis(redis::RedisError),
}

impl AppError {
Expand All @@ -24,6 +25,7 @@ impl AppError {
Self::Unauthorized => StatusCode::UNAUTHORIZED,
Self::UnprocessableEntity { .. } => StatusCode::UNPROCESSABLE_ENTITY,
Self::Sqlx(_) | Self::GenericError(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::Redis(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
/// Convenient constructor for `Error::UnprocessableEntity`.
Expand All @@ -49,6 +51,12 @@ impl From<sqlx::Error> for AppError {
}
}

impl From<redis::RedisError> for AppError {
fn from(inner: redis::RedisError) -> Self {
AppError::Redis(inner)
}
}

impl From<BackendError> for AppError {
fn from(e: BackendError) -> Self {
match e {
Expand All @@ -67,6 +75,7 @@ impl Display for AppError {
AppError::Sqlx(e) => write!(f, "{}", e),
AppError::UnprocessableEntity(e) => write!(f, "{:?}", e),
AppError::Unauthorized => write!(f, "Unauthorized"),
AppError::Redis(e) => write!(f, "{}", e),
}
}
}
Expand Down Expand Up @@ -118,6 +127,7 @@ impl IntoResponse for AppError {
}
AppError::Sqlx(ref e) => error!("SQLx error: {:?}", e),
AppError::GenericError(ref e) => error!("Generic error: {:?}", e),
AppError::Redis(ref e) => error!("Redis error: {:?}", e),
};

// Return a http status code and json body with error message.
Expand Down
1 change: 1 addition & 0 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod auth;
mod err;
mod health_check;
pub mod routing;
pub mod search;
pub mod secrets;
pub mod settings;
pub mod startup;
Expand Down
4 changes: 3 additions & 1 deletion server/src/routing/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use tracing::Level;

use crate::auth::models::PostgresBackend;
use crate::startup::AppState;
use crate::{auth, health_check, users};
use crate::{auth, health_check, search, users};

pub fn router(state: AppState) -> color_eyre::Result<Router> {
//sqlx::migrate!().run(&db).await?;
Expand All @@ -33,6 +33,8 @@ pub fn router(state: AppState) -> color_eyre::Result<Router> {

let api_routes = Router::new()
//.nest("/search", search::routes())
//.layer(middleware::from_fn(some_auth_middleware))
.nest("/search", search::routes())
.nest("/users", users::routes())
.route_layer(login_required!(
PostgresBackend,
Expand Down
7 changes: 7 additions & 0 deletions server/src/search/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pub use models::*;
pub use routes::*;
pub use services::*;

pub mod models;
pub mod routes;
pub mod services;
44 changes: 44 additions & 0 deletions server/src/search/models.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use serde::{Deserialize, Serialize};
use sqlx::types::time;
use sqlx::FromRow;
use std::fmt::Debug;

#[derive(Serialize, Deserialize, Debug)]
pub struct TopSearchRequest {
pub limit: Option<i64>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct RAGTokenResponse {
pub access_token: String,
pub token_type: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SearchQueryRequest {
pub query: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SearchHistoryRequest {
pub limit: Option<u8>,
pub offset: Option<u8>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct SearchResponse {
pub result: String,
pub sources: Vec<String>,
}

#[derive(FromRow, Serialize, Deserialize, Clone, Debug)]
pub struct SearchHistory {
pub search_history_id: uuid::Uuid,
pub user_id: uuid::Uuid,
pub query: String,
pub result: String,
pub sources: Vec<String>,

pub created_at: time::OffsetDateTime,
pub updated_at: time::OffsetDateTime,
}
80 changes: 80 additions & 0 deletions server/src/search/routes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use crate::err::AppError;
use crate::search::services;
use crate::search::{SearchHistoryRequest, SearchQueryRequest, TopSearchRequest};
use crate::startup::AppState;
use crate::users::User;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::{Json, Router};
use redis::{AsyncCommands, Client as RedisClient};
use sqlx::PgPool;

#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))]
async fn get_search_handler(
State(pool): State<PgPool>,
State(cache): State<RedisClient>,
user: User,
Query(search_query): Query<SearchQueryRequest>,
) -> crate::Result<impl IntoResponse> {
let user_id = user.user_id;

let mut connection = cache
.get_multiplexed_async_connection()
.await
.map_err(|e| AppError::from(e))?;

let search_response = services::search(&mut connection, &search_query).await?;
services::insert_search_history(
&pool,
&mut connection,
&user_id,
&search_query,
&search_response,
)
.await?;

connection
.zincr("search_queries", &search_query.query, 1)
.await
.map_err(|e| AppError::from(e))?;

Ok((StatusCode::OK, Json(search_response)))
}

#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))]
async fn get_search_history_handler(
State(pool): State<PgPool>,
user: User,
Query(search_history_request): Query<SearchHistoryRequest>,
) -> crate::Result<impl IntoResponse> {
let user_id = user.user_id;

let search_history =
services::get_search_history(&pool, &user_id, &search_history_request).await?;

Ok((StatusCode::OK, Json(search_history)))
}

#[tracing::instrument(level = "debug", skip_all, ret, err(Debug))]
async fn get_top_searches_handler(
State(cache): State<RedisClient>,
Query(query): Query<TopSearchRequest>,
) -> crate::Result<impl IntoResponse> {
let mut connection = cache
.get_multiplexed_async_connection()
.await
.map_err(|e| AppError::from(e))?;

let top_searches = services::get_top_searches(&mut connection, &query).await?;

Ok((StatusCode::OK, Json(top_searches)))
}

pub fn routes() -> Router<AppState> {
Router::new()
.route("/", get(get_search_handler))
.route("/history", get(get_search_history_handler))
.route("/top", get(get_top_searches_handler))
}
Loading

0 comments on commit 305041d

Please sign in to comment.