Skip to content

Commit 208b80b

Browse files
authored
recvmsg: Check if CMSG buffer was too small and return an error (#2413)
If MSG_CTRUNC is set, it is not safe to iterate the cmsgs, since they could have been truncated. Change RecvMsg::cmsgs() to return a Result, and to check for this flag (an API change). Update tests for API change. Add test for too-small buffer.
1 parent ecd12a9 commit 208b80b

File tree

3 files changed

+53
-31
lines changed

3 files changed

+53
-31
lines changed

changelog/2413.changed.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`RecvMsg::cmsgs()` now returns a `Result`, and checks that cmsgs were not truncated.

src/sys/socket/mod.rs

+13-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use libc::{self, c_int, size_t, socklen_t};
1313
#[cfg(all(feature = "uio", not(target_os = "redox")))]
1414
use libc::{
1515
c_void, iovec, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE,
16+
MSG_CTRUNC,
1617
};
1718
#[cfg(not(target_os = "redox"))]
1819
use std::io::{IoSlice, IoSliceMut};
@@ -599,13 +600,19 @@ pub struct RecvMsg<'a, 's, S> {
599600
}
600601

601602
impl<'a, S> RecvMsg<'a, '_, S> {
602-
/// Iterate over the valid control messages pointed to by this
603-
/// msghdr.
604-
pub fn cmsgs(&self) -> CmsgIterator {
605-
CmsgIterator {
603+
/// Iterate over the valid control messages pointed to by this msghdr. If
604+
/// allocated space for CMSGs was too small it is not safe to iterate,
605+
/// instead return an `Error::ENOBUFS` error.
606+
pub fn cmsgs(&self) -> Result<CmsgIterator> {
607+
608+
if self.mhdr.msg_flags & MSG_CTRUNC == MSG_CTRUNC {
609+
return Err(Errno::ENOBUFS);
610+
}
611+
612+
Ok(CmsgIterator {
606613
cmsghdr: self.cmsghdr,
607614
mhdr: &self.mhdr
608-
}
615+
})
609616
}
610617
}
611618

@@ -700,7 +707,7 @@ pub enum ControlMessageOwned {
700707
/// let mut iov = [IoSliceMut::new(&mut buffer)];
701708
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
702709
/// .unwrap();
703-
/// let rtime = match r.cmsgs().next() {
710+
/// let rtime = match r.cmsgs().unwrap().next() {
704711
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
705712
/// Some(_) => panic!("Unexpected control message"),
706713
/// None => panic!("No control message")

test/sys/test_socket.rs

+39-25
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ pub fn test_timestamping() {
5555
.unwrap();
5656

5757
let mut ts = None;
58-
for c in recv.cmsgs() {
58+
for c in recv.cmsgs().unwrap() {
5959
if let ControlMessageOwned::ScmTimestampsns(timestamps) = c {
6060
ts = Some(timestamps.system);
6161
}
@@ -117,7 +117,7 @@ pub fn test_timestamping_realtime() {
117117
.unwrap();
118118

119119
let mut ts = None;
120-
for c in recv.cmsgs() {
120+
for c in recv.cmsgs().unwrap() {
121121
if let ControlMessageOwned::ScmRealtime(timeval) = c {
122122
ts = Some(timeval);
123123
}
@@ -179,7 +179,7 @@ pub fn test_timestamping_monotonic() {
179179
.unwrap();
180180

181181
let mut ts = None;
182-
for c in recv.cmsgs() {
182+
for c in recv.cmsgs().unwrap() {
183183
if let ControlMessageOwned::ScmMonotonic(timeval) = c {
184184
ts = Some(timeval);
185185
}
@@ -889,7 +889,7 @@ pub fn test_scm_rights() {
889889
)
890890
.unwrap();
891891

892-
for cmsg in msg.cmsgs() {
892+
for cmsg in msg.cmsgs().unwrap() {
893893
if let ControlMessageOwned::ScmRights(fd) = cmsg {
894894
assert_eq!(received_r, None);
895895
assert_eq!(fd.len(), 1);
@@ -1330,7 +1330,7 @@ fn test_scm_rights_single_cmsg_multiple_fds() {
13301330
.flags
13311331
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
13321332

1333-
let mut cmsgs = msg.cmsgs();
1333+
let mut cmsgs = msg.cmsgs().unwrap();
13341334
match cmsgs.next() {
13351335
Some(ControlMessageOwned::ScmRights(fds)) => {
13361336
assert_eq!(
@@ -1399,7 +1399,7 @@ pub fn test_sendmsg_empty_cmsgs() {
13991399
)
14001400
.unwrap();
14011401

1402-
if msg.cmsgs().next().is_some() {
1402+
if msg.cmsgs().unwrap().next().is_some() {
14031403
panic!("unexpected cmsg");
14041404
}
14051405
assert!(!msg
@@ -1466,7 +1466,7 @@ fn test_scm_credentials() {
14661466
.unwrap();
14671467
let mut received_cred = None;
14681468

1469-
for cmsg in msg.cmsgs() {
1469+
for cmsg in msg.cmsgs().unwrap() {
14701470
let cred = match cmsg {
14711471
#[cfg(linux_android)]
14721472
ControlMessageOwned::ScmCredentials(cred) => cred,
@@ -1497,7 +1497,7 @@ fn test_scm_credentials() {
14971497
#[test]
14981498
fn test_scm_credentials_and_rights() {
14991499
let space = cmsg_space!(libc::ucred, RawFd);
1500-
test_impl_scm_credentials_and_rights(space);
1500+
test_impl_scm_credentials_and_rights(space).unwrap();
15011501
}
15021502

15031503
/// Ensure that passing a an oversized control message buffer to recvmsg
@@ -1509,11 +1509,23 @@ fn test_scm_credentials_and_rights() {
15091509
#[test]
15101510
fn test_too_large_cmsgspace() {
15111511
let space = vec![0u8; 1024];
1512-
test_impl_scm_credentials_and_rights(space);
1512+
test_impl_scm_credentials_and_rights(space).unwrap();
15131513
}
15141514

15151515
#[cfg(linux_android)]
1516-
fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
1516+
#[test]
1517+
fn test_too_small_cmsgspace() {
1518+
let space = vec![0u8; 4];
1519+
assert_eq!(
1520+
test_impl_scm_credentials_and_rights(space),
1521+
Err(nix::errno::Errno::ENOBUFS)
1522+
);
1523+
}
1524+
1525+
#[cfg(linux_android)]
1526+
fn test_impl_scm_credentials_and_rights(
1527+
mut space: Vec<u8>,
1528+
) -> Result<(), nix::errno::Errno> {
15171529
use libc::ucred;
15181530
use nix::sys::socket::sockopt::PassCred;
15191531
use nix::sys::socket::{
@@ -1573,9 +1585,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
15731585
.unwrap();
15741586
let mut received_cred = None;
15751587

1576-
assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
1588+
assert_eq!(msg.cmsgs()?.count(), 2, "expected 2 cmsgs");
15771589

1578-
for cmsg in msg.cmsgs() {
1590+
for cmsg in msg.cmsgs()? {
15791591
match cmsg {
15801592
ControlMessageOwned::ScmRights(fds) => {
15811593
assert_eq!(received_r, None, "already received fd");
@@ -1606,6 +1618,8 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
16061618
read(received_r.as_raw_fd(), &mut buf).unwrap();
16071619
assert_eq!(&buf[..], b"world");
16081620
close(received_r).unwrap();
1621+
1622+
Ok(())
16091623
}
16101624

16111625
// Test creating and using named unix domain sockets
@@ -1837,7 +1851,7 @@ pub fn test_recv_ipv4pktinfo() {
18371851
.flags
18381852
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
18391853

1840-
let mut cmsgs = msg.cmsgs();
1854+
let mut cmsgs = msg.cmsgs().unwrap();
18411855
if let Some(ControlMessageOwned::Ipv4PacketInfo(pktinfo)) = cmsgs.next()
18421856
{
18431857
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
@@ -1929,11 +1943,11 @@ pub fn test_recvif() {
19291943
assert!(!msg
19301944
.flags
19311945
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
1932-
assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
1946+
assert_eq!(msg.cmsgs().unwrap().count(), 2, "expected 2 cmsgs");
19331947

19341948
let mut rx_recvif = false;
19351949
let mut rx_recvdstaddr = false;
1936-
for cmsg in msg.cmsgs() {
1950+
for cmsg in msg.cmsgs().unwrap() {
19371951
match cmsg {
19381952
ControlMessageOwned::Ipv4RecvIf(dl) => {
19391953
rx_recvif = true;
@@ -2027,10 +2041,10 @@ pub fn test_recvif_ipv4() {
20272041
assert!(!msg
20282042
.flags
20292043
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
2030-
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
2044+
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");
20312045

20322046
let mut rx_recvorigdstaddr = false;
2033-
for cmsg in msg.cmsgs() {
2047+
for cmsg in msg.cmsgs().unwrap() {
20342048
match cmsg {
20352049
ControlMessageOwned::Ipv4OrigDstAddr(addr) => {
20362050
rx_recvorigdstaddr = true;
@@ -2113,10 +2127,10 @@ pub fn test_recvif_ipv6() {
21132127
assert!(!msg
21142128
.flags
21152129
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
2116-
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
2130+
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");
21172131

21182132
let mut rx_recvorigdstaddr = false;
2119-
for cmsg in msg.cmsgs() {
2133+
for cmsg in msg.cmsgs().unwrap() {
21202134
match cmsg {
21212135
ControlMessageOwned::Ipv6OrigDstAddr(addr) => {
21222136
rx_recvorigdstaddr = true;
@@ -2214,7 +2228,7 @@ pub fn test_recv_ipv6pktinfo() {
22142228
.flags
22152229
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
22162230

2217-
let mut cmsgs = msg.cmsgs();
2231+
let mut cmsgs = msg.cmsgs().unwrap();
22182232
if let Some(ControlMessageOwned::Ipv6PacketInfo(pktinfo)) = cmsgs.next()
22192233
{
22202234
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
@@ -2357,7 +2371,7 @@ fn test_recvmsg_timestampns() {
23572371
flags,
23582372
)
23592373
.unwrap();
2360-
let rtime = match r.cmsgs().next() {
2374+
let rtime = match r.cmsgs().unwrap().next() {
23612375
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
23622376
Some(_) => panic!("Unexpected control message"),
23632377
None => panic!("No control message"),
@@ -2418,7 +2432,7 @@ fn test_recvmmsg_timestampns() {
24182432
)
24192433
.unwrap()
24202434
.collect();
2421-
let rtime = match r[0].cmsgs().next() {
2435+
let rtime = match r[0].cmsgs().unwrap().next() {
24222436
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
24232437
Some(_) => panic!("Unexpected control message"),
24242438
None => panic!("No control message"),
@@ -2508,7 +2522,7 @@ fn test_recvmsg_rxq_ovfl() {
25082522
MsgFlags::MSG_DONTWAIT,
25092523
) {
25102524
Ok(r) => {
2511-
drop_counter = match r.cmsgs().next() {
2525+
drop_counter = match r.cmsgs().unwrap().next() {
25122526
Some(ControlMessageOwned::RxqOvfl(drop_counter)) => {
25132527
drop_counter
25142528
}
@@ -2687,7 +2701,7 @@ mod linux_errqueue {
26872701
assert_eq!(msg.address, Some(sock_addr));
26882702

26892703
// Check for expected control message.
2690-
let ext_err = match msg.cmsgs().next() {
2704+
let ext_err = match msg.cmsgs().unwrap().next() {
26912705
Some(cmsg) => testf(&cmsg),
26922706
None => panic!("No control message"),
26932707
};
@@ -2878,7 +2892,7 @@ fn test_recvmm2() -> nix::Result<()> {
28782892
#[cfg(not(any(qemu, target_arch = "aarch64")))]
28792893
let mut saw_time = false;
28802894
let mut recvd = 0;
2881-
for cmsg in rmsg.cmsgs() {
2895+
for cmsg in rmsg.cmsgs().unwrap() {
28822896
if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg {
28832897
let ts = timestamps.system;
28842898

0 commit comments

Comments
 (0)