Skip to content

Commit 9de49c2

Browse files
authored
Merge pull request #143 from MathiasKoch/enhancement/cancel-safe-read
enhancement(async): Make RecordReaders read fn cancel-safe
2 parents f48952f + ff80fdc commit 9de49c2

File tree

5 files changed

+53
-53
lines changed

5 files changed

+53
-53
lines changed

src/asynch.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ where
2929
delegate: Socket,
3030
opened: bool,
3131
key_schedule: KeySchedule<CipherSuite>,
32-
record_reader: RecordReader<'a, CipherSuite>,
32+
record_reader: RecordReader<'a>,
3333
record_write_buf: WriteBuffer<'a>,
3434
decrypted: DecryptedBufferInfo,
3535
}
@@ -365,7 +365,7 @@ where
365365
state: State,
366366
delegate: Socket,
367367
key_schedule: ReadKeySchedule<CipherSuite>,
368-
record_reader: RecordReader<'a, CipherSuite>,
368+
record_reader: RecordReader<'a>,
369369
decrypted: DecryptedBufferInfo,
370370
}
371371

src/blocking.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ where
2828
delegate: Socket,
2929
opened: bool,
3030
key_schedule: KeySchedule<CipherSuite>,
31-
record_reader: RecordReader<'a, CipherSuite>,
31+
record_reader: RecordReader<'a>,
3232
record_write_buf: WriteBuffer<'a>,
3333
decrypted: DecryptedBufferInfo,
3434
}
@@ -356,7 +356,7 @@ where
356356
state: State,
357357
delegate: Socket,
358358
key_schedule: ReadKeySchedule<CipherSuite>,
359-
record_reader: RecordReader<'a, CipherSuite>,
359+
record_reader: RecordReader<'a>,
360360
decrypted: DecryptedBufferInfo,
361361
}
362362

src/connection.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ impl<'a> State {
168168
self,
169169
transport: &mut Transport,
170170
handshake: &mut Handshake<Provider::CipherSuite>,
171-
record_reader: &mut RecordReader<'_, Provider::CipherSuite>,
171+
record_reader: &mut RecordReader<'_>,
172172
tx_buf: &mut WriteBuffer<'_>,
173173
key_schedule: &mut KeySchedule<Provider::CipherSuite>,
174174
config: &TlsConfig<'a>,
@@ -237,7 +237,7 @@ impl<'a> State {
237237
self,
238238
transport: &mut Transport,
239239
handshake: &mut Handshake<Provider::CipherSuite>,
240-
record_reader: &mut RecordReader<'_, Provider::CipherSuite>,
240+
record_reader: &mut RecordReader<'_>,
241241
tx_buf: &mut WriteBuffer,
242242
key_schedule: &mut KeySchedule<Provider::CipherSuite>,
243243
config: &TlsConfig<'a>,

src/record.rs

+2
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ pub struct RecordHeader {
171171
}
172172

173173
impl RecordHeader {
174+
pub const LEN: usize = 5;
175+
174176
pub fn content_type(&self) -> ContentType {
175177
// Content type already validated in read
176178
unwrap!(ContentType::of(self.header[0]))

src/record_reader.rs

+45-47
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use core::marker::PhantomData;
2-
31
use crate::key_schedule::ReadKeySchedule;
42
use embedded_io::{Error, Read as BlockingRead};
53
use embedded_io_async::Read as AsyncRead;
@@ -10,22 +8,15 @@ use crate::{
108
TlsError,
119
};
1210

13-
pub struct RecordReader<'a, CipherSuite>
14-
where
15-
CipherSuite: TlsCipherSuite,
16-
{
11+
pub struct RecordReader<'a> {
1712
pub(crate) buf: &'a mut [u8],
1813
/// The number of decoded bytes in the buffer
1914
decoded: usize,
2015
/// The number of read but not yet decoded bytes in the buffer
2116
pending: usize,
22-
cipher_suite: PhantomData<CipherSuite>,
2317
}
2418

25-
impl<'a, CipherSuite> RecordReader<'a, CipherSuite>
26-
where
27-
CipherSuite: TlsCipherSuite,
28-
{
19+
impl<'a> RecordReader<'a> {
2920
pub fn new(buf: &'a mut [u8]) -> Self {
3021
if buf.len() < 16640 {
3122
warn!("Read buffer is smaller than 16640 bytes, which may cause problems!");
@@ -34,33 +25,26 @@ where
3425
buf,
3526
decoded: 0,
3627
pending: 0,
37-
cipher_suite: PhantomData,
3828
}
3929
}
4030

41-
pub async fn read<'m>(
31+
pub async fn read<'m, CipherSuite: TlsCipherSuite>(
4232
&'m mut self,
4333
transport: &mut impl AsyncRead,
4434
key_schedule: &mut ReadKeySchedule<CipherSuite>,
4535
) -> Result<ServerRecord<'m, CipherSuite>, TlsError> {
46-
let header = self.advance(transport, 5).await?;
47-
let header = RecordHeader::decode(unwrap!(header.try_into().ok()))?;
48-
49-
let content_length = header.content_length();
50-
debug!(
51-
"advance: {:?} - content_length = {} bytes",
52-
header.content_type(),
53-
content_length
54-
);
55-
let data = self.advance(transport, content_length).await?;
56-
ServerRecord::decode(header, data, key_schedule.transcript_hash())
36+
self.advance(transport, RecordHeader::LEN).await?;
37+
let header = self.record_header()?;
38+
self.advance(transport, RecordHeader::LEN + header.content_length())
39+
.await?;
40+
self.consume(header, key_schedule.transcript_hash())
5741
}
5842

5943
async fn advance<'m>(
6044
&'m mut self,
6145
transport: &mut impl AsyncRead,
6246
amount: usize,
63-
) -> Result<&'m mut [u8], TlsError> {
47+
) -> Result<(), TlsError> {
6448
self.ensure_contiguous(amount)?;
6549

6650
while self.pending < amount {
@@ -74,27 +58,25 @@ where
7458
self.pending += read;
7559
}
7660

77-
Ok(self.consume(amount))
61+
Ok(())
7862
}
7963

80-
pub fn read_blocking<'m>(
64+
pub fn read_blocking<'m, CipherSuite: TlsCipherSuite>(
8165
&'m mut self,
8266
transport: &mut impl BlockingRead,
8367
key_schedule: &mut ReadKeySchedule<CipherSuite>,
8468
) -> Result<ServerRecord<'m, CipherSuite>, TlsError> {
85-
let header = self.advance_blocking(transport, 5)?;
86-
let header = RecordHeader::decode(unwrap!(header.try_into().ok()))?;
87-
88-
let content_length = header.content_length();
89-
let data = self.advance_blocking(transport, content_length)?;
90-
ServerRecord::decode(header, data, key_schedule.transcript_hash())
69+
self.advance_blocking(transport, RecordHeader::LEN)?;
70+
let header = self.record_header()?;
71+
self.advance_blocking(transport, RecordHeader::LEN + header.content_length())?;
72+
self.consume(header, key_schedule.transcript_hash())
9173
}
9274

9375
fn advance_blocking<'m>(
9476
&'m mut self,
9577
transport: &mut impl BlockingRead,
9678
amount: usize,
97-
) -> Result<&'m mut [u8], TlsError> {
79+
) -> Result<(), TlsError> {
9880
self.ensure_contiguous(amount)?;
9981

10082
while self.pending < amount {
@@ -107,14 +89,30 @@ where
10789
self.pending += read;
10890
}
10991

110-
Ok(self.consume(amount))
92+
Ok(())
11193
}
11294

113-
fn consume(&mut self, amount: usize) -> &mut [u8] {
114-
let slice = &mut self.buf[self.decoded..self.decoded + amount];
115-
self.decoded += amount;
116-
self.pending -= amount;
117-
slice
95+
fn record_header(&self) -> Result<RecordHeader, TlsError> {
96+
RecordHeader::decode(unwrap!(self.buf
97+
[self.decoded..self.decoded + RecordHeader::LEN]
98+
.try_into()
99+
.ok()))
100+
}
101+
102+
fn consume<'m, CipherSuite: TlsCipherSuite>(
103+
&'m mut self,
104+
header: RecordHeader,
105+
digest: &mut CipherSuite::Hash,
106+
) -> Result<ServerRecord<'m, CipherSuite>, TlsError> {
107+
let content_len = header.content_length();
108+
109+
let slice = &mut self.buf
110+
[self.decoded + RecordHeader::LEN..self.decoded + RecordHeader::LEN + content_len];
111+
112+
self.decoded += RecordHeader::LEN + content_len;
113+
self.pending -= RecordHeader::LEN + content_len;
114+
115+
ServerRecord::decode(header, slice, digest)
118116
}
119117

120118
fn ensure_contiguous(&mut self, len: usize) -> Result<(), TlsError> {
@@ -207,7 +205,7 @@ mod tests {
207205
);
208206

209207
let mut buf = [0; 32];
210-
let mut reader = RecordReader::<Aes128GcmSha256>::new(&mut buf);
208+
let mut reader = RecordReader::new(&mut buf);
211209
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();
212210

213211
{
@@ -265,8 +263,8 @@ mod tests {
265263
]
266264
.as_slice();
267265

268-
let mut buf = [0; 5]; // This buffer is so small that it cannot contain both the header and data
269-
let mut reader = RecordReader::<Aes128GcmSha256>::new(&mut buf);
266+
let mut buf = [0; 9]; // This buffer is so small that it cannot contain both the header and data
267+
let mut reader = RecordReader::new(&mut buf);
270268
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();
271269

272270
{
@@ -279,8 +277,8 @@ mod tests {
279277
panic!("Wrong server record");
280278
}
281279

282-
assert_eq!(4, reader.decoded); // The buffer is rotated after decoding the header
283-
assert_eq!(1, reader.pending);
280+
assert_eq!(9, reader.decoded); // The buffer is rotated after decoding the header
281+
assert_eq!(0, reader.pending);
284282
}
285283

286284
{
@@ -293,7 +291,7 @@ mod tests {
293291
panic!("Wrong server record");
294292
}
295293

296-
assert_eq!(2, reader.decoded);
294+
assert_eq!(7, reader.decoded);
297295
assert_eq!(0, reader.pending);
298296
}
299297
}
@@ -318,7 +316,7 @@ mod tests {
318316
.as_slice();
319317

320318
let mut buf = [0; 32];
321-
let mut reader = RecordReader::<Aes128GcmSha256>::new(&mut buf);
319+
let mut reader = RecordReader::new(&mut buf);
322320
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();
323321

324322
{

0 commit comments

Comments
 (0)