Skip to content

Commit 667a05a

Browse files
committed
Ensure backward compatibility with legacy EOF format
1 parent c34acf9 commit 667a05a

File tree

4 files changed

+51
-43
lines changed

4 files changed

+51
-43
lines changed

connection.go

+22-11
Original file line numberDiff line numberDiff line change
@@ -180,16 +180,24 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
180180

181181
// Read Result
182182
columnCount, err := stmt.readPrepareResultPacket()
183-
if err != nil {
184-
return stmt, err
185-
}
186-
187-
if err := mc.readPackets(stmt.paramCount); err != nil {
188-
return nil, err
189-
}
183+
if err == nil {
184+
if stmt.paramCount > 0 {
185+
// FIXME - seems like a bug in MySQL (or it's intended).
186+
// There's no EOF return after parameters.
187+
// However, this behavior isn't consistent to Maria DB.
188+
if mc.flags&clientDeprecateEOF == 0 {
189+
if err = mc.readUntilEOF(); err != nil {
190+
return nil, err
191+
}
192+
}
193+
if err = mc.readExactPackets(stmt.paramCount); err != nil {
194+
return nil, err
195+
}
196+
}
190197

191-
if err := mc.readPackets(int(columnCount)); err != nil {
192-
return nil, err
198+
if columnCount > 0 {
199+
err = mc.readUntilEOF()
200+
}
193201
}
194202

195203
return stmt, err
@@ -415,8 +423,11 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
415423
rows.mc = mc
416424
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
417425

418-
if err := mc.readPackets(resLen); err != nil {
419-
return nil, err
426+
if resLen > 0 {
427+
// Columns
428+
if err := mc.readUntilEOF(); err != nil {
429+
return nil, err
430+
}
420431
}
421432

422433
dest := make([]driver.Value, resLen)

packets.go

+28-29
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,12 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
238238
pos += 1 + 2
239239

240240
// capability flags (upper 2 bytes) [2 bytes]
241-
mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
241+
mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
242242
pos += 2
243243

244244
// length of auth-plugin-data [1 byte]
245245
// reserved (all [00]) [10 bytes]
246-
pos += +1 + 10
246+
pos += 1 + 10
247247

248248
// second part of the password cipher [mininum 13 bytes],
249249
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -614,7 +614,7 @@ func readStatus(b []byte) statusFlag {
614614
}
615615

616616
// Ok Packet
617-
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
617+
// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
618618
func (mc *mysqlConn) handleOkPacket(data []byte) error {
619619
// 0x00 or 0xFE [1 byte]
620620
n := 1
@@ -640,22 +640,34 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
640640

641641
// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet
642642
// acting as an EOF.
643-
func isEOFPacket(data []byte) bool {
644-
return data[0] == iEOF && len(data) < 9
643+
func (mc *mysqlConn) isEOFPacket(data []byte) bool {
644+
// Legacy EOF packet
645+
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) && mc.flags&clientDeprecateEOF == 0 {
646+
return true
647+
}
648+
return data[0] == iEOF && len(data) < 9 && mc.flags&clientDeprecateEOF != 0
645649
}
646650

647651
// Read Packets as Field Packets until EOF-Packet or an Error appears
648652
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
649653
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
650654
columns := make([]mysqlField, count)
651655

652-
for i := 0; i < count; i++ {
656+
// If we set clientDeprecateEOF capability flag,
657+
// the EOF will be no longer sent after all columns.
658+
packets := count
659+
if mc.flags&clientDeprecateEOF == 0 {
660+
// Legacy way, read one more EOF packet.
661+
packets += 1
662+
}
663+
664+
for i := 0; i < packets; i++ {
653665
data, err := mc.readPacket()
654666
if err != nil {
655667
return nil, err
656668
}
657669

658-
if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) {
670+
if mc.isEOFPacket(data) {
659671
if i == count {
660672
return columns, nil
661673
}
@@ -759,12 +771,13 @@ func (rows *textRows) readRow(dest []driver.Value) error {
759771
}
760772

761773
// EOF Packet
762-
if isEOFPacket(data) {
774+
if mc.isEOFPacket(data) {
763775
if mc.flags&clientDeprecateEOF == 0 {
764776
// server_status [2 bytes]
765777
rows.mc.status = readStatus(data[3:])
766778
} else {
767779
if err := mc.handleOkPacket(data); err != nil {
780+
rows.mc = nil
768781
return err
769782
}
770783
}
@@ -830,37 +843,22 @@ func (mc *mysqlConn) readUntilEOF() error {
830843
switch {
831844
case data[0] == iERR:
832845
return mc.handleErrorPacket(data)
833-
case isEOFPacket(data):
846+
case mc.isEOFPacket(data):
834847
if mc.flags&clientDeprecateEOF == 0 {
835848
mc.status = readStatus(data[3:])
836-
} else {
837-
return mc.handleOkPacket(data)
849+
return nil
838850
}
839-
return nil
851+
return mc.handleOkPacket(data)
840852
}
841853
}
842854
}
843855

844-
func (mc *mysqlConn) readPackets(num int) error {
845-
846-
// we need to read EOF as well
847-
if mc.flags&clientDeprecateEOF == 0 {
848-
num++
849-
}
850-
856+
func (mc *mysqlConn) readExactPackets(num int) error {
851857
for i := 0; i < num; i++ {
852-
data, err := mc.readPacket()
858+
_, err := mc.readPacket()
853859
if err != nil {
854860
return err
855861
}
856-
857-
switch {
858-
case data[0] == iERR:
859-
return mc.handleErrorPacket(data)
860-
case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data):
861-
mc.status = readStatus(data[3:])
862-
return nil
863-
}
864862
}
865863
return nil
866864
}
@@ -1223,11 +1221,12 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
12231221

12241222
// packet indicator [1 byte]
12251223
if data[0] != iOK {
1226-
if isEOFPacket(data) {
1224+
if rows.mc.isEOFPacket(data) {
12271225
if rows.mc.flags&clientDeprecateEOF == 0 {
12281226
rows.mc.status = readStatus(data[3:])
12291227
} else {
12301228
if err := rows.mc.handleOkPacket(data); err != nil {
1229+
rows.mc = nil
12311230
return err
12321231
}
12331232
}

rows.go

-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,6 @@ func (rows *textRows) Next(dest []driver.Value) error {
215215
if err := mc.error(); err != nil {
216216
return err
217217
}
218-
219218
// Fetch next row from stream
220219
return rows.readRow(dest)
221220
}

statement.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
7373

7474
if resLen > 0 {
7575
// Columns
76-
if err = mc.readUntilEOF(); err != nil {
76+
if err = mc.readExactPackets(resLen); err != nil {
7777
return nil, err
7878
}
79-
8079
// Rows
8180
if err := mc.readUntilEOF(); err != nil {
8281
return nil, err

0 commit comments

Comments
 (0)