diff --git a/AUTHORS b/AUTHORS index 510b869b..a261819f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -37,6 +37,7 @@ Daniel Montoya Daniel Nichter Daniƫl van Eeden Dave Protasowski +Diego Dupin Dirkjan Bussink DisposaBoy Egor Smolyakov diff --git a/auth_test.go b/auth_test.go index 46e1e3b4..a8f1d4bd 100644 --- a/auth_test.go +++ b/auth_test.go @@ -89,7 +89,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -176,7 +176,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -232,7 +232,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -284,7 +284,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -357,7 +357,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -400,7 +400,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -459,7 +459,7 @@ func TestAuthFastNativePassword(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -502,7 +502,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -544,7 +544,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -592,7 +592,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -641,7 +641,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.TLS = nil - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -1343,7 +1343,7 @@ func TestEd25519Auth(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } diff --git a/benchmark_test.go b/benchmark_test.go index 5c9a046b..8275ebc4 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -440,3 +440,53 @@ func BenchmarkReceiveMassiveRows(b *testing.B) { } }) } + +// BenchmarkReceiveMetadata measures performance of receiving more metadata than real data +func BenchmarkReceiveMetadata(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + + // Create a table with 1000 integer fields + createTableQuery := "CREATE TABLE large_integer_table (" + for i := 0; i < 1000; i++ { + createTableQuery += fmt.Sprintf("col_%d INT", i) + if i < 999 { + createTableQuery += ", " + } + } + createTableQuery += ")" + + // Initialize database + db := initDB(b, false, + "DROP TABLE IF EXISTS large_integer_table", + createTableQuery, + "INSERT INTO large_integer_table VALUES ("+ + strings.Repeat("0,", 999)+"0)", // Insert a row of zeros + ) + defer db.Close() + + // Prepare a SELECT query to retrieve metadata + stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1")) + defer stmt.Close() + + b.StartTimer() + + // Benchmark metadata retrieval + for i := 0; i < b.N; i++ { + rows := tb.checkRows(stmt.Query()) + + // Create a slice to scan all columns + values := make([]interface{}, 1000) + valuePtrs := make([]interface{}, 1000) + for j := range values { + valuePtrs[j] = &values[j] + } + rows.Next() + // Scan the row + err := rows.Scan(valuePtrs...) + tb.check(err) + + rows.Close() + } +} diff --git a/connection.go b/connection.go index 3e455a3f..cd84f29d 100644 --- a/connection.go +++ b/connection.go @@ -24,21 +24,22 @@ import ( ) type mysqlConn struct { - buf buffer - netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. - result mysqlResult // managed by clearResult() and handleOkPacket(). - compIO *compIO - cfg *Config - connector *connector - maxAllowedPacket int - maxWriteSize int - flags clientFlag - status statusFlag - sequence uint8 - compressSequence uint8 - parseTime bool - compress bool + buf buffer + netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. + result mysqlResult // managed by clearResult() and handleOkPacket(). + compIO *compIO + cfg *Config + connector *connector + maxAllowedPacket int + maxWriteSize int + clientCapabilities capabilityFlag + clientExtCapabilities extendedCapabilityFlag + status statusFlag + sequence uint8 + compressSequence uint8 + parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -223,13 +224,21 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { + if err = mc.skipColumns(stmt.paramCount); err != nil { return nil, err } } if columnCount > 0 { - err = mc.readUntilEOF() + if mc.clientExtCapabilities&clientCacheMetadata != 0 { + if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil { + return nil, err + } + } else { + if err = mc.skipColumns(int(columnCount)); err != nil { + return nil, err + } + } } } @@ -370,19 +379,19 @@ func (mc *mysqlConn) exec(query string) error { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipResultSetRows(); err != nil { return err } } @@ -419,7 +428,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // Read Result var resLen int - resLen, err = handleOk.readResultSetHeaderPacket() + resLen, _, err = handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -453,7 +462,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc @@ -461,14 +470,14 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if resLen > 0 { // Columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return nil, err } } dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { - return dest[0].([]byte), mc.readUntilEOF() + return dest[0].([]byte), mc.skipResultSetRows() } } return nil, err diff --git a/connector.go b/connector.go index bc1d46af..fec1c3dd 100644 --- a/connector.go +++ b/connector.go @@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer() // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err @@ -153,7 +153,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + if err = mc.writeHandshakeResponsePacket(authResp, serverCapabilities, serverExtendedCapabilities, plugin); err != nil { mc.cleanup() return nil, err } @@ -167,7 +167,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 { mc.compress = true mc.compIO = newCompIO(mc) } diff --git a/const.go b/const.go index 4aadcd64..b33b1452 100644 --- a/const.go +++ b/const.go @@ -43,10 +43,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags -type clientFlag uint32 +type capabilityFlag uint32 const ( - clientLongPassword clientFlag = 1 << iota + clientMySQL capabilityFlag = 1 << iota clientFoundRows clientLongFlag clientConnectWithDB @@ -73,6 +73,20 @@ const ( clientDeprecateEOF ) +// https://mariadb.com/kb/en/connection/#capabilities +type extendedCapabilityFlag uint32 + +const ( + progressIndicator extendedCapabilityFlag = 1 << iota + clientComMulti + clientStmtBulkOperations + clientExtendedMetadata + clientCacheMetadata + clientUnitBulkResult +) + +// https://mariadb.com/kb/en/connection/#capabilities + const ( comQuit byte = iota + 1 comInitDB diff --git a/driver_test.go b/driver_test.go index 00e82865..8569494e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1630,13 +1630,46 @@ func TestCollation(t *testing.T) { } runTests(t, tdsn, func(dbt *DBTest) { + // see https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation + // when character_set_collations is set for the charset, it overrides the default collation + // so we need to check if the default collation is overridden + forceExpected := expected + var defaultCollations string + err := dbt.db.QueryRow("SELECT @@character_set_collations").Scan(&defaultCollations) + if err == nil { + // Query succeeded, need to check if we should override expected collation + collationMap := make(map[string]string) + pairs := strings.Split(defaultCollations, ",") + for _, pair := range pairs { + parts := strings.Split(pair, "=") + if len(parts) == 2 { + collationMap[parts[0]] = parts[1] + } + } + + // Get charset prefix from expected collation + parts := strings.Split(expected, "_") + if len(parts) > 0 { + charset := parts[0] + if newCollation, ok := collationMap[charset]; ok { + forceExpected = newCollation + } + } + } + var got string if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { dbt.Fatal(err) } if got != expected { - dbt.Fatalf("expected connection collation %s but got %s", expected, got) + if forceExpected != expected { + if got != forceExpected { + dbt.Fatalf("expected forced connection collation %s but got %s", forceExpected, got) + } + } else { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } } }) } @@ -1685,7 +1718,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) { } func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + zones := []string{"UTC", "America/New_York", "Asia/Hong_Kong", "Local"} // Regression test for timezone handling tzTest := func(dbt *DBTest) { @@ -1693,8 +1726,8 @@ func TestTimezoneConversion(t *testing.T) { dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") // Insert local time into database (should be converted) - usCentral, _ := time.LoadLocation("US/Central") - reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + newYorkTz, _ := time.LoadLocation("America/New_York") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(newYorkTz) dbt.mustExec("INSERT INTO test VALUE (?)", reftime) // Retrieve time from DB @@ -1713,7 +1746,7 @@ func TestTimezoneConversion(t *testing.T) { // Check that dates match if reftime.Unix() != dbTime.Unix() { dbt.Errorf("times do not match.\n") - dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(%v)=%v\n", newYorkTz, reftime) dbt.Errorf(" Now(UTC)=%v\n", dbTime) } } @@ -3541,6 +3574,15 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} + var varName string + var varValue string + err := dbt.db.QueryRow("SHOW VARIABLES LIKE 'performance_schema'").Scan(&varName, &varValue) + if err != nil { + t.Fatalf("error: %s", err.Error()) + } + if varValue != "ON" { + t.Skipf("Performance schema is not enabled. skipping") + } queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" rows := dbt.mustQuery(queryString) defer rows.Close() diff --git a/packets.go b/packets.go index 4b836216..1bdf2aea 100644 --- a/packets.go +++ b/packets.go @@ -174,19 +174,19 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { +func (mc *mysqlConn) readHandshakePacket() (data []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string, err error) { data, err = mc.readPacket() if err != nil { return } if data[0] == iERR { - return nil, "", mc.handleErrorPacket(data) + return nil, 0, 0, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, "", fmt.Errorf( + return nil, 0, 0, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -204,15 +204,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro pos += 8 + 1 // capability flags (lower 2 bytes) [2 bytes] - mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) - if mc.flags&clientProtocol41 == 0 { - return nil, "", ErrOldProtocol + serverCapabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if serverCapabilities&clientProtocol41 == 0 { + return nil, serverCapabilities, 0, "", ErrOldProtocol } - if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { + if serverCapabilities&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil } else { - return nil, "", ErrNoTLS + return nil, serverCapabilities, 0, "", ErrNoTLS } } pos += 2 @@ -222,11 +222,16 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // status flags [2 bytes] pos += 3 // capability flags (upper 2 bytes) [2 bytes] - mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + serverCapabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 pos += 2 // length of auth-plugin-data [1 byte] - // reserved (all [00]) [10 bytes] - pos += 11 + // reserved (all [00]) [6 bytes] + pos += 7 + if serverCapabilities&clientMySQL == 0 { + // MariaDB server, use extended capability flags + serverExtendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4])) + } + pos += 4 // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -250,67 +255,74 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro } else { plugin = string(data[pos:]) } - // make a memory safe copy of the cipher slice var b [20]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], serverCapabilities, serverExtendedCapabilities, plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], serverCapabilities, 0, plugin, nil } -// Client Authentication Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { - // Adjust client flags based on server support - clientFlags := clientProtocol41 | - clientSecureConn | - clientLongPassword | - clientTransactions | - clientLocalFiles | - clientPluginAuth | - clientMultiResults | - mc.flags&clientConnectAttrs | - mc.flags&clientLongFlag - - sendConnectAttrs := mc.flags&clientConnectAttrs != 0 - - if mc.cfg.ClientFoundRows { - clientFlags |= clientFoundRows +// initClientCapabilities initializes the client capabilities based on server support and configuration +func (mc *mysqlConn) initClientCapabilities(serverCapabilities capabilityFlag, cfg *Config) capabilityFlag { + + clientCapabilities := + clientMySQL | + clientLongFlag | + clientIgnoreSpace | + clientProtocol41 | + clientSecureConn | + clientTransactions | + clientPluginAuthLenEncClientData | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + clientConnectAttrs | + clientDeprecateEOF + + if cfg.ClientFoundRows { + clientCapabilities |= clientFoundRows } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { - clientFlags |= clientCompress + if cfg.compress { + clientCapabilities |= clientCompress } // To enable TLS / SSL if mc.cfg.TLS != nil { - clientFlags |= clientSSL + clientCapabilities |= clientSSL } if mc.cfg.MultiStatements { - clientFlags |= clientMultiStatements + clientCapabilities |= clientMultiStatements + } + if n := len(cfg.DBName); n > 0 { + clientCapabilities |= clientConnectWithDB } + return clientCapabilities & serverCapabilities +} + +// Client Authentication Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string) error { + // Adjust client flags based on server support + mc.clientCapabilities = mc.initClientCapabilities(serverCapabilities, mc.cfg) + mc.clientExtCapabilities = clientCacheMetadata & serverExtendedCapabilities + + sendConnectAttrs := mc.clientCapabilities&clientConnectAttrs != 0 + // encode length of the auth plugin data var authRespLEIBuf [9]byte authRespLen := len(authResp) authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) - if len(authRespLEI) > 1 { - // if the length can not be written in 1 byte, it must be written as a - // length encoded integer - clientFlags |= clientPluginAuthLenEncClientData - } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name - if n := len(mc.cfg.DBName); n > 0 { - clientFlags |= clientConnectWithDB - pktLen += n + 1 - } + pktLen += len(mc.cfg.DBName) + 1 // encode length of the connection attributes var connAttrsLEI []byte @@ -328,8 +340,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } - // ClientFlags [32 bit] - binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags)) + // clientCapabilities [32 bit] + binary.LittleEndian.PutUint32(data[4:], uint32(mc.clientCapabilities)) // MaxPacketSize [32 bit] (none) binary.LittleEndian.PutUint32(data[8:], 0) @@ -348,9 +360,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Filler [23 bytes] (all 0x00) pos := 13 - for ; pos < 13+23; pos++ { + for ; pos < 13+19; pos++ { data[pos] = 0 } + if mc.clientCapabilities&clientMySQL == 0 { + // clientExtendedCapabilities [32 bit] + binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.clientExtCapabilities)) + pos += 4 + } else { + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + } // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -385,9 +406,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { pos += copy(data[pos:], mc.cfg.DBName) - data[pos] = 0x00 - pos++ } + data[pos] = 0x00 + pos++ pos += copy(data[pos:], plugin) data[pos] = 0x00 @@ -535,32 +556,37 @@ func (mc *okHandler) readResultOK() error { // Result Set Header Packet // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html -func (mc *okHandler) readResultSetHeaderPacket() (int, error) { +func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) { // handleOkPacket replaces both values; other cases leave the values unchanged. mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.insertIds = append(mc.result.insertIds, 0) data, err := mc.conn().readPacket() if err != nil { - return 0, err + return 0, false, err } switch data[0] { case iOK: - return 0, mc.handleOkPacket(data) + return 0, false, mc.handleOkPacket(data) case iERR: - return 0, mc.conn().handleErrorPacket(data) + return 0, false, mc.conn().handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, false, mc.handleInFileRequest(string(data[1:])) } // column count // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html - num, _, _ := readLengthEncodedInteger(data) + // https://mariadb.com/kb/en/result-set-packets/#column-count-packet + num, _, len := readLengthEncodedInteger(data) + + if mc.clientExtCapabilities&clientCacheMetadata != 0 { + return int(num), data[len] == 0x01, nil + } // ignore remaining data in the packet. see #1478. - return int(num), nil + return int(num), true, nil } // Error Packet @@ -684,43 +710,28 @@ func (mc *okHandler) handleOkPacket(data []byte) error { func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; ; i++ { + for i := 0; i < count; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { - if i == count { - return columns, nil - } - return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) - } - // Catalog - pos, err := skipLengthEncodedString(data) - if err != nil { - return nil, err - } - + pos := int(data[0]) + 1 // Database [len coded string] - n, err := skipLengthEncodedString(data[pos:]) - if err != nil { - return nil, err - } - pos += n + pos += int(data[pos]) + 1 - // Table [len coded string] + // Table alias [len coded string] + // alias length can be up to 256 if mc.cfg.ColumnsWithAlias { tableName, _, n, err := readLengthEncodedString(data[pos:]) if err != nil { return nil, err } pos += n - columns[i].tableName = string(tableName) + columns[i].tableName = tableName } else { - n, err = skipLengthEncodedString(data[pos:]) + n, err := skipLengthEncodedString(data[pos:]) if err != nil { return nil, err } @@ -728,26 +739,18 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { } // Original table [len coded string] - n, err = skipLengthEncodedString(data[pos:]) - if err != nil { - return nil, err - } - pos += n + pos += int(data[pos]) + 1 - // Name [len coded string] + // Name alias [len coded string] name, _, n, err := readLengthEncodedString(data[pos:]) if err != nil { return nil, err } - columns[i].name = string(name) + columns[i].name = name pos += n // Original name [len coded string] - n, err = skipLengthEncodedString(data[pos:]) - if err != nil { - return nil, err - } - pos += n + pos += int(data[pos]) + 1 // Filler [uint8] pos++ @@ -770,13 +773,13 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Decimals [uint8] columns[i].decimals = data[pos] - //pos++ + } - // Default value [len coded binary] - //if pos < len(data) { - // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) - //} + // skip EOF packet if client does not support deprecateEOF + if err := mc.skipEof(); err != nil { + return nil, err } + return columns, nil } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -794,9 +797,16 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + if data[0] == iEOF && len(data) < 0xffffff { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + mc.status = readStatus(data[3:]) + } else { + // Ok Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -816,9 +826,9 @@ func (rows *textRows) readRow(dest []driver.Value) error { ) for i := range dest { - // Read bytes and convert to string + // Read field bytes var buf []byte - buf, isNull, n, err = readLengthEncodedString(data[pos:]) + buf, isNull, n, err = readLengthEncodedBytes(data[pos:]) pos += n if err != nil { @@ -861,6 +871,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { default: dest[i] = buf + continue } if err != nil { return err @@ -870,8 +881,33 @@ func (rows *textRows) readRow(dest []driver.Value) error { return nil } -// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() error { +func (mc *mysqlConn) skipPackets(number int) error { + for i := 0; i < number; i++ { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) skipEof() error { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) skipColumns(resLen int) error { + if err := mc.skipPackets(resLen); err != nil { + return err + } + return mc.skipEof() +} + +// Reads Packets until EOF-Packet or an Error appears. +func (mc *mysqlConn) skipResultSetRows() error { for { data, err := mc.readPacket() if err != nil { @@ -882,10 +918,18 @@ func (mc *mysqlConn) readUntilEOF() error { case iERR: return mc.handleErrorPacket(data) case iEOF: - if len(data) == 5 { - mc.status = readStatus(data[3:]) + if len(data) < 0xffffff { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + mc.status = readStatus(data[3:]) + } else { + // OK packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + mc.status = readStatus(data[1+n+m:]) + } + return nil } - return nil } } } @@ -1176,17 +1220,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // mc.affectedRows and mc.insertIds. func (mc *okHandler) discardResults() error { for mc.status&statusMoreResultsExists != 0 { - resLen, err := mc.readResultSetHeaderPacket() + resLen, _, err := mc.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipColumns(resLen); err != nil { return err } // rows - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipResultSetRows(); err != nil { return err } } @@ -1203,19 +1247,27 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + // EOF/OK Packet + if data[0] == iEOF { + if rows.mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + rows.mc.status = readStatus(data[3:]) + } else { + // OK Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + rows.mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil } return io.EOF } - mc := rows.mc - rows.mc = nil // Error otherwise + mc := rows.mc + rows.mc = nil return mc.handleErrorPacket(data) } @@ -1297,7 +1349,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { fieldTypeVector: var isNull bool var n int - dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + dest[i], isNull, n, err = readLengthEncodedBytes(data[pos:]) pos += n if err == nil { if !isNull { diff --git a/packets_test.go b/packets_test.go index 694b0564..b487051e 100644 --- a/packets_test.go +++ b/packets_test.go @@ -332,11 +332,19 @@ func TestRegression801(t *testing.T) { 112, 97, 115, 115, 119, 111, 114, 100} conn.maxReads = 1 - authData, pluginName, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtendedCapabilities, pluginName, err := mc.readHandshakePacket() if err != nil { t.Fatalf("got error: %v", err) } + if serverCapabilities != 2148530143 { + t.Fatalf("expected serverCapabilities to be 2148530143, got %v", serverCapabilities) + } + + if serverExtendedCapabilities != 0 { + t.Fatalf("expected serverExtendedCapabilities to be 0, got %v", serverExtendedCapabilities) + } + if pluginName != "mysql_native_password" { t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) } diff --git a/rows.go b/rows.go index df98417b..bfb821dc 100644 --- a/rows.go +++ b/rows.go @@ -113,7 +113,7 @@ func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.rs.done { - err = mc.readUntilEOF() + err = mc.skipResultSetRows() } if err == nil { handleOk := mc.clearResult() @@ -143,7 +143,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { // Remove unread packets from stream if !rows.rs.done { - if err := rows.mc.readUntilEOF(); err != nil { + if err := rows.mc.skipResultSetRows(); err != nil { return 0, err } rows.rs.done = true @@ -156,7 +156,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { rows.rs = resultSet{} // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to // nextResultSet. - resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() + resLen, _, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() if err != nil { // Clean up about multi-results flag rows.rs.done = true diff --git a/statement.go b/statement.go index 35df8545..7c63f1ed 100644 --- a/statement.go +++ b/statement.go @@ -20,6 +20,7 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int + columns []mysqlField } func (stmt *mysqlStmt) Close() error { @@ -64,19 +65,19 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { handleOk := stmt.mc.clearResult() // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { + if err = mc.skipColumns(resLen); err != nil { return nil, err } // Rows - if err := mc.readUntilEOF(); err != nil { + if err = mc.skipResultSetRows(); err != nil { return nil, err } } @@ -107,7 +108,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Read Result handleOk := stmt.mc.clearResult() - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -116,7 +117,17 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc - rows.rs.columns, err = mc.readColumns(resLen) + if metadataFollows { + if rows.rs.columns, err = mc.readColumns(resLen); err != nil { + return nil, err + } + stmt.columns = rows.rs.columns + } else { + if err = mc.skipEof(); err != nil { + return nil, err + } + rows.rs.columns = stmt.columns + } } else { rows.rs.done = true diff --git a/utils.go b/utils.go index 8716c26c..92445a28 100644 --- a/utils.go +++ b/utils.go @@ -524,10 +524,7 @@ func uint64ToString(n uint64) []byte { return a[i:] } -// returns the string read as a bytes slice, whether the value is NULL, -// the number of bytes read and an error, in case the string is longer than -// the input slice -func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { +func readLengthEncodedBytes(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := readLengthEncodedInteger(b) if num < 1 { @@ -543,6 +540,25 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { return nil, false, n, io.EOF } +// returns the string read as a bytes slice, whether the value is NULL, +// the number of bytes read and an error, in case the string is longer than +// the input slice +func readLengthEncodedString(b []byte) (string, bool, int, error) { + // Get length + num, isNull, n := readLengthEncodedInteger(b) + if num < 1 { + return "", isNull, n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return string(b[n-int(num) : n : n]), false, n, nil + } + return "", false, n, io.EOF +} + // returns the number of bytes skipped and an error, in case the string is // longer than the input slice func skipLengthEncodedString(b []byte) (int, error) { @@ -567,7 +583,9 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { if len(b) == 0 { return 0, true, 1 } - + if b[0] < 251 { + return uint64(b[0]), false, 1 + } switch b[0] { // 251: NULL case 0xfb: @@ -582,12 +600,9 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { return uint64(getUint24(b[1:])), false, 4 // 254: value of following 8 - case 0xfe: - return uint64(binary.LittleEndian.Uint64(b[1:])), false, 9 + default: + return binary.LittleEndian.Uint64(b[1:]), false, 9 } - - // 0-250: value of first byte - return uint64(b[0]), false, 1 } // encodes a uint64 value and appends it to the given bytes slice