From 1cf788d8be06cda3e86202303b05433270200f28 Mon Sep 17 00:00:00 2001 From: Colin Marc Date: Wed, 18 Sep 2024 23:15:50 +0200 Subject: [PATCH] Add support for sendmmsg(2) on linux https://man7.org/linux/man-pages/man2/sendmmsg.2.html Partially addresses #1156. Signed-off-by: Colin Marc --- src/backend/libc/net/syscalls.rs | 21 +++++++ src/backend/linux_raw/c.rs | 10 +-- src/backend/linux_raw/net/syscalls.rs | 32 +++++++++- src/net/send_recv/msg.rs | 67 ++++++++++++++++++++ tests/net/v4.rs | 91 +++++++++++++++++++++++++++ tests/net/v6.rs | 91 +++++++++++++++++++++++++++ 6 files changed, 305 insertions(+), 7 deletions(-) diff --git a/src/backend/libc/net/syscalls.rs b/src/backend/libc/net/syscalls.rs index 14bce06e5..3acc7957b 100644 --- a/src/backend/libc/net/syscalls.rs +++ b/src/backend/libc/net/syscalls.rs @@ -3,9 +3,13 @@ use super::read_sockaddr::initialize_family_to_unspec; use super::send_recv::{RecvFlags, SendFlags}; use crate::backend::c; +#[cfg(target_os = "linux")] +use crate::backend::conv::ret_u32; use crate::backend::conv::{borrowed_fd, ret, ret_owned_fd, ret_send_recv, send_recv_len}; use crate::fd::{BorrowedFd, OwnedFd}; use crate::io; +#[cfg(target_os = "linux")] +use crate::net::MMsgHdr; use crate::net::SocketAddrBuf; use crate::net::{ addr::SocketAddrArg, AddressFamily, Protocol, Shutdown, SocketAddrAny, SocketFlags, SocketType, @@ -231,6 +235,23 @@ pub(crate) fn sendmsg_addr( }) } +#[cfg(target_os = "linux")] +pub(crate) fn sendmmsg( + sockfd: BorrowedFd<'_>, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + unsafe { + ret_u32(c::sendmmsg( + borrowed_fd(sockfd), + msgs.as_mut_ptr() as _, + msgs.len().try_into().unwrap_or(c::c_uint::MAX), + bitflags_bits!(flags), + )) + .map(|ret| ret as usize) + } +} + #[cfg(not(any( apple, windows, diff --git a/src/backend/linux_raw/c.rs b/src/backend/linux_raw/c.rs index fee712e05..8a0ce02dd 100644 --- a/src/backend/linux_raw/c.rs +++ b/src/backend/linux_raw/c.rs @@ -76,12 +76,12 @@ pub(crate) use linux_raw_sys::{ general::{O_CLOEXEC as SOCK_CLOEXEC, O_NONBLOCK as SOCK_NONBLOCK}, if_ether::*, net::{ - linger, msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet, __kernel_sa_family_t as sa_family_t, __kernel_sockaddr_storage as sockaddr_storage, - cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, AF_APPLETALK, - AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, AF_ECONET, - AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, AF_LLC, - AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE, + cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, linger, mmsghdr, + msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet, + AF_APPLETALK, AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, + AF_ECONET, AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, + AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE, AF_RXRPC, AF_SECURITY, AF_SNA, AF_TIPC, AF_UNIX, AF_UNSPEC, AF_WANPIPE, AF_X25, AF_XDP, IP6T_SO_ORIGINAL_DST, IPPROTO_FRAGMENT, IPPROTO_ICMPV6, IPPROTO_MH, IPPROTO_ROUTING, IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_FREEBIND, IPV6_MULTICAST_HOPS, diff --git a/src/backend/linux_raw/net/syscalls.rs b/src/backend/linux_raw/net/syscalls.rs index 29f275008..cfc4b69b2 100644 --- a/src/backend/linux_raw/net/syscalls.rs +++ b/src/backend/linux_raw/net/syscalls.rs @@ -9,6 +9,8 @@ use super::msghdr::{with_msghdr, with_noaddr_msghdr, with_recv_msghdr}; use super::read_sockaddr::initialize_family_to_unspec; use super::send_recv::{RecvFlags, ReturnFlags, SendFlags}; use crate::backend::c; +#[cfg(target_os = "linux")] +use crate::backend::conv::slice_mut; use crate::backend::conv::{ by_mut, by_ref, c_int, c_uint, pass_usize, ret, ret_owned_fd, ret_usize, size_of, slice, socklen_t, zero, @@ -16,6 +18,8 @@ use crate::backend::conv::{ use crate::backend::reg::raw_arg; use crate::fd::{BorrowedFd, OwnedFd}; use crate::io::{self, IoSlice, IoSliceMut}; +#[cfg(target_os = "linux")] +use crate::net::MMsgHdr; use crate::net::SocketAddrBuf; use crate::net::{ addr::SocketAddrArg, AddressFamily, Protocol, RecvAncillaryBuffer, RecvMsg, @@ -28,8 +32,8 @@ use { crate::backend::reg::{ArgReg, SocketArg}, linux_raw_sys::net::{ SYS_ACCEPT, SYS_ACCEPT4, SYS_BIND, SYS_CONNECT, SYS_GETPEERNAME, SYS_GETSOCKNAME, - SYS_LISTEN, SYS_RECV, SYS_RECVFROM, SYS_RECVMSG, SYS_SEND, SYS_SENDMSG, SYS_SENDTO, - SYS_SHUTDOWN, SYS_SOCKET, SYS_SOCKETPAIR, + SYS_LISTEN, SYS_RECV, SYS_RECVFROM, SYS_RECVMSG, SYS_SEND, SYS_SENDMMSG, SYS_SENDMSG, + SYS_SENDTO, SYS_SHUTDOWN, SYS_SOCKET, SYS_SOCKETPAIR, }, }; @@ -331,6 +335,30 @@ pub(crate) fn sendmsg_addr( }) } +#[cfg(target_os = "linux")] +#[inline] +pub(crate) fn sendmmsg( + sockfd: BorrowedFd<'_>, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + let (msgs, len) = slice_mut(msgs); + + #[cfg(not(target_arch = "x86"))] + let result = unsafe { ret_usize(syscall!(__NR_sendmmsg, sockfd, msgs, len, flags)) }; + + #[cfg(target_arch = "x86")] + let result = unsafe { + ret_usize(syscall!( + __NR_socketcall, + x86_sys(SYS_SENDMMSG), + slice_just_addr::, _>(&[sockfd.into(), msgs, len, flags.into()]) + )) + }; + + result +} + #[inline] pub(crate) fn shutdown(fd: BorrowedFd<'_>, how: Shutdown) -> io::Result<()> { #[cfg(not(target_arch = "x86"))] diff --git a/src/net/send_recv/msg.rs b/src/net/send_recv/msg.rs index 72e12fc79..ce67c71e3 100644 --- a/src/net/send_recv/msg.rs +++ b/src/net/send_recv/msg.rs @@ -2,6 +2,8 @@ #![allow(unsafe_code)] +#[cfg(target_os = "linux")] +use crate::backend::net::msghdr::{with_msghdr, with_noaddr_msghdr}; use crate::backend::{self, c}; use crate::fd::{AsFd, BorrowedFd, OwnedFd}; use crate::io::{self, IoSlice, IoSliceMut}; @@ -591,6 +593,55 @@ impl<'buf> Iterator for AncillaryDrain<'buf> { impl FusedIterator for AncillaryDrain<'_> {} +/// An ABI-compatible wrapper for `mmsghdr`, for sending multiple messages with +/// [sendmmsg]. +#[cfg(target_os = "linux")] +#[repr(transparent)] +pub struct MMsgHdr<'a> { + raw: c::mmsghdr, + _phantom: PhantomData<&'a mut ()>, +} + +#[cfg(target_os = "linux")] +impl<'a> MMsgHdr<'a> { + /// Constructs a new message with no destination address. + pub fn new(iov: &'a [IoSlice<'_>], control: &'a mut SendAncillaryBuffer<'_, '_, '_>) -> Self { + with_noaddr_msghdr(iov, control, Self::wrap) + } + + /// Constructs a new message to a specific address. + /// + /// The lifetime of `addr` (and the underlying + /// [SocketAddrStorage](crate::net::addr::SocketAddrStorage)) must be valid + /// until the call to [sendmmsg], so types implementing + /// [SocketAddrArg](crate::net::addr::SocketAddrArg) can't be used here + /// without first being converted using + /// [SocketAddrArg::as_any](crate::net::addr::SocketAddrArg::as_any). + pub fn new_with_addr( + addr: &'a SocketAddrAny, + iov: &'a [IoSlice<'_>], + control: &'a mut SendAncillaryBuffer<'_, '_, '_>, + ) -> MMsgHdr<'a> { + with_msghdr(addr, iov, control, Self::wrap) + } + + fn wrap(msg_hdr: c::msghdr) -> Self { + Self { + raw: c::mmsghdr { + msg_hdr, + msg_len: 0, + }, + _phantom: PhantomData, + } + } + + /// Returns the number of bytes sent. This will return 0 until after a + /// successful call to [sendmmsg]. + pub fn bytes_sent(&self) -> usize { + self.raw.msg_len as _ + } +} + /// `sendmsg(msghdr)`—Sends a message on a socket. /// /// This function is for use on connected sockets, as it doesn't have @@ -656,6 +707,22 @@ pub fn sendmsg_addr( backend::net::syscalls::sendmsg_addr(socket.as_fd(), addr, iov, control, flags) } +/// `sendmmsg(msghdr)`—Sends multiple messages on a socket. +/// +/// # References +/// - [Linux] +/// +/// [Linux]: https://man7.org/linux/man-pages/man2/sendmmsg.2.html +#[inline] +#[cfg(target_os = "linux")] +pub fn sendmmsg( + socket: impl AsFd, + msgs: &mut [MMsgHdr<'_>], + flags: SendFlags, +) -> io::Result { + backend::net::syscalls::sendmmsg(socket.as_fd(), msgs, flags) +} + /// `recvmsg(msghdr)`—Receives a message from a socket. /// /// # References diff --git a/tests/net/v4.rs b/tests/net/v4.rs index fec3c320b..129468d41 100644 --- a/tests/net/v4.rs +++ b/tests/net/v4.rs @@ -194,3 +194,94 @@ fn test_v4_msg() { client.join().unwrap(); server.join().unwrap(); } + +#[test] +#[cfg(target_os = "linux")] +fn test_v4_sendmmsg() { + crate::init(); + + use std::net::TcpStream; + + use rustix::io::IoSlice; + use rustix::net::addr::SocketAddrArg as _; + use rustix::net::{sendmmsg, MMsgHdr}; + + fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0); + bind(&connection_socket, &name).unwrap(); + + let who = getsockname(&connection_socket).unwrap(); + let who = SocketAddrV4::try_from(who).unwrap(); + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; 13]; + let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into(); + + std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer), "hello...world"); + } + + fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port); + let data_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap(); + connect(&data_socket, &addr).unwrap(); + + let mut off = 0; + while off < 2 { + let sent = sendmmsg( + &data_socket, + &mut [ + MMsgHdr::new(&[IoSlice::new(b"hello")], &mut Default::default()), + MMsgHdr::new_with_addr( + &addr.as_any(), + &[IoSlice::new(b"...world")], + &mut Default::default(), + ), + ][off..], + SendFlags::empty(), + ) + .unwrap(); + + off += sent; + } + } + + let ready = Arc::new((Mutex::new(0_u16), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +} diff --git a/tests/net/v6.rs b/tests/net/v6.rs index 82fb954c7..9bcba69c5 100644 --- a/tests/net/v6.rs +++ b/tests/net/v6.rs @@ -193,3 +193,94 @@ fn test_v6_msg() { client.join().unwrap(); server.join().unwrap(); } + +#[test] +#[cfg(target_os = "linux")] +fn test_v6_sendmmsg() { + crate::init(); + + use std::net::TcpStream; + + use rustix::io::IoSlice; + use rustix::net::addr::SocketAddrArg as _; + use rustix::net::{sendmmsg, MMsgHdr}; + + fn server(ready: Arc<(Mutex, Condvar)>) { + let connection_socket = socket(AddressFamily::INET6, SocketType::STREAM, None).unwrap(); + + let name = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 0, 0, 0); + bind(&connection_socket, &name).unwrap(); + + let who = getsockname(&connection_socket).unwrap(); + let who = SocketAddrV6::try_from(who).unwrap(); + + listen(&connection_socket, 1).unwrap(); + + { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + *port = who.port(); + cvar.notify_all(); + } + + let mut buffer = vec![0; 13]; + let mut data_socket: TcpStream = accept(&connection_socket).unwrap().into(); + + std::io::Read::read_exact(&mut data_socket, &mut buffer).unwrap(); + assert_eq!(String::from_utf8_lossy(&buffer), "hello...world"); + } + + fn client(ready: Arc<(Mutex, Condvar)>) { + let port = { + let (lock, cvar) = &*ready; + let mut port = lock.lock().unwrap(); + while *port == 0 { + port = cvar.wait(port).unwrap(); + } + *port + }; + + let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), port, 0, 0); + let data_socket = socket(AddressFamily::INET6, SocketType::STREAM, None).unwrap(); + connect(&data_socket, &addr).unwrap(); + + let mut off = 0; + while off < 2 { + let sent = sendmmsg( + &data_socket, + &mut [ + MMsgHdr::new(&[IoSlice::new(b"hello")], &mut Default::default()), + MMsgHdr::new_with_addr( + &addr.as_any(), + &[IoSlice::new(b"...world")], + &mut Default::default(), + ), + ][off..], + SendFlags::empty(), + ) + .unwrap(); + + off += sent; + } + } + + let ready = Arc::new((Mutex::new(0_u16), Condvar::new())); + let ready_clone = Arc::clone(&ready); + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(ready); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(ready_clone); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +}