1
- use core:: marker:: PhantomData ;
2
-
3
1
use crate :: key_schedule:: ReadKeySchedule ;
4
2
use embedded_io:: { Error , Read as BlockingRead } ;
5
3
use embedded_io_async:: Read as AsyncRead ;
@@ -10,22 +8,15 @@ use crate::{
10
8
TlsError ,
11
9
} ;
12
10
13
- pub struct RecordReader < ' a , CipherSuite >
14
- where
15
- CipherSuite : TlsCipherSuite ,
16
- {
11
+ pub struct RecordReader < ' a > {
17
12
pub ( crate ) buf : & ' a mut [ u8 ] ,
18
13
/// The number of decoded bytes in the buffer
19
14
decoded : usize ,
20
15
/// The number of read but not yet decoded bytes in the buffer
21
16
pending : usize ,
22
- cipher_suite : PhantomData < CipherSuite > ,
23
17
}
24
18
25
- impl < ' a , CipherSuite > RecordReader < ' a , CipherSuite >
26
- where
27
- CipherSuite : TlsCipherSuite ,
28
- {
19
+ impl < ' a > RecordReader < ' a > {
29
20
pub fn new ( buf : & ' a mut [ u8 ] ) -> Self {
30
21
if buf. len ( ) < 16640 {
31
22
warn ! ( "Read buffer is smaller than 16640 bytes, which may cause problems!" ) ;
@@ -34,33 +25,26 @@ where
34
25
buf,
35
26
decoded : 0 ,
36
27
pending : 0 ,
37
- cipher_suite : PhantomData ,
38
28
}
39
29
}
40
30
41
- pub async fn read < ' m > (
31
+ pub async fn read < ' m , CipherSuite : TlsCipherSuite > (
42
32
& ' m mut self ,
43
33
transport : & mut impl AsyncRead ,
44
34
key_schedule : & mut ReadKeySchedule < CipherSuite > ,
45
35
) -> Result < ServerRecord < ' m , CipherSuite > , TlsError > {
46
- let header = self . advance ( transport, 5 ) . await ?;
47
- let header = RecordHeader :: decode ( unwrap ! ( header. try_into( ) . ok( ) ) ) ?;
48
-
49
- let content_length = header. content_length ( ) ;
50
- debug ! (
51
- "advance: {:?} - content_length = {} bytes" ,
52
- header. content_type( ) ,
53
- content_length
54
- ) ;
55
- let data = self . advance ( transport, content_length) . await ?;
56
- ServerRecord :: decode ( header, data, key_schedule. transcript_hash ( ) )
36
+ self . advance ( transport, RecordHeader :: LEN ) . await ?;
37
+ let header = self . record_header ( ) ?;
38
+ self . advance ( transport, RecordHeader :: LEN + header. content_length ( ) )
39
+ . await ?;
40
+ self . consume ( header, key_schedule. transcript_hash ( ) )
57
41
}
58
42
59
43
async fn advance < ' m > (
60
44
& ' m mut self ,
61
45
transport : & mut impl AsyncRead ,
62
46
amount : usize ,
63
- ) -> Result < & ' m mut [ u8 ] , TlsError > {
47
+ ) -> Result < ( ) , TlsError > {
64
48
self . ensure_contiguous ( amount) ?;
65
49
66
50
while self . pending < amount {
@@ -74,27 +58,25 @@ where
74
58
self . pending += read;
75
59
}
76
60
77
- Ok ( self . consume ( amount ) )
61
+ Ok ( ( ) )
78
62
}
79
63
80
- pub fn read_blocking < ' m > (
64
+ pub fn read_blocking < ' m , CipherSuite : TlsCipherSuite > (
81
65
& ' m mut self ,
82
66
transport : & mut impl BlockingRead ,
83
67
key_schedule : & mut ReadKeySchedule < CipherSuite > ,
84
68
) -> Result < ServerRecord < ' m , CipherSuite > , TlsError > {
85
- let header = self . advance_blocking ( transport, 5 ) ?;
86
- let header = RecordHeader :: decode ( unwrap ! ( header. try_into( ) . ok( ) ) ) ?;
87
-
88
- let content_length = header. content_length ( ) ;
89
- let data = self . advance_blocking ( transport, content_length) ?;
90
- ServerRecord :: decode ( header, data, key_schedule. transcript_hash ( ) )
69
+ self . advance_blocking ( transport, RecordHeader :: LEN ) ?;
70
+ let header = self . record_header ( ) ?;
71
+ self . advance_blocking ( transport, RecordHeader :: LEN + header. content_length ( ) ) ?;
72
+ self . consume ( header, key_schedule. transcript_hash ( ) )
91
73
}
92
74
93
75
fn advance_blocking < ' m > (
94
76
& ' m mut self ,
95
77
transport : & mut impl BlockingRead ,
96
78
amount : usize ,
97
- ) -> Result < & ' m mut [ u8 ] , TlsError > {
79
+ ) -> Result < ( ) , TlsError > {
98
80
self . ensure_contiguous ( amount) ?;
99
81
100
82
while self . pending < amount {
@@ -107,14 +89,30 @@ where
107
89
self . pending += read;
108
90
}
109
91
110
- Ok ( self . consume ( amount ) )
92
+ Ok ( ( ) )
111
93
}
112
94
113
- fn consume ( & mut self , amount : usize ) -> & mut [ u8 ] {
114
- let slice = & mut self . buf [ self . decoded ..self . decoded + amount] ;
115
- self . decoded += amount;
116
- self . pending -= amount;
117
- slice
95
+ fn record_header ( & self ) -> Result < RecordHeader , TlsError > {
96
+ RecordHeader :: decode ( unwrap ! ( self . buf
97
+ [ self . decoded..self . decoded + RecordHeader :: LEN ]
98
+ . try_into( )
99
+ . ok( ) ) )
100
+ }
101
+
102
+ fn consume < ' m , CipherSuite : TlsCipherSuite > (
103
+ & ' m mut self ,
104
+ header : RecordHeader ,
105
+ digest : & mut CipherSuite :: Hash ,
106
+ ) -> Result < ServerRecord < ' m , CipherSuite > , TlsError > {
107
+ let content_len = header. content_length ( ) ;
108
+
109
+ let slice = & mut self . buf
110
+ [ self . decoded + RecordHeader :: LEN ..self . decoded + RecordHeader :: LEN + content_len] ;
111
+
112
+ self . decoded += RecordHeader :: LEN + content_len;
113
+ self . pending -= RecordHeader :: LEN + content_len;
114
+
115
+ ServerRecord :: decode ( header, slice, digest)
118
116
}
119
117
120
118
fn ensure_contiguous ( & mut self , len : usize ) -> Result < ( ) , TlsError > {
@@ -207,7 +205,7 @@ mod tests {
207
205
) ;
208
206
209
207
let mut buf = [ 0 ; 32 ] ;
210
- let mut reader = RecordReader :: < Aes128GcmSha256 > :: new ( & mut buf) ;
208
+ let mut reader = RecordReader :: new ( & mut buf) ;
211
209
let mut key_schedule = KeySchedule :: < Aes128GcmSha256 > :: new ( ) ;
212
210
213
211
{
@@ -265,8 +263,8 @@ mod tests {
265
263
]
266
264
. as_slice ( ) ;
267
265
268
- let mut buf = [ 0 ; 5 ] ; // This buffer is so small that it cannot contain both the header and data
269
- let mut reader = RecordReader :: < Aes128GcmSha256 > :: new ( & mut buf) ;
266
+ let mut buf = [ 0 ; 9 ] ; // This buffer is so small that it cannot contain both the header and data
267
+ let mut reader = RecordReader :: new ( & mut buf) ;
270
268
let mut key_schedule = KeySchedule :: < Aes128GcmSha256 > :: new ( ) ;
271
269
272
270
{
@@ -279,8 +277,8 @@ mod tests {
279
277
panic ! ( "Wrong server record" ) ;
280
278
}
281
279
282
- assert_eq ! ( 4 , reader. decoded) ; // The buffer is rotated after decoding the header
283
- assert_eq ! ( 1 , reader. pending) ;
280
+ assert_eq ! ( 9 , reader. decoded) ; // The buffer is rotated after decoding the header
281
+ assert_eq ! ( 0 , reader. pending) ;
284
282
}
285
283
286
284
{
@@ -293,7 +291,7 @@ mod tests {
293
291
panic ! ( "Wrong server record" ) ;
294
292
}
295
293
296
- assert_eq ! ( 2 , reader. decoded) ;
294
+ assert_eq ! ( 7 , reader. decoded) ;
297
295
assert_eq ! ( 0 , reader. pending) ;
298
296
}
299
297
}
@@ -318,7 +316,7 @@ mod tests {
318
316
. as_slice ( ) ;
319
317
320
318
let mut buf = [ 0 ; 32 ] ;
321
- let mut reader = RecordReader :: < Aes128GcmSha256 > :: new ( & mut buf) ;
319
+ let mut reader = RecordReader :: new ( & mut buf) ;
322
320
let mut key_schedule = KeySchedule :: < Aes128GcmSha256 > :: new ( ) ;
323
321
324
322
{
0 commit comments