1- use core:: marker:: PhantomData ;
2-
31use crate :: key_schedule:: ReadKeySchedule ;
42use embedded_io:: { Error , Read as BlockingRead } ;
53use embedded_io_async:: Read as AsyncRead ;
@@ -10,22 +8,15 @@ use crate::{
108 TlsError ,
119} ;
1210
13- pub struct RecordReader < ' a , CipherSuite >
14- where
15- CipherSuite : TlsCipherSuite ,
16- {
11+ pub struct RecordReader < ' a > {
1712 pub ( crate ) buf : & ' a mut [ u8 ] ,
1813 /// The number of decoded bytes in the buffer
1914 decoded : usize ,
2015 /// The number of read but not yet decoded bytes in the buffer
2116 pending : usize ,
22- cipher_suite : PhantomData < CipherSuite > ,
2317}
2418
25- impl < ' a , CipherSuite > RecordReader < ' a , CipherSuite >
26- where
27- CipherSuite : TlsCipherSuite ,
28- {
19+ impl < ' a > RecordReader < ' a > {
2920 pub fn new ( buf : & ' a mut [ u8 ] ) -> Self {
3021 if buf. len ( ) < 16640 {
3122 warn ! ( "Read buffer is smaller than 16640 bytes, which may cause problems!" ) ;
@@ -34,33 +25,26 @@ where
3425 buf,
3526 decoded : 0 ,
3627 pending : 0 ,
37- cipher_suite : PhantomData ,
3828 }
3929 }
4030
41- pub async fn read < ' m > (
31+ pub async fn read < ' m , CipherSuite : TlsCipherSuite > (
4232 & ' m mut self ,
4333 transport : & mut impl AsyncRead ,
4434 key_schedule : & mut ReadKeySchedule < CipherSuite > ,
4535 ) -> Result < ServerRecord < ' m , CipherSuite > , TlsError > {
46- let header = self . advance ( transport, 5 ) . await ?;
47- let header = RecordHeader :: decode ( unwrap ! ( header. try_into( ) . ok( ) ) ) ?;
48-
49- let content_length = header. content_length ( ) ;
50- debug ! (
51- "advance: {:?} - content_length = {} bytes" ,
52- header. content_type( ) ,
53- content_length
54- ) ;
55- let data = self . advance ( transport, content_length) . await ?;
56- ServerRecord :: decode ( header, data, key_schedule. transcript_hash ( ) )
36+ self . advance ( transport, RecordHeader :: LEN ) . await ?;
37+ let header = self . record_header ( ) ?;
38+ self . advance ( transport, RecordHeader :: LEN + header. content_length ( ) )
39+ . await ?;
40+ self . consume ( header, key_schedule. transcript_hash ( ) )
5741 }
5842
5943 async fn advance < ' m > (
6044 & ' m mut self ,
6145 transport : & mut impl AsyncRead ,
6246 amount : usize ,
63- ) -> Result < & ' m mut [ u8 ] , TlsError > {
47+ ) -> Result < ( ) , TlsError > {
6448 self . ensure_contiguous ( amount) ?;
6549
6650 while self . pending < amount {
@@ -74,27 +58,25 @@ where
7458 self . pending += read;
7559 }
7660
77- Ok ( self . consume ( amount ) )
61+ Ok ( ( ) )
7862 }
7963
80- pub fn read_blocking < ' m > (
64+ pub fn read_blocking < ' m , CipherSuite : TlsCipherSuite > (
8165 & ' m mut self ,
8266 transport : & mut impl BlockingRead ,
8367 key_schedule : & mut ReadKeySchedule < CipherSuite > ,
8468 ) -> Result < ServerRecord < ' m , CipherSuite > , TlsError > {
85- let header = self . advance_blocking ( transport, 5 ) ?;
86- let header = RecordHeader :: decode ( unwrap ! ( header. try_into( ) . ok( ) ) ) ?;
87-
88- let content_length = header. content_length ( ) ;
89- let data = self . advance_blocking ( transport, content_length) ?;
90- ServerRecord :: decode ( header, data, key_schedule. transcript_hash ( ) )
69+ self . advance_blocking ( transport, RecordHeader :: LEN ) ?;
70+ let header = self . record_header ( ) ?;
71+ self . advance_blocking ( transport, RecordHeader :: LEN + header. content_length ( ) ) ?;
72+ self . consume ( header, key_schedule. transcript_hash ( ) )
9173 }
9274
9375 fn advance_blocking < ' m > (
9476 & ' m mut self ,
9577 transport : & mut impl BlockingRead ,
9678 amount : usize ,
97- ) -> Result < & ' m mut [ u8 ] , TlsError > {
79+ ) -> Result < ( ) , TlsError > {
9880 self . ensure_contiguous ( amount) ?;
9981
10082 while self . pending < amount {
@@ -107,14 +89,30 @@ where
10789 self . pending += read;
10890 }
10991
110- Ok ( self . consume ( amount ) )
92+ Ok ( ( ) )
11193 }
11294
113- fn consume ( & mut self , amount : usize ) -> & mut [ u8 ] {
114- let slice = & mut self . buf [ self . decoded ..self . decoded + amount] ;
115- self . decoded += amount;
116- self . pending -= amount;
117- slice
95+ fn record_header ( & self ) -> Result < RecordHeader , TlsError > {
96+ RecordHeader :: decode ( unwrap ! ( self . buf
97+ [ self . decoded..self . decoded + RecordHeader :: LEN ]
98+ . try_into( )
99+ . ok( ) ) )
100+ }
101+
102+ fn consume < ' m , CipherSuite : TlsCipherSuite > (
103+ & ' m mut self ,
104+ header : RecordHeader ,
105+ digest : & mut CipherSuite :: Hash ,
106+ ) -> Result < ServerRecord < ' m , CipherSuite > , TlsError > {
107+ let content_len = header. content_length ( ) ;
108+
109+ let slice = & mut self . buf
110+ [ self . decoded + RecordHeader :: LEN ..self . decoded + RecordHeader :: LEN + content_len] ;
111+
112+ self . decoded += RecordHeader :: LEN + content_len;
113+ self . pending -= RecordHeader :: LEN + content_len;
114+
115+ ServerRecord :: decode ( header, slice, digest)
118116 }
119117
120118 fn ensure_contiguous ( & mut self , len : usize ) -> Result < ( ) , TlsError > {
@@ -207,7 +205,7 @@ mod tests {
207205 ) ;
208206
209207 let mut buf = [ 0 ; 32 ] ;
210- let mut reader = RecordReader :: < Aes128GcmSha256 > :: new ( & mut buf) ;
208+ let mut reader = RecordReader :: new ( & mut buf) ;
211209 let mut key_schedule = KeySchedule :: < Aes128GcmSha256 > :: new ( ) ;
212210
213211 {
@@ -265,8 +263,8 @@ mod tests {
265263 ]
266264 . as_slice ( ) ;
267265
268- let mut buf = [ 0 ; 5 ] ; // This buffer is so small that it cannot contain both the header and data
269- let mut reader = RecordReader :: < Aes128GcmSha256 > :: new ( & mut buf) ;
266+ let mut buf = [ 0 ; 9 ] ; // This buffer is so small that it cannot contain both the header and data
267+ let mut reader = RecordReader :: new ( & mut buf) ;
270268 let mut key_schedule = KeySchedule :: < Aes128GcmSha256 > :: new ( ) ;
271269
272270 {
@@ -279,8 +277,8 @@ mod tests {
279277 panic ! ( "Wrong server record" ) ;
280278 }
281279
282- assert_eq ! ( 4 , reader. decoded) ; // The buffer is rotated after decoding the header
283- assert_eq ! ( 1 , reader. pending) ;
280+ assert_eq ! ( 9 , reader. decoded) ; // The buffer is rotated after decoding the header
281+ assert_eq ! ( 0 , reader. pending) ;
284282 }
285283
286284 {
@@ -293,7 +291,7 @@ mod tests {
293291 panic ! ( "Wrong server record" ) ;
294292 }
295293
296- assert_eq ! ( 2 , reader. decoded) ;
294+ assert_eq ! ( 7 , reader. decoded) ;
297295 assert_eq ! ( 0 , reader. pending) ;
298296 }
299297 }
@@ -318,7 +316,7 @@ mod tests {
318316 . as_slice ( ) ;
319317
320318 let mut buf = [ 0 ; 32 ] ;
321- let mut reader = RecordReader :: < Aes128GcmSha256 > :: new ( & mut buf) ;
319+ let mut reader = RecordReader :: new ( & mut buf) ;
322320 let mut key_schedule = KeySchedule :: < Aes128GcmSha256 > :: new ( ) ;
323321
324322 {
0 commit comments