Skip to content

Commit e0d410d

Browse files
committed
remove errBadConnNoWrite and markBadConn
1 parent 87443b9 commit e0d410d

File tree

5 files changed

+48
-70
lines changed

5 files changed

+48
-70
lines changed

connection.go

+12-22
Original file line numberDiff line numberDiff line change
@@ -111,23 +111,12 @@ func (mc *mysqlConn) handleParams() (err error) {
111111
return
112112
}
113113

114-
func (mc *mysqlConn) markBadConn(err error) error {
115-
if mc == nil {
116-
return err
117-
}
118-
if err != errBadConnNoWrite {
119-
return err
120-
}
121-
return driver.ErrBadConn
122-
}
123-
124114
func (mc *mysqlConn) Begin() (driver.Tx, error) {
125115
return mc.begin(false)
126116
}
127117

128118
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
129119
if mc.closed.Load() {
130-
mc.log(ErrInvalidConn)
131120
return nil, driver.ErrBadConn
132121
}
133122
var q string
@@ -140,7 +129,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
140129
if err == nil {
141130
return &mysqlTx{mc}, err
142131
}
143-
return nil, mc.markBadConn(err)
132+
return nil, err
144133
}
145134

146135
func (mc *mysqlConn) Close() (err error) {
@@ -189,7 +178,6 @@ func (mc *mysqlConn) error() error {
189178

190179
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
191180
if mc.closed.Load() {
192-
mc.log(ErrInvalidConn)
193181
return nil, driver.ErrBadConn
194182
}
195183
// Send command
@@ -324,7 +312,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
324312

325313
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
326314
if mc.closed.Load() {
327-
mc.log(ErrInvalidConn)
328315
return nil, driver.ErrBadConn
329316
}
330317
if len(args) != 0 {
@@ -344,15 +331,15 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
344331
copied := mc.result
345332
return &copied, err
346333
}
347-
return nil, mc.markBadConn(err)
334+
return nil, err
348335
}
349336

350337
// Internal function to execute commands
351338
func (mc *mysqlConn) exec(query string) error {
352339
handleOk := mc.clearResult()
353340
// Send command
354341
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
355-
return mc.markBadConn(err)
342+
return err
356343
}
357344

358345
// Read Result
@@ -382,11 +369,10 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
382369

383370
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
384371
handleOk := mc.clearResult()
385-
386372
if mc.closed.Load() {
387-
mc.log(ErrInvalidConn)
388373
return nil, driver.ErrBadConn
389374
}
375+
390376
if len(args) != 0 {
391377
if !mc.cfg.InterpolateParams {
392378
return nil, driver.ErrSkip
@@ -398,17 +384,18 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
398384
}
399385
query = prepared
400386
}
387+
401388
// Send command
402389
err := mc.writeCommandPacketStr(comQuery, query)
403390
if err != nil {
404-
return nil, mc.markBadConn(err)
391+
return nil, err
405392
}
406393

407394
// Read Result
408395
var resLen int
409396
resLen, err = handleOk.readResultSetHeaderPacket()
410397
if err != nil {
411-
return nil, mc.markBadConn(err)
398+
return nil, err
412399
}
413400

414401
rows := new(textRows)
@@ -482,7 +469,6 @@ func (mc *mysqlConn) finish() {
482469
// Ping implements driver.Pinger interface
483470
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
484471
if mc.closed.Load() {
485-
mc.log(ErrInvalidConn)
486472
return driver.ErrBadConn
487473
}
488474

@@ -493,7 +479,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
493479

494480
handleOk := mc.clearResult()
495481
if err = mc.writeCommandPacket(comPing); err != nil {
496-
return mc.markBadConn(err)
482+
return err
497483
}
498484

499485
return handleOk.readResultOK()
@@ -699,8 +685,12 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
699685
return nil
700686
}
701687

688+
var _ driver.SessionResetter = &mysqlConn{}
689+
702690
// IsValid implements driver.Validator interface
703691
// (From Go 1.15)
704692
func (mc *mysqlConn) IsValid() bool {
705693
return !mc.closed.Load()
706694
}
695+
696+
var _ driver.Validator = &mysqlConn{}

connection_test.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,13 @@ func TestPingMarkBadConnection(t *testing.T) {
163163
netConn: nc,
164164
buf: newBuffer(nc),
165165
maxAllowedPacket: defaultMaxAllowedPacket,
166+
closech: make(chan struct{}),
166167
}
167168

168169
err := mc.Ping(context.Background())
169170

170-
if err != driver.ErrBadConn {
171-
t.Errorf("expected driver.ErrBadConn, got %#v", err)
171+
if !errors.Is(err, nc.err) {
172+
t.Errorf("expected %v, got %#v", nc.err, err)
172173
}
173174
}
174175

@@ -184,8 +185,8 @@ func TestPingErrInvalidConn(t *testing.T) {
184185

185186
err := mc.Ping(context.Background())
186187

187-
if err != ErrInvalidConn {
188-
t.Errorf("expected ErrInvalidConn, got %#v", err)
188+
if !errors.Is(err, nc.err) {
189+
t.Errorf("expected %v, got %#v", nc.err, err)
189190
}
190191
}
191192

errors.go

-6
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,6 @@ var (
2929
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
3030
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
3131
ErrBusyBuffer = errors.New("busy buffer")
32-
33-
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
34-
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
35-
// to trigger a resend.
36-
// See https://github.com/go-sql-driver/mysql/pull/302
37-
errBadConnNoWrite = errors.New("bad connection")
3832
)
3933

4034
var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime))

packets.go

+29-35
Original file line numberDiff line numberDiff line change
@@ -117,39 +117,33 @@ func (mc *mysqlConn) writePacket(data []byte) error {
117117
// Write packet
118118
if mc.writeTimeout > 0 {
119119
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
120-
mc.cleanup()
121120
mc.log(err)
121+
mc.cleanup()
122122
return err
123123
}
124124
}
125125

126126
n, err := mc.netConn.Write(data[:4+size])
127-
if err == nil && n == 4+size {
128-
mc.sequence++
129-
if size != maxPacketSize {
130-
return nil
131-
}
132-
pktLen -= size
133-
data = data[size:]
134-
continue
135-
}
136-
137-
// Handle error
138-
if err == nil { // n != len(data)
127+
if err != nil {
139128
mc.cleanup()
140-
mc.log(ErrMalformPkt)
141-
} else {
142129
if cerr := mc.canceled.Value(); cerr != nil {
143130
return cerr
144131
}
145-
if n == 0 && pktLen == len(data)-4 {
146-
// only for the first loop iteration when nothing was written yet
147-
return errBadConnNoWrite
148-
}
132+
return err
133+
}
134+
if n != 4+size {
135+
// io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
136+
// The io.ErrShortWrite error is used to indicate that this rule has not been followed.
149137
mc.cleanup()
150-
mc.log(err)
138+
return io.ErrShortWrite
151139
}
152-
return ErrInvalidConn
140+
141+
mc.sequence++
142+
if size != maxPacketSize {
143+
return nil
144+
}
145+
pktLen -= size
146+
data = data[size:]
153147
}
154148
}
155149

@@ -305,8 +299,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
305299
data, err := mc.buf.takeBuffer(pktLen + 4)
306300
if err != nil {
307301
// cannot take the buffer. Something must be wrong with the connection
308-
mc.log(err)
309-
return errBadConnNoWrite
302+
mc.cleanup()
303+
return err
310304
}
311305

312306
// ClientFlags [32 bit]
@@ -394,8 +388,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
394388
data, err := mc.buf.takeSmallBuffer(pktLen)
395389
if err != nil {
396390
// cannot take the buffer. Something must be wrong with the connection
397-
mc.log(err)
398-
return errBadConnNoWrite
391+
mc.cleanup()
392+
return err
399393
}
400394

401395
// Add the auth data [EOF]
@@ -414,8 +408,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
414408
data, err := mc.buf.takeSmallBuffer(4 + 1)
415409
if err != nil {
416410
// cannot take the buffer. Something must be wrong with the connection
417-
mc.log(err)
418-
return errBadConnNoWrite
411+
mc.cleanup()
412+
return err
419413
}
420414

421415
// Add command byte
@@ -433,8 +427,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
433427
data, err := mc.buf.takeBuffer(pktLen + 4)
434428
if err != nil {
435429
// cannot take the buffer. Something must be wrong with the connection
436-
mc.log(err)
437-
return errBadConnNoWrite
430+
mc.cleanup()
431+
return err
438432
}
439433

440434
// Add command byte
@@ -454,8 +448,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
454448
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
455449
if err != nil {
456450
// cannot take the buffer. Something must be wrong with the connection
457-
mc.log(err)
458-
return errBadConnNoWrite
451+
mc.cleanup()
452+
return err
459453
}
460454

461455
// Add command byte
@@ -997,8 +991,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
997991
}
998992
if err != nil {
999993
// cannot take the buffer. Something must be wrong with the connection
1000-
mc.log(err)
1001-
return errBadConnNoWrite
994+
mc.cleanup()
995+
return err
1002996
}
1003997

1004998
// command [1 byte]
@@ -1196,8 +1190,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11961190
if valuesCap != cap(paramValues) {
11971191
data = append(data[:pos], paramValues...)
11981192
if err = mc.buf.store(data); err != nil {
1199-
mc.log(err)
1200-
return errBadConnNoWrite
1193+
mc.cleanup()
1194+
return err
12011195
}
12021196
}
12031197

statement.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
5757
// Send command
5858
err := stmt.writeExecutePacket(args)
5959
if err != nil {
60-
return nil, stmt.mc.markBadConn(err)
60+
return nil, err
6161
}
6262

6363
mc := stmt.mc
@@ -95,13 +95,12 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
9595

9696
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
9797
if stmt.mc.closed.Load() {
98-
stmt.mc.log(ErrInvalidConn)
9998
return nil, driver.ErrBadConn
10099
}
101100
// Send command
102101
err := stmt.writeExecutePacket(args)
103102
if err != nil {
104-
return nil, stmt.mc.markBadConn(err)
103+
return nil, err
105104
}
106105

107106
mc := stmt.mc

0 commit comments

Comments
 (0)