|
| 1 | +use crate::{ |
| 2 | + convert::TryFrom, |
| 3 | + io::{self, IoSlice, IoSliceMut}, |
| 4 | + marker::PhantomData, |
| 5 | + mem::{size_of, zeroed}, |
| 6 | + ptr::{eq, read_unaligned}, |
| 7 | + slice::from_raw_parts, |
| 8 | + sys::net::Socket, |
| 9 | +}; |
| 10 | + |
| 11 | +// FIXME(#43348): Make libc adapt #[doc(cfg(...))] so we don't need these fake definitions here? |
| 12 | +#[cfg(all(doc, not(target_os = "linux"), not(target_os = "android")))] |
| 13 | +#[allow(non_camel_case_types)] |
| 14 | +mod libc { |
| 15 | + pub use libc::c_int; |
| 16 | + pub struct cmsghdr; |
| 17 | +} |
| 18 | + |
| 19 | +impl Socket { |
| 20 | + pub(super) unsafe fn recv_vectored_with_ancillary_from( |
| 21 | + &self, |
| 22 | + msg_name: *mut libc::c_void, |
| 23 | + msg_namelen: libc::socklen_t, |
| 24 | + bufs: &mut [IoSliceMut<'_>], |
| 25 | + ancillary: &mut Ancillary<'_>, |
| 26 | + ) -> io::Result<(usize, bool, libc::socklen_t)> { |
| 27 | + let mut msg: libc::msghdr = zeroed(); |
| 28 | + msg.msg_name = msg_name; |
| 29 | + msg.msg_namelen = msg_namelen; |
| 30 | + msg.msg_iov = bufs.as_mut_ptr().cast(); |
| 31 | + msg.msg_iovlen = bufs.len() as _; |
| 32 | + msg.msg_controllen = ancillary.buffer.len() as _; |
| 33 | + // macos requires that the control pointer is null when the len is 0. |
| 34 | + if msg.msg_controllen > 0 { |
| 35 | + msg.msg_control = ancillary.buffer.as_mut_ptr().cast(); |
| 36 | + } |
| 37 | + |
| 38 | + let count = self.recv_msg(&mut msg)?; |
| 39 | + |
| 40 | + ancillary.length = msg.msg_controllen as usize; |
| 41 | + ancillary.truncated = msg.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC; |
| 42 | + |
| 43 | + let truncated = msg.msg_flags & libc::MSG_TRUNC == libc::MSG_TRUNC; |
| 44 | + |
| 45 | + Ok((count, truncated, msg.msg_namelen)) |
| 46 | + } |
| 47 | + |
| 48 | + pub(super) unsafe fn send_vectored_with_ancillary_to( |
| 49 | + &self, |
| 50 | + msg_name: *mut libc::c_void, |
| 51 | + msg_namelen: libc::socklen_t, |
| 52 | + bufs: &[IoSlice<'_>], |
| 53 | + ancillary: &mut Ancillary<'_>, |
| 54 | + ) -> io::Result<usize> { |
| 55 | + let mut msg: libc::msghdr = zeroed(); |
| 56 | + msg.msg_name = msg_name; |
| 57 | + msg.msg_namelen = msg_namelen; |
| 58 | + msg.msg_iov = bufs.as_ptr() as *mut _; |
| 59 | + msg.msg_iovlen = bufs.len() as _; |
| 60 | + msg.msg_controllen = ancillary.length as _; |
| 61 | + // macos requires that the control pointer is null when the len is 0. |
| 62 | + if msg.msg_controllen > 0 { |
| 63 | + msg.msg_control = ancillary.buffer.as_mut_ptr().cast(); |
| 64 | + } |
| 65 | + |
| 66 | + ancillary.truncated = false; |
| 67 | + |
| 68 | + self.send_msg(&mut msg) |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +#[derive(Debug)] |
| 73 | +pub(crate) struct Ancillary<'a> { |
| 74 | + buffer: &'a mut [u8], |
| 75 | + length: usize, |
| 76 | + truncated: bool, |
| 77 | +} |
| 78 | + |
| 79 | +impl<'a> Ancillary<'a> { |
| 80 | + pub(super) fn new(buffer: &'a mut [u8]) -> Self { |
| 81 | + Ancillary { buffer, length: 0, truncated: false } |
| 82 | + } |
| 83 | +} |
| 84 | + |
| 85 | +impl Ancillary<'_> { |
| 86 | + pub(super) fn add_to_ancillary_data<T>( |
| 87 | + &mut self, |
| 88 | + source: &[T], |
| 89 | + cmsg_level: libc::c_int, |
| 90 | + cmsg_type: libc::c_int, |
| 91 | + ) -> bool { |
| 92 | + self.truncated = false; |
| 93 | + |
| 94 | + let source_len = if let Some(source_len) = source.len().checked_mul(size_of::<T>()) { |
| 95 | + if let Ok(source_len) = u32::try_from(source_len) { |
| 96 | + source_len |
| 97 | + } else { |
| 98 | + return false; |
| 99 | + } |
| 100 | + } else { |
| 101 | + return false; |
| 102 | + }; |
| 103 | + |
| 104 | + unsafe { |
| 105 | + let additional_space = libc::CMSG_SPACE(source_len) as usize; |
| 106 | + |
| 107 | + let new_length = if let Some(new_length) = additional_space.checked_add(self.length) { |
| 108 | + new_length |
| 109 | + } else { |
| 110 | + return false; |
| 111 | + }; |
| 112 | + |
| 113 | + if new_length > self.buffer.len() { |
| 114 | + return false; |
| 115 | + } |
| 116 | + |
| 117 | + self.buffer[self.length..new_length].fill(0); |
| 118 | + |
| 119 | + self.length = new_length; |
| 120 | + |
| 121 | + let mut msg: libc::msghdr = zeroed(); |
| 122 | + msg.msg_control = self.buffer.as_mut_ptr().cast(); |
| 123 | + msg.msg_controllen = self.length as _; |
| 124 | + |
| 125 | + let mut cmsg = libc::CMSG_FIRSTHDR(&msg); |
| 126 | + let mut previous_cmsg = cmsg; |
| 127 | + while !cmsg.is_null() { |
| 128 | + previous_cmsg = cmsg; |
| 129 | + cmsg = libc::CMSG_NXTHDR(&msg, cmsg); |
| 130 | + |
| 131 | + // Most operating systems, but not Linux or emscripten, return the previous pointer |
| 132 | + // when its length is zero. Therefore, check if the previous pointer is the same as |
| 133 | + // the current one. |
| 134 | + if eq(cmsg, previous_cmsg) { |
| 135 | + break; |
| 136 | + } |
| 137 | + } |
| 138 | + |
| 139 | + if previous_cmsg.is_null() { |
| 140 | + return false; |
| 141 | + } |
| 142 | + |
| 143 | + (*previous_cmsg).cmsg_level = cmsg_level; |
| 144 | + (*previous_cmsg).cmsg_type = cmsg_type; |
| 145 | + (*previous_cmsg).cmsg_len = libc::CMSG_LEN(source_len) as _; |
| 146 | + |
| 147 | + let data = libc::CMSG_DATA(previous_cmsg).cast(); |
| 148 | + |
| 149 | + libc::memcpy(data, source.as_ptr().cast(), source_len as usize); |
| 150 | + } |
| 151 | + true |
| 152 | + } |
| 153 | + |
| 154 | + pub(super) fn capacity(&self) -> usize { |
| 155 | + self.buffer.len() |
| 156 | + } |
| 157 | + |
| 158 | + pub(super) fn is_empty(&self) -> bool { |
| 159 | + self.length == 0 |
| 160 | + } |
| 161 | + |
| 162 | + pub(super) fn len(&self) -> usize { |
| 163 | + self.length |
| 164 | + } |
| 165 | + |
| 166 | + pub(super) fn messages<T>(&self) -> Messages<'_, T> { |
| 167 | + Messages { buffer: &self.buffer[..self.length], current: None, phantom: PhantomData {} } |
| 168 | + } |
| 169 | + |
| 170 | + pub(super) fn truncated(&self) -> bool { |
| 171 | + self.truncated |
| 172 | + } |
| 173 | + |
| 174 | + pub(super) fn clear(&mut self) { |
| 175 | + self.length = 0; |
| 176 | + self.truncated = false; |
| 177 | + } |
| 178 | +} |
| 179 | + |
| 180 | +pub(super) struct AncillaryDataIter<'a, T> { |
| 181 | + data: &'a [u8], |
| 182 | + phantom: PhantomData<T>, |
| 183 | +} |
| 184 | + |
| 185 | +impl<'a, T> AncillaryDataIter<'a, T> { |
| 186 | + /// Create `AncillaryDataIter` struct to iterate through the data unit in the control message. |
| 187 | + /// |
| 188 | + /// # Safety |
| 189 | + /// |
| 190 | + /// `data` must contain a valid control message. |
| 191 | + pub(super) unsafe fn new(data: &'a [u8]) -> AncillaryDataIter<'a, T> { |
| 192 | + AncillaryDataIter { data, phantom: PhantomData } |
| 193 | + } |
| 194 | +} |
| 195 | + |
| 196 | +impl<'a, T> Iterator for AncillaryDataIter<'a, T> { |
| 197 | + type Item = T; |
| 198 | + |
| 199 | + fn next(&mut self) -> Option<T> { |
| 200 | + if size_of::<T>() <= self.data.len() { |
| 201 | + unsafe { |
| 202 | + let unit = read_unaligned(self.data.as_ptr().cast()); |
| 203 | + self.data = &self.data[size_of::<T>()..]; |
| 204 | + Some(unit) |
| 205 | + } |
| 206 | + } else { |
| 207 | + None |
| 208 | + } |
| 209 | + } |
| 210 | +} |
| 211 | + |
| 212 | +/// The error type which is returned from parsing the type a control message. |
| 213 | +#[non_exhaustive] |
| 214 | +#[derive(Debug)] |
| 215 | +#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")] |
| 216 | +pub enum AncillaryError { |
| 217 | + Unknown { cmsg_level: i32, cmsg_type: i32 }, |
| 218 | +} |
| 219 | + |
| 220 | +/// Return the data of `cmsghdr` as a `u8` slice. |
| 221 | +pub(super) unsafe fn get_data_from_cmsghdr(cmsg: &libc::cmsghdr) -> &[u8] { |
| 222 | + let cmsg_len_zero = libc::CMSG_LEN(0) as usize; |
| 223 | + let data_len = (*cmsg).cmsg_len as usize - cmsg_len_zero; |
| 224 | + let data = libc::CMSG_DATA(cmsg).cast(); |
| 225 | + from_raw_parts(data, data_len) |
| 226 | +} |
| 227 | + |
| 228 | +/// This struct is used to iterate through the control messages. |
| 229 | +#[unstable(feature = "unix_socket_ancillary_data", issue = "76915")] |
| 230 | +pub struct Messages<'a, T> { |
| 231 | + buffer: &'a [u8], |
| 232 | + current: Option<&'a libc::cmsghdr>, |
| 233 | + phantom: PhantomData<T>, |
| 234 | +} |
| 235 | + |
| 236 | +impl<'a, T> Messages<'a, T> { |
| 237 | + pub(super) unsafe fn next_cmsghdr(&mut self) -> Option<&'a libc::cmsghdr> { |
| 238 | + let mut msg: libc::msghdr = zeroed(); |
| 239 | + msg.msg_control = self.buffer.as_ptr() as *mut _; |
| 240 | + msg.msg_controllen = self.buffer.len() as _; |
| 241 | + |
| 242 | + let cmsg = if let Some(current) = self.current { |
| 243 | + libc::CMSG_NXTHDR(&msg, current) |
| 244 | + } else { |
| 245 | + libc::CMSG_FIRSTHDR(&msg) |
| 246 | + }; |
| 247 | + |
| 248 | + let cmsg = cmsg.as_ref()?; |
| 249 | + |
| 250 | + // Most operating systems, but not Linux or emscripten, return the previous pointer |
| 251 | + // when its length is zero. Therefore, check if the previous pointer is the same as |
| 252 | + // the current one. |
| 253 | + if let Some(current) = self.current { |
| 254 | + if eq(current, cmsg) { |
| 255 | + return None; |
| 256 | + } |
| 257 | + } |
| 258 | + |
| 259 | + self.current = Some(cmsg); |
| 260 | + Some(cmsg) |
| 261 | + } |
| 262 | +} |
0 commit comments