diff --git a/edge-http/src/io.rs b/edge-http/src/io.rs index 390cd2c..3d58acd 100644 --- a/edge-http/src/io.rs +++ b/edge-http/src/io.rs @@ -128,18 +128,15 @@ impl<'b, const N: usize> RequestHeaders<'b, N> { unreachable!("Should not happen. HTTP header parsing is indeterminate.") } - self.http11 = if let Some(version) = parser.version { - if version > 1 { - Err(Error::InvalidHeaders)?; - } - - Some(version == 1) - } else { - None + self.http11 = match parser.version { + Some(0) => false, + Some(1) => true, + _ => Err(Error::InvalidHeaders)?, }; - self.method = parser.method.and_then(Method::new); - self.path = parser.path; + let method_str = parser.method.ok_or(Error::InvalidHeaders)?; + self.method = Method::new(method_str).ok_or(Error::InvalidHeaders)?; + self.path = parser.path.ok_or(Error::InvalidHeaders)?; trace!("Received:\n{}", self); @@ -151,8 +148,7 @@ impl<'b, const N: usize> RequestHeaders<'b, N> { /// Resolve the connection type and body type from the headers pub fn resolve(&self) -> Result<(ConnectionType, BodyType), Error> { - self.headers - .resolve::(None, true, self.http11.unwrap_or(false)) + self.headers.resolve::(None, true, self.http11) } /// Send the headers to the output stream, returning the connection type and body type @@ -164,12 +160,10 @@ impl<'b, const N: usize> RequestHeaders<'b, N> { where W: Write, { - let http11 = self.http11.unwrap_or(false); - - send_request(http11, self.method, self.path, &mut output).await?; + send_request(self.http11, self.method, self.path, &mut output).await?; self.headers - .send(None, true, http11, chunked_if_unspecified, output) + .send(None, true, self.http11, chunked_if_unspecified, output) .await } } @@ -199,17 +193,13 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> { unreachable!("Should not happen. HTTP header parsing is indeterminate.") } - self.http11 = if let Some(version) = parser.version { - if version > 1 { - Err(Error::InvalidHeaders)?; - } - - Some(version == 1) - } else { - None + self.http11 = match parser.version { + Some(0) => false, + Some(1) => true, + _ => Err(Error::InvalidHeaders)?, }; - self.code = parser.code; + self.code = parser.code.ok_or(Error::InvalidHeaders)?; self.reason = parser.reason; trace!("Received:\n{}", self); @@ -225,11 +215,8 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> { &self, request_connection_type: ConnectionType, ) -> Result<(ConnectionType, BodyType), Error> { - self.headers.resolve::( - Some(request_connection_type), - false, - self.http11.unwrap_or(false), - ) + self.headers + .resolve::(Some(request_connection_type), false, self.http11) } /// Send the headers to the output stream, returning the connection type and body type @@ -242,15 +229,13 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> { where W: Write, { - let http11 = self.http11.unwrap_or(false); - - send_status(http11, self.code, self.reason, &mut output).await?; + send_status(self.http11, self.code, self.reason, &mut output).await?; self.headers .send( Some(request_connection_type), false, - http11, + self.http11, chunked_if_unspecified, output, ) @@ -260,42 +245,56 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> { pub(crate) async fn send_request( http11: bool, - method: Option, - path: Option<&str>, - output: W, + method: Method, + path: &str, + mut output: W, ) -> Result<(), Error> where W: Write, { - raw::send_status_line( - true, - http11, - method.map(|method| method.as_str()), - path, - output, - ) - .await + // RFC 9112: request-line = method SP request-target SP HTTP-version + + output + .write_all(method.as_str().as_bytes()) + .await + .map_err(Error::Io)?; + output.write_all(b" ").await.map_err(Error::Io)?; + output.write_all(path.as_bytes()).await.map_err(Error::Io)?; + output.write_all(b" ").await.map_err(Error::Io)?; + raw::send_version(&mut output, http11).await?; + output.write_all(b"\r\n").await.map_err(Error::Io)?; + + Ok(()) } pub(crate) async fn send_status( http11: bool, - status: Option, + status: u16, reason: Option<&str>, - output: W, + mut output: W, ) -> Result<(), Error> where W: Write, { - let status_str: Option> = status.map(|status| status.try_into().unwrap()); + // RFC 9112: status-line = HTTP-version SP status-code SP [ reason-phrase ] - raw::send_status_line( - false, - http11, - status_str.as_ref().map(|status| status.as_str()), - reason, - output, - ) - .await + raw::send_version(&mut output, http11).await?; + output.write_all(b" ").await.map_err(Error::Io)?; + let status_str: heapless::String<5> = status.try_into().unwrap(); + output + .write_all(status_str.as_bytes()) + .await + .map_err(Error::Io)?; + output.write_all(b" ").await.map_err(Error::Io)?; + if let Some(reason) = reason { + output + .write_all(reason.as_bytes()) + .await + .map_err(Error::Io)?; + } + output.write_all(b"\r\n").await.map_err(Error::Io)?; + + Ok(()) } pub(crate) async fn send_headers<'a, H, W>( @@ -1181,61 +1180,6 @@ mod raw { } } - pub(crate) async fn send_status_line( - request: bool, - http11: bool, - token: Option<&str>, - extra: Option<&str>, - mut output: W, - ) -> Result<(), Error> - where - W: Write, - { - let mut written = false; - - if !request { - send_version(&mut output, http11).await?; - written = true; - } - - if let Some(token) = token { - if written { - output.write_all(b" ").await.map_err(Error::Io)?; - } - - output - .write_all(token.as_bytes()) - .await - .map_err(Error::Io)?; - - written = true; - } - - if written { - output.write_all(b" ").await.map_err(Error::Io)?; - } - if let Some(extra) = extra { - output - .write_all(extra.as_bytes()) - .await - .map_err(Error::Io)?; - - written = true; - } - - if request { - if written { - output.write_all(b" ").await.map_err(Error::Io)?; - } - - send_version(&mut output, http11).await?; - } - - output.write_all(b"\r\n").await.map_err(Error::Io)?; - - Ok(()) - } - pub(crate) async fn send_version(mut output: W, http11: bool) -> Result<(), Error> where W: Write, diff --git a/edge-http/src/io/client.rs b/edge-http/src/io/client.rs index 3775e1d..6f2fb10 100644 --- a/edge-http/src/io/client.rs +++ b/edge-http/src/io/client.rs @@ -174,7 +174,7 @@ where let mut state = self.unbind(); let result = async { - match send_request(http11, Some(method), Some(uri), state.io.as_mut().unwrap()).await { + match send_request(http11, method, uri, state.io.as_mut().unwrap()).await { Ok(_) => (), Err(Error::Io(_)) => { if !fresh_connection { @@ -182,8 +182,7 @@ where state.io = None; state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?); - send_request(http11, Some(method), Some(uri), state.io.as_mut().unwrap()) - .await?; + send_request(http11, method, uri, state.io.as_mut().unwrap()).await?; } } Err(other) => Err(other)?, @@ -263,7 +262,6 @@ where let mut state = self.unbind(); let buf_ptr: *mut [u8] = state.buf; - let mut response = ResponseHeaders::new(); match response diff --git a/edge-http/src/io/server.rs b/edge-http/src/io/server.rs index 1b92a7b..1d7a756 100644 --- a/edge-http/src/io/server.rs +++ b/edge-http/src/io/server.rs @@ -103,7 +103,7 @@ where message: Option<&str>, headers: &[(&str, &str)], ) -> Result<(), Error> { - self.complete_request(Some(status), message, headers).await + self.complete_request(status, message, headers).await } /// A convenience method to initiate a WebSocket upgrade response @@ -125,7 +125,7 @@ where /// If the connection is still in a request state, and empty 200 OK response is sent pub async fn complete(&mut self) -> Result<(), Error> { if self.is_request_initiated() { - self.complete_request(Some(200), Some("OK"), &[]).await?; + self.complete_request(200, Some("OK"), &[]).await?; } if self.is_response_initiated() { @@ -145,7 +145,7 @@ where Ok(_) => { let headers = [("Connection", "Close"), ("Content-Type", "text/plain")]; - self.complete_request(Some(500), Some("Internal Error"), &headers) + self.complete_request(500, Some("Internal Error"), &headers) .await?; let response = self.response_mut()?; @@ -181,7 +181,7 @@ where async fn complete_request( &mut self, - status: Option, + status: u16, reason: Option<&str>, headers: &[(&str, &str)], ) -> Result<(), Error> { @@ -190,7 +190,7 @@ where let mut buf = [0; COMPLETION_BUF_SIZE]; while request.io.read(&mut buf).await? > 0 {} - let http11 = request.request.http11.unwrap_or(false); + let http11 = request.request.http11; let request_connection_type = request.connection_type; let mut io = self.unbind_mut(); @@ -918,12 +918,7 @@ mod embedded_svc_compat { let headers = connection.headers().ok(); if let Some(headers) = headers { - if headers.path.map(|path| self.path == path).unwrap_or(false) - && headers - .method - .map(|method| self.method == method.into()) - .unwrap_or(false) - { + if headers.path == self.path && headers.method == self.method.into() { return self.handler.handle(connection).await; } } diff --git a/edge-http/src/lib.rs b/edge-http/src/lib.rs index bdb77c6..6fe9546 100644 --- a/edge-http/src/lib.rs +++ b/edge-http/src/lib.rs @@ -704,27 +704,27 @@ impl Display for BodyType { } /// Request headers including the request line (method, path) -#[derive(Default, Debug)] +#[derive(Debug)] pub struct RequestHeaders<'b, const N: usize> { - /// Whether the request is HTTP/1.1, if present. If not present, HTTP/1.0 should be assumed - pub http11: Option, - /// The HTTP method, if present - pub method: Option, - /// The request path, if present - pub path: Option<&'b str>, + /// Whether the request is HTTP/1.1 + pub http11: bool, + /// The HTTP method + pub method: Method, + /// The request path + pub path: &'b str, /// The headers pub headers: Headers<'b, N>, } impl RequestHeaders<'_, N> { - /// Create a new RequestHeaders instance for HTTP/1.1 + // Create a new RequestHeaders instance, defaults to GET / HTTP/1.1 #[inline(always)] pub const fn new() -> Self { Self { - http11: Some(true), - method: None, - path: None, - headers: Headers::::new(), + http11: true, + method: Method::Get, + path: "/", + headers: Headers::new(), } } @@ -734,15 +734,18 @@ impl RequestHeaders<'_, N> { } } +impl<'b, const N: usize> Default for RequestHeaders<'b, N> { + #[inline(always)] + fn default() -> Self { + Self::new() + } +} + impl Display for RequestHeaders<'_, N> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(http11) = self.http11 { - write!(f, "{} ", if http11 { "HTTP/1.1" } else { "HTTP/1.0" })?; - } + write!(f, "{} ", if self.http11 { "HTTP/1.1" } else { "HTTP/1.0" })?; - if let Some(method) = self.method { - writeln!(f, "{method} {}", self.path.unwrap_or(""))?; - } + writeln!(f, "{} {}", self.method, self.path)?; for (name, value) in self.headers.iter() { if name.is_empty() { @@ -757,12 +760,12 @@ impl Display for RequestHeaders<'_, N> { } /// Response headers including the response line (HTTP version, status code, reason phrase) -#[derive(Default, Debug)] +#[derive(Debug)] pub struct ResponseHeaders<'b, const N: usize> { - /// Whether the response is HTTP/1.1, if present. If not present, HTTP/1.0 should be assumed - pub http11: Option, - /// The status code, if present - pub code: Option, + /// Whether the response is HTTP/1.1 + pub http11: bool, + /// The status code + pub code: u16, /// The reason phrase, if present pub reason: Option<&'b str>, /// The headers @@ -770,14 +773,14 @@ pub struct ResponseHeaders<'b, const N: usize> { } impl ResponseHeaders<'_, N> { - /// Create a new ResponseHeaders instance for HTTP/1.1 + /// Create a new ResponseHeaders instance, defaults to HTTP/1.1 200 OK #[inline(always)] pub const fn new() -> Self { Self { - http11: Some(true), - code: None, + http11: true, + code: 200, reason: None, - headers: Headers::::new(), + headers: Headers::new(), } } @@ -792,15 +795,18 @@ impl ResponseHeaders<'_, N> { } } +impl<'b, const N: usize> Default for ResponseHeaders<'b, N> { + #[inline(always)] + fn default() -> Self { + Self::new() + } +} + impl Display for ResponseHeaders<'_, N> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(http11) = self.http11 { - writeln!(f, "{} ", if http11 { "HTTP/1.1 " } else { "HTTP/1.0" })?; - } + write!(f, "{} ", if self.http11 { "HTTP/1.1 " } else { "HTTP/1.0" })?; - if let Some(code) = self.code { - writeln!(f, "{code} {}", self.reason.unwrap_or(""))?; - } + writeln!(f, "{} {}", self.code, self.reason.unwrap_or(""))?; for (name, value) in self.headers.iter() { if name.is_empty() { @@ -859,11 +865,11 @@ pub mod ws { } /// Check if the request is a Websocket upgrade request - pub fn is_upgrade_request<'a, H>(method: Option, request_headers: H) -> bool + pub fn is_upgrade_request<'a, H>(method: Method, request_headers: H) -> bool where H: IntoIterator, { - if !matches!(method, Some(Method::Get)) { + if method != Method::Get { return false; } @@ -960,7 +966,7 @@ pub mod ws { /// - `nonce`: The nonce used for the `Sec-WebSocket-Key` header in the WS upgrade request /// - `buf`: A buffer to use when performing the check pub fn is_upgrade_accepted<'a, H>( - code: Option, + code: u16, response_headers: H, nonce: &[u8; NONCE_LEN], buf: &'a mut [u8; MAX_BASE64_KEY_RESPONSE_LEN], @@ -968,7 +974,7 @@ pub mod ws { where H: IntoIterator, { - if !matches!(code, Some(101)) { + if code != 101 { return false; } @@ -1407,11 +1413,11 @@ mod embedded_svc_compat { impl<'b, const N: usize> embedded_svc::http::Query for super::RequestHeaders<'b, N> { fn uri(&self) -> &'_ str { - self.path.unwrap_or("") + self.path } fn method(&self) -> Method { - self.method.unwrap_or(super::Method::Get).into() + self.method.into() } } @@ -1423,7 +1429,7 @@ mod embedded_svc_compat { impl<'b, const N: usize> embedded_svc::http::Status for super::ResponseHeaders<'b, N> { fn status(&self) -> u16 { - self.code.unwrap_or(200) + self.code } fn status_message(&self) -> Option<&'_ str> { diff --git a/examples/http_server.rs b/examples/http_server.rs index 5b01887..bbdda29 100644 --- a/examples/http_server.rs +++ b/examples/http_server.rs @@ -42,10 +42,10 @@ where async fn handle(&self, conn: &mut Connection<'b, T, N>) -> Result<(), Self::Error> { let headers = conn.headers()?; - if !matches!(headers.method, Some(Method::Get)) { + if headers.method != Method::Get { conn.initiate_response(405, Some("Method Not Allowed"), &[]) .await?; - } else if !matches!(headers.path, Some("/")) { + } else if headers.path != "/" { conn.initiate_response(404, Some("Not Found"), &[]).await?; } else { conn.initiate_response(200, Some("OK"), &[("Content-Type", "text/plain")]) diff --git a/examples/ws_server.rs b/examples/ws_server.rs index cdbd48a..a047a1a 100644 --- a/examples/ws_server.rs +++ b/examples/ws_server.rs @@ -58,10 +58,10 @@ where async fn handle(&self, conn: &mut Connection<'b, T, N>) -> Result<(), Self::Error> { let headers = conn.headers()?; - if !matches!(headers.method, Some(Method::Get)) { + if headers.method != Method::Get { conn.initiate_response(405, Some("Method Not Allowed"), &[]) .await?; - } else if !matches!(headers.path, Some("/")) { + } else if headers.path != "/" { conn.initiate_response(404, Some("Not Found"), &[]).await?; } else if !conn.is_ws_upgrade_request()? { conn.initiate_response(200, Some("OK"), &[("Content-Type", "text/plain")])