Skip to content

Commit 125aebe

Browse files
authored
Merge pull request #307 from petrosagg/initial-error-packet
conn: handle initial error packet correctly
2 parents a475a5f + 4bf929a commit 125aebe

File tree

1 file changed

+57
-2
lines changed

1 file changed

+57
-2
lines changed

src/conn/mod.rs

+57-2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ struct ConnInner {
102102
status: StatusFlags,
103103
last_ok_packet: Option<OkPacket<'static>>,
104104
last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
105+
handshake_complete: bool,
105106
pool: Option<Pool>,
106107
pending_result: std::result::Result<Option<PendingResult>, ServerError>,
107108
tx_status: TxStatus,
@@ -147,6 +148,7 @@ impl ConnInner {
147148
status: StatusFlags::empty(),
148149
last_ok_packet: None,
149150
last_err_packet: None,
151+
handshake_complete: false,
150152
stream: None,
151153
is_mariadb: false,
152154
version: (0, 0, 0),
@@ -585,6 +587,7 @@ impl Conn {
585587
handshake_response.serialize(buf.as_mut());
586588

587589
self.write_packet(buf).await?;
590+
self.inner.handshake_complete = true;
588591
Ok(())
589592
}
590593

@@ -789,7 +792,19 @@ impl Conn {
789792
if let Ok(ok_packet) = ok_packet {
790793
self.handle_ok(ok_packet.into_owned());
791794
} else {
792-
let err_packet = ParseBuf(packet).parse::<ErrPacket>(self.capabilities());
795+
// If we haven't completed the handshake the server will not be aware of our
796+
// capabilities and so it will behave as if we have none. In particular, the error
797+
// packet will not contain a SQL State field even if our capabilities do contain the
798+
// `CLIENT_PROTOCOL_41` flag. Therefore it is necessary to parse an incoming packet
799+
// with no capability assumptions if we have not completed the handshake.
800+
//
801+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html
802+
let capabilities = if self.inner.handshake_complete {
803+
self.capabilities()
804+
} else {
805+
CapabilityFlags::empty()
806+
};
807+
let err_packet = ParseBuf(packet).parse::<ErrPacket>(capabilities);
793808
if let Ok(err_packet) = err_packet {
794809
self.handle_err(err_packet)?;
795810
return Ok(true);
@@ -1270,10 +1285,11 @@ mod test {
12701285
use futures_util::stream::{self, StreamExt};
12711286
use mysql_common::constants::MAX_PAYLOAD_LEN;
12721287
use rand::Fill;
1288+
use tokio::{io::AsyncWriteExt, net::TcpListener};
12731289

12741290
use crate::{
12751291
from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error,
1276-
OptsBuilder, Pool, Value, WhiteListFsHandler,
1292+
OptsBuilder, Pool, ServerError, Value, WhiteListFsHandler,
12771293
};
12781294

12791295
#[tokio::test]
@@ -2189,6 +2205,45 @@ mod test {
21892205
Ok(())
21902206
}
21912207

2208+
#[tokio::test]
2209+
async fn should_handle_initial_error_packet() {
2210+
let header = [
2211+
0x68, 0x00, 0x00, // packet_length
2212+
0x00, // sequence
2213+
0xff, // error_header
2214+
0x69, 0x04, // error_code
2215+
];
2216+
let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'";
2217+
2218+
// Create a fake MySQL server that immediately replies with an error packet.
2219+
let listener = TcpListener::bind("127.0.0.1:0000").await.unwrap();
2220+
2221+
let listen_addr = listener.local_addr().unwrap();
2222+
2223+
tokio::task::spawn(async move {
2224+
let (mut stream, _) = listener.accept().await.unwrap();
2225+
stream.write_all(&header).await.unwrap();
2226+
stream.write_all(error_message.as_bytes()).await.unwrap();
2227+
stream.shutdown().await.unwrap();
2228+
});
2229+
2230+
let opts = OptsBuilder::default()
2231+
.ip_or_hostname(listen_addr.ip().to_string())
2232+
.tcp_port(listen_addr.port());
2233+
let server_err = match Conn::new(opts).await {
2234+
Err(Error::Server(server_err)) => server_err,
2235+
other => panic!("expected server error but got: {:?}", other),
2236+
};
2237+
assert_eq!(
2238+
server_err,
2239+
ServerError {
2240+
code: 1129,
2241+
state: "HY000".to_owned(),
2242+
message: error_message.to_owned(),
2243+
}
2244+
);
2245+
}
2246+
21922247
#[cfg(feature = "nightly")]
21932248
mod bench {
21942249
use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};

0 commit comments

Comments
 (0)