diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs index 6d13a9e3..506ee486 100644 --- a/rumqttc/examples/ack_promise_sync.rs +++ b/rumqttc/examples/ack_promise_sync.rs @@ -1,5 +1,5 @@ use flume::bounded; -use rumqttc::{Client, MqttOptions, PromiseError, QoS}; +use rumqttc::{Client, MqttOptions, QoS, TokenError}; use std::error::Error; use std::thread::{self, sleep}; use std::time::Duration; @@ -26,7 +26,7 @@ fn main() -> Result<(), Box> { match client .subscribe("hello/world", QoS::AtMostOnce) .unwrap() - .blocking_wait() + .wait() { Ok(pkid) => println!("Acknowledged Sub({pkid})"), Err(e) => println!("Subscription failed: {e:?}"), @@ -40,7 +40,7 @@ fn main() -> Result<(), Box> { match client .publish("hello/world", qos, false, vec![1; i]) .unwrap() - .blocking_wait() + .wait() { Ok(pkid) => println!("Acknowledged Pub({pkid})"), Err(e) => println!("Publish failed: {e:?}"), @@ -59,7 +59,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx = tx.clone(); thread::spawn(move || { - let res = token.blocking_wait(); + let res = token.wait(); tx.send(res).unwrap() }); } @@ -69,8 +69,8 @@ fn main() -> Result<(), Box> { .publish("hello/world", QoS::AtMostOnce, false, vec![1; 4]) .unwrap(); thread::spawn(move || loop { - match token.try_resolve() { - Err(PromiseError::Waiting) => { + match token.check() { + Err(TokenError::Waiting) => { println!("Promise yet to resolve, retrying"); sleep(Duration::from_secs(1)); } @@ -89,7 +89,7 @@ fn main() -> Result<(), Box> { } // Unsubscribe and wait for broker acknowledgement - match client.unsubscribe("hello/world").unwrap().blocking_wait() { + match client.unsubscribe("hello/world").unwrap().wait() { Ok(pkid) => println!("Acknowledged Unsub({pkid})"), Err(e) => println!("Unsubscription failed: {e:?}"), } diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 47b9da0c..ffab94f0 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -3,9 +3,9 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; +use crate::tokens::{NoResponse, Resolver, Token}; use crate::{ - valid_filter, valid_topic, AckPromise, ConnectionError, Event, EventLoop, MqttOptions, - PromiseTx, Request, + valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Pkid, Request, }; use bytes::Bytes; @@ -23,15 +23,15 @@ pub enum ClientError { TryRequest(Request), } -impl From)>> for ClientError { - fn from(e: SendError<(Request, Option)>) -> Self { - Self::Request(e.into_inner().0) +impl From> for ClientError { + fn from(e: SendError) -> Self { + Self::Request(e.into_inner()) } } -impl From)>> for ClientError { - fn from(e: TrySendError<(Request, Option)>) -> Self { - Self::TryRequest(e.into_inner().0) +impl From> for ClientError { + fn from(e: TrySendError) -> Self { + Self::TryRequest(e.into_inner()) } } @@ -44,7 +44,7 @@ impl From)>> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender<(Request, Option)>, + request_tx: Sender, } impl AsyncClient { @@ -64,7 +64,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender<(Request, Option)>) -> AsyncClient { + pub fn from_senders(request_tx: Sender) -> AsyncClient { AsyncClient { request_tx } } @@ -75,24 +75,22 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result, ClientError> where S: Into, V: Into>, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish, resolver); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((publish, Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -102,43 +100,44 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result, ClientError> where S: Into, V: Into>, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish, resolver); if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); + return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send((publish, Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - + pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.request_tx.send_async((ack, None)).await?; + self.request_tx.send_async(ack).await?; } - Ok(()) + Ok(token) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.request_tx.try_send((ack, None))?; + self.request_tx.try_send(ack)?; } - Ok(()) + Ok(token) } /// Sends a MQTT Publish to the `EventLoop` @@ -148,19 +147,17 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result + ) -> Result, ClientError> where S: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); - self.request_tx - .send_async((publish, Some(promise_tx))) - .await?; + let request = Request::Publish(publish, resolver); + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -168,17 +165,18 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((subscribe.into(), Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Attempts to send a MQTT Subscribe to the `EventLoop` @@ -186,94 +184,101 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((subscribe.into(), Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub async fn subscribe_many(&self, topics: T) -> Result + pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((subscribe.into(), Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((subscribe.into(), Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub async fn unsubscribe>(&self, topic: S) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + pub async fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - self.request_tx - .send_async((unsubscribe.into(), Some(promise_tx))) - .await?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - self.request_tx - .try_send((unsubscribe.into(), Some(promise_tx)))?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect(Disconnect); - self.request_tx.send_async((request, None)).await?; + pub async fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.request_tx.send_async(request).await?; - Ok(()) + Ok(token) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect(Disconnect); - self.request_tx.try_send((request, None))?; + pub fn try_disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.request_tx.try_send(request)?; - Ok(()) + Ok(token) } } -fn get_ack_req(publish: &Publish) -> Option { +fn get_ack_req(publish: &Publish, resolver: Resolver<()>) -> Option { let ack = match publish.qos { - QoS::AtMostOnce => return None, - QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)), + QoS::AtMostOnce => { + resolver.resolve(()); + return None; + } + QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid), resolver), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid), resolver), }; Some(ack) } @@ -313,7 +318,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender<(Request, Option)>) -> Client { + pub fn from_sender(request_tx: Sender) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -326,22 +331,22 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result, ClientError> where S: Into, V: Into>, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.client.request_tx.send((publish, Some(promise_tx)))?; + self.client.request_tx.send(publish)?; - Ok(promise) + Ok(token) } pub fn try_publish( @@ -350,7 +355,7 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result, ClientError> where S: Into, V: Into>, @@ -359,18 +364,18 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - + pub fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.client.request_tx.send((ack, None))?; + self.client.request_tx.send(ack)?; } - Ok(()) + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { self.client.try_ack(publish) } @@ -379,17 +384,17 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.client - .request_tx - .send((subscribe.into(), Some(promise_tx)))?; + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -397,28 +402,29 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn subscribe_many(&self, topics: T) -> Result + pub fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.client - .request_tx - .send((subscribe.into(), Some(promise_tx)))?; + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -426,31 +432,31 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn unsubscribe>(&self, topic: S) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send((request, Some(promise_tx)))?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result { - let (promise_tx, promise) = PromiseTx::new(); - let request = Request::Disconnect(Disconnect); - self.client.request_tx.send((request, Some(promise_tx)))?; + pub fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result, ClientError> { self.client.try_disconnect() } } diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index b98a390b..00063bd0 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,5 +1,5 @@ use crate::{framed::Network, Transport}; -use crate::{Incoming, MqttState, NetworkOptions, Packet, PromiseTx, Request, StateError}; +use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError}; use crate::{MqttOptions, Outgoing}; use crate::framed::AsyncReadWrite; @@ -75,11 +75,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver<(Request, Option)>, + requests_rx: Receiver, /// Requests handle to send requests - pub(crate) requests_tx: Sender<(Request, Option)>, + pub(crate) requests_tx: Sender, /// Pending packets from last session - pub pending: VecDeque<(Request, Option)>, + pub pending: VecDeque, /// Network connection to the broker pub network: Option, /// Keep alive time @@ -132,9 +132,9 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|(request, _)| { + requests_in_channel.retain(|request| { match request { - Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack + Request::PubAck(..) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, } }); @@ -241,8 +241,8 @@ impl EventLoop { &self.requests_rx, self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok((request, tx)) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request, tx)? { + Ok(request) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -260,7 +260,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq), None)? { + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -282,10 +282,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque<(Request, Option)>, - rx: &Receiver<(Request, Option)>, + pending: &mut VecDeque, + rx: &Receiver, pending_throttle: Duration, - ) -> Result<(Request, Option), ConnectionError> { + ) -> Result { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .pop_front() AFTER sleep() otherwise we would have diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 2a6ead18..b707782d 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -98,12 +98,7 @@ #[macro_use] extern crate log; -use std::{ - fmt::{self, Debug, Formatter}, - future::Future, - pin::Pin, - task::{Context, Poll}, -}; +use std::fmt::{self, Debug, Formatter}; #[cfg(any(feature = "use-rustls", feature = "websocket"))] use std::sync::Arc; @@ -135,6 +130,7 @@ type RequestModifierFn = Arc< #[cfg(feature = "proxy")] mod proxy; +mod tokens; pub use client::{AsyncClient, Client, ClientError, Connection, Iter, RecvError, RecvTimeoutError}; pub use eventloop::{ConnectionError, Event, EventLoop}; @@ -145,7 +141,8 @@ use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] pub use tls::Error as TlsError; -use tokio::sync::{oneshot, oneshot::error::TryRecvError}; +use tokens::Resolver; +pub use tokens::{Token, TokenError}; #[cfg(feature = "use-native-tls")] pub use tokio_native_tls; #[cfg(feature = "use-native-tls")] @@ -189,130 +186,21 @@ pub enum Outgoing { /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug)] pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), - PingReq(PingReq), - PingResp(PingResp), - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect(Disconnect), -} - -impl From for Request { - fn from(publish: Publish) -> Request { - Request::Publish(publish) - } -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(subscribe) - } -} - -impl From for Request { - fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(unsubscribe) - } + Publish(Publish, Resolver), + PubAck(PubAck, Resolver<()>), + PubRec(PubRec, Resolver<()>), + PubRel(PubRel, Resolver), + Subscribe(Subscribe, Resolver), + Unsubscribe(Unsubscribe, Resolver), + Disconnect(Resolver<()>), + PingReq, } /// Packet Identifier with which Publish/Subscribe/Unsubscribe packets are identified while inflight. pub type Pkid = u16; -#[derive(Debug, thiserror::Error)] -pub enum PromiseError { - #[error("Sender has nothing to send instantly")] - Waiting, - #[error("Sender side of channel was dropped")] - Disconnected, - #[error("Broker rejected the request, reason: {reason}")] - Rejected { reason: String }, -} - -/// Resolves with [`Pkid`] used against packet when: -/// 1. Packet is acknowldged by the broker, e.g. QoS 1/2 Publish, Subscribe and Unsubscribe -/// 2. QoS 0 packet finishes processing in the [`EventLoop`] -pub struct AckPromise { - rx: oneshot::Receiver>, -} - -impl Future for AckPromise { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let polled = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }.poll(cx); - - match polled { - Poll::Ready(Ok(p)) => Poll::Ready(p), - Poll::Ready(Err(_)) => Poll::Ready(Err(PromiseError::Disconnected)), - Poll::Pending => Poll::Pending, - } - } -} - -impl AckPromise { - /// Blocks on the current thread and waits till the packet is acknowledged by the broker. - /// - /// Returns [`PromiseError::Disconnected`] if the [`EventLoop`] was dropped(usually), - /// [`PromiseError::Rejected`] if the packet acknowledged but not accepted. - pub fn blocking_wait(self) -> Result { - self.rx - .blocking_recv() - .map_err(|_| PromiseError::Disconnected)? - } - - /// Attempts to check if the broker acknowledged the packet, without blocking the current thread. - /// - /// Returns [`PromiseError::Waiting`] if the packet wasn't acknowledged yet. - /// - /// Multiple calls to this functions can fail with [`PromiseError::Disconnected`] if the promise - /// has already been resolved. - pub fn try_resolve(&mut self) -> Result { - match self.rx.try_recv() { - Ok(Ok(p)) => Ok(p), - Ok(Err(e)) => Err(e), - Err(TryRecvError::Empty) => Err(PromiseError::Waiting), - Err(TryRecvError::Closed) => Err(PromiseError::Disconnected), - } - } -} - -#[derive(Debug)] -pub struct PromiseTx { - tx: oneshot::Sender>, -} - -impl PromiseTx { - fn new() -> (PromiseTx, AckPromise) { - let (tx, rx) = oneshot::channel(); - - (PromiseTx { tx }, AckPromise { rx }) - } - - fn resolve(self, pkid: Pkid) { - if self.tx.send(Ok(pkid)).is_err() { - trace!("Promise was dropped") - } - } - - fn fail(self, reason: String) { - if self - .tx - .send(Err(PromiseError::Rejected { reason })) - .is_err() - { - trace!("Promise was dropped") - } - } -} - /// Transport methods. Defaults to TCP. #[derive(Clone)] pub enum Transport { diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 37ddc462..4698996a 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,9 +1,10 @@ -use crate::{Event, Incoming, Outgoing, PromiseTx, Request}; +use crate::Pkid; +use crate::{tokens::Resolver, Event, Incoming, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; use fixedbitset::FixedBitSet; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; /// Errors during state handling @@ -67,13 +68,17 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option<(Publish, Option)>, + pub collision: Option<(Publish, Resolver)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - /// Waiters for publish/subscribe/unsubscribe acknowledgements - ack_waiter: Vec>, + /// Waiters for publish acknowledgements + pub_ack_waiter: HashMap>, + /// Waiters for subscribe acknowledgements + sub_ack_waiter: HashMap>, + /// Waiters for unsubscribe acknowledgements + unsub_ack_waiter: HashMap>, } impl MqttState { @@ -98,12 +103,14 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), manual_acks, - ack_waiter: (0..max_inflight as usize + 1).map(|_| None).collect(), + pub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + sub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + unsub_ack_waiter: HashMap::with_capacity(max_inflight as usize), } } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec<(Request, Option)> { + pub fn clean(&mut self) -> Vec { let mut pending = Vec::with_capacity(100); let (first_half, second_half) = self .outgoing_pub @@ -111,17 +118,18 @@ impl MqttState { for publish in second_half.iter_mut().chain(first_half) { if let Some(publish) = publish.take() { - let tx = self.ack_waiter[publish.pkid as usize].take(); - let request = Request::Publish(publish); - pending.push((request, tx)); + let resolver = self.pub_ack_waiter.remove(&publish.pkid).unwrap(); + let request = Request::Publish(publish, resolver); + pending.push(request); } } // remove and collect pending releases for pkid in self.outgoing_rel.ones() { - let tx = self.ack_waiter[pkid].take(); - let request = Request::PubRel(PubRel::new(pkid as u16)); - pending.push((request, tx)); + let pkid = pkid as u16; + let resolver = self.pub_ack_waiter.remove(&pkid).unwrap(); + let request = Request::PubRel(PubRel::new(pkid), resolver); + pending.push(request); } self.outgoing_rel.clear(); @@ -143,18 +151,29 @@ impl MqttState { pub fn handle_outgoing_packet( &mut self, request: Request, - tx: Option, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish, tx)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, - Request::PingReq(_) => self.outgoing_ping()?, - Request::Disconnect(_) => self.outgoing_disconnect()?, - Request::PubAck(puback) => self.outgoing_puback(puback)?, - Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, - _ => unimplemented!(), + Request::Publish(publish, resolver) => self.outgoing_publish(publish, resolver)?, + Request::PubRel(pubrel, resolver) => self.outgoing_pubrel(pubrel, resolver)?, + Request::Subscribe(subscribe, resolver) => { + self.outgoing_subscribe(subscribe, resolver)? + } + Request::Unsubscribe(unsubscribe, resolver) => { + self.outgoing_unsubscribe(unsubscribe, resolver)? + } + Request::PingReq => self.outgoing_ping()?, + Request::Disconnect(resolver) => { + resolver.resolve(()); + self.outgoing_disconnect()? + } + Request::PubAck(puback, resolver) => { + resolver.resolve(()); + self.outgoing_puback(puback)? + } + Request::PubRec(pubrec, resolver) => { + resolver.resolve(()); + self.outgoing_pubrec(pubrec)? + } }; self.last_outgoing = Instant::now(); @@ -171,7 +190,7 @@ impl MqttState { ) -> Result, StateError> { self.events.push_back(Event::Incoming(packet.clone())); - let outgoing = match &packet { + let outgoing = match packet { Incoming::PingResp => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, @@ -190,26 +209,18 @@ impl MqttState { Ok(outgoing) } - fn is_pkid_of_publish(&self, pkid: u16) -> bool { - self.outgoing_pub[pkid as usize].is_some() || self.outgoing_rel.contains(pkid as usize) - } - - fn handle_incoming_suback(&mut self, suback: &SubAck) -> Result, StateError> { - // Expected ack for a subscribe packet, not a publish packet - if self.is_pkid_of_publish(suback.pkid) { + fn handle_incoming_suback(&mut self, suback: SubAck) -> Result, StateError> { + let Some(resolver) = self.sub_ack_waiter.remove(&suback.pkid) else { return Err(StateError::Unsolicited(suback.pkid)); - } - - if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { - if suback - .return_codes - .iter() - .all(|r| matches!(r, SubscribeReasonCode::Success(_))) - { - tx.resolve(suback.pkid); - } else { - tx.fail(format!("{:?}", suback.return_codes)); - } + }; + if suback + .return_codes + .iter() + .all(|r| matches!(r, SubscribeReasonCode::Success(_))) + { + resolver.resolve(suback.pkid); + } else { + resolver.reject(suback.return_codes); } Ok(None) @@ -217,23 +228,20 @@ impl MqttState { fn handle_incoming_unsuback( &mut self, - unsuback: &UnsubAck, + unsuback: UnsubAck, ) -> Result, StateError> { - // Expected ack for a unsubscribe packet, not a publish packet - if self.is_pkid_of_publish(unsuback.pkid) { + let Some(resolver) = self.unsub_ack_waiter.remove(&unsuback.pkid) else { return Err(StateError::Unsolicited(unsuback.pkid)); - } + }; - if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { - tx.resolve(unsuback.pkid); - } + resolver.resolve(unsuback.pkid); Ok(None) } /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &Publish) -> Result, StateError> { + fn handle_incoming_publish(&mut self, publish: Publish) -> Result, StateError> { let qos = publish.qos; match qos { @@ -258,7 +266,7 @@ impl MqttState { } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { + fn handle_incoming_puback(&mut self, puback: PubAck) -> Result, StateError> { let p = self .outgoing_pub .get_mut(puback.pkid as usize) @@ -271,28 +279,32 @@ impl MqttState { return Err(StateError::Unsolicited(puback.pkid)); } - if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { - // Resolve promise for QoS 1 - tx.resolve(puback.pkid); - } + let Some(resolver) = self.pub_ack_waiter.remove(&puback.pkid) else { + return Err(StateError::Unsolicited(puback.pkid)); + }; + + // Resolve promise for QoS 1 + resolver.resolve(puback.pkid); self.inflight -= 1; - let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { - self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); - self.inflight += 1; + let packet = self + .check_collision(puback.pkid) + .map(|(publish, resolver)| { + self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + self.inflight += 1; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - self.ack_waiter[puback.pkid as usize] = tx; + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(publish.pkid, resolver); - Packet::Publish(publish) - }); + Packet::Publish(publish) + }); Ok(packet) } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { if self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -313,7 +325,7 @@ impl MqttState { Ok(Some(Packet::PubRel(pubrel))) } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { if !self.incoming_pub.contains(pubrel.pkid as usize) { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); return Err(StateError::Unsolicited(pubrel.pkid)); @@ -327,27 +339,31 @@ impl MqttState { Ok(Some(Packet::PubComp(pubcomp))) } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { + fn handle_incoming_pubcomp(&mut self, pubcomp: PubComp) -> Result, StateError> { if !self.outgoing_rel.contains(pubcomp.pkid as usize) { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } - if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { - // Resolve promise for QoS 2 - tx.resolve(pubcomp.pkid); - } + let Some(resolver) = self.pub_ack_waiter.remove(&pubcomp.pkid) else { + return Err(StateError::Unsolicited(pubcomp.pkid)); + }; + + // Resolve promise for QoS 2 + resolver.resolve(pubcomp.pkid); self.outgoing_rel.set(pubcomp.pkid as usize, false); self.inflight -= 1; - let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - self.ack_waiter[pubcomp.pkid as usize] = tx; + let packet = self + .check_collision(pubcomp.pkid) + .map(|(publish, resolver)| { + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(publish.pkid, resolver); - Packet::Publish(publish) - }); + Packet::Publish(publish) + }); Ok(packet) } @@ -363,7 +379,7 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { @@ -378,7 +394,7 @@ impl MqttState { .is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some((publish, tx)); + self.collision = Some((publish, resolver)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -399,20 +415,26 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); - match (publish.qos, tx) { - (QoS::AtMostOnce, Some(tx)) => tx.resolve(publish.pkid), - (_, tx) => self.ack_waiter[publish.pkid as usize] = tx, + if publish.qos == QoS::AtMostOnce { + resolver.resolve(publish.pkid); + } else { + self.pub_ack_waiter.insert(publish.pkid, resolver); } Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + resolver: Resolver, + ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); + self.pub_ack_waiter.insert(pubrel.pkid, resolver); Ok(Some(Packet::PubRel(pubrel))) } @@ -469,7 +491,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -485,7 +507,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); - self.ack_waiter[subscription.pkid as usize] = tx; + self.sub_ack_waiter.insert(subscription.pkid, resolver); Ok(Some(Packet::Subscribe(subscription))) } @@ -493,7 +515,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -505,7 +527,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); - self.ack_waiter[unsub.pkid as usize] = tx; + self.unsub_ack_waiter.insert(unsub.pkid, resolver); Ok(Some(Packet::Unsubscribe(unsub))) } @@ -519,7 +541,7 @@ impl MqttState { Ok(Some(Packet::Disconnect)) } - fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); @@ -568,7 +590,8 @@ impl MqttState { mod test { use super::{MqttState, StateError}; use crate::mqttbytes::v4::*; - use crate::mqttbytes::*; + use crate::tokens::Resolver; + use crate::{mqttbytes::*, Pkid}; use crate::{Event, Incoming, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { @@ -619,7 +642,8 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -627,12 +651,14 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -640,12 +666,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -659,9 +687,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); // only qos2 publish should be add to queue assert!(mqtt.incoming_pub.contains(3)); @@ -676,9 +704,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { assert_eq!(pkid, 2); @@ -703,9 +731,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); assert!(mqtt.incoming_pub.contains(3)); @@ -717,7 +745,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); + let packet = mqtt.handle_incoming_publish(publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), _ => panic!("Invalid network request: {:?}", packet), @@ -731,14 +759,16 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1, None).unwrap(); - mqtt.outgoing_publish(publish2, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish1, resolver).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish2, resolver).unwrap(); assert_eq!(mqtt.inflight, 2); - mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1)).unwrap(); assert_eq!(mqtt.inflight, 1); - mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2)).unwrap(); assert_eq!(mqtt.inflight, 0); assert!(mqtt.outgoing_pub[1].is_none()); @@ -749,7 +779,7 @@ mod test { fn incoming_puback_with_pkid_greater_than_max_inflight_should_be_handled_gracefully() { let mut mqtt = build_mqttstate(); - let got = mqtt.handle_incoming_puback(&PubAck::new(101)).unwrap_err(); + let got = mqtt.handle_incoming_puback(PubAck::new(101)).unwrap_err(); match got { StateError::Unsolicited(pkid) => assert_eq!(pkid, 101), @@ -764,10 +794,12 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1, None); - let _publish_out = mqtt.outgoing_publish(publish2, None); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish1, resolver); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish2, resolver); - mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(2)).unwrap(); assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 @@ -783,14 +815,15 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - let packet = mqtt.outgoing_publish(publish, None).unwrap().unwrap(); + let resolver = Resolver::mock(); + let packet = mqtt.outgoing_publish(publish, resolver).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } let packet = mqtt - .handle_incoming_pubrec(&PubRec::new(1)) + .handle_incoming_pubrec(PubRec::new(1)) .unwrap() .unwrap(); match packet { @@ -804,14 +837,14 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); + let packet = mqtt.handle_incoming_publish(publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } let packet = mqtt - .handle_incoming_pubrel(&PubRel::new(1)) + .handle_incoming_pubrel(PubRel::new(1)) .unwrap() .unwrap(); match packet { @@ -825,10 +858,11 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish, None).unwrap(); - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(1)).unwrap(); - mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + mqtt.handle_incoming_pubcomp(PubComp::new(1)).unwrap(); assert_eq!(mqtt.inflight, 0); } @@ -839,7 +873,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish), None) + let resolver = Resolver::mock(); + mqtt.handle_outgoing_packet(Request::Publish(publish, resolver)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) .unwrap(); @@ -868,8 +903,8 @@ mod test { fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); - fn build_outgoing_pub() -> Vec> { - vec![ + fn build_outgoing_pub(state: &mut MqttState) { + state.outgoing_pub = vec![ None, Some(Publish { dup: false, @@ -905,39 +940,47 @@ mod test { pkid: 6, payload: "".into(), }), - ] + ]; + for (i, _) in state + .outgoing_pub + .iter() + .enumerate() + .filter(|(_, p)| p.is_some()) + { + state.pub_ack_waiter.insert(i as Pkid, Resolver::mock()); + } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 3; let requests = mqtt.clean(); let res = vec![6, 1, 2, 3]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = &req.0 { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 0; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = &req.0 { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 6; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = &req.0 { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() diff --git a/rumqttc/src/tokens.rs b/rumqttc/src/tokens.rs new file mode 100644 index 00000000..26f307f0 --- /dev/null +++ b/rumqttc/src/tokens.rs @@ -0,0 +1,118 @@ +use std::{ + fmt::Debug, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::sync::oneshot::{self, error::TryRecvError}; + +pub trait Reason: Debug + Send {} +impl Reason for T where T: Debug + Send {} + +#[derive(Debug, thiserror::Error)] +#[error("Broker rejected the request, reason: {0:?}")] +pub struct Rejection(Box); + +impl Rejection { + fn new(reason: R) -> Self { + Self(Box::new(reason)) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum TokenError { + #[error("Sender has nothing to send instantly")] + Waiting, + #[error("Sender side of channel was dropped")] + Disconnected, + #[error("Broker rejected the request, reason: {0:?}")] + Rejection(#[from] Rejection), +} + +pub type NoResponse = (); + +/// Resolves with [`Pkid`] used against packet when: +/// 1. Packet is acknowldged by the broker, e.g. QoS 1/2 Publish, Subscribe and Unsubscribe +/// 2. QoS 0 packet finishes processing in the [`EventLoop`] +pub struct Token { + rx: oneshot::Receiver>, +} + +impl Future for Token { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let polled = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }.poll(cx); + + match polled { + Poll::Ready(Ok(Ok(p))) => Poll::Ready(Ok(p)), + Poll::Ready(Ok(Err(e))) => Poll::Ready(Err(TokenError::Rejection(e))), + Poll::Ready(Err(_)) => Poll::Ready(Err(TokenError::Disconnected)), + Poll::Pending => Poll::Pending, + } + } +} + +/// There is a type of token returned for each type of [`Request`] when it is created and +/// sent to the [`EventLoop`] for further processing from the [`Client`]/[`AsyncClient`]. +/// Some tokens such as those associated with the resolve with the `pkid` value used in the packet sent to the broker while other +/// tokens don't return such a value. +impl Token { + /// Blocks on the current thread and waits till the packet completes being handled. + /// + /// ## Errors + /// Returns [`TokenError::Disconnected`] if the [`EventLoop`] was dropped(usually), + /// [`TokenError::Rejection`] if the packet acknowledged but not accepted. + pub fn wait(self) -> Result { + self.rx + .blocking_recv() + .map_err(|_| TokenError::Disconnected)? + .map_err(|e| TokenError::Rejection(e)) + } + + /// Attempts to check if the packet handling has been completed, without blocking the current thread. + /// + /// ## Errors + /// Returns [`TokenError::Waiting`] if the packet wasn't acknowledged yet. + /// Multiple calls to this functions can fail with [`TokenError::Disconnected`] + /// if the promise has already been resolved. + pub fn check(&mut self) -> Result { + match self.rx.try_recv() { + Ok(r) => r.map_err(|e| TokenError::Rejection(e)), + Err(TryRecvError::Empty) => Err(TokenError::Waiting), + Err(TryRecvError::Closed) => Err(TokenError::Disconnected), + } + } +} + +#[derive(Debug)] +pub struct Resolver { + tx: oneshot::Sender>, +} + +impl Resolver { + pub fn new() -> (Self, Token) { + let (tx, rx) = oneshot::channel(); + + (Self { tx }, Token { rx }) + } + + #[cfg(test)] + pub fn mock() -> Self { + let (tx, _) = oneshot::channel(); + + Self { tx } + } + + pub fn resolve(self, resolved: T) { + if self.tx.send(Ok(resolved)).is_err() { + trace!("Promise was dropped") + } + } + + pub fn reject(self, reasons: R) { + if self.tx.send(Err(Rejection::new(reasons))).is_err() { + trace!("Promise was dropped") + } + } +} diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index a099f54d..6f537f4d 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -8,7 +8,8 @@ use super::mqttbytes::v5::{ }; use super::mqttbytes::QoS; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; -use crate::{valid_filter, valid_topic, AckPromise, PromiseTx}; +use crate::tokens::{NoResponse, Resolver, Token}; +use crate::{valid_filter, valid_topic, Pkid}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -25,15 +26,15 @@ pub enum ClientError { TryRequest(Request), } -impl From)>> for ClientError { - fn from(e: SendError<(Request, Option)>) -> Self { - Self::Request(e.into_inner().0) +impl From> for ClientError { + fn from(e: SendError) -> Self { + Self::Request(e.into_inner()) } } -impl From)>> for ClientError { - fn from(e: TrySendError<(Request, Option)>) -> Self { - Self::TryRequest(e.into_inner().0) +impl From> for ClientError { + fn from(e: TrySendError) -> Self { + Self::TryRequest(e.into_inner()) } } @@ -46,7 +47,7 @@ impl From)>> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender<(Request, Option)>, + request_tx: Sender, } impl AsyncClient { @@ -66,7 +67,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender<(Request, Option)>) -> AsyncClient { + pub fn from_senders(request_tx: Sender) -> AsyncClient { AsyncClient { request_tx } } @@ -78,24 +79,22 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.request_tx - .send_async((publish, Some(promise_tx))) - .await?; + self.request_tx.send_async(publish).await?; - Ok(promise) + Ok(token) } pub async fn publish_with_properties( @@ -105,7 +104,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -120,7 +119,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -136,22 +135,22 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } - self.request_tx.try_send((publish, Some(promise_tx)))?; + self.request_tx.try_send(publish)?; - Ok(promise) + Ok(token) } pub fn try_publish_with_properties( @@ -161,7 +160,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -175,7 +174,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -184,24 +183,26 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.request_tx.send_async((ack, None)).await?; + self.request_tx.send_async(ack).await?; } - Ok(()) + Ok(token) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.request_tx.try_send((ack, None))?; + self.request_tx.try_send(ack)?; } - Ok(()) + Ok(token) } /// Sends a MQTT Publish to the `EventLoop` @@ -212,20 +213,18 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: Option, - ) -> Result + ) -> Result, ClientError> where S: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); - self.request_tx - .send_async((publish, Some(promise_tx))) - .await?; + let publish = Request::Publish(publish, resolver); + self.request_tx.send_async(publish).await?; - Ok(promise) + Ok(token) } pub async fn publish_bytes_with_properties( @@ -235,7 +234,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, { @@ -249,7 +248,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result + ) -> Result, ClientError> where S: Into, { @@ -263,18 +262,18 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((subscribe.into(), Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } pub async fn subscribe_with_properties>( @@ -282,7 +281,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, Some(properties)).await } @@ -290,7 +289,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, None).await } @@ -300,17 +299,18 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((subscribe.into(), Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } pub fn try_subscribe_with_properties>( @@ -318,7 +318,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_try_subscribe(topic, qos, Some(properties)) } @@ -326,7 +326,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.handle_try_subscribe(topic, qos, None) } @@ -335,34 +335,34 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((subscribe.into(), Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } pub async fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } - pub async fn subscribe_many(&self, topics: T) -> Result + pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -374,33 +374,34 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((subscribe.into(), Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } pub fn try_subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -412,26 +413,24 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); - self.request_tx - .send_async((request, Some(promise_tx))) - .await?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } pub async fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_unsubscribe(topic, Some(properties)).await } - pub async fn unsubscribe>(&self, topic: S) -> Result { + pub async fn unsubscribe>(&self, topic: S) -> Result, ClientError> { self.handle_unsubscribe(topic, None).await } @@ -440,49 +439,54 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send((request, Some(promise_tx)))?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } pub fn try_unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_try_unsubscribe(topic, Some(properties)) } - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { self.handle_try_unsubscribe(topic, None) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; - self.request_tx.send_async((request, None)).await?; + pub async fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.request_tx.send_async(request).await?; - Ok(()) + Ok(token) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; - self.request_tx.try_send((request, None))?; + pub fn try_disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.request_tx.try_send(request)?; - Ok(()) + Ok(token) } } -fn get_ack_req(publish: &Publish) -> Option { +fn get_ack_req(publish: &Publish, resolver: Resolver<()>) -> Option { let ack = match publish.qos { - QoS::AtMostOnce => return None, - QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid, None)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid, None)), + QoS::AtMostOnce => { + resolver.resolve(()); + return None; + } + QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid, None), resolver), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid, None), resolver), }; Some(ack) } @@ -523,7 +527,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender<(Request, Option)>) -> Client { + pub fn from_sender(request_tx: Sender) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -537,22 +541,22 @@ impl Client { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.client.request_tx.send((publish, Some(promise_tx)))?; + self.client.request_tx.send(publish)?; - Ok(promise) + Ok(token) } pub fn publish_with_properties( @@ -562,7 +566,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -576,7 +580,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -591,7 +595,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -606,7 +610,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -615,18 +619,19 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.client.request_tx.send((ack, None))?; + self.client.request_tx.send(ack)?; } - Ok(()) + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { self.client.try_ack(publish) } @@ -636,18 +641,18 @@ impl Client { topic: S, qos: QoS, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.client - .request_tx - .send((subscribe.into(), Some(promise_tx)))?; + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } pub fn subscribe_with_properties>( @@ -655,7 +660,7 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, Some(properties)) } @@ -663,7 +668,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, None) } @@ -673,7 +678,7 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.client .try_subscribe_with_properties(topic, qos, properties) } @@ -682,7 +687,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.client.try_subscribe(topic, qos) } @@ -691,34 +696,34 @@ impl Client { &self, topics: T, properties: Option, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.client - .request_tx - .send((subscribe.into(), Some(promise_tx)))?; + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } pub fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } - pub fn subscribe_many(&self, topics: T) -> Result + pub fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -729,7 +734,7 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { @@ -737,7 +742,7 @@ impl Client { .try_subscribe_many_with_properties(topics, properties) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -749,24 +754,24 @@ impl Client { &self, topic: S, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send((request, Some(promise_tx)))?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } pub fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_unsubscribe(topic, Some(properties)) } - pub fn unsubscribe>(&self, topic: S) -> Result { + pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { self.handle_unsubscribe(topic, None) } @@ -775,26 +780,26 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.client .try_unsubscribe_with_properties(topic, properties) } - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result { - let (promise_tx, promise) = PromiseTx::new(); - let request = Request::Disconnect; - self.client.request_tx.send((request, Some(promise_tx)))?; + pub fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result, ClientError> { self.client.try_disconnect() } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index b2c2fc50..ea361b4e 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -3,7 +3,6 @@ use super::mqttbytes::v5::*; use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; use crate::eventloop::socket_connect; use crate::framed::AsyncReadWrite; -use crate::PromiseTx; use flume::{bounded, Receiver, Sender}; use tokio::select; @@ -74,11 +73,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver<(Request, Option)>, + requests_rx: Receiver, /// Requests handle to send requests - pub(crate) requests_tx: Sender<(Request, Option)>, + pub(crate) requests_tx: Sender, /// Pending packets from last session - pub pending: VecDeque<(Request, Option)>, + pub pending: VecDeque, /// Network connection to the broker network: Option, /// Keep alive time @@ -129,9 +128,9 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|(request, _)| { + requests_in_channel.retain(|request| { match request { - Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack + Request::PubAck(..) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, } }); @@ -224,8 +223,8 @@ impl EventLoop { &self.requests_rx, self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok((request, tx)) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request, tx)? { + Ok(request) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { network.write(outgoing).await?; } network.flush().await?; @@ -246,7 +245,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq, None)? { + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { network.write(outgoing).await?; } network.flush().await?; @@ -256,10 +255,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque<(Request, Option)>, - rx: &Receiver<(Request, Option)>, + pending: &mut VecDeque, + rx: &Receiver, pending_throttle: Duration, - ) -> Result<(Request, Option), ConnectionError> { + ) -> Result { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .next() AFTER sleep() otherwise .next() would diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 6e0e4393..22b1942c 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -14,8 +14,9 @@ mod framed; pub mod mqttbytes; mod state; -use crate::Outgoing; +use crate::tokens::Resolver; use crate::{NetworkOptions, Transport}; +use crate::{Outgoing, Pkid}; use mqttbytes::v5::*; @@ -33,26 +34,16 @@ pub type Incoming = Packet; /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug)] pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), + Publish(Publish, Resolver), + PubAck(PubAck, Resolver<()>), + PubRec(PubRec, Resolver<()>), + PubRel(PubRel, Resolver), + Subscribe(Subscribe, Resolver), + Unsubscribe(Unsubscribe, Resolver), + Disconnect(Resolver<()>), PingReq, - PingResp, - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect, -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Self { - Self::Subscribe(subscribe) - } } #[cfg(feature = "websocket")] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 97a6def3..6f9e430c 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,13 +1,17 @@ -use crate::PromiseTx; - -use super::mqttbytes::v5::{ - ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, - PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, - SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, +use crate::{tokens::Resolver, Pkid}; + +use super::{ + mqttbytes::{ + self, + v5::{ + ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, + PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, + Publish, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, + }, + Error as MqttError, QoS, + }, + Event, Incoming, Outgoing, Request, }; -use super::mqttbytes::{self, Error as MqttError, QoS}; - -use super::{Event, Incoming, Outgoing, Request}; use bytes::Bytes; use fixedbitset::FixedBitSet; @@ -99,7 +103,7 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option<(Publish, Option)>, + pub collision: Option<(Publish, Resolver)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -112,8 +116,12 @@ pub struct MqttState { pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, - /// Waiters for publish/subscribe/unsubscribe acknowledgements - ack_waiter: Vec>, + /// Waiters for publish acknowledgements + pub_ack_waiter: HashMap>, + /// Waiters for subscribe acknowledgements + sub_ack_waiter: HashMap>, + /// Waiters for unsubscribe acknowledgements + unsub_ack_waiter: HashMap>, } impl MqttState { @@ -141,27 +149,29 @@ impl MqttState { broker_topic_alias_max: 0, max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, - ack_waiter: (0..max_inflight as usize + 1).map(|_| None).collect(), + pub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + sub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + unsub_ack_waiter: HashMap::with_capacity(max_inflight as usize), } } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec<(Request, Option)> { + pub fn clean(&mut self) -> Vec { let mut pending = Vec::with_capacity(100); // remove and collect pending publishes for publish in self.outgoing_pub.iter_mut() { if let Some(publish) = publish.take() { - let tx = self.ack_waiter[publish.pkid as usize].take(); - let request = Request::Publish(publish); - pending.push((request, tx)); + let resolver = self.pub_ack_waiter.remove(&publish.pkid).unwrap(); + let request = Request::Publish(publish, resolver); + pending.push(request); } } // remove and collect pending releases for pkid in self.outgoing_rel.ones() { - let tx = self.ack_waiter[pkid].take(); - let request = Request::PubRel(PubRel::new(pkid as u16, None)); - pending.push((request, tx)); + let resolver = self.pub_ack_waiter.remove(&(pkid as u16)).unwrap(); + let request = Request::PubRel(PubRel::new(pkid as u16, None), resolver); + pending.push(request); } self.outgoing_rel.clear(); @@ -183,20 +193,29 @@ impl MqttState { pub fn handle_outgoing_packet( &mut self, request: Request, - tx: Option, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish, tx)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, + Request::Publish(publish, resolver) => self.outgoing_publish(publish, resolver)?, + Request::PubRel(pubrel, resolver) => self.outgoing_pubrel(pubrel, resolver)?, + Request::Subscribe(subscribe, resolver) => { + self.outgoing_subscribe(subscribe, resolver)? + } + Request::Unsubscribe(unsubscribe, resolver) => { + self.outgoing_unsubscribe(unsubscribe, resolver)? + } Request::PingReq => self.outgoing_ping()?, - Request::Disconnect => { + Request::Disconnect(resolver) => { + resolver.resolve(()); self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? } - Request::PubAck(puback) => self.outgoing_puback(puback)?, - Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, - _ => unimplemented!(), + Request::PubAck(puback, resolver) => { + resolver.resolve(()); + self.outgoing_puback(puback)? + } + Request::PubRec(pubrec, resolver) => { + resolver.resolve(()); + self.outgoing_pubrec(pubrec)? + } }; self.last_outgoing = Instant::now(); @@ -209,11 +228,11 @@ impl MqttState { /// be forwarded to user and Pubck packet will be written to network pub fn handle_incoming_packet( &mut self, - mut packet: Incoming, + packet: Incoming, ) -> Result, StateError> { self.events.push_back(Event::Incoming(packet.to_owned())); - let outgoing = match &mut packet { + let outgoing = match packet { Incoming::PingResp(_) => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, @@ -239,18 +258,11 @@ impl MqttState { self.outgoing_disconnect(DisconnectReasonCode::ProtocolError) } - fn is_pkid_of_publish(&self, pkid: u16) -> bool { - self.outgoing_pub[pkid as usize].is_some() || self.outgoing_rel.contains(pkid as usize) - } - - fn handle_incoming_suback( - &mut self, - suback: &mut SubAck, - ) -> Result, StateError> { + fn handle_incoming_suback(&mut self, suback: SubAck) -> Result, StateError> { // Expected ack for a subscribe packet, not a publish packet - if self.is_pkid_of_publish(suback.pkid) { + let Some(resolver) = self.sub_ack_waiter.remove(&suback.pkid) else { return Err(StateError::Unsolicited(suback.pkid)); - } + }; for reason in suback.return_codes.iter() { match reason { @@ -263,16 +275,14 @@ impl MqttState { } } - if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { - if suback - .return_codes - .iter() - .all(|r| matches!(r, SubscribeReasonCode::Success(_))) - { - tx.resolve(suback.pkid); - } else { - tx.fail(format!("{:?}", suback.return_codes)); - } + if suback + .return_codes + .iter() + .all(|r| matches!(r, SubscribeReasonCode::Success(_))) + { + resolver.resolve(suback.pkid); + } else { + resolver.reject(suback.return_codes); } Ok(None) @@ -280,12 +290,11 @@ impl MqttState { fn handle_incoming_unsuback( &mut self, - unsuback: &mut UnsubAck, + unsuback: UnsubAck, ) -> Result, StateError> { - // Expected ack for a unsubscribe packet, not a publish packet - if self.is_pkid_of_publish(unsuback.pkid) { + let Some(resolver) = self.unsub_ack_waiter.remove(&unsuback.pkid) else { return Err(StateError::Unsolicited(unsuback.pkid)); - } + }; for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { @@ -293,25 +302,20 @@ impl MqttState { } } - if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { - if unsuback - .reasons - .iter() - .all(|r| matches!(r, UnsubAckReason::Success)) - { - tx.resolve(unsuback.pkid); - } else { - tx.fail(format!("{:?}", unsuback.reasons)); - } + if unsuback + .reasons + .iter() + .all(|r| matches!(r, UnsubAckReason::Success)) + { + resolver.resolve(unsuback.pkid); + } else { + resolver.reject(unsuback.reasons); } Ok(None) } - fn handle_incoming_connack( - &mut self, - connack: &mut ConnAck, - ) -> Result, StateError> { + fn handle_incoming_connack(&mut self, connack: ConnAck) -> Result, StateError> { if connack.code != ConnectReturnCode::Success { return Err(StateError::ConnFail { reason: connack.code, @@ -335,7 +339,7 @@ impl MqttState { fn handle_incoming_disconn( &mut self, - disconn: &mut Disconnect, + disconn: Disconnect, ) -> Result, StateError> { let reason_code = disconn.reason_code; let reason_string = if let Some(props) = &disconn.properties { @@ -353,7 +357,7 @@ impl MqttState { /// in case of QoS1 and Replys rec in case of QoS while also storing the message fn handle_incoming_publish( &mut self, - publish: &mut Publish, + mut publish: Publish, ) -> Result, StateError> { let qos = publish.qos; @@ -396,24 +400,22 @@ impl MqttState { } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; - - if publish.take().is_none() { + fn handle_incoming_puback(&mut self, puback: PubAck) -> Result, StateError> { + let Some(resolver) = self.pub_ack_waiter.remove(&puback.pkid) else { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); - } + }; - if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { - // Resolve promise for QoS 1 - if puback.reason == PubAckReason::Success { - tx.resolve(puback.pkid); - } else { - tx.fail(format!("{:?}", puback.reason)); - } + self.outgoing_pub + .get_mut(puback.pkid as usize) + .ok_or(StateError::Unsolicited(puback.pkid))? + .take(); + + // Resolve promise for QoS 1 + if puback.reason == PubAckReason::Success { + resolver.resolve(puback.pkid); + } else { + resolver.reject(puback.reason); } self.inflight -= 1; @@ -428,7 +430,7 @@ impl MqttState { return Ok(None); } - if let Some((publish, tx)) = self.check_collision(puback.pkid) { + if let Some((publish, resolver)) = self.check_collision(puback.pkid) { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; @@ -436,7 +438,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; - self.ack_waiter[puback.pkid as usize] = tx; + self.pub_ack_waiter.insert(puback.pkid, resolver); return Ok(Some(Packet::Publish(publish))); } @@ -444,7 +446,7 @@ impl MqttState { Ok(None) } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -473,7 +475,7 @@ impl MqttState { Ok(Some(Packet::PubRel(PubRel::new(pubrec.pkid, None)))) } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { if !self.incoming_pub.contains(pubrel.pkid as usize) { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); return Err(StateError::Unsolicited(pubrel.pkid)); @@ -494,30 +496,30 @@ impl MqttState { Ok(Some(Packet::PubComp(PubComp::new(pubrel.pkid, None)))) } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { + fn handle_incoming_pubcomp(&mut self, pubcomp: PubComp) -> Result, StateError> { + let Some(resolver) = self.pub_ack_waiter.remove(&pubcomp.pkid) else { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); - } + }; - if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { - // Resolve promise for QoS 2 - if pubcomp.reason == PubCompReason::Success { - tx.resolve(pubcomp.pkid); - } else { - tx.fail(format!("{:?}", pubcomp.reason)); - } + // Resolve promise for QoS 2 + if pubcomp.reason == PubCompReason::Success { + resolver.resolve(pubcomp.pkid); + } else { + resolver.reject(pubcomp.reason); } self.outgoing_rel.set(pubcomp.pkid as usize, false); - let outgoing = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - self.ack_waiter[pubcomp.pkid as usize] = tx; + let outgoing = self + .check_collision(pubcomp.pkid) + .map(|(publish, resolver)| { + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(pubcomp.pkid, resolver); - Packet::Publish(publish) - }); + Packet::Publish(publish) + }); if pubcomp.reason != PubCompReason::Success { warn!( @@ -541,7 +543,7 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { @@ -556,7 +558,7 @@ impl MqttState { .is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some((publish, tx)); + self.collision = Some((publish, resolver)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -592,21 +594,27 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); - match (publish.qos, tx) { - (QoS::AtMostOnce, Some(tx)) => tx.resolve(0), - (_, tx) => self.ack_waiter[publish.pkid as usize] = tx, + if publish.qos == QoS::AtMostOnce { + resolver.resolve(0); + } else { + self.pub_ack_waiter.insert(publish.pkid, resolver); } Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + resolver: Resolver, + ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); + self.pub_ack_waiter.insert(pubrel.pkid, resolver); Ok(Some(Packet::PubRel(PubRel::new(pubrel.pkid, None)))) } @@ -662,7 +670,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -679,7 +687,7 @@ impl MqttState { let pkid = subscription.pkid; let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); - self.ack_waiter[subscription.pkid as usize] = tx; + self.sub_ack_waiter.insert(subscription.pkid, resolver); Ok(Some(Packet::Subscribe(subscription))) } @@ -687,7 +695,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -700,7 +708,7 @@ impl MqttState { let pkid = unsub.pkid; let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); - self.ack_waiter[unsub.pkid as usize] = tx; + self.unsub_ack_waiter.insert(unsub.pkid, resolver); Ok(Some(Packet::Unsubscribe(unsub))) } @@ -716,7 +724,7 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); @@ -763,6 +771,8 @@ impl MqttState { #[cfg(test)] mod test { + use crate::tokens::Resolver; + use super::mqttbytes::v5::*; use super::mqttbytes::*; use super::{Event, Incoming, Outgoing, Request}; @@ -816,7 +826,9 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -824,12 +836,15 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -837,12 +852,15 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -854,27 +872,31 @@ mod test { // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be set back down to 0, since we hit the limit - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); // This should cause a collition - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 2); assert!(mqtt.collision.is_some()); - mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); - mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 1); // Now there should be space in the outgoing queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); } @@ -884,13 +906,13 @@ mod test { let mut mqtt = build_mqttstate(); // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); // only qos2 publish should be add to queue assert!(mqtt.incoming_pub.contains(3)); @@ -901,13 +923,13 @@ mod test { let mut mqtt = build_mqttstate(); // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { assert_eq!(pkid, 2); @@ -928,13 +950,13 @@ mod test { mqtt.manual_acks = true; // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); assert!(mqtt.incoming_pub.contains(3)); assert!(mqtt.events.is_empty()); @@ -943,9 +965,9 @@ mod test { #[test] fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { let mut mqtt = build_mqttstate(); - let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() { + match mqtt.handle_incoming_publish(publish).unwrap().unwrap() { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } @@ -958,14 +980,16 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1, None).unwrap(); - mqtt.outgoing_publish(publish2, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish1, resolver).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish2, resolver).unwrap(); assert_eq!(mqtt.inflight, 2); - mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1, None)).unwrap(); assert_eq!(mqtt.inflight, 1); - mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 0); assert!(mqtt.outgoing_pub[1].is_none()); @@ -977,7 +1001,7 @@ mod test { let mut mqtt = build_mqttstate(); let got = mqtt - .handle_incoming_puback(&PubAck::new(101, None)) + .handle_incoming_puback(PubAck::new(101, None)) .unwrap_err(); match got { @@ -993,10 +1017,12 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1, None); - let _publish_out = mqtt.outgoing_publish(publish2, None); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish1, resolver); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish2, resolver); - mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 @@ -1012,13 +1038,14 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - match mqtt.outgoing_publish(publish, None).unwrap().unwrap() { + let resolver = Resolver::mock(); + match mqtt.outgoing_publish(publish, resolver).unwrap().unwrap() { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } match mqtt - .handle_incoming_pubrec(&PubRec::new(1, None)) + .handle_incoming_pubrec(PubRec::new(1, None)) .unwrap() .unwrap() { @@ -1030,15 +1057,15 @@ mod test { #[test] fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { let mut mqtt = build_mqttstate(); - let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() { + match mqtt.handle_incoming_publish(publish).unwrap().unwrap() { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } match mqtt - .handle_incoming_pubrel(&PubRel::new(1, None)) + .handle_incoming_pubrel(PubRel::new(1, None)) .unwrap() .unwrap() { @@ -1052,11 +1079,11 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish, None).unwrap(); - mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(1, None)).unwrap(); - mqtt.handle_incoming_pubcomp(&PubComp::new(1, None)) - .unwrap(); + mqtt.handle_incoming_pubcomp(PubComp::new(1, None)).unwrap(); assert_eq!(mqtt.inflight, 0); } @@ -1067,7 +1094,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish), None) + let resolver = Resolver::mock(); + mqtt.handle_outgoing_packet(Request::Publish(publish, resolver)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None))) .unwrap();