diff --git a/changelog/2609.added.md b/changelog/2609.added.md new file mode 100644 index 0000000000..041d07e6c0 --- /dev/null +++ b/changelog/2609.added.md @@ -0,0 +1 @@ +Add `accept_from` function returning the remote-address. diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index e26e327d78..84fecaf351 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -2364,6 +2364,41 @@ pub fn accept4(sockfd: RawFd, flags: SockFlag) -> Result { Errno::result(res) } +/// Accept a connection on a socket, returning the remote address. +/// +/// [Further reading](https://man7.org/linux/man-pages/man2/accept.2.html) +#[cfg(any( + all( + target_os = "android", + any( + target_arch = "aarch64", + target_arch = "x86", + target_arch = "x86_64" + ) + ), + freebsdlike, + netbsdlike, + target_os = "emscripten", + target_os = "fuchsia", + solarish, + target_os = "linux", +))] +pub fn accept_from(sockfd: RawFd, flags: SockFlag) -> Result<(RawFd, Option)> { + let mut storage = std::mem::MaybeUninit::::uninit(); + let mut socklen = std::mem::size_of::() as libc::socklen_t; + let res = unsafe { + libc::accept4(sockfd, storage.as_mut_ptr().cast(), &mut socklen as *mut _, flags.bits()) + }; + + let sock = Errno::result(res)?; + let addr = unsafe { + let storage = storage.assume_init(); + S::from_raw((&storage as *const libc::sockaddr_storage).cast(), Some(socklen)) + }; + + Ok((sock, addr)) +} + /// Initiate a connection on a socket /// /// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/connect.html) diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index 5a43aab437..c265cf74af 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -3161,3 +3161,56 @@ fn can_open_routing_socket() { socket(AddressFamily::Route, SockType::Raw, SockFlag::empty(), None) .expect("Failed to open routing socket"); } + +#[cfg(any( + all( + target_os = "android", + any( + target_arch = "aarch64", + target_arch = "x86", + target_arch = "x86_64" + ) + ), + freebsdlike, + netbsdlike, + target_os = "emscripten", + target_os = "fuchsia", + solarish, + target_os = "linux", +))] +#[test] +fn test_accept_from() { + use nix::sys::socket::{accept_from, bind, connect, listen, socket}; + use nix::sys::socket::{Backlog, SockFlag, SockType, SockaddrIn}; + use std::net::Ipv4Addr; + + let sock_addr = SockaddrIn::from_str("127.0.0.1:6780").unwrap(); + let listener = socket( + AddressFamily::Inet, + SockType::Stream, + SockFlag::empty(), + None, + ) + .expect("listener socket failed"); + bind(listener.as_raw_fd(), &sock_addr).expect("bind failed"); + listen(&listener, Backlog::MAXCONN).expect("listen failed"); + + let connector = socket( + AddressFamily::Inet, + SockType::Stream, + SockFlag::empty(), + None, + ) + .expect("connector socket failed"); + + let send_thread = + std::thread::spawn(move || connect(connector.as_raw_fd(), &sock_addr)); + let (_peer, address) = + accept_from::(listener.as_raw_fd(), SockFlag::empty()) + .unwrap(); + let address = address.expect("no address"); + + assert_eq!(address.ip(), Ipv4Addr::LOCALHOST); + + send_thread.join().unwrap().unwrap(); +}