-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlib.rs
269 lines (233 loc) · 8.76 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
//! Message Serialization/Deserialization (Protocol) for client <-> server communication
//!
//! Ideally you would use an existing serialization/deserialization library,
//! but this hand-rolled version shows what's going on under the hood.
//!
//! ## Libraries for serialization/deserialization
//! - [Serde](https://docs.rs/serde/1.0.114/serde/index.html)
//! - [tokio_util::codec](https://docs.rs/tokio-util/0.3.1/tokio_util/codec/index.html)
//! - [bincode](https://github.com/servo/bincode)
use std::convert::From;
use std::io::{self, Read, Write};
use std::net::{SocketAddr, TcpStream};
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
/// Default server address as a `host:port` string: IPv4 loopback, port 4000.
pub const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000";
/// A type that can encode itself onto a byte sink (anything implementing `Write`).
pub trait Serialize {
    /// Write this value's wire representation into `buf`.
    ///
    /// Returns how many bytes were written, or the first I/O error hit.
    fn serialize(&self, buf: &mut impl Write) -> io::Result<usize>;
}
/// A type that can be decoded from a byte source (anything implementing `Read`).
pub trait Deserialize {
    /// The value produced by decoding.
    type Output;

    /// Decode one value from `buf`, or return the I/O / format error encountered.
    fn deserialize(buf: &mut impl Read) -> io::Result<Self::Output>;
}
/// Request object (client -> server)
///
/// Each variant is encoded on the wire with its own leading type byte.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Request {
    /// Echo a message back
    Echo(String),
    /// Jumble up a message with given amount of entropy before echoing
    Jumble {
        /// The text to jumble
        message: String,
        /// How much entropy to apply
        amount: u16,
    },
}
/// Map each `Request` variant to its one-byte wire tag (a `u8`, so at most 255 kinds).
///
/// Implemented for `&Request` because only the variant is inspected; nothing is
/// moved out of or mutated in the request.
impl From<&Request> for u8 {
    fn from(request: &Request) -> u8 {
        match *request {
            Request::Echo(..) => 1,
            Request::Jumble { .. } => 2,
        }
    }
}
/// Wire format for a `Request`:
/// ```ignore
/// | u8   | u16    | [u8]        | ... u16    | ... [u8]        |
/// | type | length | value bytes | ... length | ... value bytes |
/// ```
///
/// One type byte, followed by any number of (length, bytes) pairs.
/// Integers are written big-endian (network byte order).
impl Request {
    /// Borrow the text payload of this request, whichever variant it is.
    pub fn message(&self) -> &str {
        let text = match self {
            Request::Echo(message) | Request::Jumble { message, .. } => message,
        };
        text.as_str()
    }
}
impl Serialize for Request {
    /// Serialize Request to bytes (to send to server)
    ///
    /// Returns the number of bytes written.
    ///
    /// # Errors
    /// Returns `io::ErrorKind::InvalidInput`, before anything is written, if the
    /// message is longer than `u16::MAX` bytes, because its length cannot be
    /// encoded in the u16 length prefix. Otherwise propagates write errors from `buf`.
    fn serialize(&self, buf: &mut impl Write) -> io::Result<usize> {
        let message = self.message().as_bytes();
        // A plain `len() as u16` would silently truncate, producing a frame whose
        // length prefix disagrees with its payload and desynchronising the reader.
        // Check first so no partial frame is ever emitted.
        if message.len() > u16::MAX as usize {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Request message too long (max 65535 bytes)",
            ));
        }

        buf.write_u8(self.into())?; // Message Type byte
        // The variable length message string, preceded by its length
        buf.write_u16::<NetworkEndian>(message.len() as u16)?;
        buf.write_all(message)?;
        let mut bytes_written = 1 + 2 + message.len();

        match self {
            Request::Echo(_) => {}
            Request::Jumble { amount, .. } => {
                // `amount` is always 2 bytes long, but it is length-prefixed
                // anyway to keep every field in (length, bytes) form
                buf.write_u16::<NetworkEndian>(2)?;
                buf.write_u16::<NetworkEndian>(*amount)?;
                bytes_written += 4;
            }
        }
        Ok(bytes_written)
    }
}
impl Deserialize for Request {
    type Output = Request;

    /// Deserialize Request from bytes (to receive from TcpStream)
    ///
    /// # Errors
    /// - `io::ErrorKind::InvalidData` for an unknown type byte, a Jumble `amount`
    ///   whose length prefix is not 2, or a message that is not valid UTF-8.
    /// - Any I/O error from `buf`, e.g. `UnexpectedEof` on truncated input.
    fn deserialize(mut buf: &mut impl Read) -> io::Result<Self::Output> {
        match buf.read_u8()? {
            // Echo
            1 => Ok(Request::Echo(extract_string(&mut buf)?)),
            // Jumble
            2 => {
                let message = extract_string(&mut buf)?;
                // The serializer always prefixes `amount` with length 2. Any other
                // value means a malformed frame, and reading on would misinterpret
                // the bytes that follow.
                let amount_len = buf.read_u16::<NetworkEndian>()?;
                if amount_len != 2 {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "Invalid Jumble amount length",
                    ));
                }
                let amount = buf.read_u16::<NetworkEndian>()?;
                Ok(Request::Jumble { message, amount })
            }
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid Request Type",
            )),
        }
    }
}
/// Response object from server
///
/// In the real world, this would likely be an enum as well to signal Success vs. Error.
/// But since we're showing that capability with the `Request` enum, we'll keep this one simple.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Response(pub String);
/// Message format for Response is:
/// ```ignore
/// | u16 | [u8] |
/// | length | value bytes |
/// ```
///
impl Response {
/// Create a new response with a given message
pub fn new(message: String) -> Self {
Self(message)
}
/// Get the response message value
pub fn message(&self) -> &str {
&self.0
}
}
impl Serialize for Response {
    /// Serialize Response to bytes (to send to client)
    ///
    /// Returns the number of bytes written: the 2-byte length prefix plus the
    /// message bytes (a Response carries no type byte).
    ///
    /// # Errors
    /// Returns `io::ErrorKind::InvalidInput`, before anything is written, if the
    /// message is longer than `u16::MAX` bytes, because its length cannot be
    /// encoded in the u16 prefix. Otherwise propagates write errors from `buf`.
    fn serialize(&self, buf: &mut impl Write) -> io::Result<usize> {
        let resp_bytes = self.0.as_bytes();
        // `len() as u16` would silently truncate and corrupt the frame
        if resp_bytes.len() > u16::MAX as usize {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Response message too long (max 65535 bytes)",
            ));
        }
        buf.write_u16::<NetworkEndian>(resp_bytes.len() as u16)?;
        buf.write_all(resp_bytes)?;
        Ok(2 + resp_bytes.len()) // len + bytes
    }
}
impl Deserialize for Response {
type Output = Response;
/// Deserialize Response to bytes (to receive from server)
fn deserialize(mut buf: &mut impl Read) -> io::Result<Self::Output> {
let value = extract_string(&mut buf)?;
Ok(Response(value))
}
}
/// Read one length-prefixed UTF-8 string from `buf`.
///
/// The prefix is a big-endian `u16` byte count; exactly that many bytes are
/// then read and decoded.
///
/// # Errors
/// Propagates I/O errors from `buf` (including `UnexpectedEof` if the input
/// ends early) and returns `InvalidData` if the bytes are not valid UTF-8.
fn extract_string(buf: &mut impl Read) -> io::Result<String> {
    let mut prefix = [0u8; 2];
    buf.read_exact(&mut prefix)?;
    let mut payload = vec![0u8; usize::from(u16::from_be_bytes(prefix))];
    buf.read_exact(&mut payload)?;
    String::from_utf8(payload)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid utf8"))
}
/// Wraps a connected `TcpStream` and manages sending and receiving of messages.
///
/// Holds two handles to the same socket: a buffered one for reading and the
/// plain one for writing.
pub struct Protocol {
    /// Buffered read side, built from a cloned handle of `stream`
    reader: io::BufReader<TcpStream>,
    /// Write side
    stream: TcpStream,
}
impl Protocol {
    /// Wrap an already-connected TcpStream with Protocol.
    ///
    /// # Errors
    /// Fails if the stream handle cannot be cloned for the buffered reader.
    pub fn with_stream(stream: TcpStream) -> io::Result<Self> {
        Ok(Self {
            // The reader gets its own handle so reads can be buffered while
            // writes go straight to `stream`.
            reader: io::BufReader::new(stream.try_clone()?),
            stream,
        })
    }

    /// Open a TCP connection to `dest` and wrap it with Protocol.
    ///
    /// # Errors
    /// Propagates connection failures and stream-cloning failures.
    pub fn connect(dest: SocketAddr) -> io::Result<Self> {
        // Log before the blocking connect, so the message reflects what is
        // happening and still appears when the attempt fails.
        eprintln!("Connecting to {}", dest);
        let stream = TcpStream::connect(dest)?;
        Self::with_stream(stream)
    }

    /// Serialize `message` and write it to the TcpStream as one frame.
    ///
    /// The message is encoded into an in-memory buffer first. A serialization
    /// error therefore never leaves a partial frame on the wire, and the frame
    /// reaches the socket through a single `write_all` call rather than one
    /// small write per field.
    pub fn send_message(&mut self, message: &impl Serialize) -> io::Result<()> {
        let mut frame = Vec::new();
        message.serialize(&mut frame)?;
        self.stream.write_all(&frame)?;
        self.stream.flush()
    }

    /// Read and decode one message of type `T` from the connection.
    ///
    /// Blocks until a complete message has been read, so only call this when
    /// a message is expected to arrive.
    ///
    /// # Errors
    /// Returns whatever `T::deserialize` returns. For the types in this module
    /// that includes I/O errors, `UnexpectedEof` if the peer closes mid-message,
    /// and `InvalidData` for malformed input.
    pub fn read_message<T: Deserialize>(&mut self) -> io::Result<T::Output> {
        T::deserialize(&mut self.reader)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::io::{self, Cursor};

    #[test]
    fn test_request_echo_roundtrip() {
        let req = Request::Echo(String::from("Hello"));
        let mut bytes: Vec<u8> = vec![];
        let written = req.serialize(&mut bytes).unwrap();
        // Pin the wire format: type byte, big-endian u16 length, message bytes
        assert_eq!(bytes, [1, 0, 5, b'H', b'e', b'l', b'l', b'o']);
        assert_eq!(written, bytes.len());
        let mut reader = Cursor::new(bytes);
        let roundtrip_req = Request::deserialize(&mut reader).unwrap();
        assert!(matches!(roundtrip_req, Request::Echo(_)));
        assert_eq!(roundtrip_req.message(), "Hello");
    }

    #[test]
    fn test_request_jumble_roundtrip() {
        let req = Request::Jumble {
            message: String::from("Hello"),
            amount: 42,
        };
        let mut bytes: Vec<u8> = vec![];
        let written = req.serialize(&mut bytes).unwrap();
        // type, message (length + bytes), then amount (length 2 + big-endian u16)
        assert_eq!(bytes, [2, 0, 5, b'H', b'e', b'l', b'l', b'o', 0, 2, 0, 42]);
        assert_eq!(written, bytes.len());
        let mut reader = Cursor::new(bytes);
        match Request::deserialize(&mut reader).unwrap() {
            Request::Jumble { message, amount } => {
                assert_eq!(message, "Hello");
                assert_eq!(amount, 42);
            }
            other => panic!("expected Jumble, got {:?}", other),
        }
    }

    #[test]
    fn test_response_roundtrip() {
        let resp = Response(String::from("Hello"));
        let mut bytes: Vec<u8> = vec![];
        resp.serialize(&mut bytes).unwrap();
        // Responses have no type byte: just length + message bytes
        assert_eq!(bytes, [0, 5, b'H', b'e', b'l', b'l', b'o']);
        let mut reader = Cursor::new(bytes);
        let roundtrip_resp = Response::deserialize(&mut reader).unwrap();
        assert_eq!(roundtrip_resp.message(), "Hello");
    }

    #[test]
    fn test_request_rejects_unknown_type() {
        let mut reader = Cursor::new(vec![0u8, 0, 0]);
        let err = Request::deserialize(&mut reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_request_rejects_invalid_utf8() {
        let mut reader = Cursor::new(vec![1u8, 0, 1, 0xFF]);
        let err = Request::deserialize(&mut reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_truncated_input_is_unexpected_eof() {
        let mut reader = Cursor::new(vec![1u8, 0, 5, b'H']);
        let err = Request::deserialize(&mut reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}