diff --git a/Cargo.toml b/Cargo.toml index 2f1db89..290f032 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ tracing = { version = "0.1.37", default-features = false, features = ["std", "lo serde = { version = "1.0.145", features = ["derive"], optional = true } [dev-dependencies] -tokio = { version = "1.32.0", features = ["rt", "macros"] } +tokio = { version = "1.32.0", features = ["full"] } [features] default = ["serde"] diff --git a/examples/logchanges.rs b/examples/logchanges.rs new file mode 100644 index 0000000..c041bfa --- /dev/null +++ b/examples/logchanges.rs @@ -0,0 +1,11 @@ +use timestamped_socket::interface::ChangeDetector; + +#[tokio::main] +async fn main() { + let mut detector = ChangeDetector::new().unwrap(); + + loop { + detector.wait_for_change().await; + println!("Change detected"); + } +} diff --git a/src/interface.rs b/src/interface.rs index b5436e1..2150c53 100644 --- a/src/interface.rs +++ b/src/interface.rs @@ -6,6 +6,22 @@ use std::{ use super::cerr; +#[cfg(target_os = "linux")] +mod linux; +#[cfg(target_os = "linux")] +pub use linux::ChangeDetector; + +// NOTE: this detection logic is not sharable with macos! +#[cfg(target_os = "freebsd")] +mod freebsd; +#[cfg(target_os = "freebsd")] +pub use freebsd::ChangeDetector; + +#[cfg(not(any(target_os = "linux", target_os = "freebsd")))] +mod fallback; +#[cfg(not(any(target_os = "linux", target_os = "freebsd")))] +pub use fallback::ChangeDetector; + pub fn interfaces() -> std::io::Result> { let mut elements = HashMap::default(); diff --git a/src/interface/fallback.rs b/src/interface/fallback.rs new file mode 100644 index 0000000..8eccac5 --- /dev/null +++ b/src/interface/fallback.rs @@ -0,0 +1,16 @@ +struct Private; + +pub struct ChangeDetector { + _private: Private, +} + +impl ChangeDetector { + pub fn new() -> std::io::Result { + Ok(Self { _private: Private }) + } + + pub async fn wait_for_change(&mut self) { + // No platform independent way, but checking every so often is fine for a fallback + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + } +} diff --git a/src/interface/freebsd.rs b/src/interface/freebsd.rs new file mode 100644 index 0000000..aca587d --- /dev/null +++ b/src/interface/freebsd.rs @@ -0,0 +1,101 @@ +use std::{io::ErrorKind, os::fd::RawFd}; + +use libc::recv; +use tokio::io::{unix::AsyncFd, Interest}; + +use crate::{cerr, control_message::zeroed_sockaddr_storage}; + +pub struct ChangeDetector { + fd: AsyncFd, +} + +impl ChangeDetector { + const SOCKET_PATH: &'static [u8] = b"/var/run/devd.seqpacket.pipe"; + pub fn new() -> std::io::Result { + const _: () = assert!( + std::mem::size_of::() + >= std::mem::size_of::() + ); + const _: () = assert!( + std::mem::align_of::() + >= std::mem::align_of::() + ); + + let mut address_buf = zeroed_sockaddr_storage(); + // Safety: the above assertions guarantee that alignment and size are correct. + // the resulting reference won't outlast the function, and result lives the entire + // duration of the function + let address = unsafe { + &mut *(&mut address_buf as *mut libc::sockaddr_storage as *mut libc::sockaddr_un) + }; + + address.sun_family = libc::AF_UNIX as _; + for i in 0..Self::SOCKET_PATH.len() { + address.sun_path[i] = Self::SOCKET_PATH[i] as _; + } + + // Safety: calling socket is safe + let fd = cerr(unsafe { libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) })?; + // Safety: address is valid for the duration of the call + cerr(unsafe { + libc::bind( + fd, + address as *mut _ as *mut _, + std::mem::size_of_val(address) as _, + ) + })?; + + let nonblocking = 1 as libc::c_int; + // Safety: nonblocking lives for the duration of the call, and is 4 bytes long as expected for FIONBIO + cerr(unsafe { libc::ioctl(fd, libc::FIONBIO, &nonblocking) })?; + + Ok(ChangeDetector { + fd: AsyncFd::new(fd)?, + }) + } + + fn empty(fd: i32) { + loop { + // Safety: buf is valid for the duration of the call, and it's length is passed as the len argument + let mut buf = [0u8; 16]; + match cerr(unsafe { + recv( + fd, + &mut buf as *mut _ as *mut _, + std::mem::size_of_val(&buf) as _, + 0, + ) as _ + }) { + Ok(_) => continue, + Err(e) if e.kind() == ErrorKind::WouldBlock => break, + Err(e) => { + tracing::error!("Could not receive on change socket: {}", e); + break; + } + } + } + } + + pub async fn wait_for_change(&mut self) { + if let Err(e) = self + .fd + .async_io(Interest::READABLE, |fd| { + // Safety: buf is valid for the duration of the call, and it's length is passed as the len argument + let mut buf = [0u8; 16]; + cerr(unsafe { + recv( + *fd, + &mut buf as *mut _ as *mut _, + std::mem::size_of_val(&buf) as _, + 0, + ) as _ + })?; + Self::empty(*fd); + Ok(()) + }) + .await + { + tracing::error!("Could not receive on change socket: {}", e); + } + } +} diff --git a/src/interface/linux.rs b/src/interface/linux.rs new file mode 100644 index 0000000..51aa6cb --- /dev/null +++ b/src/interface/linux.rs @@ -0,0 +1,100 @@ +use std::{io::ErrorKind, os::fd::RawFd}; + +use libc::recv; +use tokio::io::{unix::AsyncFd, Interest}; + +use crate::{cerr, control_message::zeroed_sockaddr_storage}; + +pub struct ChangeDetector { + fd: AsyncFd, +} + +impl ChangeDetector { + pub fn new() -> std::io::Result { + const _: () = assert!( + std::mem::size_of::() + >= std::mem::size_of::() + ); + const _: () = assert!( + std::mem::align_of::() + >= std::mem::align_of::() + ); + + let mut address_buf = zeroed_sockaddr_storage(); + // Safety: the above assertions guarantee that alignment and size are correct. + // the resulting reference won't outlast the function, and result lives the entire + // duration of the function + let address = unsafe { + &mut *(&mut address_buf as *mut libc::sockaddr_storage as *mut libc::sockaddr_nl) + }; + + address.nl_family = libc::AF_NETLINK as _; + address.nl_groups = + (libc::RTMGRP_IPV4_IFADDR | libc::RTMGRP_IPV6_IFADDR | libc::RTMGRP_LINK) as _; + + // Safety: calling socket is safe + let fd = + cerr(unsafe { libc::socket(libc::AF_NETLINK, libc::SOCK_RAW, libc::NETLINK_ROUTE) })?; + // Safety: address is valid for the duration of the call + cerr(unsafe { + libc::bind( + fd, + address as *mut _ as *mut _, + std::mem::size_of_val(address) as _, + ) + })?; + + let nonblocking = 1 as libc::c_int; + // Safety: nonblocking lives for the duration of the call, and is 4 bytes long as expected for FIONBIO + cerr(unsafe { libc::ioctl(fd, libc::FIONBIO, &nonblocking) })?; + + Ok(ChangeDetector { + fd: AsyncFd::new(fd)?, + }) + } + + fn empty(fd: i32) { + loop { + // Safety: buf is valid for the duration of the call, and it's length is passed as the len argument + let mut buf = [0u8; 16]; + match cerr(unsafe { + recv( + fd, + &mut buf as *mut _ as *mut _, + std::mem::size_of_val(&buf) as _, + 0, + ) as _ + }) { + Ok(_) => continue, + Err(e) if e.kind() == ErrorKind::WouldBlock => break, + Err(e) => { + tracing::error!("Could not receive on change socket: {}", e); + break; + } + } + } + } + + pub async fn wait_for_change(&mut self) { + if let Err(e) = self + .fd + .async_io(Interest::READABLE, |fd| { + // Safety: buf is valid for the duration of the call, and it's length is passed as the len argument + let mut buf = [0u8; 16]; + cerr(unsafe { + recv( + *fd, + &mut buf as *mut _ as *mut _, + std::mem::size_of_val(&buf) as _, + 0, + ) as _ + })?; + Self::empty(*fd); + Ok(()) + }) + .await + { + tracing::error!("Could not receive on change socket: {}", e); + } + } +}