Skip to content

Commit d7ddb8b

Browse files
authored
Fix issue 1567 (#1570)
### Description closes #1567 When TLS is enabled, `mc.netConn` is rewritten after the TLS handshak as detailed here: https://github.com/go-sql-driver/mysql/blob/d86c4527bae98ccd4e5060f72887520ce30eda5e/packets.go#L355 Therefore, `mc.netConn` should not be accessed within the watcher goroutine. Instead, `mc.rawConn` should be initialized prior to invoking `mc.startWatcher`, and `mc.rawConn` should be used in lieu of `mc.netConn`. ### Checklist - [x] Code compiles correctly - [x] Created tests which fail without the change (if possible) - [x] All tests passing - [x] Extended the README / documentation, if necessary - [x] Added myself / the copyright holder to the AUTHORS file <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Improved variable naming for better code readability and maintenance. - Enhanced network connection handling logic. - **New Features** - Updated TCP connection handling to better support TCP Keepalives. - **Tests** - Added a new test to address and verify the fix for a specific issue related to TLS, connection pooling, and round trip time estimation. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent d86c452 commit d7ddb8b

File tree

4 files changed

+37
-5
lines changed

4 files changed

+37
-5
lines changed

connection.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ func (mc *mysqlConn) cleanup() {
153153

154154
// Makes cleanup idempotent
155155
close(mc.closech)
156-
nc := mc.netConn
157-
if nc == nil {
156+
conn := mc.rawConn
157+
if conn == nil {
158158
return
159159
}
160-
if err := nc.Close(); err != nil {
160+
if err := conn.Close(); err != nil {
161161
mc.log(err)
162162
}
163163
// This function can be called from multiple goroutines.

connector.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
102102
nd := net.Dialer{Timeout: mc.cfg.Timeout}
103103
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
104104
}
105-
106105
if err != nil {
107106
return nil, err
108107
}
108+
mc.rawConn = mc.netConn
109109

110110
// Enable TCP Keepalives on TCP connections
111111
if tc, ok := mc.netConn.(*net.TCPConn); ok {

driver_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"io"
2121
"log"
2222
"math"
23+
mrand "math/rand"
2324
"net"
2425
"net/url"
2526
"os"
@@ -3577,3 +3578,35 @@ func runCallCommand(dbt *DBTest, query, name string) {
35773578
}
35783579
}
35793580
}
3581+
3582+
func TestIssue1567(t *testing.T) {
3583+
// enable TLS.
3584+
runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) {
3585+
// disable connection pooling.
3586+
// data race happens when new connection is created.
3587+
dbt.db.SetMaxIdleConns(0)
3588+
3589+
// estimate round trip time.
3590+
start := time.Now()
3591+
if err := dbt.db.PingContext(context.Background()); err != nil {
3592+
t.Fatal(err)
3593+
}
3594+
rtt := time.Since(start)
3595+
if rtt <= 0 {
3596+
// In some environments, rtt may become 0, so set it to at least 1ms.
3597+
rtt = time.Millisecond
3598+
}
3599+
3600+
count := 1000
3601+
if testing.Short() {
3602+
count = 10
3603+
}
3604+
3605+
for i := 0; i < count; i++ {
3606+
timeout := time.Duration(mrand.Int63n(int64(rtt)))
3607+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
3608+
dbt.db.PingContext(ctx)
3609+
cancel()
3610+
}
3611+
})
3612+
}

packets.go

-1
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
351351
if err := tlsConn.Handshake(); err != nil {
352352
return err
353353
}
354-
mc.rawConn = mc.netConn
355354
mc.netConn = tlsConn
356355
mc.buf.nc = tlsConn
357356
}

0 commit comments

Comments
 (0)