4
4
use std:: collections:: VecDeque ;
5
5
use std:: fs:: File ;
6
6
use std:: io:: { Read , Write } ;
7
+ use std:: os:: unix:: io:: FromRawFd ;
7
8
8
9
use crate :: common:: ascii:: { CR , CRLF_LEN , LF } ;
9
10
use crate :: common:: Body ;
@@ -15,6 +16,7 @@ use crate::server::MAX_PAYLOAD_SIZE;
15
16
use vmm_sys_util:: sock_ctrl_msg:: ScmSocket ;
16
17
17
18
const BUFFER_SIZE : usize = 1024 ;
19
+ const SCM_MAX_FD : usize = 253 ;
18
20
19
21
/// Describes the state machine of an HTTP connection.
20
22
enum ConnectionState {
@@ -52,9 +54,9 @@ pub struct HttpConnection<T> {
52
54
/// A buffer containing the bytes of a response that is currently
53
55
/// being sent.
54
56
response_buffer : Option < Vec < u8 > > ,
55
- /// The latest file that has been received and which must be associated
57
+ /// The list of files that has been received and which must be associated
56
58
/// with the pending request.
57
- file : Option < File > ,
59
+ files : Vec < File > ,
58
60
/// Optional payload max size.
59
61
payload_max_size : usize ,
60
62
}
@@ -73,7 +75,7 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
73
75
parsed_requests : VecDeque :: new ( ) ,
74
76
response_queue : VecDeque :: new ( ) ,
75
77
response_buffer : None ,
76
- file : None ,
78
+ files : Vec :: new ( ) ,
77
79
payload_max_size : MAX_PAYLOAD_SIZE ,
78
80
}
79
81
}
@@ -123,7 +125,7 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
123
125
self . state = ConnectionState :: WaitingForRequestLine ;
124
126
self . body_bytes_to_be_read = 0 ;
125
127
let mut pending_request = self . pending_request . take ( ) . unwrap ( ) ;
126
- pending_request. file = self . file . take ( ) ;
128
+ pending_request. files = self . files . drain ( .. ) . collect ( ) ;
127
129
self . parsed_requests . push_back ( pending_request) ;
128
130
}
129
131
} ;
@@ -143,15 +145,11 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
143
145
}
144
146
// Append new bytes to what we already have in the buffer.
145
147
// The slice access is safe, the index is checked above.
146
- let ( bytes_read, file) = self
147
- . stream
148
- . recv_with_fd ( & mut self . buffer [ self . read_cursor ..] )
149
- . map_err ( ConnectionError :: StreamReadError ) ?;
150
-
151
- // Update the internal file that must be associated with the request.
152
- if file. is_some ( ) {
153
- self . file = file;
154
- }
148
+ let ( bytes_read, new_files) = self . recv_with_fds ( ) ?;
149
+
150
+ // Update the internal list of files that must be associated with the
151
+ // request.
152
+ self . files . extend ( new_files) ;
155
153
156
154
// If the read returned 0 then the client has closed the connection.
157
155
if bytes_read == 0 {
@@ -162,6 +160,43 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
162
160
. ok_or ( ConnectionError :: ParseError ( RequestError :: Overflow ) )
163
161
}
164
162
163
+ /// Receive data along with optional files descriptors.
164
+ /// It is a wrapper around the same function from vmm-sys-util.
165
+ ///
166
+ /// # Errors
167
+ /// `StreamError` is returned if any error occurred while reading the stream.
168
+ fn recv_with_fds ( & mut self ) -> Result < ( usize , Vec < File > ) , ConnectionError > {
169
+ let buf = & mut self . buffer [ self . read_cursor ..] ;
170
+ // We must allocate the maximum number of receivable file descriptors
171
+ // if don't want to miss any of them. Allocating a too small number
172
+ // would lead to the incapacity of receiving the file descriptors.
173
+ let mut fds = [ 0 ; SCM_MAX_FD ] ;
174
+ let mut iovecs = [ libc:: iovec {
175
+ iov_base : buf. as_mut_ptr ( ) as * mut libc:: c_void ,
176
+ iov_len : buf. len ( ) ,
177
+ } ] ;
178
+
179
+ // Safe because we have mutably borrowed buf and it's safe to write
180
+ // arbitrary data to a slice.
181
+ let ( read_count, fd_count) = unsafe {
182
+ self . stream
183
+ . recv_with_fds ( & mut iovecs, & mut fds)
184
+ . map_err ( ConnectionError :: StreamReadError ) ?
185
+ } ;
186
+
187
+ Ok ( (
188
+ read_count,
189
+ fds. iter ( )
190
+ . take ( fd_count)
191
+ . map ( |fd| {
192
+ // Safe because all fds are owned by us after they have been
193
+ // received through the socket.
194
+ unsafe { File :: from_raw_fd ( * fd) }
195
+ } )
196
+ . collect ( ) ,
197
+ ) )
198
+ }
199
+
165
200
/// Parses bytes in `buffer` for a valid request line.
166
201
/// Returns `false` if there are no more bytes to be parsed in the buffer.
167
202
///
@@ -197,7 +232,7 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
197
232
. map_err ( ConnectionError :: ParseError ) ?,
198
233
headers : Headers :: default ( ) ,
199
234
body : None ,
200
- file : None ,
235
+ files : Vec :: new ( ) ,
201
236
} ) ;
202
237
self . state = ConnectionState :: WaitingForHeaders ;
203
238
Ok ( true )
@@ -517,13 +552,17 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
517
552
518
553
#[ cfg( test) ]
519
554
mod tests {
555
+ use std:: io:: { Seek , SeekFrom } ;
520
556
use std:: net:: Shutdown ;
557
+ use std:: os:: unix:: io:: IntoRawFd ;
521
558
use std:: os:: unix:: net:: UnixStream ;
522
559
523
560
use super :: * ;
524
561
use crate :: common:: { Method , Version } ;
525
562
use crate :: server:: MAX_PAYLOAD_SIZE ;
526
563
564
+ use vmm_sys_util:: tempfile:: TempFile ;
565
+
527
566
#[ test]
528
567
fn test_try_read_expect ( ) {
529
568
// Test request with `Expect` header.
@@ -548,7 +587,7 @@ mod tests {
548
587
request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
549
588
headers : Headers :: new ( 26 , true , true ) ,
550
589
body : Some ( Body :: new ( b"this is not\n \r \n a json \n body" . to_vec ( ) ) ) ,
551
- file : None ,
590
+ files : Vec :: new ( ) ,
552
591
} ;
553
592
554
593
assert_eq ! ( request, expected_request) ;
@@ -585,7 +624,7 @@ mod tests {
585
624
request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
586
625
headers : Headers :: new ( 26 , true , true ) ,
587
626
body : Some ( Body :: new ( b"this is not\n \r \n a json \n body" . to_vec ( ) ) ) ,
588
- file : None ,
627
+ files : Vec :: new ( ) ,
589
628
} ;
590
629
assert_eq ! ( request, expected_request) ;
591
630
}
@@ -619,7 +658,7 @@ mod tests {
619
658
request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
620
659
headers : Headers :: new ( 26 , true , true ) ,
621
660
body : Some ( Body :: new ( b"this is not\n \r \n a json \n body" . to_vec ( ) ) ) ,
622
- file : None ,
661
+ files : Vec :: new ( ) ,
623
662
} ;
624
663
assert_eq ! ( request, expected_request) ;
625
664
}
@@ -684,7 +723,7 @@ mod tests {
684
723
request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
685
724
headers : Headers :: new ( 1400 , true , true ) ,
686
725
body : Some ( Body :: new ( request_body) ) ,
687
- file : None ,
726
+ files : Vec :: new ( ) ,
688
727
} ;
689
728
690
729
assert_eq ! ( request, expected_request) ;
@@ -755,7 +794,7 @@ mod tests {
755
794
request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
756
795
headers : Headers :: new ( 0 , true , true ) ,
757
796
body : None ,
758
- file : None ,
797
+ files : Vec :: new ( ) ,
759
798
} ;
760
799
assert_eq ! ( request, expected_request) ;
761
800
}
@@ -777,7 +816,7 @@ mod tests {
777
816
request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
778
817
headers : Headers :: new ( 0 , false , false ) ,
779
818
body : None ,
780
- file : None ,
819
+ files : Vec :: new ( ) ,
781
820
} ;
782
821
assert_eq ! ( request, expected_request) ;
783
822
}
@@ -806,7 +845,7 @@ mod tests {
806
845
request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
807
846
headers : Headers :: new ( 0 , false , false ) ,
808
847
body : None ,
809
- file : None ,
848
+ files : Vec :: new ( ) ,
810
849
} ;
811
850
assert_eq ! ( request, expected_request) ;
812
851
@@ -825,7 +864,7 @@ mod tests {
825
864
) ,
826
865
headers : Headers :: new ( 0 , false , false ) ,
827
866
body : None ,
828
- file : None ,
867
+ files : Vec :: new ( ) ,
829
868
} ;
830
869
assert_eq ! ( request, expected_request) ;
831
870
}
@@ -853,7 +892,7 @@ mod tests {
853
892
request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
854
893
headers : Headers :: new ( 26 , false , true ) ,
855
894
body : Some ( Body :: new ( b"this is not\n \r \n a json \n body" . to_vec ( ) ) ) ,
856
- file : None ,
895
+ files : Vec :: new ( ) ,
857
896
} ;
858
897
859
898
conn. try_read ( ) . unwrap ( ) ;
@@ -864,7 +903,7 @@ mod tests {
864
903
request_line : RequestLine :: new ( Method :: Put , "http://farhost/away" , Version :: Http11 ) ,
865
904
headers : Headers :: new ( 23 , false , false ) ,
866
905
body : Some ( Body :: new ( b"this is another request" . to_vec ( ) ) ) ,
867
- file : None ,
906
+ files : Vec :: new ( ) ,
868
907
} ;
869
908
assert_eq ! ( request_first, expected_request_first) ;
870
909
assert_eq ! ( request_second, expected_request_second) ;
@@ -999,6 +1038,77 @@ mod tests {
999
1038
) ;
1000
1039
}
1001
1040
1041
+ #[ test]
1042
+ fn test_read_bytes_with_files ( ) {
1043
+ let ( sender, receiver) = UnixStream :: pair ( ) . unwrap ( ) ;
1044
+ receiver. set_nonblocking ( true ) . expect ( "Can't modify socket" ) ;
1045
+ let mut conn = HttpConnection :: new ( receiver) ;
1046
+
1047
+ // Create 3 files, edit the content and rewind back to the start.
1048
+ let mut file1 = TempFile :: new ( ) . unwrap ( ) . into_file ( ) ;
1049
+ let mut file2 = TempFile :: new ( ) . unwrap ( ) . into_file ( ) ;
1050
+ let mut file3 = TempFile :: new ( ) . unwrap ( ) . into_file ( ) ;
1051
+ file1. write ( b"foo" ) . unwrap ( ) ;
1052
+ file1. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
1053
+ file2. write ( b"bar" ) . unwrap ( ) ;
1054
+ file2. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
1055
+ file3. write ( b"foobar" ) . unwrap ( ) ;
1056
+ file3. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
1057
+
1058
+ // Send 2 file descriptors along with 3 bytes of data.
1059
+ assert_eq ! (
1060
+ sender. send_with_fds(
1061
+ & [ [ 1 , 2 , 3 ] . as_ref( ) ] ,
1062
+ & [ file1. into_raw_fd( ) , file2. into_raw_fd( ) ]
1063
+ ) ,
1064
+ Ok ( 3 )
1065
+ ) ;
1066
+
1067
+ // Check we receive the right amount of data along with the right
1068
+ // amount of file descriptors.
1069
+ assert_eq ! ( conn. read_bytes( ) , Ok ( 3 ) ) ;
1070
+ assert_eq ! ( conn. files. len( ) , 2 ) ;
1071
+
1072
+ // Check the content of the data received
1073
+ assert_eq ! ( conn. buffer[ 0 ] , 1 ) ;
1074
+ assert_eq ! ( conn. buffer[ 1 ] , 2 ) ;
1075
+ assert_eq ! ( conn. buffer[ 2 ] , 3 ) ;
1076
+
1077
+ // Check the file descriptors are usable by checking the content that
1078
+ // can be read.
1079
+ let mut buf = [ 0 ; 10 ] ;
1080
+ assert_eq ! ( conn. files[ 0 ] . read( & mut buf) . unwrap( ) , 3 ) ;
1081
+ assert_eq ! ( & buf[ ..3 ] , b"foo" ) ;
1082
+ assert_eq ! ( conn. files[ 1 ] . read( & mut buf) . unwrap( ) , 3 ) ;
1083
+ assert_eq ! ( & buf[ ..3 ] , b"bar" ) ;
1084
+
1085
+ // Send the 3rd file descriptor along with 1 byte of data.
1086
+ assert_eq ! (
1087
+ sender. send_with_fds( & [ [ 10 ] . as_ref( ) ] , & [ file3. into_raw_fd( ) ] ) ,
1088
+ Ok ( 1 )
1089
+ ) ;
1090
+
1091
+ // Check the amount of data along with the amount of file descriptors
1092
+ // are updated.
1093
+ assert_eq ! ( conn. read_bytes( ) , Ok ( 1 ) ) ;
1094
+ assert_eq ! ( conn. files. len( ) , 3 ) ;
1095
+
1096
+ // Check the content of the new data received
1097
+ assert_eq ! ( conn. buffer[ 0 ] , 10 ) ;
1098
+
1099
+ // Check the latest file descriptor is usable by checking the content
1100
+ // that can be read.
1101
+ let mut buf = [ 0 ; 10 ] ;
1102
+ assert_eq ! ( conn. files[ 2 ] . read( & mut buf) . unwrap( ) , 6 ) ;
1103
+ assert_eq ! ( & buf[ ..6 ] , b"foobar" ) ;
1104
+
1105
+ sender. shutdown ( Shutdown :: Write ) . unwrap ( ) ;
1106
+ assert_eq ! (
1107
+ conn. read_bytes( ) . unwrap_err( ) ,
1108
+ ConnectionError :: ConnectionClosed
1109
+ ) ;
1110
+ }
1111
+
1002
1112
#[ test]
1003
1113
fn test_shift_buffer_left ( ) {
1004
1114
let ( _, receiver) = UnixStream :: pair ( ) . unwrap ( ) ;
@@ -1095,7 +1205,7 @@ mod tests {
1095
1205
request_line : RequestLine :: new ( Method :: Get , "http://foo/bar" , Version :: Http11 ) ,
1096
1206
headers : Headers :: new ( 0 , true , true ) ,
1097
1207
body : None ,
1098
- file : None ,
1208
+ files : Vec :: new ( ) ,
1099
1209
} ) ;
1100
1210
assert_eq ! (
1101
1211
conn. parse_headers( & mut 0 , BUFFER_SIZE ) . unwrap_err( ) ,
@@ -1153,7 +1263,7 @@ mod tests {
1153
1263
request_line : RequestLine :: new ( Method :: Get , "http://foo/bar" , Version :: Http11 ) ,
1154
1264
headers : Headers :: new ( 0 , true , true ) ,
1155
1265
body : None ,
1156
- file : None ,
1266
+ files : Vec :: new ( ) ,
1157
1267
} ) ;
1158
1268
conn. body_vec = vec ! [ 0xde , 0xad , 0xbe , 0xef ] ;
1159
1269
assert_eq ! (
0 commit comments