@@ -18,16 +18,17 @@ use thiserror::*;
18
18
use tls_codec:: DeserializeBytes ;
19
19
use tokio:: {
20
20
net:: TcpStream ,
21
- sync:: broadcast:: { self , Receiver , Sender } ,
22
- task:: JoinHandle ,
21
+ sync:: mpsc,
23
22
time:: { sleep, Instant } ,
24
23
} ;
25
24
use tokio_tungstenite:: {
26
25
connect_async,
27
26
tungstenite:: { client:: IntoClientRequest , protocol:: Message } ,
28
27
MaybeTlsStream , WebSocketStream ,
29
28
} ;
29
+ use tokio_util:: sync:: { CancellationToken , DropGuard } ;
30
30
use tracing:: { error, info} ;
31
+ use uuid:: Uuid ;
31
32
32
33
use crate :: { ApiClient , Protocol } ;
33
34
@@ -53,9 +54,12 @@ impl ConnectionStatus {
53
54
Self { connected : false }
54
55
}
55
56
56
- fn set_connected ( & mut self , tx : & Sender < WsEvent > ) -> Result < ( ) , ConnectionStatusError > {
57
+ async fn set_connected (
58
+ & mut self ,
59
+ tx : & mpsc:: Sender < WsEvent > ,
60
+ ) -> Result < ( ) , ConnectionStatusError > {
57
61
if !self . connected {
58
- if let Err ( error) = tx. send ( WsEvent :: ConnectedEvent ) {
62
+ if let Err ( error) = tx. send ( WsEvent :: ConnectedEvent ) . await {
59
63
error ! ( %error, "Error sending to channel" ) ;
60
64
self . connected = false ;
61
65
return Err ( ConnectionStatusError :: ChannelClosed ) ;
@@ -65,9 +69,12 @@ impl ConnectionStatus {
65
69
Ok ( ( ) )
66
70
}
67
71
68
- fn set_disconnected ( & mut self , tx : & Sender < WsEvent > ) -> Result < ( ) , ConnectionStatusError > {
72
+ async fn set_disconnected (
73
+ & mut self ,
74
+ tx : & mpsc:: Sender < WsEvent > ,
75
+ ) -> Result < ( ) , ConnectionStatusError > {
69
76
if self . connected {
70
- if let Err ( error) = tx. send ( WsEvent :: DisconnectedEvent ) {
77
+ if let Err ( error) = tx. send ( WsEvent :: DisconnectedEvent ) . await {
71
78
error ! ( %error, "Error sending to channel" ) ;
72
79
return Err ( ConnectionStatusError :: ChannelClosed ) ;
73
80
}
@@ -79,48 +86,30 @@ impl ConnectionStatus {
79
86
80
87
/// A websocket connection to the QS server. See the
81
88
/// [`ApiClient::spawn_websocket`] method for more information.
89
+ ///
90
+ /// When dropped, the websocket connection will be closed.
82
91
pub struct QsWebSocket {
83
- rx : Receiver < WsEvent > ,
84
- tx : Sender < WsEvent > ,
85
- handle : JoinHandle < ( ) > ,
92
+ rx : mpsc:: Receiver < WsEvent > ,
93
+ _cancel : DropGuard ,
86
94
}
87
95
88
96
impl QsWebSocket {
89
97
/// Returns the next [`WsEvent`] event. This will block until an event is
90
98
/// sent or the connection is closed (in which case a final `None` is
91
99
/// returned).
92
100
pub async fn next ( & mut self ) -> Option < WsEvent > {
93
- match self . rx . recv ( ) . await {
94
- Ok ( message) => Some ( message) ,
95
- Err ( error) => {
96
- error ! ( %error, "Error receiving from channel" ) ;
97
- None
98
- }
99
- }
100
- }
101
-
102
- /// Subscribe to the event stream
103
- pub fn subscribe ( & self ) -> Receiver < WsEvent > {
104
- self . tx . subscribe ( )
105
- }
106
-
107
- /// Join the websocket connection task. This will block until the task has
108
- /// completed.
109
- pub async fn join ( self ) -> Result < ( ) , tokio:: task:: JoinError > {
110
- self . handle . await
111
- }
112
-
113
- /// Abort the websocket connection task. This will close the websocket connection.
114
- pub fn abort ( & mut self ) {
115
- self . handle . abort ( ) ;
101
+ self . rx . recv ( ) . await
116
102
}
117
103
118
104
/// Internal helper function to handle an established websocket connection
105
+ ///
106
+ /// Returns `true` if the connection should be re-established, otherwise `false`.
119
107
async fn handle_connection (
120
108
ws_stream : WebSocketStream < MaybeTlsStream < TcpStream > > ,
121
- tx : & Sender < WsEvent > ,
109
+ tx : & mpsc :: Sender < WsEvent > ,
122
110
timeout : u64 ,
123
- ) {
111
+ cancel : & CancellationToken ,
112
+ ) -> bool {
124
113
let mut last_ping = Instant :: now ( ) ;
125
114
126
115
// Watchdog to monitor the connection.
@@ -131,26 +120,31 @@ impl QsWebSocket {
131
120
132
121
// Initialize the connection status
133
122
let mut connection_status = ConnectionStatus :: new ( ) ;
134
- if connection_status. set_connected ( tx) . is_err ( ) {
123
+ if connection_status. set_connected ( tx) . await . is_err ( ) {
135
124
// Close the stream if all subscribers of the watch have been dropped
136
125
let _ = ws_stream. close ( ) . await ;
137
- return ;
126
+ return false ;
138
127
}
139
128
140
129
// Loop while the connection is open
141
130
loop {
142
131
tokio:: select! {
132
+ // Check is the handler is cancelled
133
+ _ = cancel. cancelled( ) => {
134
+ info!( "QS WebSocket connection cancelled" ) ;
135
+ break false ;
136
+ } ,
143
137
// Check if the connection is still alive
144
138
_ = interval. tick( ) => {
145
139
let now = Instant :: now( ) ;
146
140
// Check if we have reached the timeout
147
141
if now. duration_since( last_ping) > Duration :: from_secs( timeout) {
148
142
// Change the status to Disconnected and send an event
149
143
let _ = ws_stream. close( ) . await ;
150
- if connection_status. set_disconnected( tx) . is_err( ) {
144
+ if connection_status. set_disconnected( tx) . await . is_err( ) {
151
145
// Close the stream if all subscribers of the watch have been dropped
152
146
info!( "Closing the connection because all subscribers are dropped" ) ;
153
- return ;
147
+ return false ;
154
148
}
155
149
}
156
150
} ,
@@ -163,53 +157,53 @@ impl QsWebSocket {
163
157
// Reset the last ping time
164
158
last_ping = Instant :: now( ) ;
165
159
// Change the status to Connected and send an event
166
- if connection_status. set_connected( tx) . is_err( ) {
160
+ if connection_status. set_connected( tx) . await . is_err( ) {
167
161
// Close the stream if all subscribers of the watch have been dropped
168
162
info!( "Closing the connection because all subscribers are dropped" ) ;
169
163
let _ = ws_stream. close( ) . await ;
170
- return ;
164
+ return false ;
171
165
}
172
166
// Try to deserialize the message
173
167
if let Ok ( QsWsMessage :: QueueUpdate ) =
174
168
QsWsMessage :: tls_deserialize_exact_bytes( & data)
175
169
{
176
170
// We received a new message notification from the QS
177
171
// Send the event to the channel
178
- if tx. send( WsEvent :: MessageEvent ( QsWsMessage :: QueueUpdate ) ) . is_err( ) {
172
+ if tx. send( WsEvent :: MessageEvent ( QsWsMessage :: QueueUpdate ) ) . await . is_err( ) {
179
173
info!( "Closing the connection because all subscribers are dropped" ) ;
180
174
// Close the stream if all subscribers of the watch have been dropped
181
175
let _ = ws_stream. close( ) . await ;
182
- return ;
176
+ return false ;
183
177
}
184
178
}
185
179
} ,
186
180
// We received a ping
187
181
Message :: Ping ( _) => {
188
182
// We update the last ping time
189
183
last_ping = Instant :: now( ) ;
190
- if connection_status. set_connected( tx) . is_err( ) {
184
+ if connection_status. set_connected( tx) . await . is_err( ) {
191
185
// Close the stream if all subscribers of the watch have been dropped
192
186
info!( "Closing the connection because all subscribers are dropped" ) ;
193
187
let _ = ws_stream. close( ) . await ;
194
- return ;
188
+ return false ;
195
189
}
196
190
}
197
191
Message :: Close ( _) => {
198
192
// Change the status to Disconnected and send an
199
193
// event
200
- let _ = connection_status. set_disconnected( tx) ;
194
+ let _ = connection_status. set_disconnected( tx) . await ;
201
195
// We close the websocket
202
196
let _ = ws_stream. close( ) . await ;
203
- return ;
197
+ return true ;
204
198
}
205
199
_ => {
206
200
}
207
201
}
208
202
} else {
209
203
// It seems the connection is closed, send disconnect
210
204
// event
211
- let _ = connection_status. set_disconnected( tx) ;
212
- break ;
205
+ let _ = connection_status. set_disconnected( tx) . await ;
206
+ break true ;
213
207
}
214
208
} ,
215
209
}
@@ -255,13 +249,14 @@ impl ApiClient {
255
249
/// [`WsEvent::ConnectedEvent].
256
250
///
257
251
/// The connection will be closed if all subscribers of the [`QsWebSocket`]
258
- /// have been dropped, or when it is manually closed with using the
259
- /// [`QsWebSocket::abort()`] function .
252
+ /// have been dropped, or when it is manually closed by cancelling the token
253
+ /// `cancel` .
260
254
///
261
255
/// # Arguments
262
256
/// - `queue_id` - The ID of the queue monitor.
263
257
/// - `timeout` - The timeout for the connection in seconds.
264
258
/// - `retry_interval` - The interval between connection attempts in seconds.
259
+ /// - `cancel` - The cancellation token to stop the socket.
265
260
///
266
261
/// # Returns
267
262
/// A new [`QsWebSocket`] that represents the websocket connection.
@@ -270,6 +265,7 @@ impl ApiClient {
270
265
queue_id : QsClientId ,
271
266
timeout : u64 ,
272
267
retry_interval : u64 ,
268
+ cancel : CancellationToken ,
273
269
) -> Result < QsWebSocket , SpawnWsError > {
274
270
// Set the request parameter
275
271
let qs_ws_open_params = QsOpenWsParams { queue_id } ;
@@ -289,19 +285,19 @@ impl ApiClient {
289
285
} ) ?;
290
286
291
287
// We create a channel to send events to
292
- let ( tx, rx) = broadcast :: channel ( 100 ) ;
288
+ let ( tx, rx) = mpsc :: channel ( 100 ) ;
293
289
294
- // We clone the sender, so that we can subscribe to more receivers
295
- let tx_clone = tx . clone ( ) ;
290
+ let connection_id = Uuid :: new_v4 ( ) ;
291
+ info ! ( %connection_id , "Spawning the websocket connection..." ) ;
296
292
297
- info ! ( "Spawning the websocket connection..." ) ;
293
+ let cancel_guard = cancel . clone ( ) . drop_guard ( ) ;
298
294
299
295
// Spawn the connection task
300
- let handle = tokio:: spawn ( async move {
296
+ tokio:: spawn ( async move {
301
297
// Connection loop
302
298
#[ cfg( test) ]
303
299
let mut counter = 0 ;
304
- loop {
300
+ while !cancel . is_cancelled ( ) {
305
301
// We build the request and set a custom header
306
302
let req = match address. clone ( ) . into_client_request ( ) {
307
303
Ok ( mut req) => {
@@ -319,13 +315,15 @@ impl ApiClient {
319
315
match connect_async ( req) . await {
320
316
// The connection was established
321
317
Ok ( ( ws_stream, _) ) => {
322
- info ! ( "Connected to QS WebSocket" ) ;
318
+ info ! ( %connection_id , "Connected to QS WebSocket" ) ;
323
319
// Hand over the connection to the handler
324
- QsWebSocket :: handle_connection ( ws_stream, & tx, timeout) . await ;
320
+ if !QsWebSocket :: handle_connection ( ws_stream, & tx, timeout, & cancel) . await {
321
+ break ;
322
+ }
325
323
}
326
324
// The connection was not established, wait and try again
327
- Err ( e ) => {
328
- error ! ( "Error connecting to QS WebSocket: {}" , e ) ;
325
+ Err ( error ) => {
326
+ error ! ( %error , "Error connecting to QS WebSocket" ) ;
329
327
#[ cfg( test) ]
330
328
{
331
329
counter += 1 ;
@@ -336,17 +334,20 @@ impl ApiClient {
336
334
}
337
335
}
338
336
info ! (
337
+ %connection_id,
339
338
retry_in_sec = retry_interval,
339
+ is_cancelled = cancel. is_cancelled( ) ,
340
340
"The websocket was closed, will reconnect..." ,
341
341
) ;
342
342
sleep ( time:: Duration :: from_secs ( retry_interval) ) . await ;
343
343
}
344
+
345
+ info ! ( %connection_id, "QS WebSocket closed" ) ;
344
346
} ) ;
345
347
346
348
Ok ( QsWebSocket {
347
349
rx,
348
- tx : tx_clone,
349
- handle,
350
+ _cancel : cancel_guard,
350
351
} )
351
352
}
352
353
}
0 commit comments