diff --git a/boltconn/src/transport/wireguard.rs b/boltconn/src/transport/wireguard.rs index 580ded7..e698b5c 100644 --- a/boltconn/src/transport/wireguard.rs +++ b/boltconn/src/transport/wireguard.rs @@ -15,6 +15,7 @@ use std::io; use std::io::ErrorKind; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; +use std::time::Duration; use tokio::sync::Notify; // We left AllowedIPs since it's boltconn that manages routing. @@ -157,12 +158,16 @@ impl WireguardTunnel { let name = config.name.clone(); local_async_run(async move { // dedicated to poll UDP from small kernel buffer - loop { + while !in_tx.is_disconnected() { let key = match pool.clone().create_owned() { Some(mut buf) => { let key = BufferIndex::Pool(buf.key()); buf.resize(MAX_UDP_PKT_SIZE, 0); - let Ok(len) = socket.recv(&mut buf).await else { + let recv_result = tokio::select! { + recv_result = socket.recv(&mut buf) => recv_result, + _ = tokio::time::sleep(Duration::from_millis(500)) => continue, + }; + let Ok(len) = recv_result else { tracing::warn!( "WireGuard #{} failed to receive from socket", name, @@ -174,7 +179,11 @@ impl WireguardTunnel { } None => { let mut buf = vec![0; MAX_UDP_PKT_SIZE]; - let len = match socket.recv(&mut buf).await { + let recv_result = tokio::select! { + recv_result = socket.recv(&mut buf) => recv_result, + _ = tokio::time::sleep(Duration::from_millis(500)) => continue, + }; + let len = match recv_result { Ok(len) => len, Err(e) => { tracing::warn!( @@ -190,9 +199,14 @@ impl WireguardTunnel { } }; if let Err(err) = in_tx.try_send(key) { + let is_disconnected = + matches!(err, flume::TrySendError::Disconnected(_)); if let BufferIndex::Pool(key) = err.into_inner() { pool.clear(key); } + if is_disconnected { + break; + } tracing::warn!("channel full, dropping packet"); } }