diff --git a/Cargo.lock b/Cargo.lock
index 59f8430d1..8da4c8f88 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -377,6 +377,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+[[package]]
+name = "cfg_aliases"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e"
+
[[package]]
name = "clap"
version = "4.5.1"
@@ -548,6 +554,16 @@ dependencies = [
"typenum",
]
+[[package]]
+name = "ctrlc"
+version = "3.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345"
+dependencies = [
+ "nix 0.28.0",
+ "windows-sys 0.52.0",
+]
+
[[package]]
name = "data-encoding"
version = "2.5.0"
@@ -1323,6 +1339,18 @@ dependencies = [
"libc",
]
+[[package]]
+name = "nix"
+version = "0.28.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4"
+dependencies = [
+ "bitflags 2.4.2",
+ "cfg-if",
+ "cfg_aliases",
+ "libc",
+]
+
[[package]]
name = "nom"
version = "7.1.3"
@@ -1644,7 +1672,7 @@ dependencies = [
"inferno",
"libc",
"log",
- "nix",
+ "nix 0.26.4",
"once_cell",
"parking_lot",
"prost",
@@ -1958,6 +1986,7 @@ dependencies = [
"bytes",
"clap",
"config",
+ "ctrlc",
"flume",
"futures-util",
"metrics",
diff --git a/rumqttd/CHANGELOG.md b/rumqttd/CHANGELOG.md
index 090bef2fe..07ee5456d 100644
--- a/rumqttd/CHANGELOG.md
+++ b/rumqttd/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Assign random identifier to clients connecting with empty client id.
+- Ability to gracefully shut down the broker.
### Changed
- Public re-export `Strategy` for shared subscriptions
diff --git a/rumqttd/Cargo.toml b/rumqttd/Cargo.toml
index 8d1fd604b..6865457b6 100644
--- a/rumqttd/Cargo.toml
+++ b/rumqttd/Cargo.toml
@@ -16,6 +16,7 @@ tokio = { version = "1.36", features = ["rt", "time", "net", "io-util", "macros"
serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.113"
bytes = { version = "1", features = ["serde"] }
+ctrlc = "3.4"
flume = { version = "0.11.0", default-features = false, features = ["async"]}
slab = "0.4.9"
thiserror = "1.0.57"
diff --git a/rumqttd/examples/external_auth.rs b/rumqttd/examples/external_auth.rs
index f5d7cdfcd..805b6b67f 100644
--- a/rumqttd/examples/external_auth.rs
+++ b/rumqttd/examples/external_auth.rs
@@ -32,9 +32,15 @@ fn main() {
// true
// });
- let mut broker = Broker::new(config);
+ let _guard = Broker::new(config).start().unwrap().drop_guard();
- broker.start().unwrap();
+ let (tx, rx) = flume::bounded::<()>(1);
+ ctrlc::set_handler(move || {
+ let _ = tx.send(());
+ })
+ .expect("Error setting Ctrl-C handler");
+
+ rx.recv().expect("Could not receive Ctrl-C signal");
}
async fn auth(_client_id: String, _username: String, _password: String) -> bool {
diff --git a/rumqttd/examples/graceful_shutdown.rs b/rumqttd/examples/graceful_shutdown.rs
new file mode 100644
index 000000000..ac318688c
--- /dev/null
+++ b/rumqttd/examples/graceful_shutdown.rs
@@ -0,0 +1,36 @@
+use rumqttd::{Broker, Config};
+use tracing::Level;
+
+use core::time::Duration;
+use std::thread;
+
+fn main() {
+ let builder = tracing_subscriber::fmt()
+ .pretty()
+ .with_max_level(Level::DEBUG)
+ .with_line_number(false)
+ .with_file(false)
+ .with_thread_ids(false)
+ .with_thread_names(false);
+
+ builder
+ .try_init()
+ .expect("initialized subscriber succesfully");
+
+ // As examples are compiled as seperate binary so this config is current path dependent. Run it
+ // from root of this crate
+ let config = config::Config::builder()
+ .add_source(config::File::with_name("rumqttd.toml"))
+ .build()
+ .unwrap();
+
+ let config: Config = config.try_deserialize().unwrap();
+
+ dbg!(&config);
+
+ let handler = Broker::new(config).start().unwrap();
+
+ thread::sleep(Duration::from_secs(1));
+ handler.shutdown();
+ thread::sleep(Duration::from_secs(1));
+}
diff --git a/rumqttd/examples/meters.rs b/rumqttd/examples/meters.rs
index ba872cca0..a4a1388fb 100644
--- a/rumqttd/examples/meters.rs
+++ b/rumqttd/examples/meters.rs
@@ -15,7 +15,7 @@ fn main() {
dbg!(&config);
- let mut broker = Broker::new(config);
+ let broker = Broker::new(config);
let meters = broker.meters().unwrap();
let (mut link_tx, mut link_rx) = broker.link("consumer").unwrap();
@@ -59,11 +59,7 @@ fn main() {
});
}
- thread::spawn(move || {
- if let Err(e) = broker.start() {
- println!("Broker stopped: {e}");
- }
- });
+ let _guard = broker.start().unwrap().drop_guard();
thread::sleep(Duration::from_secs(2));
loop {
diff --git a/rumqttd/examples/singlenode.rs b/rumqttd/examples/singlenode.rs
index bcab216c3..50477a568 100644
--- a/rumqttd/examples/singlenode.rs
+++ b/rumqttd/examples/singlenode.rs
@@ -1,7 +1,5 @@
use rumqttd::{Broker, Config, Notification};
-use std::thread;
-
fn main() {
let builder = tracing_subscriber::fmt()
.pretty()
@@ -25,11 +23,9 @@ fn main() {
dbg!(&config);
- let mut broker = Broker::new(config);
+ let broker = Broker::new(config);
let (mut link_tx, mut link_rx) = broker.link("singlenode").unwrap();
- thread::spawn(move || {
- broker.start().unwrap();
- });
+ let _guard = broker.start().unwrap().drop_guard();
link_tx.subscribe("#").unwrap();
diff --git a/rumqttd/src/lib.rs b/rumqttd/src/lib.rs
index 0b62f89f8..6b544617c 100644
--- a/rumqttd/src/lib.rs
+++ b/rumqttd/src/lib.rs
@@ -23,7 +23,7 @@ pub use link::local;
pub use link::meters;
pub use router::{Alert, IncomingMeter, Meter, Notification, OutgoingMeter};
use segments::Storage;
-pub use server::Broker;
+pub use server::{Broker, BrokerHandler, ShutdownDropGuard};
pub use self::router::shared_subs::Strategy;
diff --git a/rumqttd/src/link/bridge.rs b/rumqttd/src/link/bridge.rs
index 0c6b193be..eac2050dd 100644
--- a/rumqttd/src/link/bridge.rs
+++ b/rumqttd/src/link/bridge.rs
@@ -12,6 +12,7 @@ use std::{io, net::AddrParseError, time::Duration};
use tokio::{
net::TcpStream,
+ sync::watch,
time::{sleep, sleep_until, Instant},
};
@@ -48,6 +49,7 @@ pub async fn start
(
config: BridgeConfig,
router_tx: Sender<(ConnectionId, Event)>,
protocol: P,
+ mut shutdown_rx: watch::Receiver<()>,
) -> Result<(), BridgeError>
where
P: Protocol + Clone + Send + 'static,
@@ -154,6 +156,10 @@ where
// resetting timeout because tokio::select! consumes the old timeout future
timeout = sleep_until(ping_time + Duration::from_secs(config.ping_delay));
}
+ _ = shutdown_rx.changed() => {
+ debug!("Shutting down bridge");
+ break 'outer Ok(());
+ }
}
}
}
diff --git a/rumqttd/src/link/console.rs b/rumqttd/src/link/console.rs
index 9f03f8a1e..d566e1896 100644
--- a/rumqttd/src/link/console.rs
+++ b/rumqttd/src/link/console.rs
@@ -9,8 +9,8 @@ use axum::Json;
use axum::{routing::get, Router};
use flume::Sender;
use std::sync::Arc;
-use tokio::net::TcpListener;
-use tracing::info;
+use tokio::{net::TcpListener, sync::watch};
+use tracing::{debug, info};
#[derive(Debug)]
pub struct ConsoleLink {
@@ -39,7 +39,7 @@ impl ConsoleLink {
}
#[tracing::instrument]
-pub async fn start(console: Arc) {
+pub async fn start(console: Arc, mut shutdown_rx: watch::Receiver<()>) {
let listener = TcpListener::bind(console.config.listen.clone())
.await
.unwrap();
@@ -56,7 +56,13 @@ pub async fn start(console: Arc) {
.route("/logs", post(logs))
.with_state(console);
- axum::serve(listener, app).await.unwrap();
+ axum::serve(listener, app)
+ .with_graceful_shutdown(async move {
+ debug!("Shutting down console");
+ let _ = shutdown_rx.changed().await;
+ })
+ .await
+ .unwrap();
}
async fn root(State(console): State>) -> impl IntoResponse {
diff --git a/rumqttd/src/link/timer.rs b/rumqttd/src/link/timer.rs
index fcc3bf00c..e72d48d25 100644
--- a/rumqttd/src/link/timer.rs
+++ b/rumqttd/src/link/timer.rs
@@ -5,7 +5,8 @@ use crate::{router::Event, MetricType};
use crate::{ConnectionId, MetricSettings};
use flume::{SendError, Sender};
use tokio::select;
-use tracing::error;
+use tokio::sync::watch;
+use tracing::{debug, error};
#[derive(Debug, thiserror::Error)]
pub enum Error {
@@ -18,6 +19,7 @@ pub enum Error {
pub async fn start(
config: HashMap,
router_tx: Sender<(ConnectionId, Event)>,
+ mut shutdown_rx: watch::Receiver<()>,
) {
let span = tracing::info_span!("metrics_timer");
let _guard = span.enter();
@@ -42,6 +44,10 @@ pub async fn start(
error!("Failed to push alerts: {e}");
}
}
+ _ = shutdown_rx.changed() => {
+ debug!("Shutting down metrics timer");
+ break;
+ }
}
}
}
diff --git a/rumqttd/src/main.rs b/rumqttd/src/main.rs
index 8b22805f8..d8ac9e747 100644
--- a/rumqttd/src/main.rs
+++ b/rumqttd/src/main.rs
@@ -84,10 +84,15 @@ fn main() {
validate_config(&configs);
- // println!("{:#?}", configs);
+ let _guard = Broker::new(configs).start().unwrap().drop_guard();
- let mut broker = Broker::new(configs);
- broker.start().unwrap();
+ let (tx, rx) = flume::bounded::<()>(1);
+ ctrlc::set_handler(move || {
+ let _ = tx.send(());
+ })
+ .expect("Error setting Ctrl-C handler");
+
+ rx.recv().expect("Could not receive Ctrl-C signal");
}
// Do any extra validation that needs to be done before starting the broker here.
diff --git a/rumqttd/src/server/broker.rs b/rumqttd/src/server/broker.rs
index 9886541c9..34ad68bdb 100644
--- a/rumqttd/src/server/broker.rs
+++ b/rumqttd/src/server/broker.rs
@@ -11,10 +11,10 @@ use crate::protocol::{Packet, Protocol};
use crate::server::tls::{self, TLSAcceptor};
use crate::{meters, ConnectionSettings, Meter};
use flume::{RecvError, SendError, Sender};
-use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::{Arc, Mutex};
-use tracing::{error, field, info, warn, Instrument};
+use std::{collections::HashMap, thread::JoinHandle};
+use tracing::{debug, error, field, info, warn, Instrument};
use uuid::Uuid;
#[cfg(feature = "websocket")]
@@ -39,6 +39,7 @@ use crate::router::{Event, Router};
use crate::{Config, ConnectionId, ServerSettings};
use tokio::net::{TcpListener, TcpStream};
+use tokio::sync::watch;
use tokio::time::error::Elapsed;
use tokio::{task, time};
@@ -155,8 +156,12 @@ impl Broker {
Ok((link_tx, link_rx))
}
+ /// Starts the MQTT broker server and spawns various components like the router, cluster, metrics, and console.
+ ///
+ /// This function sets up the necessary components for the MQTT broker server to run. The function returns a
+ /// [`BrokerHandler`] that can be used to gracefully shut down the server.
#[tracing::instrument(skip(self))]
- pub fn start(&mut self) -> Result<(), Error> {
+ pub fn start(self) -> Result {
if self.config.v4.is_none()
&& self.config.v5.is_none()
&& (cfg!(not(feature = "websocket")) || self.config.ws.is_none())
@@ -171,16 +176,18 @@ impl Broker {
// we don't know which servers (v4/v5/ws) user will spawn
// so we collect handles for all of the spawned servers
let mut server_thread_handles = Vec::new();
+ let (shutdown_tx, shutdown_rx) = watch::channel(());
if let Some(metrics_config) = self.config.metrics.clone() {
let timer_thread = thread::Builder::new().name("timer".to_owned());
let router_tx = self.router_tx.clone();
+ let shutdown_rx = shutdown_rx.clone();
timer_thread.spawn(move || {
let mut runtime = tokio::runtime::Builder::new_current_thread();
let runtime = runtime.enable_all().build().unwrap();
runtime.block_on(async move {
- timer::start(metrics_config, router_tx).await;
+ timer::start(metrics_config, router_tx, shutdown_rx).await;
});
})?;
}
@@ -189,12 +196,13 @@ impl Broker {
if let Some(bridge_config) = self.config.bridge.clone() {
let bridge_thread = thread::Builder::new().name(bridge_config.name.clone());
let router_tx = self.router_tx.clone();
+ let shutdown_rx = shutdown_rx.clone();
bridge_thread.spawn(move || {
let mut runtime = tokio::runtime::Builder::new_current_thread();
let runtime = runtime.enable_all().build().unwrap();
runtime.block_on(async move {
- if let Err(e) = bridge::start(bridge_config, router_tx, V4).await {
+ if let Err(e) = bridge::start(bridge_config, router_tx, V4, shutdown_rx).await {
error!(error=?e, "Bridge Link error");
};
});
@@ -206,13 +214,16 @@ impl Broker {
for (_, config) in v4_config.clone() {
let server_thread = thread::Builder::new().name(config.name.clone());
let mut server = Server::new(config, self.router_tx.clone(), V4);
+ let shutdown_rx = shutdown_rx.clone();
let handle = server_thread.spawn(move || {
let mut runtime = tokio::runtime::Builder::new_current_thread();
let runtime = runtime.enable_all().build().unwrap();
runtime.block_on(async {
- if let Err(e) = server.start(LinkType::Remote).await {
+ if let Err(e) = server.start(LinkType::Remote, shutdown_rx).await {
error!(error=?e, "Server error - V4");
+ } else {
+ debug!("Shutting down v4 server");
}
});
})?;
@@ -224,13 +235,16 @@ impl Broker {
for (_, config) in v5_config.clone() {
let server_thread = thread::Builder::new().name(config.name.clone());
let mut server = Server::new(config, self.router_tx.clone(), V5);
+ let shutdown_rx = shutdown_rx.clone();
let handle = server_thread.spawn(move || {
let mut runtime = tokio::runtime::Builder::new_current_thread();
let runtime = runtime.enable_all().build().unwrap();
runtime.block_on(async {
- if let Err(e) = server.start(LinkType::Remote).await {
+ if let Err(e) = server.start(LinkType::Remote, shutdown_rx).await {
error!(error=?e, "Server error - V5");
+ } else {
+ debug!("Shutting down v5 server");
}
});
})?;
@@ -249,13 +263,16 @@ impl Broker {
let server_thread = thread::Builder::new().name(config.name.clone());
//TODO: Add support for V5 procotol with websockets. Registered in config or on ServerSettings
let mut server = Server::new(config, self.router_tx.clone(), V4);
+ let shutdown_rx = shutdown_rx.clone();
let handle = server_thread.spawn(move || {
let mut runtime = tokio::runtime::Builder::new_current_thread();
let runtime = runtime.enable_all().build().unwrap();
runtime.block_on(async {
- if let Err(e) = server.start(LinkType::Websocket).await {
+ if let Err(e) = server.start(LinkType::Websocket, shutdown_rx).await {
error!(error=?e, "Server error - WS");
+ } else {
+ debug!("Shutting down websocket server");
}
});
})?;
@@ -280,6 +297,7 @@ impl Broker {
};
let metrics_thread = thread::Builder::new().name("Metrics".to_owned());
let meter_link = self.meters().unwrap();
+ let shutdown_rx = shutdown_rx.clone();
metrics_thread.spawn(move || {
let builder = PrometheusBuilder::new().with_http_listener(addr);
builder.install().unwrap();
@@ -301,6 +319,11 @@ impl Broker {
}
}
+ if shutdown_rx.has_changed().is_ok_and(|flag| flag) {
+ debug!("Shutting down metrics");
+ break;
+ }
+
std::thread::sleep(Duration::from_secs(timeout));
}
})?;
@@ -311,23 +334,21 @@ impl Broker {
let console_link = Arc::new(console_link);
let console_thread = thread::Builder::new().name("Console".to_string());
+ let shutdown_rx = shutdown_rx.clone();
console_thread.spawn(move || {
let mut runtime = tokio::runtime::Builder::new_current_thread();
let runtime = runtime.enable_all().build().unwrap();
- runtime.block_on(console::start(console_link));
+ runtime.block_on(console::start(console_link, shutdown_rx));
})?;
}
- // in ideal case, where server doesn't crash, join() will never resolve
- // we still try to join threads so that we don't return from function
- // unless everything crashes.
- server_thread_handles.into_iter().for_each(|handle| {
- // join() might panic in case the thread panics
- // we just ignore it
- let _ = handle.join();
- });
-
- Ok(())
+ Ok(BrokerHandler {
+ inner: InnerHandler {
+ threads: server_thread_handles,
+ shutdown_tx,
+ }
+ .into(),
+ })
}
}
@@ -379,7 +400,11 @@ impl Server {
Ok((Box::new(stream), None))
}
- async fn start(&mut self, link_type: LinkType) -> Result<(), Error> {
+ async fn start(
+ &mut self,
+ link_type: LinkType,
+ mut shutdown_rx: watch::Receiver<()>,
+ ) -> Result<(), Error> {
let listener = TcpListener::bind(&self.config.listen).await?;
let delay = Duration::from_millis(self.config.next_connection_delay_ms);
let mut count: usize = 0;
@@ -392,19 +417,33 @@ impl Server {
);
loop {
// Await new network connection.
- let (stream, addr) = match listener.accept().await {
- Ok((s, r)) => (s, r),
- Err(e) => {
- error!(error=?e, "Unable to accept socket.");
- continue;
+ let (stream, addr) = tokio::select! {
+ accept = listener.accept() => {
+ match accept {
+ Ok((s, r)) => (s, r),
+ Err(e) => {
+ error!(error=?e, "Unable to accept socket.");
+ continue;
+ }
+ }
+ }
+ _ = shutdown_rx.changed() => {
+ return Ok(());
}
};
- let (network, tenant_id) = match self.tls_accept(stream).await {
- Ok(o) => o,
- Err(e) => {
- error!(error=?e, "Tls accept error");
- continue;
+ let (network, tenant_id) = tokio::select! {
+ accept = self.tls_accept(stream) => {
+ match accept {
+ Ok(o) => o,
+ Err(e) => {
+ error!(error=?e, "Tls accept error");
+ continue;
+ }
+ }
+ }
+ _ = shutdown_rx.changed() => {
+ return Ok(());
}
};
@@ -420,11 +459,18 @@ impl Server {
match link_type {
#[cfg(feature = "websocket")]
LinkType::Websocket => {
- let stream = match accept_hdr_async(network, WSCallback).await {
- Ok(s) => Box::new(WsStream::new(s)),
- Err(e) => {
- error!(error=?e, "Websocket failed handshake");
- continue;
+ let stream = tokio::select! {
+ hdr_accept = accept_hdr_async(network, WSCallback) => {
+ match hdr_accept {
+ Ok(s) => Box::new(WsStream::new(s)),
+ Err(e) => {
+ error!(error=?e, "Websocket failed handshake");
+ continue;
+ }
+ }
+ }
+ _ = shutdown_rx.changed() => {
+ return Ok(());
}
};
task::spawn(
@@ -461,7 +507,12 @@ impl Server {
),
};
- time::sleep(delay).await;
+ tokio::select! {
+ _ = time::sleep(delay) => {}
+ _ = shutdown_rx.changed() => {
+ return Ok(());
+ }
+ };
}
}
}
@@ -627,3 +678,100 @@ async fn remote(
router_tx.send((connection_id, message)).ok();
}
}
+
+/// An internal handler that manages the shutdown process for a broker.
+///
+/// The `InnerHandler` struct is responsible for coordinating the shutdown of a broker by maintaining a list of running
+/// threads and a channel for sending a shutdown signal. It is used internally by the `ShutdownHandler` to manage the
+/// shutdown process.
+#[derive(Debug)]
+struct InnerHandler {
+ threads: Vec>,
+ shutdown_tx: watch::Sender<()>,
+}
+
+/// A struct that handles the shutdown process for a broker.
+///
+/// The `ShutdownHandler` struct is responsible for coordinating the shutdown of a broker by sending a shutdown signal to
+/// all running threads and waiting for them to join.
+#[derive(Debug)]
+pub struct BrokerHandler {
+ inner: Option,
+}
+
+impl BrokerHandler {
+ /// Shuts down the server by sending a shutdown signal to all running threads and waiting for them to join.
+ ///
+ /// This method is responsible for coordinating the shutdown of the broker by taking ownership of the `BrokerHandler`
+ /// and sending a shutdown signal to all running threads. It then waits for each thread to join before returning.
+ /// This ensures that all broker resources are properly cleaned up and the broker can be safely shut down.
+ pub fn shutdown(mut self) {
+ if let Some(handler) = self.inner.take() {
+ let _ = handler.shutdown_tx.send(());
+ for thread in handler.threads {
+ let _ = thread.join();
+ }
+ }
+ }
+
+ /// Joins all running threads associated with the `BrokerHandler`.
+ ///
+ /// This method takes ownership of the `BrokerHandler` and waits for all running threads to join. This ensures that all
+ /// broker resources are properly cleaned up and the broker can be safely shut down.
+ pub fn join(mut self) {
+ if let Some(handler) = self.inner.take() {
+ for thread in handler.threads {
+ let _ = thread.join();
+ }
+ }
+ }
+
+ /// Creates a `ShutdownDropGuard` that will automatically handle the shutdown process when the guard is dropped.
+ ///
+ /// The `drop_guard()` method takes ownership of the `BrokerHandler` and returns a [`ShutdownDropGuard`] that will
+ /// automatically send a shutdown signal to all running threads and wait for them to join when the guard is dropped.
+ /// This is useful for ensuring that the server is properly shut down, even in the event of an unexpected error or
+ /// early exit from the program.
+ #[inline]
+ pub fn drop_guard(mut self) -> ShutdownDropGuard {
+ ShutdownDropGuard {
+ inner: self.inner.take(),
+ }
+ }
+}
+
+/// A guard that automatically handles the shutdown process for a broker when dropped.
+///
+/// The `ShutdownDropGuard` struct is responsible for coordinating the shutdown of a broker by sending a shutdown signal to
+/// all running threads and waiting for them to join when the guard is dropped. This ensures that the broker is properly
+/// shut down, even in the event of an unexpected error or early exit from the program.
+#[derive(Debug)]
+pub struct ShutdownDropGuard {
+ inner: Option,
+}
+
+impl ShutdownDropGuard {
+ /// Disarms the `ShutdownDropGuard` and returns a [`BrokerHandler`] that can be used to manually shut down the broker.
+ ///
+ /// This method takes ownership of the `ShutdownDropGuard` and returns a new [`BrokerHandler`] that contains the same
+ /// internal state as the `ShutdownDropGuard`. This allows the caller to take control of the shutdown process and
+ /// manually shut down the server when needed, rather than relying on the automatic shutdown when the `ShutdownDropGuard`
+ /// is dropped.
+ #[inline]
+ pub fn disarm(mut self) -> BrokerHandler {
+ BrokerHandler {
+ inner: self.inner.take(),
+ }
+ }
+}
+
+impl Drop for ShutdownDropGuard {
+ fn drop(&mut self) {
+ if let Some(handler) = self.inner.take() {
+ let _ = handler.shutdown_tx.send(());
+ for thread in handler.threads {
+ let _ = thread.join();
+ }
+ }
+ }
+}
diff --git a/rumqttd/src/server/mod.rs b/rumqttd/src/server/mod.rs
index fb151b1bf..4fcfc9a00 100644
--- a/rumqttd/src/server/mod.rs
+++ b/rumqttd/src/server/mod.rs
@@ -4,7 +4,7 @@ mod broker;
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
mod tls;
-pub use broker::Broker;
+pub use broker::{Broker, BrokerHandler, ShutdownDropGuard};
// pub trait IO: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
// impl IO for T {}