From 37839aca5a4d06b39b48957964f6ae0a6e13a1cc Mon Sep 17 00:00:00 2001 From: Basti Ortiz <39114273+BastiDood@users.noreply.github.com> Date: Thu, 21 Dec 2023 04:38:16 +0800 Subject: [PATCH] refactor: pass in `Response` mutably --- Cargo.lock | 1 - Cargo.toml | 1 - crates/api/src/bot/mod.rs | 2 +- crates/api/src/lib.rs | 102 ++++++++++++++++++++++++++------------ src/main.rs | 18 ++----- 5 files changed, 76 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cffcc97..8062420 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -985,7 +985,6 @@ dependencies = [ "anyhow", "env_logger", "hex", - "http-body-util", "hyper 1.1.0", "hyper-util", "log", diff --git a/Cargo.toml b/Cargo.toml index 132bbe1..7118735 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ anyhow = "1" api = { path = "crates/api", package = "quizzo-api" } env_logger = { version = "0.10", default-features = false } hex = "0.4" -http-body-util = "0.1.0" hyper-util = { version = "0.1.1", features = ["tokio"] } log = "0.4" diff --git a/crates/api/src/bot/mod.rs b/crates/api/src/bot/mod.rs index 23ed694..3c7d75e 100644 --- a/crates/api/src/bot/mod.rs +++ b/crates/api/src/bot/mod.rs @@ -65,7 +65,7 @@ impl Bot { _ => Err(error::Error::Schema), }; result.unwrap_or_else(|err| { - log::error!("Interaction failed with `{err:?}`"); + log::error!("interaction failed with `{err:?}`"); InteractionResponse { kind: InteractionResponseType::ChannelMessageWithSource, data: Some(InteractionResponseData { diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 194416c..c3b9a9e 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -5,7 +5,7 @@ use core::num::NonZeroU64; use http_body_util::Full; use hyper::{ body::{Bytes, Incoming}, - Request, Response, StatusCode, + HeaderMap, Method, Response, StatusCode, }; pub use db::{Client, Config, Database, NoTls}; @@ -23,43 +23,58 @@ impl App { Self { bot: Bot::new(db, id, token), public } } - pub async fn try_respond(&self, req: Request) -> Result>, StatusCode> { - use hyper::{http::request::Parts, Method}; - let (Parts { uri, method, headers, .. }, mut body) = req.into_parts(); - let path = uri.path(); - + pub async fn try_respond( + &self, + response: &mut Response>, + method: Method, + path: &str, + headers: HeaderMap, + mut body: Incoming, + ) -> bool { match method { Method::GET | Method::HEAD => match path { "/health" => { - log::info!("Health check pinged"); - return Ok(Default::default()); + log::info!("health check pinged"); + return true; } _ => { - log::error!("Unexpected `{method} {path}` request received"); - return Err(StatusCode::NOT_FOUND); + log::error!("unexpected `{method} {path}` request received"); + *response.status_mut() = StatusCode::NOT_FOUND; + return false; } }, Method::POST => match path { "/discord" => (), _ => { - log::error!("Unexpected `POST {path}` request received"); - return Err(StatusCode::NOT_FOUND); + log::error!("unexpected `POST {path}` request received"); + *response.status_mut() = StatusCode::NOT_FOUND; + return false; } }, _ => { - log::error!("Unexpected `{method} {path}` request received"); - return Err(StatusCode::METHOD_NOT_ALLOWED); + log::error!("unexpected `{method} {path}` request received"); + *response.status_mut() = StatusCode::METHOD_NOT_ALLOWED; + return false; } } + log::debug!("new Discord interaction received"); + // Retrieve security headers - log::debug!("New Discord interaction received"); let signature = headers.get("X-Signature-Ed25519"); let timestamp = headers.get("X-Signature-Timestamp"); - let (sig, timestamp) = signature.zip(timestamp).ok_or(StatusCode::UNAUTHORIZED)?; - let mut signature = [0; 64]; - hex::decode_to_slice(sig, &mut signature).map_err(|_| StatusCode::BAD_REQUEST)?; - let signature = ed25519_dalek::Signature::from_bytes(&signature); + let Some((signature, timestamp)) = signature.zip(timestamp) else { + log::error!("no signatures in headers"); + *response.status_mut() = StatusCode::UNAUTHORIZED; + return false; + }; + + let mut buffer = [0; 64]; + if let Err(err) = hex::decode_to_slice(signature, &mut buffer) { + log::error!("bad signature hex encoding: {err}"); + *response.status_mut() = StatusCode::BAD_REQUEST; + return false; + } // Append body after the timestamp use http_body_util::BodyExt; @@ -70,30 +85,53 @@ impl App { Ok(frame) => frame, Err(err) => { log::error!("body stream prematurely ended: {err}"); - return Err(StatusCode::INTERNAL_SERVER_ERROR); + *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return false; } }; if let Some(data) = frame.data_ref() { message.extend_from_slice(data); } } - log::debug!("Fully received payload body."); + + log::debug!("fully received payload body."); // Validate the challenge - self.public.verify_strict(&message, &signature).map_err(|_| StatusCode::UNAUTHORIZED)?; + let signature = ed25519_dalek::Signature::from_bytes(&buffer); + if let Err(err) = self.public.verify_strict(&message, &signature) { + log::error!("cannot verify message with signature: {err}"); + *response.status_mut() = StatusCode::FORBIDDEN; + return false; + } - // Parse incoming interaction - let payload = message.get(start..).ok_or(StatusCode::BAD_REQUEST)?; - let interaction = serde_json::from_slice(payload).map_err(|_| StatusCode::BAD_REQUEST)?; + let Some(payload) = message.get(start..) else { + log::error!("body is empty"); + *response.status_mut() = StatusCode::BAD_REQUEST; + return false; + }; - // Construct new body - let reply = self.bot.on_message(interaction).await; - let bytes = serde_json::to_vec(&reply).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let reply = match serde_json::from_slice(payload) { + Ok(interaction) => self.bot.on_message(interaction).await, + Err(err) => { + log::error!("body is not JSON-encoded: {err}"); + *response.status_mut() = StatusCode::BAD_REQUEST; + return false; + } + }; + + *response.body_mut() = match serde_json::to_vec(&reply) { + Ok(bytes) => bytes.into(), + Err(err) => { + log::error!("cannot encode reply to JSON: {err}"); + *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return false; + } + }; use hyper::header::{HeaderValue, CONTENT_TYPE}; - let mut res = Response::new(Full::from(bytes)); - let result = res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - assert!(result.is_none()); - Ok(res) + if let Some(value) = response.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")) { + log::warn!("existing header value: {value:?}"); + } + true } } diff --git a/src/main.rs b/src/main.rs index 78db72a..0c5c60f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,3 @@ -use http_body_util::Full; -use hyper::{body::Bytes, Response, StatusCode}; - -fn resolve_error_code(code: StatusCode) -> Response> { - let mut response = Response::default(); - *response.status_mut() = code; - response -} - fn main() -> anyhow::Result<()> { env_logger::init(); log::info!("Starting up"); @@ -43,7 +34,7 @@ fn main() -> anyhow::Result<()> { listener.set_nonblocking(true)?; let addr = listener.local_addr()?; - log::info!("Listening to {addr}"); + log::info!("listening to {addr}"); // Set up runtime let runtime = tokio::runtime::Builder::new_multi_thread().enable_io().enable_time().build()?; @@ -69,8 +60,10 @@ fn main() -> anyhow::Result<()> { let outer = state.clone(); let service = hyper::service::service_fn(move |req| { let inner = outer.clone(); + let (hyper::http::request::Parts { method, uri, headers, .. }, body) = req.into_parts(); async move { - let response = inner.try_respond(req).await.unwrap_or_else(resolve_error_code); + let mut response = Default::default(); + inner.try_respond(&mut response, method, uri.path(), headers, body).await; Ok::<_, core::convert::Infallible>(response) } }); @@ -79,7 +72,7 @@ fn main() -> anyhow::Result<()> { continue; } stop_res = &mut stop => { - log::info!("Stop signal received"); + log::info!("stop signal received"); stop_res?; break; }, @@ -91,7 +84,6 @@ fn main() -> anyhow::Result<()> { else => continue, } } - anyhow::Ok(()) })?;