Skip to content

Commit 6218e09

Browse files
methanejulienschmidt
authored andcommitted
Fix canceled context broke mysqlConn (#862)
Fix #858
1 parent 885c931 commit 6218e09

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

Diff for: connection_go18.go

+7-8
Original file line numberDiff line numberDiff line change
@@ -149,22 +149,21 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error {
149149
mc.cleanup()
150150
return nil
151151
}
152+
// When ctx is already cancelled, don't watch it.
153+
if err := ctx.Err(); err != nil {
154+
return err
155+
}
156+
// When ctx is not cancellable, don't watch it.
152157
if ctx.Done() == nil {
153158
return nil
154159
}
155-
156-
mc.watching = true
157-
select {
158-
default:
159-
case <-ctx.Done():
160-
return ctx.Err()
161-
}
160+
// When watcher is not alive, can't watch it.
162161
if mc.watcher == nil {
163162
return nil
164163
}
165164

165+
mc.watching = true
166166
mc.watcher <- ctx
167-
168167
return nil
169168
}
170169

Diff for: connection_go18_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
package mysql
1212

1313
import (
14+
"context"
1415
"database/sql/driver"
1516
"testing"
1617
)
@@ -28,3 +29,31 @@ func TestCheckNamedValue(t *testing.T) {
2829
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
2930
}
3031
}
32+
33+
// TestCleanCancel tests passed context is cancelled at start.
34+
// No packet should be sent. Connection should keep current status.
35+
func TestCleanCancel(t *testing.T) {
36+
mc := &mysqlConn{
37+
closech: make(chan struct{}),
38+
}
39+
mc.startWatcher()
40+
defer mc.cleanup()
41+
42+
ctx, cancel := context.WithCancel(context.Background())
43+
cancel()
44+
45+
for i := 0; i < 3; i++ { // Repeat same behavior
46+
err := mc.Ping(ctx)
47+
if err != context.Canceled {
48+
t.Errorf("expected context.Canceled, got %#v", err)
49+
}
50+
51+
if mc.closed.IsSet() {
52+
t.Error("expected mc is not closed, closed actually")
53+
}
54+
55+
if mc.watching {
56+
t.Error("expected watching is false, but true")
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)