From d6b1d564105d5363dcd12a49f1689f855924f9e0 Mon Sep 17 00:00:00 2001 From: boxdot Date: Wed, 5 Feb 2025 19:17:31 +0100 Subject: [PATCH] fix: close websocket handler when web socket is dropped (#320) * When `QsWebSocket` was dropped, the handler was still running. Now, it is also cancelled. * `UserCubit` considered web socket disconnect events as an error, and reconnected again. This is not needed since internally `QsWebSocket` automatically reconnects. Additionally, replace the broadcast channel by mpsc one, since we have only a single subscriber. Also remove unused public APIs. --- Cargo.lock | 3 + Cargo.toml | 3 + apiclient/Cargo.toml | 2 + apiclient/src/qs_api/tests.rs | 4 +- apiclient/src/qs_api/ws.rs | 125 +++++++++++++++++---------------- applogic/Cargo.toml | 2 +- applogic/src/api/user_cubit.rs | 19 +++-- coreclient/Cargo.toml | 1 + coreclient/src/clients/mod.rs | 15 +++- server/Cargo.toml | 1 + server/tests/qs/ws.rs | 7 +- 11 files changed, 109 insertions(+), 73 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f02179ca..952f6893 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3302,6 +3302,7 @@ dependencies = [ "tls_codec", "tokio", "tokio-tungstenite", + "tokio-util", "tracing", "tracing-subscriber", "url", @@ -3377,6 +3378,7 @@ dependencies = [ "tls_codec", "tokio", "tokio-stream", + "tokio-util", "tracing", "trait-variant", "url", @@ -3410,6 +3412,7 @@ dependencies = [ "thiserror 1.0.69", "tls_codec", "tokio", + "tokio-util", "tracing", "tracing-actix-web", "tracing-bunyan-formatter", diff --git a/Cargo.toml b/Cargo.toml index c1d34f5d..42a391ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,9 @@ tracing-subscriber = { version = "0.3", features = [ "parking_lot", ] } +tokio-util = "0.7.13" + + [patch.crates-io] #opaque-ke = { git = "https://github.com/facebook/opaque-ke", branch = "dependabot/cargo/voprf-eq-0.5.0" } diff --git a/apiclient/Cargo.toml b/apiclient/Cargo.toml index d9b67c29..911bf945 100644 --- a/apiclient/Cargo.toml +++ b/apiclient/Cargo.toml @@ -14,6 +14,7 @@ reqwest = { workspace = true } thiserror = "1" phnxtypes = { path = "../types" } tokio = { version = "1.18.2", features = ["macros"] } +tokio-util = { workspace = true } tokio-tungstenite = { version = "0.23", features = ["rustls-tls-webpki-roots"] } futures-util = "0.3.21" http = "1" @@ -23,6 +24,7 @@ mls-assist = { workspace = true } privacypass = { workspace = true } tls_codec = { workspace = true } url = "2" +uuid = { version = "1", features = ["v4"] } [dev-dependencies] tokio = { version = "1.18.2", features = ["macros"] } diff --git a/apiclient/src/qs_api/tests.rs b/apiclient/src/qs_api/tests.rs index d04b074d..0f8915ae 100644 --- a/apiclient/src/qs_api/tests.rs +++ b/apiclient/src/qs_api/tests.rs @@ -20,6 +20,7 @@ use phnxtypes::{ messages::{client_ds::QsWsMessage, client_qs::QsOpenWsParams}, }; use tls_codec::Serialize; +use tokio_util::sync::CancellationToken; use tracing::{error, info}; use uuid::Uuid; @@ -49,8 +50,9 @@ async fn ws_lifecycle() { let client = ApiClient::with_default_http_client(address).expect("Failed to initialize client"); // Spawn the websocket connection task + let cancel = CancellationToken::new(); let mut ws = client - .spawn_websocket(queue_id, timeout, retry_interval) + .spawn_websocket(queue_id, timeout, retry_interval, cancel) .await .expect("Failed to execute request"); diff --git a/apiclient/src/qs_api/ws.rs b/apiclient/src/qs_api/ws.rs index 903c8f76..23d5e148 100644 --- a/apiclient/src/qs_api/ws.rs +++ b/apiclient/src/qs_api/ws.rs @@ -18,8 +18,7 @@ use thiserror::*; use tls_codec::DeserializeBytes; use tokio::{ net::TcpStream, - sync::broadcast::{self, Receiver, Sender}, - task::JoinHandle, + sync::mpsc, time::{sleep, Instant}, }; use tokio_tungstenite::{ @@ -27,7 +26,9 @@ use tokio_tungstenite::{ tungstenite::{client::IntoClientRequest, protocol::Message}, MaybeTlsStream, WebSocketStream, }; +use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::{error, info}; +use uuid::Uuid; use crate::{ApiClient, Protocol}; @@ -53,9 +54,12 @@ impl ConnectionStatus { Self { connected: false } } - fn set_connected(&mut self, tx: &Sender) -> Result<(), ConnectionStatusError> { + async fn set_connected( + &mut self, + tx: &mpsc::Sender, + ) -> Result<(), ConnectionStatusError> { if !self.connected { - if let Err(error) = tx.send(WsEvent::ConnectedEvent) { + if let Err(error) = tx.send(WsEvent::ConnectedEvent).await { error!(%error, "Error sending to channel"); self.connected = false; return Err(ConnectionStatusError::ChannelClosed); @@ -65,9 +69,12 @@ impl ConnectionStatus { Ok(()) } - fn set_disconnected(&mut self, tx: &Sender) -> Result<(), ConnectionStatusError> { + async fn set_disconnected( + &mut self, + tx: &mpsc::Sender, + ) -> Result<(), ConnectionStatusError> { if self.connected { - if let Err(error) = tx.send(WsEvent::DisconnectedEvent) { + if let Err(error) = tx.send(WsEvent::DisconnectedEvent).await { error!(%error, "Error sending to channel"); return Err(ConnectionStatusError::ChannelClosed); } @@ -79,10 +86,11 @@ impl ConnectionStatus { /// A websocket connection to the QS server. See the /// [`ApiClient::spawn_websocket`] method for more information. +/// +/// When dropped, the websocket connection will be closed. pub struct QsWebSocket { - rx: Receiver, - tx: Sender, - handle: JoinHandle<()>, + rx: mpsc::Receiver, + _cancel: DropGuard, } impl QsWebSocket { @@ -90,37 +98,18 @@ impl QsWebSocket { /// sent or the connection is closed (in which case a final `None` is /// returned). pub async fn next(&mut self) -> Option { - match self.rx.recv().await { - Ok(message) => Some(message), - Err(error) => { - error!(%error, "Error receiving from channel"); - None - } - } - } - - /// Subscribe to the event stream - pub fn subscribe(&self) -> Receiver { - self.tx.subscribe() - } - - /// Join the websocket connection task. This will block until the task has - /// completed. - pub async fn join(self) -> Result<(), tokio::task::JoinError> { - self.handle.await - } - - /// Abort the websocket connection task. This will close the websocket connection. - pub fn abort(&mut self) { - self.handle.abort(); + self.rx.recv().await } /// Internal helper function to handle an established websocket connection + /// + /// Returns `true` if the connection should be re-established, otherwise `false`. async fn handle_connection( ws_stream: WebSocketStream>, - tx: &Sender, + tx: &mpsc::Sender, timeout: u64, - ) { + cancel: &CancellationToken, + ) -> bool { let mut last_ping = Instant::now(); // Watchdog to monitor the connection. @@ -131,15 +120,20 @@ impl QsWebSocket { // Initialize the connection status let mut connection_status = ConnectionStatus::new(); - if connection_status.set_connected(tx).is_err() { + if connection_status.set_connected(tx).await.is_err() { // Close the stream if all subscribers of the watch have been dropped let _ = ws_stream.close().await; - return; + return false; } // Loop while the connection is open loop { tokio::select! { + // Check is the handler is cancelled + _ = cancel.cancelled() => { + info!("QS WebSocket connection cancelled"); + break false; + }, // Check if the connection is still alive _ = interval.tick() => { let now = Instant::now(); @@ -147,10 +141,10 @@ impl QsWebSocket { if now.duration_since(last_ping) > Duration::from_secs(timeout) { // Change the status to Disconnected and send an event let _ = ws_stream.close().await; - if connection_status.set_disconnected(tx).is_err() { + if connection_status.set_disconnected(tx).await.is_err() { // Close the stream if all subscribers of the watch have been dropped info!("Closing the connection because all subscribers are dropped"); - return; + return false; } } }, @@ -163,11 +157,11 @@ impl QsWebSocket { // Reset the last ping time last_ping = Instant::now(); // Change the status to Connected and send an event - if connection_status.set_connected(tx).is_err() { + if connection_status.set_connected(tx).await.is_err() { // Close the stream if all subscribers of the watch have been dropped info!("Closing the connection because all subscribers are dropped"); let _ = ws_stream.close().await; - return; + return false; } // Try to deserialize the message if let Ok(QsWsMessage::QueueUpdate) = @@ -175,11 +169,11 @@ impl QsWebSocket { { // We received a new message notification from the QS // Send the event to the channel - if tx.send(WsEvent::MessageEvent(QsWsMessage::QueueUpdate)).is_err() { + if tx.send(WsEvent::MessageEvent(QsWsMessage::QueueUpdate)).await.is_err() { info!("Closing the connection because all subscribers are dropped"); // Close the stream if all subscribers of the watch have been dropped let _ = ws_stream.close().await; - return; + return false; } } }, @@ -187,20 +181,20 @@ impl QsWebSocket { Message::Ping(_) => { // We update the last ping time last_ping = Instant::now(); - if connection_status.set_connected(tx).is_err() { + if connection_status.set_connected(tx).await.is_err() { // Close the stream if all subscribers of the watch have been dropped info!("Closing the connection because all subscribers are dropped"); let _ = ws_stream.close().await; - return; + return false; } } Message::Close(_) => { // Change the status to Disconnected and send an // event - let _ = connection_status.set_disconnected(tx); + let _ = connection_status.set_disconnected(tx).await; // We close the websocket let _ = ws_stream.close().await; - return; + return true; } _ => { } @@ -208,8 +202,8 @@ impl QsWebSocket { } else { // It seems the connection is closed, send disconnect // event - let _ = connection_status.set_disconnected(tx); - break; + let _ = connection_status.set_disconnected(tx).await; + break true; } }, } @@ -255,13 +249,14 @@ impl ApiClient { /// [`WsEvent::ConnectedEvent]. /// /// The connection will be closed if all subscribers of the [`QsWebSocket`] - /// have been dropped, or when it is manually closed with using the - /// [`QsWebSocket::abort()`] function. + /// have been dropped, or when it is manually closed by cancelling the token + /// `cancel`. /// /// # Arguments /// - `queue_id` - The ID of the queue monitor. /// - `timeout` - The timeout for the connection in seconds. /// - `retry_interval` - The interval between connection attempts in seconds. + /// - `cancel` - The cancellation token to stop the socket. /// /// # Returns /// A new [`QsWebSocket`] that represents the websocket connection. @@ -270,6 +265,7 @@ impl ApiClient { queue_id: QsClientId, timeout: u64, retry_interval: u64, + cancel: CancellationToken, ) -> Result { // Set the request parameter let qs_ws_open_params = QsOpenWsParams { queue_id }; @@ -289,19 +285,19 @@ impl ApiClient { })?; // We create a channel to send events to - let (tx, rx) = broadcast::channel(100); + let (tx, rx) = mpsc::channel(100); - // We clone the sender, so that we can subscribe to more receivers - let tx_clone = tx.clone(); + let connection_id = Uuid::new_v4(); + info!(%connection_id, "Spawning the websocket connection..."); - info!("Spawning the websocket connection..."); + let cancel_guard = cancel.clone().drop_guard(); // Spawn the connection task - let handle = tokio::spawn(async move { + tokio::spawn(async move { // Connection loop #[cfg(test)] let mut counter = 0; - loop { + while !cancel.is_cancelled() { // We build the request and set a custom header let req = match address.clone().into_client_request() { Ok(mut req) => { @@ -319,13 +315,15 @@ impl ApiClient { match connect_async(req).await { // The connection was established Ok((ws_stream, _)) => { - info!("Connected to QS WebSocket"); + info!(%connection_id, "Connected to QS WebSocket"); // Hand over the connection to the handler - QsWebSocket::handle_connection(ws_stream, &tx, timeout).await; + if !QsWebSocket::handle_connection(ws_stream, &tx, timeout, &cancel).await { + break; + } } // The connection was not established, wait and try again - Err(e) => { - error!("Error connecting to QS WebSocket: {}", e); + Err(error) => { + error!(%error, "Error connecting to QS WebSocket"); #[cfg(test)] { counter += 1; @@ -336,17 +334,20 @@ impl ApiClient { } } info!( + %connection_id, retry_in_sec = retry_interval, + is_cancelled = cancel.is_cancelled(), "The websocket was closed, will reconnect...", ); sleep(time::Duration::from_secs(retry_interval)).await; } + + info!(%connection_id, "QS WebSocket closed"); }); Ok(QsWebSocket { rx, - tx: tx_clone, - handle, + _cancel: cancel_guard, }) } } diff --git a/applogic/Cargo.toml b/applogic/Cargo.toml index cbbfb556..1ab4981a 100644 --- a/applogic/Cargo.toml +++ b/applogic/Cargo.toml @@ -37,6 +37,6 @@ flutter_rust_bridge = { version = "=2.7.0", features = ["chrono", "uuid"] } notify-rust = "4" chrono = { workspace = true } jni = "0.21" -tokio-util = "0.7.13" +tokio-util = { workspace = true } tokio-stream = "0.1.17" blake3 = "1.5.5" diff --git a/applogic/src/api/user_cubit.rs b/applogic/src/api/user_cubit.rs index cab6dce1..442d9122 100644 --- a/applogic/src/api/user_cubit.rs +++ b/applogic/src/api/user_cubit.rs @@ -110,6 +110,12 @@ pub struct UserCubitBase { _background_tasks_cancel: DropGuard, } +impl Drop for UserCubitBase { + fn drop(&mut self) { + info!("Dropping UserCubitBase"); + } +} + const WEBSOCKET_TIMEOUT: Duration = Duration::from_secs(30); const WEBSCOKET_RETRY_INTERVAL: Duration = Duration::from_secs(10); const POLLING_INTERVAL: Duration = Duration::from_secs(10); @@ -244,9 +250,12 @@ impl UserCubitBase { fn spawn_websocket(core_user: CoreUser, cancel: CancellationToken) { spawn_from_sync(async move { let mut backoff = FibonacciBackoff::new(); - while let Err(error) = run_websocket(&core_user, &cancel, &mut backoff).await { + let mut websocket_cancel = cancel.child_token(); + while let Err(error) = run_websocket(&core_user, &websocket_cancel, &mut backoff).await { let timeout = backoff.next_backoff(); info!(%error, retry_in =? timeout, "Websocket failed"); + websocket_cancel.cancel(); + websocket_cancel = cancel.child_token(); tokio::time::sleep(timeout).await; } info!("Websocket handler stopped normally"); @@ -263,6 +272,7 @@ async fn run_websocket( .websocket( WEBSOCKET_TIMEOUT.as_secs(), WEBSCOKET_RETRY_INTERVAL.as_secs(), + cancel.clone(), ) .await?; loop { @@ -271,7 +281,7 @@ async fn run_websocket( _ = cancel.cancelled() => return Ok(()), }; match event { - Some(event) => handle_websocket_message(event, core_user).await?, + Some(event) => handle_websocket_message(event, core_user).await, None => bail!("unexpected disconnect"), } backoff.reset(); // reset backoff after a successful message @@ -306,10 +316,10 @@ fn spawn_polling(core_user: CoreUser, cancel: CancellationToken) { }); } -async fn handle_websocket_message(event: WsEvent, core_user: &CoreUser) -> anyhow::Result<()> { +async fn handle_websocket_message(event: WsEvent, core_user: &CoreUser) { match event { WsEvent::ConnectedEvent => info!("connected to websocket"), - WsEvent::DisconnectedEvent => bail!("server disconnect"), + WsEvent::DisconnectedEvent => info!("disconnected from websocket"), WsEvent::MessageEvent(QsWsMessage::Event(event)) => { warn!("ignoring websocket event: {event:?}") } @@ -326,7 +336,6 @@ async fn handle_websocket_message(event: WsEvent, core_user: &CoreUser) -> anyho } } } - Ok(()) } async fn process_fetched_messages(_fetched_messages: FetchedMessages) { diff --git a/coreclient/Cargo.toml b/coreclient/Cargo.toml index a9421dcf..17b46537 100644 --- a/coreclient/Cargo.toml +++ b/coreclient/Cargo.toml @@ -27,6 +27,7 @@ anyhow = { version = "1.0", features = ["backtrace"] } rand = "0.8.4" rand_chacha = "0.3.1" tokio = { version = "1" } +tokio-util = { workspace = true } image = "0.25.1" kamadak-exif = "0.5.5" derive_more = { version = "0.99.18", features = ["from"] } diff --git a/coreclient/src/clients/mod.rs b/coreclient/src/clients/mod.rs index 251dc83f..7931b16d 100644 --- a/coreclient/src/clients/mod.rs +++ b/coreclient/src/clients/mod.rs @@ -49,6 +49,7 @@ use serde::{Deserialize, Serialize}; use store::ClientRecord; use thiserror::Error; use tokio_stream::Stream; +use tokio_util::sync::CancellationToken; use tracing::{error, info}; use crate::store::StoreNotificationsSender; @@ -1028,10 +1029,20 @@ impl CoreUser { .map(|group| group.pending_removes(connection)) } - pub async fn websocket(&self, timeout: u64, retry_interval: u64) -> Result { + pub async fn websocket( + &self, + timeout: u64, + retry_interval: u64, + cancel: CancellationToken, + ) -> Result { let api_client = self.inner.api_clients.default_client(); Ok(api_client? - .spawn_websocket(self.inner.qs_client_id.clone(), timeout, retry_interval) + .spawn_websocket( + self.inner.qs_client_id.clone(), + timeout, + retry_interval, + cancel, + ) .await?) } diff --git a/server/Cargo.toml b/server/Cargo.toml index 392afd3a..457a274a 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -31,6 +31,7 @@ async-trait = "0.1.74" actix-web-actors = "4.2.0" actix = "0.13" tokio = "1" +tokio-util = { workspace = true } base64 = "0.22" thiserror = "1.0" tracing = { version = "0.1", features = ["log"] } diff --git a/server/tests/qs/ws.rs b/server/tests/qs/ws.rs index 779c996d..cb5fb69b 100644 --- a/server/tests/qs/ws.rs +++ b/server/tests/qs/ws.rs @@ -8,6 +8,7 @@ use phnxbackend::qs::{WebsocketNotifier, WsNotification}; use phnxserver::network_provider::MockNetworkProvider; use phnxserver_test_harness::utils::spawn_app; use phnxtypes::{identifiers::QsClientId, messages::client_ds::QsWsMessage}; +use tokio_util::sync::CancellationToken; /// Test the websocket reconnect. #[actix_rt::test] @@ -29,8 +30,9 @@ async fn ws_reconnect() { let address = format!("http://{}", address); let client = ApiClient::with_default_http_client(address).expect("Failed to initialize client"); + let cancel = CancellationToken::new(); let mut ws = client - .spawn_websocket(client_id, timeout, retry_interval) + .spawn_websocket(client_id, timeout, retry_interval, cancel) .await .expect("Failed to execute request"); @@ -67,8 +69,9 @@ async fn ws_sending() { let address = format!("http://{}", address); let client = ApiClient::with_default_http_client(address).expect("Failed to initialize client"); + let cancel = CancellationToken::new(); let mut ws = client - .spawn_websocket(client_id.clone(), timeout, retry_interval) + .spawn_websocket(client_id.clone(), timeout, retry_interval, cancel) .await .expect("Failed to execute request");