Skip to content

Commit 0c75fb5

Browse files
ghananiganscathay4t
authored andcommitted
Support full done message
Currently, done messages are treated as having a zero-sized payload but they are expected to have at least a 4 byte payload holding an error code. This causes issues when serializing done messages that are consumed applications expecting conformant done messages. See https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types for details. Fixes #11
1 parent 96b3136 commit 0c75fb5

File tree

4 files changed

+221
-12
lines changed

4 files changed

+221
-12
lines changed

src/done.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// SPDX-License-Identifier: MIT
2+
3+
use std::mem::size_of;
4+
5+
use byteorder::{ByteOrder, NativeEndian};
6+
use netlink_packet_utils::DecodeError;
7+
8+
use crate::{Emitable, Field, Parseable, Rest};
9+
10+
const CODE: Field = 0..4;
11+
const EXTENDED_ACK: Rest = 4..;
12+
const DONE_HEADER_LEN: usize = EXTENDED_ACK.start;
13+
14+
#[derive(Debug, PartialEq, Eq, Clone)]
15+
#[non_exhaustive]
16+
pub struct DoneBuffer<T> {
17+
buffer: T,
18+
}
19+
20+
impl<T: AsRef<[u8]>> DoneBuffer<T> {
21+
pub fn new(buffer: T) -> DoneBuffer<T> {
22+
DoneBuffer { buffer }
23+
}
24+
25+
/// Consume the packet, returning the underlying buffer.
26+
pub fn into_inner(self) -> T {
27+
self.buffer
28+
}
29+
30+
pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
31+
let packet = Self::new(buffer);
32+
packet.check_buffer_length()?;
33+
Ok(packet)
34+
}
35+
36+
fn check_buffer_length(&self) -> Result<(), DecodeError> {
37+
let len = self.buffer.as_ref().len();
38+
if len < DONE_HEADER_LEN {
39+
Err(format!(
40+
"invalid DoneBuffer: length is {len} but DoneBuffer are \
41+
at least {DONE_HEADER_LEN} bytes"
42+
)
43+
.into())
44+
} else {
45+
Ok(())
46+
}
47+
}
48+
49+
/// Return the error code
50+
pub fn code(&self) -> i32 {
51+
let data = self.buffer.as_ref();
52+
NativeEndian::read_i32(&data[CODE])
53+
}
54+
}
55+
56+
impl<'a, T: AsRef<[u8]> + ?Sized> DoneBuffer<&'a T> {
57+
/// Return a pointer to the extended ack attributes.
58+
pub fn extended_ack(&self) -> &'a [u8] {
59+
let data = self.buffer.as_ref();
60+
&data[EXTENDED_ACK]
61+
}
62+
}
63+
64+
impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> DoneBuffer<&'a mut T> {
65+
/// Return a mutable pointer to the extended ack attributes.
66+
pub fn extended_ack_mut(&mut self) -> &mut [u8] {
67+
let data = self.buffer.as_mut();
68+
&mut data[EXTENDED_ACK]
69+
}
70+
}
71+
72+
impl<T: AsRef<[u8]> + AsMut<[u8]>> DoneBuffer<T> {
73+
/// set the error code field
74+
pub fn set_code(&mut self, value: i32) {
75+
let data = self.buffer.as_mut();
76+
NativeEndian::write_i32(&mut data[CODE], value)
77+
}
78+
}
79+
80+
#[derive(Debug, Default, Clone, PartialEq, Eq)]
81+
#[non_exhaustive]
82+
pub struct DoneMessage {
83+
pub code: i32,
84+
pub extended_ack: Vec<u8>,
85+
}
86+
87+
impl Emitable for DoneMessage {
88+
fn buffer_len(&self) -> usize {
89+
size_of::<i32>() + self.extended_ack.len()
90+
}
91+
fn emit(&self, buffer: &mut [u8]) {
92+
let mut buffer = DoneBuffer::new(buffer);
93+
buffer.set_code(self.code);
94+
buffer
95+
.extended_ack_mut()
96+
.copy_from_slice(&self.extended_ack);
97+
}
98+
}
99+
100+
impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable<DoneBuffer<&'buffer T>>
101+
for DoneMessage
102+
{
103+
fn parse(buf: &DoneBuffer<&'buffer T>) -> Result<DoneMessage, DecodeError> {
104+
Ok(DoneMessage {
105+
code: buf.code(),
106+
extended_ack: buf.extended_ack().to_vec(),
107+
})
108+
}
109+
}
110+
111+
#[cfg(test)]
112+
mod tests {
113+
use super::*;
114+
115+
#[test]
116+
fn serialize_and_parse() {
117+
let mut expected = DoneMessage::default();
118+
expected.code = 5;
119+
expected.extended_ack = vec![1, 2, 3];
120+
121+
let len = expected.buffer_len();
122+
assert_eq!(len, size_of::<i32>() + expected.extended_ack.len());
123+
124+
let mut buf = vec![0; len];
125+
expected.emit(&mut buf);
126+
127+
let done_buf = DoneBuffer::new(&buf);
128+
assert_eq!(done_buf.code(), expected.code);
129+
assert_eq!(done_buf.extended_ack(), &expected.extended_ack);
130+
131+
let got = DoneMessage::parse(&done_buf).unwrap();
132+
assert_eq!(got, expected);
133+
}
134+
}

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ pub(crate) type Field = Range<usize>;
244244
/// Represent a field that starts at a given index in a packet
245245
pub(crate) type Rest = RangeFrom<usize>;
246246

247+
pub mod done;
248+
pub use self::done::*;
249+
247250
pub mod error;
248251
pub use self::error::*;
249252

src/message.rs

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ use netlink_packet_utils::DecodeError;
77

88
use crate::{
99
payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
10-
AckMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer,
11-
NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable,
12-
Parseable,
10+
AckMessage, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
11+
NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload,
12+
NetlinkSerializable, Parseable,
1313
};
1414

1515
/// Represent a netlink message.
@@ -98,9 +98,8 @@ where
9898
let bytes = buf.payload();
9999
let payload = match header.message_type {
100100
NLMSG_ERROR => {
101-
let buf = ErrorBuffer::new_checked(&bytes)
102-
.context("failed to parse NLMSG_ERROR")?;
103-
let msg = ErrorMessage::parse(&buf)
101+
let msg = ErrorBuffer::new_checked(&bytes)
102+
.and_then(|buf| ErrorMessage::parse(&buf))
104103
.context("failed to parse NLMSG_ERROR")?;
105104
if msg.code >= 0 {
106105
Ack(msg as AckMessage)
@@ -109,7 +108,12 @@ where
109108
}
110109
}
111110
NLMSG_NOOP => Noop,
112-
NLMSG_DONE => Done,
111+
NLMSG_DONE => {
112+
let msg = DoneBuffer::new_checked(&bytes)
113+
.and_then(|buf| DoneMessage::parse(&buf))
114+
.context("failed to parse NLMSG_DONE")?;
115+
Done(msg)
116+
}
113117
NLMSG_OVERRUN => Overrun(bytes.to_vec()),
114118
message_type => {
115119
let inner_msg = I::deserialize(&header, bytes).context(
@@ -130,7 +134,8 @@ where
130134
use self::NetlinkPayload::*;
131135

132136
let payload_len = match self.payload {
133-
Noop | Done => 0,
137+
Noop => 0,
138+
Done(ref msg) => msg.buffer_len(),
134139
Overrun(ref bytes) => bytes.len(),
135140
Error(ref msg) => msg.buffer_len(),
136141
Ack(ref msg) => msg.buffer_len(),
@@ -148,7 +153,8 @@ where
148153
let buffer =
149154
&mut buffer[self.header.buffer_len()..self.header.length as usize];
150155
match self.payload {
151-
Noop | Done => {}
156+
Noop => {}
157+
Done(ref msg) => msg.emit(buffer),
152158
Overrun(ref bytes) => buffer.copy_from_slice(bytes),
153159
Error(ref msg) => msg.emit(buffer),
154160
Ack(ref msg) => msg.emit(buffer),
@@ -168,3 +174,69 @@ where
168174
}
169175
}
170176
}
177+
178+
#[cfg(test)]
179+
mod tests {
180+
use super::*;
181+
182+
use std::{convert::Infallible, mem::size_of};
183+
184+
#[derive(Clone, Debug, Default, PartialEq)]
185+
struct FakeNetlinkInnerMessage;
186+
187+
impl NetlinkSerializable for FakeNetlinkInnerMessage {
188+
fn message_type(&self) -> u16 {
189+
unimplemented!("unused by tests")
190+
}
191+
192+
fn buffer_len(&self) -> usize {
193+
unimplemented!("unused by tests")
194+
}
195+
196+
fn serialize(&self, _buffer: &mut [u8]) {
197+
unimplemented!("unused by tests")
198+
}
199+
}
200+
201+
impl NetlinkDeserializable for FakeNetlinkInnerMessage {
202+
type Error = Infallible;
203+
204+
fn deserialize(
205+
_header: &NetlinkHeader,
206+
_payload: &[u8],
207+
) -> Result<Self, Self::Error> {
208+
unimplemented!("unused by tests")
209+
}
210+
}
211+
212+
#[test]
213+
fn test_done() {
214+
let header = NetlinkHeader::default();
215+
let mut done_msg = DoneMessage::default();
216+
done_msg.code = 0;
217+
done_msg.extended_ack = vec![6, 7, 8, 9];
218+
let mut want = NetlinkMessage::new(
219+
header,
220+
NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
221+
);
222+
want.finalize();
223+
224+
let len = want.buffer_len();
225+
assert_eq!(
226+
len,
227+
header.buffer_len()
228+
+ size_of::<i32>()
229+
+ done_msg.extended_ack.len()
230+
);
231+
232+
let mut buf = vec![1; len];
233+
want.emit(&mut buf);
234+
235+
let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
236+
assert_eq!(done_buf.code(), done_msg.code);
237+
assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);
238+
239+
let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
240+
assert_eq!(got, want);
241+
}
242+
}

src/payload.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
use std::fmt::Debug;
44

5-
use crate::{AckMessage, ErrorMessage, NetlinkSerializable};
5+
use crate::{AckMessage, DoneMessage, ErrorMessage, NetlinkSerializable};
66

77
/// The message is ignored.
88
pub const NLMSG_NOOP: u16 = 1;
@@ -18,7 +18,7 @@ pub const NLMSG_ALIGNTO: u16 = 4;
1818
#[derive(Debug, PartialEq, Eq, Clone)]
1919
#[non_exhaustive]
2020
pub enum NetlinkPayload<I> {
21-
Done,
21+
Done(DoneMessage),
2222
Error(ErrorMessage),
2323
Ack(AckMessage),
2424
Noop,
@@ -32,7 +32,7 @@ where
3232
{
3333
pub fn message_type(&self) -> u16 {
3434
match self {
35-
NetlinkPayload::Done => NLMSG_DONE,
35+
NetlinkPayload::Done(_) => NLMSG_DONE,
3636
NetlinkPayload::Error(_) | NetlinkPayload::Ack(_) => NLMSG_ERROR,
3737
NetlinkPayload::Noop => NLMSG_NOOP,
3838
NetlinkPayload::Overrun(_) => NLMSG_OVERRUN,

0 commit comments

Comments
 (0)