Skip to content

Commit 3e88655

Browse files
committed
add opt-in http/1.0 support
1 parent 2d17054 commit 3e88655

File tree

5 files changed

+186
-52
lines changed

5 files changed

+186
-52
lines changed

src/server/decode.rs

+43-38
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,29 @@ use std::str::FromStr;
55
use async_dup::{Arc, Mutex};
66
use async_std::io::{BufReader, Read, Write};
77
use async_std::{prelude::*, task};
8-
use http_types::content::ContentLength;
9-
use http_types::headers::{EXPECT, TRANSFER_ENCODING};
10-
use http_types::{ensure, ensure_eq, format_err};
8+
use http_types::{content::ContentLength, Version};
9+
use http_types::{ensure, format_err};
10+
use http_types::{
11+
headers::{EXPECT, TRANSFER_ENCODING},
12+
StatusCode,
13+
};
1114
use http_types::{Body, Method, Request, Url};
1215

1316
use super::body_reader::BodyReader;
14-
use crate::chunked::ChunkedDecoder;
1517
use crate::read_notifier::ReadNotifier;
18+
use crate::{chunked::ChunkedDecoder, ServerOptions};
1619
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
1720

1821
const LF: u8 = b'\n';
1922

20-
/// The number returned from httparse when the request is HTTP 1.1
21-
const HTTP_1_1_VERSION: u8 = 1;
22-
2323
const CONTINUE_HEADER_VALUE: &str = "100-continue";
2424
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
2525

2626
/// Decode an HTTP request on the server.
27-
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
27+
pub async fn decode<IO>(
28+
mut io: IO,
29+
opts: &ServerOptions,
30+
) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
2831
where
2932
IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
3033
{
@@ -63,21 +66,22 @@ where
6366
let method = httparse_req.method;
6467
let method = method.ok_or_else(|| format_err!("No method found"))?;
6568

66-
let version = httparse_req.version;
67-
let version = version.ok_or_else(|| format_err!("No version found"))?;
68-
69-
ensure_eq!(
70-
version,
71-
HTTP_1_1_VERSION,
72-
"Unsupported HTTP version 1.{}",
73-
version
74-
);
69+
let version = match (&opts.default_host, httparse_req.version) {
70+
(Some(_), None) | (Some(_), Some(0)) => Version::Http1_0,
71+
(_, Some(1)) => Version::Http1_1,
72+
_ => {
73+
let mut err = format_err!("http version not supported");
74+
err.set_status(StatusCode::HttpVersionNotSupported);
75+
return Err(err);
76+
}
77+
};
7578

76-
let url = url_from_httparse_req(&httparse_req)?;
79+
let url = url_from_httparse_req(&httparse_req, opts.default_host.as_deref())
80+
.ok_or_else(|| format_err!("unable to construct url from request"))?;
7781

7882
let mut req = Request::new(Method::from_str(method)?, url);
7983

80-
req.set_version(Some(http_types::Version::Http1_1));
84+
req.set_version(Some(version));
8185

8286
for header in httparse_req.headers.iter() {
8387
req.append_header(header.name, std::str::from_utf8(header.value)?);
@@ -141,26 +145,27 @@ where
141145
}
142146
}
143147

144-
fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<Url> {
145-
let path = req.path.ok_or_else(|| format_err!("No uri found"))?;
148+
fn url_from_httparse_req(
149+
req: &httparse::Request<'_, '_>,
150+
default_host: Option<&str>,
151+
) -> Option<Url> {
152+
let path = req.path?;
146153

147154
let host = req
148155
.headers
149156
.iter()
150157
.find(|x| x.name.eq_ignore_ascii_case("host"))
151-
.ok_or_else(|| format_err!("Mandatory Host header missing"))?
152-
.value;
153-
154-
let host = std::str::from_utf8(host)?;
158+
.and_then(|x| std::str::from_utf8(x.value).ok())
159+
.or(default_host)?;
155160

156161
if path.starts_with("http://") || path.starts_with("https://") {
157-
Ok(Url::parse(path)?)
162+
Url::parse(path).ok()
158163
} else if path.starts_with('/') {
159-
Ok(Url::parse(&format!("http://{}{}", host, path))?)
164+
Url::parse(&format!("http://{}{}", host, path)).ok()
160165
} else if req.method.unwrap().eq_ignore_ascii_case("connect") {
161-
Ok(Url::parse(&format!("http://{}/", path))?)
166+
Url::parse(&format!("http://{}/", path)).ok()
162167
} else {
163-
Err(format_err!("unexpected uri format"))
168+
None
164169
}
165170
}
166171

@@ -180,7 +185,7 @@ mod tests {
180185
httparse_req(
181186
"CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n",
182187
|req| {
183-
let url = url_from_httparse_req(&req).unwrap();
188+
let url = url_from_httparse_req(&req, None).unwrap();
184189
assert_eq!(url.as_str(), "http://server.example.com:443/");
185190
},
186191
);
@@ -191,7 +196,7 @@ mod tests {
191196
httparse_req(
192197
"GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n",
193198
|req| {
194-
let url = url_from_httparse_req(&req).unwrap();
199+
let url = url_from_httparse_req(&req, None).unwrap();
195200
assert_eq!(url.as_str(), "http://server.example.com:443/some/resource");
196201
},
197202
)
@@ -202,7 +207,7 @@ mod tests {
202207
httparse_req(
203208
"GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n",
204209
|req| {
205-
let url = url_from_httparse_req(&req).unwrap();
210+
let url = url_from_httparse_req(&req, None).unwrap();
206211
assert_eq!(url.as_str(), "http://domain.com/some/resource"); // host header MUST be ignored according to spec
207212
},
208213
)
@@ -213,7 +218,7 @@ mod tests {
213218
httparse_req(
214219
"CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n",
215220
|req| {
216-
let url = url_from_httparse_req(&req).unwrap();
221+
let url = url_from_httparse_req(&req, None).unwrap();
217222
assert_eq!(url.as_str(), "http://server.example.com:443/");
218223
},
219224
)
@@ -224,7 +229,7 @@ mod tests {
224229
httparse_req(
225230
"GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n",
226231
|req| {
227-
assert!(url_from_httparse_req(&req).is_err());
232+
assert!(url_from_httparse_req(&req, None).is_none());
228233
},
229234
)
230235
}
@@ -234,7 +239,7 @@ mod tests {
234239
httparse_req(
235240
"GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
236241
|req| {
237-
let url = url_from_httparse_req(&req).unwrap();
242+
let url = url_from_httparse_req(&req, None).unwrap();
238243
assert_eq!(
239244
url.as_str(),
240245
"http://server.example.com:443//double/slashes"
@@ -247,7 +252,7 @@ mod tests {
247252
httparse_req(
248253
"GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
249254
|req| {
250-
let url = url_from_httparse_req(&req).unwrap();
255+
let url = url_from_httparse_req(&req, None).unwrap();
251256
assert_eq!(
252257
url.as_str(),
253258
"http://server.example.com:443///triple/slashes"
@@ -261,7 +266,7 @@ mod tests {
261266
httparse_req(
262267
"GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n",
263268
|req| {
264-
let url = url_from_httparse_req(&req).unwrap();
269+
let url = url_from_httparse_req(&req, None).unwrap();
265270
assert_eq!(url.as_str(), "http://server.example.com:443/foo?bar=1");
266271
},
267272
)
@@ -272,7 +277,7 @@ mod tests {
272277
httparse_req(
273278
"GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n",
274279
|req| {
275-
let url = url_from_httparse_req(&req).unwrap();
280+
let url = url_from_httparse_req(&req, None).unwrap();
276281
assert_eq!(
277282
url.as_str(),
278283
"http://server.example.com:443/foo?bar=1#anchor"

src/server/mod.rs

+42-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
33
use async_std::future::{timeout, Future, TimeoutError};
44
use async_std::io::{self, Read, Write};
5-
use http_types::headers::{CONNECTION, UPGRADE};
65
use http_types::upgrade::Connection;
6+
use http_types::{
7+
headers::{CONNECTION, UPGRADE},
8+
Version,
9+
};
710
use http_types::{Request, Response, StatusCode};
811
use std::{marker::PhantomData, time::Duration};
912
mod body_reader;
@@ -18,12 +21,42 @@ pub use encode::Encoder;
1821
pub struct ServerOptions {
1922
/// Timeout to handle headers. Defaults to 60s.
2023
headers_timeout: Option<Duration>,
24+
default_host: Option<String>,
25+
}
26+
27+
impl ServerOptions {
28+
/// constructs a new ServerOptions with default settings
29+
pub fn new() -> Self {
30+
Self::default()
31+
}
32+
33+
/// sets the timeout by which the headers must have been received
34+
pub fn with_headers_timeout(mut self, headers_timeout: Duration) -> Self {
35+
self.headers_timeout = Some(headers_timeout);
36+
self
37+
}
38+
39+
/// Sets the default http 1.0 host for this server. If no host
40+
/// header is provided on an http/1.0 request, this host will be
41+
/// used to construct the Request Url.
42+
///
43+
/// If this is not provided, the server will respond to all
44+
/// http/1.0 requests with status `505 http version not
45+
/// supported`, whether or not a host header is provided.
46+
///
47+
/// The default value for this is None, and as a result async-h1
48+
/// is by default an http-1.1-only server.
49+
pub fn with_default_host(mut self, default_host: &str) -> Self {
50+
self.default_host = Some(default_host.into());
51+
self
52+
}
2153
}
2254

2355
impl Default for ServerOptions {
2456
fn default() -> Self {
2557
Self {
2658
headers_timeout: Some(Duration::from_secs(60)),
59+
default_host: None,
2760
}
2861
}
2962
}
@@ -111,7 +144,7 @@ where
111144
Fut: Future<Output = Response>,
112145
{
113146
// Decode a new request, timing out if this takes longer than the timeout duration.
114-
let fut = decode(self.io.clone());
147+
let fut = decode(self.io.clone(), &self.opts);
115148

116149
let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
117150
match timeout(timeout_duration, fut).await {
@@ -133,7 +166,12 @@ where
133166
.unwrap_or("");
134167

135168
let connection_header_is_upgrade = connection_header_as_str.eq_ignore_ascii_case("upgrade");
136-
let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close");
169+
170+
let mut close_connection = if req.version() == Some(Version::Http1_0) {
171+
!connection_header_as_str.eq_ignore_ascii_case("keep-alive")
172+
} else {
173+
connection_header_as_str.eq_ignore_ascii_case("close")
174+
};
137175

138176
let upgrade_requested = has_upgrade_header && connection_header_is_upgrade;
139177

@@ -168,7 +206,7 @@ where
168206

169207
if let Some(upgrade_sender) = upgrade_sender {
170208
upgrade_sender.send(Connection::new(self.io.clone())).await;
171-
return Ok(ConnectionStatus::Close);
209+
Ok(ConnectionStatus::Close)
172210
} else if close_connection {
173211
Ok(ConnectionStatus::Close)
174212
} else {

tests/continue.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ async fn test_with_expect_when_reading_body() -> Result<()> {
1616
let (mut client, server) = TestIO::new();
1717
client.write_all(REQUEST_WITH_EXPECT).await?;
1818

19-
let (mut request, _) = async_h1::server::decode(server).await?.unwrap();
19+
let (mut request, _) = async_h1::server::decode(server, &Default::default())
20+
.await?
21+
.unwrap();
2022

2123
task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written
2224

@@ -44,7 +46,9 @@ async fn test_without_expect_when_not_reading_body() -> Result<()> {
4446
let (mut client, server) = TestIO::new();
4547
client.write_all(REQUEST_WITH_EXPECT).await?;
4648

47-
let (_, _) = async_h1::server::decode(server).await?.unwrap();
49+
let (_, _) = async_h1::server::decode(server, &Default::default())
50+
.await?
51+
.unwrap();
4852

4953
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel
5054

tests/server-chunked-encode-large.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async fn server_chunked_large() -> Result<()> {
7676
let (mut client, server) = TestIO::new();
7777
async_std::io::copy(&mut client::Encoder::new(request), &mut client).await?;
7878

79-
let (request, _) = server::decode(server).await?.unwrap();
79+
let (request, _) = server::decode(server, &Default::default()).await?.unwrap();
8080

8181
let mut response = Response::new(200);
8282
response.set_body(Body::from_reader(request, None));

0 commit comments

Comments
 (0)