Skip to content

Commit ed698ac

Browse files
committed
Unifty httpd in blackholes
This commit unifies the httpd bits of the blackholes, removing a duplicated loop in three places. The type logic here started out general and got gradually more concrete and there's scope to reduce duplication even further but I don't know that it's a pressing issue. Signed-off-by: Brian L. Troutwine <[email protected]>
1 parent 8c123ec commit ed698ac

File tree

5 files changed

+286
-236
lines changed

5 files changed

+286
-236
lines changed

lading/src/blackhole.rs

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
88
use serde::{Deserialize, Serialize};
99

10+
mod common;
1011
pub mod http;
1112
pub mod splunk_hec;
1213
pub mod sqs;

lading/src/blackhole/common.rs

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use bytes::Bytes;
2+
use http_body_util::combinators::BoxBody;
3+
use hyper::service::Service;
4+
use hyper_util::{
5+
rt::{TokioExecutor, TokioIo},
6+
server::conn::auto,
7+
};
8+
use lading_signal::Watcher;
9+
use std::{net::SocketAddr, sync::Arc};
10+
use tokio::{net::TcpListener, pin, sync::Semaphore, task::JoinSet};
11+
use tracing::{debug, error, info};
12+
13+
#[derive(thiserror::Error, Debug)]
14+
pub enum Error {
15+
/// Wrapper for [`std::io::Error`].
16+
#[error("IO error: {0}")]
17+
Io(std::io::Error),
18+
}
19+
20+
pub(crate) async fn run_httpd<SF, S>(
21+
addr: SocketAddr,
22+
concurrency_limit: usize,
23+
shutdown: Watcher,
24+
make_service: SF,
25+
) -> Result<(), Error>
26+
where
27+
// "service factory"
28+
SF: Send + Sync + 'static + Clone + Fn() -> S,
29+
// The bounds on `S` per
30+
// https://docs.rs/hyper/latest/hyper/service/trait.Service.html and then
31+
// made concrete per
32+
// https://docs.rs/hyper-util/latest/hyper_util/server/conn/auto/struct.Builder.html#method.serve_connection.
33+
S: Service<
34+
hyper::Request<hyper::body::Incoming>,
35+
Response = hyper::Response<BoxBody<Bytes, hyper::Error>>,
36+
Error = hyper::Error,
37+
> + Send
38+
+ 'static,
39+
40+
S::Future: Send + 'static,
41+
{
42+
let listener = TcpListener::bind(addr).await.map_err(Error::Io)?;
43+
let sem = Arc::new(Semaphore::new(concurrency_limit));
44+
let mut join_set = JoinSet::new();
45+
46+
let shutdown_fut = shutdown.recv();
47+
pin!(shutdown_fut);
48+
loop {
49+
tokio::select! {
50+
() = &mut shutdown_fut => {
51+
info!("Shutdown signal received, stopping accept loop.");
52+
break;
53+
}
54+
55+
incoming = listener.accept() => {
56+
let (stream, addr) = match incoming {
57+
Ok(sa) => sa,
58+
Err(e) => {
59+
error!("Error accepting connection: {e}");
60+
continue;
61+
}
62+
};
63+
64+
let sem = Arc::clone(&sem);
65+
let service_factory = make_service.clone();
66+
67+
join_set.spawn(async move {
68+
debug!("Accepted connection from {addr}");
69+
let permit = match sem.acquire_owned().await {
70+
Ok(p) => p,
71+
Err(e) => {
72+
error!("Semaphore closed: {e}");
73+
return;
74+
}
75+
};
76+
77+
let builder = auto::Builder::new(TokioExecutor::new());
78+
let serve_future = builder.serve_connection_with_upgrades(
79+
TokioIo::new(stream),
80+
service_factory(),
81+
);
82+
83+
if let Err(e) = serve_future.await {
84+
error!("Error serving {addr}: {e}");
85+
}
86+
drop(permit);
87+
});
88+
}
89+
}
90+
}
91+
92+
drop(listener);
93+
while join_set.join_next().await.is_some() {}
94+
Ok(())
95+
}

lading/src/blackhole/http.rs

+31-77
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,14 @@
66
//! `requests_received`: Total requests received
77
//!
88
9-
use std::{net::SocketAddr, sync::Arc, time::Duration};
10-
119
use bytes::Bytes;
1210
use http::{header::InvalidHeaderValue, status::InvalidStatusCode, HeaderMap};
1311
use http_body_util::{combinators::BoxBody, BodyExt};
14-
use hyper::{header, service::service_fn, Request, Response, StatusCode};
15-
use hyper_util::{
16-
rt::{TokioExecutor, TokioIo},
17-
server::conn::auto,
18-
};
12+
use hyper::{header, Request, Response, StatusCode};
1913
use metrics::counter;
2014
use serde::{Deserialize, Serialize};
21-
use tokio::{pin, sync::Semaphore, task::JoinSet};
22-
use tracing::{debug, error, info};
15+
use std::{net::SocketAddr, time::Duration};
16+
use tracing::{debug, error};
2317

2418
use super::General;
2519

@@ -42,9 +36,9 @@ pub enum Error {
4236
/// Failed to deserialize the configuration.
4337
#[error("Failed to deserialize the configuration: {0}")]
4438
Serde(#[from] serde_json::Error),
45-
/// Wrapper for [`std::io::Error`].
46-
#[error("IO error: {0}")]
47-
Io(#[from] std::io::Error),
39+
/// Wrapper for [`crate::blackhole::common::Error`].
40+
#[error(transparent)]
41+
Common(#[from] crate::blackhole::common::Error),
4842
}
4943

5044
/// Body variant supported by this blackhole.
@@ -240,72 +234,32 @@ impl Http {
240234
/// Function will return an error if the configuration is invalid or if
241235
/// receiving a packet fails.
242236
pub async fn run(self) -> Result<(), Error> {
243-
let listener = tokio::net::TcpListener::bind(self.httpd_addr).await?;
244-
let sem = Arc::new(Semaphore::new(self.concurrency_limit));
245-
let mut join_set = JoinSet::new();
246-
247-
let shutdown = self.shutdown.recv();
248-
pin!(shutdown);
249-
loop {
250-
tokio::select! {
251-
() = &mut shutdown => {
252-
info!("shutdown signal received");
253-
break;
254-
}
255-
256-
incoming = listener.accept() => {
257-
let (stream, addr) = match incoming {
258-
Ok((s,a)) => (s,a),
259-
Err(e) => {
260-
error!("accept error: {e}");
261-
continue;
262-
}
263-
};
264-
265-
let metric_labels = self.metric_labels.clone();
266-
let body_bytes = self.body_bytes.clone();
267-
let headers = self.headers.clone();
268-
let status = self.status;
269-
let response_delay = self.response_delay;
270-
let sem = Arc::clone(&sem);
271-
272-
join_set.spawn(async move {
273-
debug!("Accepted connection from {addr}");
274-
let permit = match sem.acquire_owned().await {
275-
Ok(p) => p,
276-
Err(e) => {
277-
error!("Semaphore closed: {e}");
278-
return;
279-
}
280-
};
281-
282-
let builder = auto::Builder::new(TokioExecutor::new());
283-
let serve_future = builder
284-
.serve_connection(
285-
TokioIo::new(stream),
286-
service_fn(move |req: Request<hyper::body::Incoming>| {
287-
debug!("REQUEST: {:?}", req);
288-
srv(
289-
status,
290-
metric_labels.clone(),
291-
body_bytes.clone(),
292-
req,
293-
headers.clone(),
294-
response_delay,
295-
)
296-
})
297-
);
237+
crate::blackhole::common::run_httpd(
238+
self.httpd_addr,
239+
self.concurrency_limit,
240+
self.shutdown,
241+
move || {
242+
let metric_labels = self.metric_labels.clone();
243+
let body_bytes = self.body_bytes.clone();
244+
let headers = self.headers.clone();
245+
let status = self.status;
246+
let response_delay = self.response_delay;
247+
248+
hyper::service::service_fn(move |req| {
249+
debug!("REQUEST: {:?}", req);
250+
srv(
251+
status,
252+
metric_labels.clone(),
253+
body_bytes.clone(),
254+
req,
255+
headers.clone(),
256+
response_delay,
257+
)
258+
})
259+
},
260+
)
261+
.await?;
298262

299-
if let Err(e) = serve_future.await {
300-
error!("Error serving {addr}: {e}");
301-
}
302-
drop(permit);
303-
});
304-
}
305-
}
306-
}
307-
drop(listener);
308-
while join_set.join_next().await.is_some() {}
309263
Ok(())
310264
}
311265
}

0 commit comments

Comments
 (0)