Skip to content

Commit

Permalink
refactor: pass in Response mutably
Browse files Browse the repository at this point in the history
  • Loading branch information
BastiDood committed Dec 20, 2023
1 parent 24dbbc6 commit 37839ac
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 48 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ anyhow = "1"
api = { path = "crates/api", package = "quizzo-api" }
env_logger = { version = "0.10", default-features = false }
hex = "0.4"
http-body-util = "0.1.0"
hyper-util = { version = "0.1.1", features = ["tokio"] }
log = "0.4"

Expand Down
2 changes: 1 addition & 1 deletion crates/api/src/bot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl Bot {
_ => Err(error::Error::Schema),
};
result.unwrap_or_else(|err| {
log::error!("Interaction failed with `{err:?}`");
log::error!("interaction failed with `{err:?}`");
InteractionResponse {
kind: InteractionResponseType::ChannelMessageWithSource,
data: Some(InteractionResponseData {
Expand Down
102 changes: 70 additions & 32 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use core::num::NonZeroU64;
use http_body_util::Full;
use hyper::{
body::{Bytes, Incoming},
Request, Response, StatusCode,
HeaderMap, Method, Response, StatusCode,
};

pub use db::{Client, Config, Database, NoTls};
Expand All @@ -23,43 +23,58 @@ impl App {
Self { bot: Bot::new(db, id, token), public }
}

pub async fn try_respond(&self, req: Request<Incoming>) -> Result<Response<Full<Bytes>>, StatusCode> {
use hyper::{http::request::Parts, Method};
let (Parts { uri, method, headers, .. }, mut body) = req.into_parts();
let path = uri.path();

pub async fn try_respond(
&self,
response: &mut Response<Full<Bytes>>,
method: Method,
path: &str,
headers: HeaderMap,
mut body: Incoming,
) -> bool {
match method {
Method::GET | Method::HEAD => match path {
"/health" => {
log::info!("Health check pinged");
return Ok(Default::default());
log::info!("health check pinged");
return true;
}
_ => {
log::error!("Unexpected `{method} {path}` request received");
return Err(StatusCode::NOT_FOUND);
log::error!("unexpected `{method} {path}` request received");
*response.status_mut() = StatusCode::NOT_FOUND;
return false;
}
},
Method::POST => match path {
"/discord" => (),
_ => {
log::error!("Unexpected `POST {path}` request received");
return Err(StatusCode::NOT_FOUND);
log::error!("unexpected `POST {path}` request received");
*response.status_mut() = StatusCode::NOT_FOUND;
return false;
}
},
_ => {
log::error!("Unexpected `{method} {path}` request received");
return Err(StatusCode::METHOD_NOT_ALLOWED);
log::error!("unexpected `{method} {path}` request received");
*response.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
return false;
}
}

log::debug!("new Discord interaction received");

// Retrieve security headers
log::debug!("New Discord interaction received");
let signature = headers.get("X-Signature-Ed25519");
let timestamp = headers.get("X-Signature-Timestamp");
let (sig, timestamp) = signature.zip(timestamp).ok_or(StatusCode::UNAUTHORIZED)?;
let mut signature = [0; 64];
hex::decode_to_slice(sig, &mut signature).map_err(|_| StatusCode::BAD_REQUEST)?;
let signature = ed25519_dalek::Signature::from_bytes(&signature);
let Some((signature, timestamp)) = signature.zip(timestamp) else {
log::error!("no signatures in headers");
*response.status_mut() = StatusCode::UNAUTHORIZED;
return false;
};

let mut buffer = [0; 64];
if let Err(err) = hex::decode_to_slice(signature, &mut buffer) {
log::error!("bad signature hex encoding: {err}");
*response.status_mut() = StatusCode::BAD_REQUEST;
return false;
}

// Append body after the timestamp
use http_body_util::BodyExt;
Expand All @@ -70,30 +85,53 @@ impl App {
Ok(frame) => frame,
Err(err) => {
log::error!("body stream prematurely ended: {err}");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return false;
}
};
if let Some(data) = frame.data_ref() {
message.extend_from_slice(data);
}
}
log::debug!("Fully received payload body.");

log::debug!("fully received payload body.");

// Validate the challenge
self.public.verify_strict(&message, &signature).map_err(|_| StatusCode::UNAUTHORIZED)?;
let signature = ed25519_dalek::Signature::from_bytes(&buffer);
if let Err(err) = self.public.verify_strict(&message, &signature) {
log::error!("cannot verify message with signature: {err}");
*response.status_mut() = StatusCode::FORBIDDEN;
return false;
}

// Parse incoming interaction
let payload = message.get(start..).ok_or(StatusCode::BAD_REQUEST)?;
let interaction = serde_json::from_slice(payload).map_err(|_| StatusCode::BAD_REQUEST)?;
let Some(payload) = message.get(start..) else {
log::error!("body is empty");
*response.status_mut() = StatusCode::BAD_REQUEST;
return false;
};

// Construct new body
let reply = self.bot.on_message(interaction).await;
let bytes = serde_json::to_vec(&reply).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let reply = match serde_json::from_slice(payload) {
Ok(interaction) => self.bot.on_message(interaction).await,
Err(err) => {
log::error!("body is not JSON-encoded: {err}");
*response.status_mut() = StatusCode::BAD_REQUEST;
return false;
}
};

*response.body_mut() = match serde_json::to_vec(&reply) {
Ok(bytes) => bytes.into(),
Err(err) => {
log::error!("cannot encode reply to JSON: {err}");
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return false;
}
};

use hyper::header::{HeaderValue, CONTENT_TYPE};
let mut res = Response::new(Full::from(bytes));
let result = res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
assert!(result.is_none());
Ok(res)
if let Some(value) = response.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")) {
log::warn!("existing header value: {value:?}");
}
true
}
}
18 changes: 5 additions & 13 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
use http_body_util::Full;
use hyper::{body::Bytes, Response, StatusCode};

fn resolve_error_code(code: StatusCode) -> Response<Full<Bytes>> {
let mut response = Response::default();
*response.status_mut() = code;
response
}

fn main() -> anyhow::Result<()> {
env_logger::init();
log::info!("Starting up");
Expand Down Expand Up @@ -43,7 +34,7 @@ fn main() -> anyhow::Result<()> {
listener.set_nonblocking(true)?;

let addr = listener.local_addr()?;
log::info!("Listening to {addr}");
log::info!("listening to {addr}");

// Set up runtime
let runtime = tokio::runtime::Builder::new_multi_thread().enable_io().enable_time().build()?;
Expand All @@ -69,8 +60,10 @@ fn main() -> anyhow::Result<()> {
let outer = state.clone();
let service = hyper::service::service_fn(move |req| {
let inner = outer.clone();
let (hyper::http::request::Parts { method, uri, headers, .. }, body) = req.into_parts();
async move {
let response = inner.try_respond(req).await.unwrap_or_else(resolve_error_code);
let mut response = Default::default();
inner.try_respond(&mut response, method, uri.path(), headers, body).await;
Ok::<_, core::convert::Infallible>(response)
}
});
Expand All @@ -79,7 +72,7 @@ fn main() -> anyhow::Result<()> {
continue;
}
stop_res = &mut stop => {
log::info!("Stop signal received");
log::info!("stop signal received");
stop_res?;
break;
},
Expand All @@ -91,7 +84,6 @@ fn main() -> anyhow::Result<()> {
else => continue,
}
}

anyhow::Ok(())
})?;

Expand Down

0 comments on commit 37839ac

Please sign in to comment.