From 459447807ca2980f07ebf120372e30245436e64f Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 5 Apr 2024 11:25:00 +0200 Subject: [PATCH] One big commit. Sorry reviewers. --- server/Cargo.lock | 52 ++++++ server/Cargo.toml | 3 + server/config/default.toml | 2 + server/src/auth/mod.rs | 1 + server/src/auth/models.rs | 305 ++++++++++++---------------------- server/src/auth/oauth2.rs | 147 ++++++++++++++++ server/src/auth/routes.rs | 131 ++++++++++----- server/src/auth/services.rs | 23 ++- server/src/auth/utils.rs | 18 +- server/src/err.rs | 13 ++ server/src/lib.rs | 1 + server/src/routing/api.rs | 12 +- server/src/secrets.rs | 28 +--- server/src/settings.rs | 5 +- server/src/startup.rs | 22 ++- server/src/users/models.rs | 46 ++++- server/src/users/routes.rs | 18 +- server/src/users/selectors.rs | 6 +- server/src/utils.rs | 1 + server/tests/health_check.rs | 7 +- server/tests/users.rs | 63 +++---- 21 files changed, 539 insertions(+), 365 deletions(-) create mode 100644 server/src/auth/oauth2.rs create mode 100644 server/src/utils.rs diff --git a/server/Cargo.lock b/server/Cargo.lock index 3dd73cef..3d425f62 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -163,6 +163,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733" +dependencies = [ + "axum", + "axum-core", + "bytes", + "futures-util", + "headers", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-login" version = "0.15.0" @@ -646,6 +669,7 @@ checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -725,6 +749,7 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -828,6 +853,30 @@ dependencies = [ "hashbrown 0.14.3", ] +[[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64 0.21.7", + "bytes", + "headers-core", + "http 1.1.0", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http 1.1.0", +] + [[package]] name = "heck" version = "0.4.1" @@ -2064,11 +2113,14 @@ dependencies = [ "anyhow", "async-trait", "axum", + "axum-extra", "axum-login", "chrono", "color-eyre", "config", "dotenvy", + "futures", + "http 1.1.0", "hyper 1.2.0", "log", "oauth2", diff --git a/server/Cargo.toml b/server/Cargo.toml index b385f9cd..abb4337b 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -39,3 +39,6 @@ tracing-logfmt = "0.3.4" tracing-subscriber = { version = "0.3.18", features = ["json", "registry", "env-filter"] } uuid = { version = "1.8.0", features = ["serde"] } log = "0.4.21" +futures = "0.3.30" +axum-extra = { version = "0.9.3", features = ["typed-header"] } +http = "1.1.0" diff --git a/server/config/default.toml b/server/config/default.toml index 63e06b57..0459bc81 100644 --- a/server/config/default.toml +++ b/server/config/default.toml @@ -4,3 +4,5 @@ port = 3030 [log] level = "debug" format = "pretty" + +[[oauth2_clients]] \ No newline at end of file diff --git a/server/src/auth/mod.rs b/server/src/auth/mod.rs index a4b6579e..cfc490d4 100644 --- a/server/src/auth/mod.rs +++ b/server/src/auth/mod.rs @@ -2,6 +2,7 @@ pub use models::*; pub use routes::*; pub use services::*; pub mod models; +pub mod oauth2; pub mod routes; pub mod services; mod utils; diff --git a/server/src/auth/models.rs b/server/src/auth/models.rs index 75a8848c..667c4497 100644 --- a/server/src/auth/models.rs +++ b/server/src/auth/models.rs @@ -1,3 +1,4 @@ +use crate::auth::oauth2::OAuth2Client; use crate::auth::utils; use crate::secrets::Secret; use crate::telemetry::spawn_blocking_with_tracing; @@ -5,41 +6,28 @@ use crate::users::User; use async_trait::async_trait; use axum::http::header::{AUTHORIZATION, USER_AGENT}; use axum_login::{AuthnBackend, UserId}; -use oauth2::basic::{BasicClient, BasicRequestTokenError, BasicTokenResponse}; -use oauth2::reqwest::{async_http_client, AsyncHttpClientError}; +use oauth2::basic::BasicRequestTokenError; +use oauth2::reqwest::AsyncHttpClientError; use oauth2::url::Url; -use oauth2::{ - AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, - PkceCodeVerifier, RevocationUrl, TokenResponse, TokenUrl, -}; +use oauth2::{AuthorizationCode, CsrfToken, PkceCodeVerifier, TokenResponse}; use serde::de::Error; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Deserializer}; use sqlx::PgPool; use tokio::task; -use tracing::debug; -#[derive(Deserialize, Serialize, Clone, Debug)] +#[derive(Deserialize, Clone, Debug)] #[serde(remote = "Self")] pub struct RegisterUserRequest { pub email: String, pub username: String, - pub password_hash: Option>, + pub password: Option>, pub access_token: Option>, } -impl Serialize for RegisterUserRequest { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - Self::serialize(self, serializer) - } -} - impl<'de> Deserialize<'de> for RegisterUserRequest { fn deserialize>(deserializer: D) -> Result { let s = Self::deserialize(deserializer)?; - if s.password_hash.is_some() && s.access_token.is_some() { + if s.password.is_some() && s.access_token.is_some() { return Err(Error::custom("should only have password or access token")); } @@ -47,105 +35,104 @@ impl<'de> Deserialize<'de> for RegisterUserRequest { } } -#[derive(Debug, Clone)] -pub struct PostgresBackend { - db: PgPool, - oauth_clients: OAuth2Clients, -} - -#[derive(Debug, Clone)] -pub struct OAuth2Clients { - pub github: GitHubOAuthClient, - pub google: GoogleOAuthClient, -} - -impl Default for OAuth2Clients { - fn default() -> Self { - Self { - github: GitHubOAuthClient, - google: GoogleOAuthClient, - } - } -} - -#[async_trait] -trait OAuth2Client { - const CLIENT_ID: &'static str; - const CLIENT_SECRET: &'static str; - const AUTH_URL: &'static str; - const TOKEN_URL: &'static str; - const REVOCATION_URL: &'static str; - - fn client(&self) -> BasicClient { - BasicClient::new( - ClientId::new(Self::CLIENT_ID.to_string()), - Some(ClientSecret::new(Self::CLIENT_SECRET.to_string())), - AuthUrl::new(Self::AUTH_URL.to_string()).expect("invalid auth url"), - Some(TokenUrl::new(Self::TOKEN_URL.to_string()).expect("invalid token url")), - ) - .set_revocation_uri( - RevocationUrl::new(Self::REVOCATION_URL.to_string()).expect("invalid auth url"), - ) - } - fn authorize_url(&self, state: S) -> (Url, CsrfToken, PkceCodeVerifier) - where - S: FnOnce() -> CsrfToken, - { - let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256(); - let (url, token) = self - .client() - .authorize_url(state) - //.add_scope(Scope::new("user:email".to_string())) - .set_pkce_challenge(pkce_code_challenge) - .url(); - - (url, token, pkce_code_verifier) - } - async fn exchange_code( - &self, - code: AuthorizationCode, - ) -> Result> { - self.client() - .exchange_code(code) - .request_async(async_http_client) - .await - } -} - -#[derive(Debug, Clone)] -pub struct GitHubOAuthClient; - -impl OAuth2Client for GitHubOAuthClient { - const CLIENT_ID: &'static str = "github_client_id"; - const CLIENT_SECRET: &'static str = "github_client_secret"; - const AUTH_URL: &'static str = ""; - const TOKEN_URL: &'static str = ""; - const REVOCATION_URL: &'static str = ""; +#[derive(Debug, Deserialize)] +struct UserInfo { + login: String, } #[derive(Debug, Clone)] -pub struct GoogleOAuthClient; - -impl OAuth2Client for GoogleOAuthClient { - const CLIENT_ID: &'static str = ""; - const CLIENT_SECRET: &'static str = ""; - const AUTH_URL: &'static str = "https://accounts.google.com/o/oauth2/v2/auth"; - const TOKEN_URL: &'static str = "https://www.googleapis.com/oauth2/v3/token"; - const REVOCATION_URL: &'static str = "https://oauth2.googleapis.com/revoke"; +pub struct PostgresBackend { + db: PgPool, + oauth2_clients: Vec, } impl PostgresBackend { - pub fn new(db: PgPool, oauth_clients: OAuth2Clients) -> Self { - Self { db, oauth_clients } + pub fn new(db: PgPool, oauth2_clients: Vec) -> Self { + Self { db, oauth2_clients } } pub fn authorize_url(&self) -> (Url, CsrfToken, PkceCodeVerifier) { - self.oauth_clients - .google + self.oauth2_clients + .first() + .unwrap() .authorize_url(CsrfToken::new_random) } } +#[tracing::instrument(level = "debug", ret, err)] +async fn password_authenticate( + db: &PgPool, + password_credentials: PasswordCredentials, +) -> Result, BackendError> { + let user = sqlx::query_as!( + User, + "select * from users where username = $1 and password_hash is not null", + password_credentials.username + ) + .fetch_optional(db) + .await + .map_err(BackendError::Sqlx)?; + + // Verifying the password is blocking and potentially slow, so we do it via + // `spawn_blocking`. + spawn_blocking_with_tracing(move || { + utils::verify_user_password(user, password_credentials.password) + }) + .await? +} + +#[tracing::instrument(level = "debug", ret, err)] +async fn oauth_authenticate( + db: &PgPool, + oauth2_clients: &[OAuth2Client], + oauth_creds: OAuthCredentials, +) -> Result, BackendError> { + // Ensure the CSRF state has not been tampered with. + if oauth_creds.old_state.secret() != oauth_creds.new_state.secret() { + return Ok(None); + }; + + let client = oauth2_clients.first().unwrap(); + // Process authorization code, expecting a token response back. + let token_res = client + .exchange_code(AuthorizationCode::new(oauth_creds.code)) + .await + .map_err(BackendError::OAuth2)?; + + // Use access token to request user info. + let user_info = reqwest::Client::new() + .get("https://api.github.com/user") + .header(USER_AGENT.as_str(), "axum-login") // See: https://docs.github.com/en/rest/overview/resources-in-the-rest-api?apiVersion=2022-11-28#user-agent-required + .header( + AUTHORIZATION.as_str(), + format!("Bearer {}", token_res.access_token().secret()), + ) + .send() + .await + .map_err(BackendError::Reqwest)? + .json::() + .await + .map_err(BackendError::Reqwest)?; + + // Persist user in our database, so we can use `get_user`. + let user = sqlx::query_as( + r#" + insert into users (username, access_token) + values (?, ?) + on conflict(username) do update + set access_token = excluded.access_token + returning * + "#, + ) + .bind(user_info.login) + .bind(token_res.access_token().secret()) + .fetch_one(db) + .await + .map_err(BackendError::Sqlx)?; + + Ok(Some(user)) +} + #[allow(clippy::blocks_in_conditions)] #[async_trait] impl AuthnBackend for PostgresBackend { @@ -160,70 +147,10 @@ impl AuthnBackend for PostgresBackend { ) -> Result, Self::Error> { match creds { Credentials::Password(password_cred) => { - let user = sqlx::query_as!( - User, - "select * from users where username = $1 and password_hash is not null", - password_cred.username - ) - .fetch_optional(&self.db) - .await - .map_err(Self::Error::Sqlx)?; - - debug!("user: {:?}", user); - - // Verifying the password is blocking and potentially slow, so we do it via - // `spawn_blocking`. - spawn_blocking_with_tracing(move || { - utils::verify_user_password(user, password_cred.password) - }) - .await? + password_authenticate(&self.db, password_cred).await } Credentials::OAuth(oauth_creds) => { - // Ensure the CSRF state has not been tampered with. - if oauth_creds.old_state.secret() != oauth_creds.new_state.secret() { - return Ok(None); - }; - - // Process authorization code, expecting a token response back. - let token_res = self - .oauth_clients - .github - .exchange_code(AuthorizationCode::new(oauth_creds.code)) - .await - .map_err(Self::Error::OAuth2)?; - - // Use access token to request user info. - let user_info = reqwest::Client::new() - .get("https://api.github.com/user") - .header(USER_AGENT.as_str(), "axum-login") // See: https://docs.github.com/en/rest/overview/resources-in-the-rest-api?apiVersion=2022-11-28#user-agent-required - .header( - AUTHORIZATION.as_str(), - format!("Bearer {}", token_res.access_token().secret()), - ) - .send() - .await - .map_err(Self::Error::Reqwest)? - .json::() - .await - .map_err(Self::Error::Reqwest)?; - - // Persist user in our database, so we can use `get_user`. - let user = sqlx::query_as( - r#" - insert into users (username, access_token) - values (?, ?) - on conflict(username) do update - set access_token = excluded.access_token - returning * - "#, - ) - .bind(user_info.login) - .bind(token_res.access_token().secret()) - .fetch_one(&self.db) - .await - .map_err(Self::Error::Sqlx)?; - - Ok(Some(user)) + oauth_authenticate(&self.db, &self.oauth2_clients, oauth_creds).await } } } @@ -241,22 +168,22 @@ impl AuthnBackend for PostgresBackend { } } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize)] #[serde(untagged)] pub enum Credentials { - Password(PasswordCreds), - OAuth(OAuthCreds), + Password(PasswordCredentials), + OAuth(OAuthCredentials), } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct PasswordCreds { +#[derive(Debug, Clone, Deserialize)] +pub struct PasswordCredentials { pub username: String, pub password: Secret, pub next: Option, } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct OAuthCreds { +#[derive(Debug, Clone, Deserialize)] +pub struct OAuthCredentials { pub code: String, pub old_state: CsrfToken, pub new_state: CsrfToken, @@ -284,7 +211,6 @@ pub type AuthSession = axum_login::AuthSession; #[cfg(test)] mod tests { - use crate::auth::models::{Credentials, OAuth2Clients, PasswordCreds}; use crate::auth::utils::dummy_verify_password; use crate::secrets::Secret; @@ -292,31 +218,4 @@ mod tests { async fn test_dummy_verify_password() { assert!(dummy_verify_password(Secret::new("password")).is_ok()); } - - #[tokio::test] - async fn test_oauth2_clients() { - let clients = OAuth2Clients::default(); - //let _ = clients.github.client(); - //let _ = clients.google.client(); - } - - #[test] - fn test_creds() { - let creds = Credentials::Password(PasswordCreds { - username: "test".to_string(), - password: Secret::new("password".to_string()), - next: None, - }); - println!( - "as_json: {:?}", - serde_json::to_string_pretty(&creds).unwrap() - ); - - assert_eq!(1, 2) - } -} - -#[derive(Debug, Deserialize)] -struct UserInfo { - login: String, } diff --git a/server/src/auth/oauth2.rs b/server/src/auth/oauth2.rs new file mode 100644 index 00000000..3db4a1df --- /dev/null +++ b/server/src/auth/oauth2.rs @@ -0,0 +1,147 @@ +use oauth2::basic::{BasicClient, BasicRequestTokenError, BasicTokenResponse}; +use oauth2::reqwest::{async_http_client, AsyncHttpClientError}; +use oauth2::url::Url; +use oauth2::{ + AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, + PkceCodeVerifier, RevocationUrl, TokenUrl, +}; +use serde::{Deserialize, Deserializer}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq)] +pub enum OAuth2Provider { + Google, +} + +impl<'de> Deserialize<'de> for OAuth2Provider { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + match s.as_str() { + "google" => Ok(OAuth2Provider::Google), + _ => Err(serde::de::Error::custom(format!( + "Invalid OAuth2 provider: {}", + s + ))), + } + } +} + +#[derive(Debug, Clone, Deserialize)] +struct OAuth2ClientInfo { + pub provider: OAuth2Provider, + pub client_id: ClientId, + pub client_secret: ClientSecret, + pub auth_url: AuthUrl, + pub token_url: TokenUrl, + pub revocation_url: RevocationUrl, +} + +impl From for BasicClient { + fn from(info: OAuth2ClientInfo) -> Self { + BasicClient::new( + info.client_id, + Some(info.client_secret), + info.auth_url, + Some(info.token_url), + ) + .set_revocation_uri(info.revocation_url) + } +} + +#[derive(Debug, Clone)] +pub struct OAuth2Client { + client: Arc, + info: OAuth2ClientInfo, +} + +impl OAuth2Client { + pub fn client(&self) -> &BasicClient { + self.client.as_ref() + } + pub fn provider(&self) -> &OAuth2Provider { + &self.info.provider + } + pub fn client_id(&self) -> &ClientId { + &self.info.client_id + } + pub fn client_secret(&self) -> &ClientSecret { + &self.info.client_secret + } + pub fn auth_url(&self) -> &AuthUrl { + &self.info.auth_url + } + pub fn token_url(&self) -> &TokenUrl { + &self.info.token_url + } + pub fn revocation_url(&self) -> &RevocationUrl { + &self.info.revocation_url + } + + pub fn authorize_url(&self, state: S) -> (Url, CsrfToken, PkceCodeVerifier) + where + S: FnOnce() -> CsrfToken, + { + let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256(); + let (url, token) = self + .client() + .authorize_url(state) + //.add_scope(Scope::new("user:email".to_string())) + .set_pkce_challenge(pkce_code_challenge) + .url(); + + (url, token, pkce_code_verifier) + } + pub async fn exchange_code( + &self, + code: AuthorizationCode, + ) -> Result> { + self.client() + .exchange_code(code) + .request_async(async_http_client) + .await + } +} + +impl<'de> Deserialize<'de> for OAuth2Client { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let client_info = OAuth2ClientInfo::deserialize(deserializer)?; + let basic_client = BasicClient::from(client_info.clone()); + + Ok(OAuth2Client { + client: Arc::new(basic_client), + info: client_info, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn deserialize_oauth2_client() { + let json = json!({ + "provider": "google", + "client_id": "client_id", + "client_secret": "client_secret", + "auth_url": "https://auth_url", + "token_url": "https://token_url", + "revocation_url": "https://revocation_url" + }); + + let client: OAuth2Client = serde_json::from_value(json).unwrap(); + assert_eq!(client.provider(), &OAuth2Provider::Google); + assert_eq!(client.client_id().as_str(), "client_id"); + assert_eq!(client.client_secret().secret(), "client_secret"); + assert_eq!(client.auth_url().as_str(), "https://auth_url"); + assert_eq!(client.token_url().as_str(), "https://token_url"); + assert_eq!(client.revocation_url().as_str(), "https://revocation_url"); + } +} diff --git a/server/src/auth/routes.rs b/server/src/auth/routes.rs index 2325bc59..bb2c5c49 100644 --- a/server/src/auth/routes.rs +++ b/server/src/auth/routes.rs @@ -1,16 +1,19 @@ use crate::auth::models::{AuthSession, Credentials, RegisterUserRequest}; use crate::auth::services::register; +use crate::auth::{OAuthCredentials, PasswordCredentials}; use crate::err::AppError; -use axum::extract::State; +use crate::startup::AppState; +use axum::extract::{Query, State}; use axum::http::StatusCode; -use axum::response::{IntoResponse, Redirect}; +use axum::response::{IntoResponse, Redirect, Response}; use axum::routing::{get, post}; use axum::{Form, Json, Router}; +use axum_login::tower_sessions::Session; use color_eyre::eyre::eyre; +use oauth2::CsrfToken; +use serde::Deserialize; use sqlx::PgPool; -use crate::startup::AppState; - #[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] async fn register_handler( State(pool): State, @@ -24,9 +27,12 @@ async fn register_handler( #[tracing::instrument(level = "debug", skip_all)] async fn login_handler( mut auth_session: AuthSession, - Form(creds): Form, + Form(creds): Form, ) -> crate::Result<()> { - let user = match auth_session.authenticate(creds.clone()).await { + let user = match auth_session + .authenticate(Credentials::Password(creds)) + .await + { Ok(Some(user)) => user, Ok(None) => return Err(AppError::Unauthorized), Err(_) => return Err(eyre!("Could not authenticate user").into()), @@ -35,65 +41,104 @@ async fn login_handler( if auth_session.login(&user).await.is_err() { return Err(eyre!("Could not login user").into()); } - - if let Credentials::Password(pw_creds) = creds { - //if let Some(ref next) = pw_creds.next { - // return Redirect::to(next).into_response(); - //} - } + //if let Credentials::Password(_pw_creds) = creds { + // if let Some(ref next) = pw_creds.next { + // return Redirect::to(next).into_response(); + // } + //} Ok(()) } -#[tracing::instrument(level = "debug", skip_all)] -async fn login_callback_handler( - mut auth_session: AuthSession, - Form(creds): Form, +pub const CSRF_STATE_KEY: &str = "auth.csrf-state"; +pub const NEXT_URL_KEY: &str = "auth.next-url"; +pub const PKCE_VERIFIER_KEY: &str = "auth.pkce-verifier"; + +#[derive(Debug, Clone, Deserialize)] +struct NextUrl { + next: Option, +} + +async fn oauth_handler( + auth_session: AuthSession, + session: Session, + Form(NextUrl { next }): Form, ) -> impl IntoResponse { - let user = match auth_session.authenticate(creds.clone()).await { - Ok(Some(user)) => user, - Ok(None) => return StatusCode::UNAUTHORIZED.into_response(), - Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), - }; + let (auth_url, csrf_state, pkce_code_verisfier) = auth_session.backend.authorize_url(); - if auth_session.login(&user).await.is_err() { - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } + session + .insert(CSRF_STATE_KEY, csrf_state.secret()) + .await + .expect("Serialization should not fail."); - if let Credentials::Password(pw_creds) = creds { - if let Some(ref next) = pw_creds.next { - return Redirect::to(next).into_response(); - } - } - Redirect::to("/").into_response() + session + .insert(NEXT_URL_KEY, next) + .await + .expect("Serialization should not fail."); + + session + .insert(PKCE_VERIFIER_KEY, pkce_code_verisfier) + .await + .expect("Serialization should not fail."); + + Redirect::to(auth_url.as_str()).into_response() +} +#[derive(Debug, Clone, Deserialize)] +pub struct AuthzResp { + code: String, + state: CsrfToken, } #[tracing::instrument(level = "debug", skip_all)] -async fn logout_handler( +pub async fn oauth_callback_handler( mut auth_session: AuthSession, - Form(creds): Form, -) -> impl IntoResponse { - let user = match auth_session.authenticate(creds.clone()).await { + session: Session, + Query(AuthzResp { + code, + state: new_state, + }): Query, +) -> crate::Result { + let Ok(Some(old_state)) = session.get(CSRF_STATE_KEY).await else { + return Err(eyre!("Session did not contain old csrf state").into()); + }; + + let creds = Credentials::OAuth(OAuthCredentials { + code, + old_state, + new_state, + }); + + let user = match auth_session.authenticate(creds).await { Ok(Some(user)) => user, - Ok(None) => return StatusCode::UNAUTHORIZED.into_response(), - Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), + Ok(None) => return Err(AppError::Unauthorized), + Err(_) => return Err(eyre!("Could not authenticate user").into()), }; if auth_session.login(&user).await.is_err() { - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + return Err(eyre!("Could not login user").into()); } - if let Credentials::Password(pw_creds) = creds { - if let Some(ref next) = pw_creds.next { - return Redirect::to(next).into_response(); - } + if let Ok(Some(next)) = session.remove::(NEXT_URL_KEY).await { + Ok(Redirect::to(&next).into_response()) + } else { + Ok(Redirect::to("/").into_response()) } - Redirect::to("/").into_response() +} + +#[tracing::instrument(level = "debug", skip_all)] +async fn logout_handler(mut auth_session: AuthSession) -> crate::Result { + auth_session + .logout() + .await + .map_err(|e| eyre!("Failed to logout: {}", e))?; + + Ok(Redirect::to("/login").into_response()) } pub fn routes() -> Router { Router::new() .route("/register", post(register_handler)) .route("/login", post(login_handler)) - .route("/login_callback", get(login_callback_handler)) + .route("/oauth", post(oauth_handler)) + .route("/oauth_callback", get(oauth_callback_handler)) .route("/logout", get(logout_handler)) } diff --git a/server/src/auth/services.rs b/server/src/auth/services.rs index cfc24264..4a6ab52d 100644 --- a/server/src/auth/services.rs +++ b/server/src/auth/services.rs @@ -1,15 +1,14 @@ use crate::auth::models::RegisterUserRequest; use crate::auth::utils; use crate::err::{AppError, ResultExt}; +use crate::users::{User, UserRecord}; use color_eyre::eyre::eyre; use sqlx::PgPool; -use crate::users::User; - #[tracing::instrument(level = "debug", ret, err)] -pub async fn register(pool: PgPool, request: RegisterUserRequest) -> crate::Result { - if let Some(password) = request.password_hash { - let password_hash = utils::hash_password(password).await; +pub async fn register(pool: PgPool, request: RegisterUserRequest) -> crate::Result { + if let Some(password) = request.password { + let password_hash = utils::hash_password(password).await?; let user = sqlx::query_as!( User, "insert into users (email, username, password_hash) values ($1, $2, $3) returning *", @@ -19,14 +18,14 @@ pub async fn register(pool: PgPool, request: RegisterUserRequest) -> crate::Resu ) .fetch_one(&pool) .await - .on_constraint("user_username_key", |_| { + .on_constraint("users_username_key", |_| { AppError::unprocessable_entity([("username", "username taken")]) }) - .on_constraint("user_email_key", |_| { + .on_constraint("users_email_key", |_| { AppError::unprocessable_entity([("email", "email taken")]) })?; - return Ok(user); + return Ok(user.into()); } else if let Some(access_token) = request.access_token { let user = sqlx::query_as!( User, @@ -37,14 +36,14 @@ pub async fn register(pool: PgPool, request: RegisterUserRequest) -> crate::Resu ) .fetch_one(&pool) .await - .on_constraint("user_username_key", |_| { + .on_constraint("users_username_key", |_| { AppError::unprocessable_entity([("username", "username taken")]) }) - .on_constraint("user_email_key", |_| { + .on_constraint("users_email_key", |_| { AppError::unprocessable_entity([("email", "email taken")]) })?; - return Ok(user); + return Ok(user.into()); } - Err(eyre!("Either password_hash or access_token must be provided to create a user").into()) + Err(eyre!("Either password or access_token must be provided to create a user").into()) } diff --git a/server/src/auth/utils.rs b/server/src/auth/utils.rs index b842b127..891df718 100644 --- a/server/src/auth/utils.rs +++ b/server/src/auth/utils.rs @@ -2,8 +2,7 @@ use crate::auth::BackendError; use crate::secrets::Secret; use crate::users::User; use password_auth::{generate_hash, verify_password}; -use sqlx::__rt::spawn_blocking; -use tracing::debug; +use tokio::task::spawn_blocking; #[tracing::instrument(level = "debug", ret, err)] pub fn verify_user_password( @@ -22,18 +21,11 @@ pub fn verify_user_password( let Some(password_hash) = user.password_hash.expose() else { return dummy_verify_password(password_candidate); }; - debug!("password_hash: {:?}", password_hash); // If the user exists and has a password, we verify the password. match verify_password(password_candidate.expose(), password_hash.as_ref()) { - Ok(_) => { - debug!("User authenticated: {:?}", user); - Ok(Some(user)) - } - _ => { - debug!("User not authenticated"); - Ok(None) - } + Ok(_) => Ok(Some(user)), + _ => Ok(None), } } }; @@ -51,6 +43,6 @@ pub fn dummy_verify_password(pw: Secret>) -> Result) -> Secret { - spawn_blocking(move || Secret::new(generate_hash(password.expose().as_bytes()))).await +pub async fn hash_password(password: Secret) -> Result, BackendError> { + Ok(spawn_blocking(move || Secret::new(generate_hash(password.expose().as_bytes()))).await?) } diff --git a/server/src/err.rs b/server/src/err.rs index 9562d515..dbf18970 100644 --- a/server/src/err.rs +++ b/server/src/err.rs @@ -1,7 +1,9 @@ +use crate::auth::BackendError; use axum::http::header::WWW_AUTHENTICATE; use axum::http::{HeaderMap, HeaderValue, StatusCode}; use axum::response::{IntoResponse, Response}; use axum::Json; +use color_eyre::eyre::eyre; use sqlx::error::DatabaseError; use std::borrow::Cow; use std::collections::HashMap; @@ -47,6 +49,17 @@ impl From for AppError { } } +impl From for AppError { + fn from(e: BackendError) -> Self { + match e { + BackendError::Sqlx(e) => AppError::Sqlx(e), + BackendError::Reqwest(e) => AppError::GenericError(eyre!(e)), + BackendError::OAuth2(e) => AppError::GenericError(eyre!(e)), + BackendError::TaskJoin(e) => AppError::GenericError(eyre!(e)), + } + } +} + impl Display for AppError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/server/src/lib.rs b/server/src/lib.rs index e2fbfcc3..2a03549a 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -11,6 +11,7 @@ pub mod settings; pub mod startup; mod telemetry; pub mod users; +pub mod utils; pub type Result = std::result::Result; diff --git a/server/src/routing/api.rs b/server/src/routing/api.rs index 608e9aa4..996b5748 100644 --- a/server/src/routing/api.rs +++ b/server/src/routing/api.rs @@ -6,11 +6,13 @@ use axum_login::{login_required, AuthManagerLayerBuilder}; use tower_http::trace::{self, TraceLayer}; use tracing::Level; -use crate::auth::models::{OAuth2Clients, PostgresBackend}; +use crate::auth::models::PostgresBackend; use crate::startup::AppState; use crate::{auth, health_check, users}; pub fn router(state: AppState) -> color_eyre::Result { + //sqlx::migrate!().run(&db).await?; + // Session layer. // // This uses `tower-sessions` to establish a layer that will provide the session @@ -21,18 +23,16 @@ pub fn router(state: AppState) -> color_eyre::Result { .with_same_site(SameSite::Lax) // Ensure we send the cookie from the OAuth redirect. .with_expiry(Expiry::OnInactivity(Duration::days(1))); - //sqlx::migrate!().run(&db).await?; - // Auth service. // // This combines the session layer with our backend to establish the auth // service which will provide the auth session as a request extension. - let oauth_clients = OAuth2Clients::default(); - let backend = PostgresBackend::new(state.db.clone(), oauth_clients); + + let backend = PostgresBackend::new(state.db.clone(), state.oauth2_clients.clone()); let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build(); + let api_routes = Router::new() //.nest("/search", search::routes()) - //.layer(middleware::from_fn(some_auth_middleware)) .nest("/users", users::routes()) .route_layer(login_required!( PostgresBackend, diff --git a/server/src/secrets.rs b/server/src/secrets.rs index 7d620657..05db2ba5 100644 --- a/server/src/secrets.rs +++ b/server/src/secrets.rs @@ -1,9 +1,8 @@ -use std::error::Error; -use std::fmt::{Debug, Display}; - use serde::{Deserialize, Deserializer, Serialize, Serializer}; use sqlx::database::HasValueRef; use sqlx::{Decode, Postgres}; +use std::error::Error; +use std::fmt::{Debug, Display}; /// A wrapper around a value that should be kept secret /// when displayed. This is useful for fields like passwords @@ -41,19 +40,13 @@ impl Secret { } } -impl Display for Secret -where - T: Default + Clone + Display, -{ +impl Display for Secret { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[redacted]") } } -impl Debug for Secret -where - T: Default + Clone + Debug, -{ +impl Debug for Secret { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[redacted]") } @@ -61,7 +54,7 @@ where impl sqlx::Type for Secret where - T: Default + Clone + sqlx::Type, + T: sqlx::Type, { fn type_info() -> sqlx::postgres::PgTypeInfo { >::type_info() @@ -70,7 +63,7 @@ where impl sqlx::Decode<'_, Postgres> for Secret where - for<'a> T: sqlx::Type + sqlx::Decode<'a, Postgres> + Default + Clone, + for<'a> T: sqlx::Type + sqlx::Decode<'a, Postgres>, { fn decode( value: >::ValueRef, @@ -82,7 +75,7 @@ where impl<'de, T> Deserialize<'de> for Secret where - T: Deserialize<'de> + Default + Clone + Debug, + T: Deserialize<'de>, { fn deserialize(deserializer: D) -> Result where @@ -94,7 +87,7 @@ where impl Serialize for Secret where - T: Serialize + Default + Clone + Debug, + T: Serialize, { fn serialize(&self, serializer: S) -> Result where @@ -106,10 +99,7 @@ where } } -impl From for Secret -where - T: Default + Clone, -{ +impl From for Secret { fn from(s: T) -> Self { Self(s) } diff --git a/server/src/settings.rs b/server/src/settings.rs index 7c0703d2..0a93e988 100644 --- a/server/src/settings.rs +++ b/server/src/settings.rs @@ -1,12 +1,12 @@ use std::{env, fmt::Display}; +use crate::auth::oauth2::OAuth2Client; +use crate::secrets::Secret; use config::{Config, Environment, File}; use dotenvy::dotenv; use once_cell::sync::Lazy; use serde::{Deserialize, Deserializer}; -use crate::secrets::Secret; - #[derive(Debug, Clone)] pub enum LogFmt { Json, @@ -78,6 +78,7 @@ pub struct Settings { pub host: String, pub port: u16, pub db: Secret, + pub oauth2_clients: Vec, } impl Settings { diff --git a/server/src/startup.rs b/server/src/startup.rs index d54dda0b..2d91f3e5 100644 --- a/server/src/startup.rs +++ b/server/src/startup.rs @@ -1,13 +1,13 @@ +use crate::auth::oauth2::OAuth2Client; +use crate::routing::router; +use crate::settings::Settings; +use crate::Result; use axum::{extract::FromRef, routing::IntoMakeService, serve::Serve, Router}; use color_eyre::eyre::eyre; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; use tokio::net::TcpListener; -use crate::routing::router; -use crate::settings::Settings; -use crate::Result; - pub struct Application { port: u16, server: Serve, Router>, @@ -44,9 +44,20 @@ impl Application { #[derive(Clone, Debug, FromRef)] pub struct AppState { pub db: PgPool, + pub oauth2_clients: Vec, pub settings: Settings, } +impl From<(PgPool, Settings)> for AppState { + fn from((db, settings): (PgPool, Settings)) -> Self { + Self { + db, + oauth2_clients: settings.oauth2_clients.clone(), + settings, + } + } +} + pub async fn db_connect(database_url: &str) -> Result { match PgPoolOptions::new() .max_connections(5) @@ -63,8 +74,7 @@ async fn run( settings: Settings, ) -> Result, Router>> { let db = db_connect(settings.db.expose()).await?; - - let state = AppState { db, settings }; + let state = AppState::from((db, settings)); let app = router(state)?; diff --git a/server/src/users/models.rs b/server/src/users/models.rs index e461aa9d..d3ec8ff2 100644 --- a/server/src/users/models.rs +++ b/server/src/users/models.rs @@ -1,16 +1,32 @@ +use crate::auth::AuthSession; use crate::secrets::Secret; +use async_trait::async_trait; +use axum::extract::FromRequestParts; +use axum::http::{request::Parts, StatusCode}; +use axum::response::{IntoResponse, Redirect, Response}; use axum_login::AuthUser; use serde::{Deserialize, Serialize}; use sqlx::types::time; use std::fmt::Debug; -#[derive(sqlx::FromRow, Serialize, Deserialize, Clone)] -pub struct UserOut { +#[derive(sqlx::FromRow, Serialize, Clone, Debug)] +pub struct UserRecord { pub user_id: uuid::Uuid, + pub email: String, pub username: String, } -#[derive(sqlx::FromRow, Serialize, Deserialize, Clone, Debug)] +impl From for UserRecord { + fn from(user: User) -> Self { + Self { + user_id: user.user_id, + email: user.email, + username: user.username, + } + } +} + +#[derive(sqlx::FromRow, Deserialize, Clone, Debug)] pub struct User { pub user_id: uuid::Uuid, pub email: String, @@ -22,6 +38,30 @@ pub struct User { pub updated_at: Option, } +struct AuthRedirect; + +impl IntoResponse for AuthRedirect { + fn into_response(self) -> Response { + Redirect::temporary("/api/auth/login").into_response() + } +} + +#[async_trait] +impl FromRequestParts for User +where + S: Send + Sync, +{ + // If anything goes wrong or no session is found, redirect to the auth page + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let auth_session = AuthSession::from_request_parts(parts, state).await?; + auth_session + .user + .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized")) + } +} + impl AuthUser for User { type Id = uuid::Uuid; diff --git a/server/src/users/routes.rs b/server/src/users/routes.rs index 5ac6ed65..755868ee 100644 --- a/server/src/users/routes.rs +++ b/server/src/users/routes.rs @@ -1,23 +1,13 @@ use crate::startup::AppState; -use crate::users::selectors::get_user; -use crate::users::User; -use axum::extract::{Path, State}; +use crate::users::{User, UserRecord}; use axum::routing::get; use axum::{Json, Router}; -use color_eyre::eyre::eyre; -use sqlx::PgPool; #[tracing::instrument(level = "debug", skip_all, ret, err(Debug))] -async fn get_user_handler( - State(pool): State, - Path(user_id): Path, -) -> crate::Result> { - match get_user(pool, user_id).await? { - Some(user) => Ok(Json(user)), - None => Err(eyre!("User not found").into()), - } +async fn get_user_handler(user: User) -> crate::Result> { + return Ok(Json(UserRecord::from(user))); } pub fn routes() -> Router { - Router::new().route("/:user_id", get(get_user_handler)) + Router::new().route("/me", get(get_user_handler)) } diff --git a/server/src/users/selectors.rs b/server/src/users/selectors.rs index c3f7cc0b..ed0f901a 100644 --- a/server/src/users/selectors.rs +++ b/server/src/users/selectors.rs @@ -1,13 +1,13 @@ use sqlx::types::uuid; use sqlx::PgPool; -use crate::users::User; +use crate::users::{User, UserRecord}; #[tracing::instrument(level = "debug", ret, err)] -pub async fn get_user(pool: PgPool, user_id: uuid::Uuid) -> color_eyre::Result> { +pub async fn get_user(pool: PgPool, user_id: uuid::Uuid) -> color_eyre::Result> { let user = sqlx::query_as!(User, "select * from users where user_id = $1", user_id) .fetch_optional(&pool) .await?; - Ok(user) + Ok(user.map(UserRecord::from)) } diff --git a/server/src/utils.rs b/server/src/utils.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/server/src/utils.rs @@ -0,0 +1 @@ + diff --git a/server/tests/health_check.rs b/server/tests/health_check.rs index 0429370c..cd7517e8 100644 --- a/server/tests/health_check.rs +++ b/server/tests/health_check.rs @@ -10,11 +10,8 @@ use server::startup::{db_connect, AppState}; async fn health_check_works() { let settings = Settings::new(); - let db = db_connect(settings.db.expose()) - .await - .expect("Failed to connect to Postgres."); - - let state = AppState { db, settings }; + let db = db_connect(settings.db.expose()).await.unwrap(); + let state = AppState::from((db, settings)); let router = router(state).unwrap(); let request = Request::builder() diff --git a/server/tests/users.rs b/server/tests/users.rs index 04eb30b1..0bf29f21 100644 --- a/server/tests/users.rs +++ b/server/tests/users.rs @@ -1,25 +1,18 @@ use axum::body::Body; -use axum::extract::rejection::{ - FailedToDeserializeForm, FailedToDeserializeFormBody, FormRejection, RawFormRejection, -}; -use axum::extract::{FromRequest, RawForm}; use axum::http::header::CONTENT_TYPE; use axum::http::{Request, StatusCode}; -use axum::response::IntoResponse; -use axum::RequestExt; -use axum::{http, Form}; -use sqlx::PgPool; -use tower::ServiceExt; - use server::auth::models::RegisterUserRequest; +use server::auth::register; use server::routing::router; -use server::secrets::Secret; use server::settings::Settings; -use server::startup::{db_connect, AppState}; +use server::startup::AppState; use server::users::selectors::get_user; -use server::users::services::register_user; +use server::Result; +use sqlx::PgPool; +use tower::ServiceExt; + /// Helper function to create a GET request for a given URI. -fn send_get_request(uri: &str) -> Request { +fn _send_get_request(uri: &str) -> Request { Request::builder() .uri(uri) .method("GET") @@ -28,21 +21,20 @@ fn send_get_request(uri: &str) -> Request { } #[sqlx::test] -async fn register_and_get_users_test(pool: PgPool) -> color_eyre::Result<()> { +async fn register_and_get_users_test(pool: PgPool) -> Result<()> { let user = get_user(pool.clone(), uuid::Uuid::nil()).await.unwrap(); assert!(user.is_none()); - let new_user = register_user( + let new_user = register( pool.clone(), RegisterUserRequest { email: "test-email".to_string(), username: "test-username".to_string(), - password_hash: Some("password".to_string()).into(), + password: Some("password".to_string().into()), access_token: Default::default(), }, ) - .await? - .unwrap(); + .await?; let user = get_user(pool.clone(), new_user.user_id).await?.unwrap(); @@ -52,42 +44,41 @@ async fn register_and_get_users_test(pool: PgPool) -> color_eyre::Result<()> { Ok(()) } -#[tokio::test] -async fn register_users_works() { +#[sqlx::test] +async fn register_users_works(pool: PgPool) { let settings = Settings::new(); - - let db = db_connect(settings.db.expose()) - .await - .expect("Failed to connect to Postgres."); - - let state = AppState { db, settings }; + let state = AppState::from((pool, settings)); let router = router(state).unwrap(); let form = &[ ("email", "my-email@email.com"), ("username", "my-username"), - ("password_hash", "my-password"), + ("password", "my-password"), ]; - let serialized_body = serde_urlencoded::to_string(&form).unwrap(); - let request = Request::post("/api/users/register") + let serialized_body = serde_urlencoded::to_string(form).unwrap(); + let request = Request::post("/api/auth/register") .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body(serialized_body) .unwrap(); + let response = router.clone().oneshot(request.clone()).await.unwrap(); + assert_eq!(response.status(), StatusCode::CREATED); + + // Doing the same thing again should return a 422 status code. let response = router.clone().oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); let form = &[ - ("email", "my-email@email.com"), - ("username", "my-username"), + ("email", "another-email@email.com"), + ("username", "another-username"), ("access_token", "my-access-token"), ]; - let serialized_body = serde_urlencoded::to_string(&form).unwrap(); - let request = Request::post("/api/users/register") + let serialized_body = serde_urlencoded::to_string(form).unwrap(); + let request = Request::post("/api/auth/register") .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body(serialized_body) .unwrap(); let response = router.oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.status(), StatusCode::CREATED); }