Skip to content

Commit f2e0ce3

Browse files
committed
unix: Extend UdpSocket to send and receive TTL
Add the function recv_vectored_with_ancillary_from, recv_vectored_with_ancillary, send_vectored_with_ancillary_to and send_vectored_with_ancillary to UdpSocket for Unix platforms. Also add set_recvttl and recvttl to UdpSocket to tell the kernel to receive the TTL. Therefore, rename the Ancillary(Data) to UnixAncillary(Data) for the UnixDatagram and UnixStream and also add the IpAncillary(Data) for UdpSocket.
1 parent 136eaa1 commit f2e0ce3

File tree

11 files changed

+1065
-326
lines changed

11 files changed

+1065
-326
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
use crate::{
2+
convert::TryFrom,
3+
io::{self, IoSlice, IoSliceMut},
4+
marker::PhantomData,
5+
mem::{size_of, zeroed},
6+
ptr::{eq, read_unaligned},
7+
slice::from_raw_parts,
8+
sys::net::Socket,
9+
};
10+
11+
// FIXME(#43348): Make libc adapt #[doc(cfg(...))] so we don't need these fake definitions here?
12+
#[cfg(all(doc, not(target_os = "linux"), not(target_os = "android")))]
13+
#[allow(non_camel_case_types)]
14+
mod libc {
15+
pub use libc::c_int;
16+
pub struct cmsghdr;
17+
}
18+
19+
impl Socket {
20+
pub(super) unsafe fn recv_vectored_with_ancillary_from(
21+
&self,
22+
msg_name: *mut libc::c_void,
23+
msg_namelen: libc::socklen_t,
24+
bufs: &mut [IoSliceMut<'_>],
25+
ancillary: &mut Ancillary<'_>,
26+
) -> io::Result<(usize, bool, libc::socklen_t)> {
27+
let mut msg: libc::msghdr = zeroed();
28+
msg.msg_name = msg_name;
29+
msg.msg_namelen = msg_namelen;
30+
msg.msg_iov = bufs.as_mut_ptr().cast();
31+
msg.msg_iovlen = bufs.len() as _;
32+
msg.msg_controllen = ancillary.buffer.len() as _;
33+
// macos requires that the control pointer is null when the len is 0.
34+
if msg.msg_controllen > 0 {
35+
msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
36+
}
37+
38+
let count = self.recv_msg(&mut msg)?;
39+
40+
ancillary.length = msg.msg_controllen as usize;
41+
ancillary.truncated = msg.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC;
42+
43+
let truncated = msg.msg_flags & libc::MSG_TRUNC == libc::MSG_TRUNC;
44+
45+
Ok((count, truncated, msg.msg_namelen))
46+
}
47+
48+
pub(super) unsafe fn send_vectored_with_ancillary_to(
49+
&self,
50+
msg_name: *mut libc::c_void,
51+
msg_namelen: libc::socklen_t,
52+
bufs: &[IoSlice<'_>],
53+
ancillary: &mut Ancillary<'_>,
54+
) -> io::Result<usize> {
55+
let mut msg: libc::msghdr = zeroed();
56+
msg.msg_name = msg_name;
57+
msg.msg_namelen = msg_namelen;
58+
msg.msg_iov = bufs.as_ptr() as *mut _;
59+
msg.msg_iovlen = bufs.len() as _;
60+
msg.msg_controllen = ancillary.length as _;
61+
// macos requires that the control pointer is null when the len is 0.
62+
if msg.msg_controllen > 0 {
63+
msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
64+
}
65+
66+
ancillary.truncated = false;
67+
68+
self.send_msg(&mut msg)
69+
}
70+
}
71+
72+
#[derive(Debug)]
73+
pub(crate) struct Ancillary<'a> {
74+
buffer: &'a mut [u8],
75+
length: usize,
76+
truncated: bool,
77+
}
78+
79+
impl<'a> Ancillary<'a> {
80+
pub(super) fn new(buffer: &'a mut [u8]) -> Self {
81+
Ancillary { buffer, length: 0, truncated: false }
82+
}
83+
}
84+
85+
impl Ancillary<'_> {
86+
pub(super) fn add_to_ancillary_data<T>(
87+
&mut self,
88+
source: &[T],
89+
cmsg_level: libc::c_int,
90+
cmsg_type: libc::c_int,
91+
) -> bool {
92+
self.truncated = false;
93+
94+
let source_len = if let Some(source_len) = source.len().checked_mul(size_of::<T>()) {
95+
if let Ok(source_len) = u32::try_from(source_len) {
96+
source_len
97+
} else {
98+
return false;
99+
}
100+
} else {
101+
return false;
102+
};
103+
104+
unsafe {
105+
let additional_space = libc::CMSG_SPACE(source_len) as usize;
106+
107+
let new_length = if let Some(new_length) = additional_space.checked_add(self.length) {
108+
new_length
109+
} else {
110+
return false;
111+
};
112+
113+
if new_length > self.buffer.len() {
114+
return false;
115+
}
116+
117+
self.buffer[self.length..new_length].fill(0);
118+
119+
self.length = new_length;
120+
121+
let mut msg: libc::msghdr = zeroed();
122+
msg.msg_control = self.buffer.as_mut_ptr().cast();
123+
msg.msg_controllen = self.length as _;
124+
125+
let mut cmsg = libc::CMSG_FIRSTHDR(&msg);
126+
let mut previous_cmsg = cmsg;
127+
while !cmsg.is_null() {
128+
previous_cmsg = cmsg;
129+
cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
130+
131+
// Most operating systems, but not Linux or emscripten, return the previous pointer
132+
// when its length is zero. Therefore, check if the previous pointer is the same as
133+
// the current one.
134+
if eq(cmsg, previous_cmsg) {
135+
break;
136+
}
137+
}
138+
139+
if previous_cmsg.is_null() {
140+
return false;
141+
}
142+
143+
(*previous_cmsg).cmsg_level = cmsg_level;
144+
(*previous_cmsg).cmsg_type = cmsg_type;
145+
(*previous_cmsg).cmsg_len = libc::CMSG_LEN(source_len) as _;
146+
147+
let data = libc::CMSG_DATA(previous_cmsg).cast();
148+
149+
libc::memcpy(data, source.as_ptr().cast(), source_len as usize);
150+
}
151+
true
152+
}
153+
154+
pub(super) fn capacity(&self) -> usize {
155+
self.buffer.len()
156+
}
157+
158+
pub(super) fn is_empty(&self) -> bool {
159+
self.length == 0
160+
}
161+
162+
pub(super) fn len(&self) -> usize {
163+
self.length
164+
}
165+
166+
pub(super) fn messages<T>(&self) -> Messages<'_, T> {
167+
Messages { buffer: &self.buffer[..self.length], current: None, phantom: PhantomData {} }
168+
}
169+
170+
pub(super) fn truncated(&self) -> bool {
171+
self.truncated
172+
}
173+
174+
pub(super) fn clear(&mut self) {
175+
self.length = 0;
176+
self.truncated = false;
177+
}
178+
}
179+
180+
pub(super) struct AncillaryDataIter<'a, T> {
181+
data: &'a [u8],
182+
phantom: PhantomData<T>,
183+
}
184+
185+
impl<'a, T> AncillaryDataIter<'a, T> {
186+
/// Create `AncillaryDataIter` struct to iterate through the data unit in the control message.
187+
///
188+
/// # Safety
189+
///
190+
/// `data` must contain a valid control message.
191+
pub(super) unsafe fn new(data: &'a [u8]) -> AncillaryDataIter<'a, T> {
192+
AncillaryDataIter { data, phantom: PhantomData }
193+
}
194+
}
195+
196+
impl<'a, T> Iterator for AncillaryDataIter<'a, T> {
197+
type Item = T;
198+
199+
fn next(&mut self) -> Option<T> {
200+
if size_of::<T>() <= self.data.len() {
201+
unsafe {
202+
let unit = read_unaligned(self.data.as_ptr().cast());
203+
self.data = &self.data[size_of::<T>()..];
204+
Some(unit)
205+
}
206+
} else {
207+
None
208+
}
209+
}
210+
}
211+
212+
/// The error type which is returned from parsing the type a control message.
213+
#[non_exhaustive]
214+
#[derive(Debug)]
215+
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
216+
pub enum AncillaryError {
217+
Unknown { cmsg_level: i32, cmsg_type: i32 },
218+
}
219+
220+
/// Return the data of `cmsghdr` as a `u8` slice.
221+
pub(super) unsafe fn get_data_from_cmsghdr(cmsg: &libc::cmsghdr) -> &[u8] {
222+
let cmsg_len_zero = libc::CMSG_LEN(0) as usize;
223+
let data_len = (*cmsg).cmsg_len as usize - cmsg_len_zero;
224+
let data = libc::CMSG_DATA(cmsg).cast();
225+
from_raw_parts(data, data_len)
226+
}
227+
228+
/// This struct is used to iterate through the control messages.
229+
#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")]
230+
pub struct Messages<'a, T> {
231+
buffer: &'a [u8],
232+
current: Option<&'a libc::cmsghdr>,
233+
phantom: PhantomData<T>,
234+
}
235+
236+
impl<'a, T> Messages<'a, T> {
237+
pub(super) unsafe fn next_cmsghdr(&mut self) -> Option<&'a libc::cmsghdr> {
238+
let mut msg: libc::msghdr = zeroed();
239+
msg.msg_control = self.buffer.as_ptr() as *mut _;
240+
msg.msg_controllen = self.buffer.len() as _;
241+
242+
let cmsg = if let Some(current) = self.current {
243+
libc::CMSG_NXTHDR(&msg, current)
244+
} else {
245+
libc::CMSG_FIRSTHDR(&msg)
246+
};
247+
248+
let cmsg = cmsg.as_ref()?;
249+
250+
// Most operating systems, but not Linux or emscripten, return the previous pointer
251+
// when its length is zero. Therefore, check if the previous pointer is the same as
252+
// the current one.
253+
if let Some(current) = self.current {
254+
if eq(current, cmsg) {
255+
return None;
256+
}
257+
}
258+
259+
self.current = Some(cmsg);
260+
Some(cmsg)
261+
}
262+
}

0 commit comments

Comments
 (0)