Skip to content

Commit 82c4266

Browse files
authored
Merge pull request #342 from blackbeam/result-set-terminator
Fix result set terminator handling
2 parents bac4b8d + 0ba13f6 commit 82c4266

File tree

4 files changed

+42
-36
lines changed

4 files changed

+42
-36
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ crossbeam = "0.8.1"
7070
io-enum = "1.0.0"
7171
flate2 = { version = "1.0", default-features = false }
7272
lru = "0.8.1"
73-
mysql_common = { version = "0.29.1", default-features = false }
73+
mysql_common = { version = "0.29.2", default-features = false }
7474
socket2 = "0.4"
7575
once_cell = "1.7.2"
7676
pem = "1.0.1"

azure-pipelines.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ jobs:
103103
104104
- job: "TestBasicWindows"
105105
pool:
106-
vmImage: "vs2017-win2016"
106+
vmImage: "windows-2019"
107107
strategy:
108108
maxParallel: 10
109109
matrix:
@@ -125,7 +125,7 @@ jobs:
125125
echo socket=MYSQL >> C:\my.cnf
126126
echo datadir=C:\\ProgramData\\MySQL\\MySQL Server 8.0\\Data\\ >> C:\my.cnf
127127
"C:\Program Files\MySQL\MySQL Server 8.0\bin\mysqld" --install MySQL --defaults-file=C:\my.cnf
128-
net start MySql
128+
net start MySQL
129129
"C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET GLOBAL max_allowed_packet = 36700160;" -uroot -ppassword
130130
"C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET @@GLOBAL.ENFORCE_GTID_CONSISTENCY = WARN;" -uroot -ppassword
131131
"C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET @@GLOBAL.ENFORCE_GTID_CONSISTENCY = ON;" -uroot -ppassword

src/conn/mod.rs

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use mysql_common::{
1717
binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, Column, ComStmtClose,
1818
ComStmtExecuteRequestBuilder, ComStmtSendLongData, CommonOkPacket, ErrPacket,
1919
HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OkPacketKind,
20-
OldAuthSwitchRequest, ResultSetTerminator, SessionStateInfo,
20+
OldAuthSwitchRequest, OldEofPacket, ResultSetTerminator, SessionStateInfo,
2121
},
2222
proto::{codec::Compression, sync_framed::MySyncFramed, MySerialize},
2323
row::{Row, RowDeserializer},
@@ -204,6 +204,11 @@ impl ConnInner {
204204
pub struct Conn(Box<ConnInner>);
205205

206206
impl Conn {
207+
/// Must not be called before handle_handshake.
208+
const fn has_capability(&self, flag: CapabilityFlags) -> bool {
209+
self.0.capability_flags.contains(flag)
210+
}
211+
207212
/// Returns version number reported by the server.
208213
pub fn server_version(&self) -> (u16, u16, u16) {
209214
self.0
@@ -562,10 +567,7 @@ impl Conn {
562567

563568
if self.is_insecure() {
564569
if let Some(ssl_opts) = self.0.opts.get_ssl_opts().cloned() {
565-
if !handshake
566-
.capabilities()
567-
.contains(CapabilityFlags::CLIENT_SSL)
568-
{
570+
if !self.has_capability(CapabilityFlags::CLIENT_SSL) {
569571
return Err(DriverError(TlsNotSupported));
570572
} else {
571573
self.do_ssl_request()?;
@@ -596,11 +598,7 @@ impl Conn {
596598
self.write_handshake_response(&auth_plugin, auth_data.as_deref())?;
597599
self.continue_auth(&auth_plugin, &*nonce, false)?;
598600

599-
if self
600-
.0
601-
.capability_flags
602-
.contains(CapabilityFlags::CLIENT_COMPRESS)
603-
{
601+
if self.has_capability(CapabilityFlags::CLIENT_COMPRESS) {
604602
self.switch_to_compressed();
605603
}
606604

@@ -1080,32 +1078,28 @@ impl Conn {
10801078
self.query_first(format!("SELECT @@{}", name))
10811079
}
10821080

1083-
fn next_bin(&mut self, columns: Arc<[Column]>) -> Result<Option<Row>> {
1081+
fn next_row_packet(&mut self) -> Result<Option<Buffer>> {
10841082
if !self.0.has_results {
10851083
return Ok(None);
10861084
}
1087-
let pld = self.read_packet()?;
1088-
if pld[0] == 0xfe && pld.len() < 0xfe {
1089-
self.0.has_results = false;
1090-
self.handle_ok::<ResultSetTerminator>(&pld)?;
1091-
return Ok(None);
1092-
}
1093-
let row = ParseBuf(&*pld).parse::<RowDeserializer<ServerSide, Binary>>(columns)?;
1094-
Ok(Some(row.into()))
1095-
}
10961085

1097-
fn next_text(&mut self, columns: Arc<[Column]>) -> Result<Option<Row>> {
1098-
if !self.0.has_results {
1099-
return Ok(None);
1100-
}
11011086
let pld = self.read_packet()?;
1102-
if pld[0] == 0xfe && pld.len() < 0xfe {
1103-
self.0.has_results = false;
1104-
self.handle_ok::<ResultSetTerminator>(&pld)?;
1105-
return Ok(None);
1087+
1088+
if self.has_capability(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
1089+
if pld[0] == 0xfe && pld.len() < MAX_PAYLOAD_LEN {
1090+
self.0.has_results = false;
1091+
self.handle_ok::<ResultSetTerminator>(&pld)?;
1092+
return Ok(None);
1093+
}
1094+
} else {
1095+
if pld[0] == 0xfe && pld.len() < 8 {
1096+
self.0.has_results = false;
1097+
self.handle_ok::<OldEofPacket>(&pld)?;
1098+
return Ok(None);
1099+
}
11061100
}
1107-
let row = ParseBuf(&*pld).parse::<RowDeserializer<(), Text>>(columns)?;
1108-
Ok(Some(row.into()))
1101+
1102+
Ok(Some(pld))
11091103
}
11101104

11111105
fn has_stmt(&self, query: &[u8]) -> bool {

src/conn/query_result.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
pub use mysql_common::proto::{Binary, Text};
1010

11-
use mysql_common::packets::OkPacket;
11+
use mysql_common::{io::ParseBuf, packets::OkPacket, row::RowDeserializer, value::ServerSide};
1212

1313
use std::{borrow::Cow, marker::PhantomData, sync::Arc};
1414

@@ -27,13 +27,25 @@ pub trait Protocol: 'static + Send + Sync {
2727

2828
impl Protocol for Text {
2929
fn next(conn: &mut Conn, columns: Arc<[Column]>) -> Result<Option<Row>> {
30-
conn.next_text(columns)
30+
match conn.next_row_packet()? {
31+
Some(pld) => {
32+
let row = ParseBuf(&*pld).parse::<RowDeserializer<(), Text>>(columns)?;
33+
Ok(Some(row.into()))
34+
}
35+
None => Ok(None),
36+
}
3137
}
3238
}
3339

3440
impl Protocol for Binary {
3541
fn next(conn: &mut Conn, columns: Arc<[Column]>) -> Result<Option<Row>> {
36-
conn.next_bin(columns)
42+
match conn.next_row_packet()? {
43+
Some(pld) => {
44+
let row = ParseBuf(&*pld).parse::<RowDeserializer<ServerSide, Binary>>(columns)?;
45+
Ok(Some(row.into()))
46+
}
47+
None => Ok(None),
48+
}
3749
}
3850
}
3951

0 commit comments

Comments
 (0)