@@ -102,6 +102,7 @@ struct ConnInner {
102
102
status : StatusFlags ,
103
103
last_ok_packet : Option < OkPacket < ' static > > ,
104
104
last_err_packet : Option < mysql_common:: packets:: ServerError < ' static > > ,
105
+ handshake_complete : bool ,
105
106
pool : Option < Pool > ,
106
107
pending_result : std:: result:: Result < Option < PendingResult > , ServerError > ,
107
108
tx_status : TxStatus ,
@@ -147,6 +148,7 @@ impl ConnInner {
147
148
status : StatusFlags :: empty ( ) ,
148
149
last_ok_packet : None ,
149
150
last_err_packet : None ,
151
+ handshake_complete : false ,
150
152
stream : None ,
151
153
is_mariadb : false ,
152
154
version : ( 0 , 0 , 0 ) ,
@@ -585,6 +587,7 @@ impl Conn {
585
587
handshake_response. serialize ( buf. as_mut ( ) ) ;
586
588
587
589
self . write_packet ( buf) . await ?;
590
+ self . inner . handshake_complete = true ;
588
591
Ok ( ( ) )
589
592
}
590
593
@@ -789,7 +792,19 @@ impl Conn {
789
792
if let Ok ( ok_packet) = ok_packet {
790
793
self . handle_ok ( ok_packet. into_owned ( ) ) ;
791
794
} else {
792
- let err_packet = ParseBuf ( packet) . parse :: < ErrPacket > ( self . capabilities ( ) ) ;
795
+ // If we haven't completed the handshake the server will not be aware of our
796
+ // capabilities and so it will behave as if we have none. In particular, the error
797
+ // packet will not contain a SQL State field even if our capabilities do contain the
798
+ // `CLIENT_PROTOCOL_41` flag. Therefore it is necessary to parse an incoming packet
799
+ // with no capability assumptions if we have not completed the handshake.
800
+ //
801
+ // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html
802
+ let capabilities = if self . inner . handshake_complete {
803
+ self . capabilities ( )
804
+ } else {
805
+ CapabilityFlags :: empty ( )
806
+ } ;
807
+ let err_packet = ParseBuf ( packet) . parse :: < ErrPacket > ( capabilities) ;
793
808
if let Ok ( err_packet) = err_packet {
794
809
self . handle_err ( err_packet) ?;
795
810
return Ok ( true ) ;
@@ -1270,10 +1285,11 @@ mod test {
1270
1285
use futures_util:: stream:: { self , StreamExt } ;
1271
1286
use mysql_common:: constants:: MAX_PAYLOAD_LEN ;
1272
1287
use rand:: Fill ;
1288
+ use tokio:: { io:: AsyncWriteExt , net:: TcpListener } ;
1273
1289
1274
1290
use crate :: {
1275
1291
from_row, params, prelude:: * , test_misc:: get_opts, ChangeUserOpts , Conn , Error ,
1276
- OptsBuilder , Pool , Value , WhiteListFsHandler ,
1292
+ OptsBuilder , Pool , ServerError , Value , WhiteListFsHandler ,
1277
1293
} ;
1278
1294
1279
1295
#[ tokio:: test]
@@ -2189,6 +2205,45 @@ mod test {
2189
2205
Ok ( ( ) )
2190
2206
}
2191
2207
2208
+ #[ tokio:: test]
2209
+ async fn should_handle_initial_error_packet ( ) {
2210
+ let header = [
2211
+ 0x68 , 0x00 , 0x00 , // packet_length
2212
+ 0x00 , // sequence
2213
+ 0xff , // error_header
2214
+ 0x69 , 0x04 , // error_code
2215
+ ] ;
2216
+ let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'" ;
2217
+
2218
+ // Create a fake MySQL server that immediately replies with an error packet.
2219
+ let listener = TcpListener :: bind ( "127.0.0.1:0000" ) . await . unwrap ( ) ;
2220
+
2221
+ let listen_addr = listener. local_addr ( ) . unwrap ( ) ;
2222
+
2223
+ tokio:: task:: spawn ( async move {
2224
+ let ( mut stream, _) = listener. accept ( ) . await . unwrap ( ) ;
2225
+ stream. write_all ( & header) . await . unwrap ( ) ;
2226
+ stream. write_all ( error_message. as_bytes ( ) ) . await . unwrap ( ) ;
2227
+ stream. shutdown ( ) . await . unwrap ( ) ;
2228
+ } ) ;
2229
+
2230
+ let opts = OptsBuilder :: default ( )
2231
+ . ip_or_hostname ( listen_addr. ip ( ) . to_string ( ) )
2232
+ . tcp_port ( listen_addr. port ( ) ) ;
2233
+ let server_err = match Conn :: new ( opts) . await {
2234
+ Err ( Error :: Server ( server_err) ) => server_err,
2235
+ other => panic ! ( "expected server error but got: {:?}" , other) ,
2236
+ } ;
2237
+ assert_eq ! (
2238
+ server_err,
2239
+ ServerError {
2240
+ code: 1129 ,
2241
+ state: "HY000" . to_owned( ) ,
2242
+ message: error_message. to_owned( ) ,
2243
+ }
2244
+ ) ;
2245
+ }
2246
+
2192
2247
#[ cfg( feature = "nightly" ) ]
2193
2248
mod bench {
2194
2249
use crate :: { conn:: Conn , queryable:: Queryable , test_misc:: get_opts} ;
0 commit comments