Skip to content

Commit 11bda98

Browse files
FrauElstermafredri
andauthoredDec 4, 2024··
fix: avoid writing messages after close and improve handshake (#476)
Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
1 parent 1253b77 commit 11bda98

File tree

5 files changed

+252
-65
lines changed

5 files changed

+252
-65
lines changed
 

‎close.go

+3-9
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode {
100100
func (c *Conn) Close(code StatusCode, reason string) (err error) {
101101
defer errd.Wrap(&err, "failed to close WebSocket")
102102

103-
if !c.casClosing() {
103+
if c.casClosing() {
104104
err = c.waitGoroutines()
105105
if err != nil {
106106
return err
@@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) {
133133
func (c *Conn) CloseNow() (err error) {
134134
defer errd.Wrap(&err, "failed to immediately close WebSocket")
135135

136-
if !c.casClosing() {
136+
if c.casClosing() {
137137
err = c.waitGoroutines()
138138
if err != nil {
139139
return err
@@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) {
329329
}
330330

331331
func (c *Conn) casClosing() bool {
332-
c.closeMu.Lock()
333-
defer c.closeMu.Unlock()
334-
if !c.closing {
335-
c.closing = true
336-
return true
337-
}
338-
return false
332+
return c.closing.Swap(true)
339333
}
340334

341335
func (c *Conn) isClosed() bool {

‎conn.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,19 @@ type Conn struct {
6969
writeHeaderBuf [8]byte
7070
writeHeader header
7171

72+
// Close handshake state.
73+
closeStateMu sync.RWMutex
74+
closeReceivedErr error
75+
closeSentErr error
76+
77+
// CloseRead state.
7278
closeReadMu sync.Mutex
7379
closeReadCtx context.Context
7480
closeReadDone chan struct{}
7581

82+
closing atomic.Bool
83+
closeMu sync.Mutex // Protects following.
7684
closed chan struct{}
77-
closeMu sync.Mutex
78-
closing bool
7985

8086
pingCounter atomic.Int64
8187
activePingsMu sync.Mutex

‎conn_test.go

+148-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11+
"net"
1112
"net/http"
1213
"net/http/httptest"
1314
"os"
@@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
460461
}
461462

462463
func BenchmarkConn(b *testing.B) {
463-
var benchCases = []struct {
464+
benchCases := []struct {
464465
name string
465466
mode websocket.CompressionMode
466467
}{
@@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
625626
}()
626627
}
627628
}
629+
630+
func TestConnClosePropagation(t *testing.T) {
631+
t.Parallel()
632+
633+
want := []byte("hello")
634+
keepWriting := func(c *websocket.Conn) <-chan error {
635+
return xsync.Go(func() error {
636+
for {
637+
err := c.Write(context.Background(), websocket.MessageText, want)
638+
if err != nil {
639+
return err
640+
}
641+
}
642+
})
643+
}
644+
keepReading := func(c *websocket.Conn) <-chan error {
645+
return xsync.Go(func() error {
646+
for {
647+
_, got, err := c.Read(context.Background())
648+
if err != nil {
649+
return err
650+
}
651+
if !bytes.Equal(want, got) {
652+
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
653+
}
654+
}
655+
})
656+
}
657+
checkReadErr := func(t *testing.T, err error) {
658+
// Check read error (output depends on when read is called in relation to connection closure).
659+
var ce websocket.CloseError
660+
if errors.As(err, &ce) {
661+
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
662+
} else {
663+
assert.ErrorIs(t, net.ErrClosed, err)
664+
}
665+
}
666+
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
667+
for _, c := range conn {
668+
// Check write error.
669+
err := c.Write(context.Background(), websocket.MessageText, want)
670+
assert.ErrorIs(t, net.ErrClosed, err)
671+
672+
_, _, err = c.Read(context.Background())
673+
checkReadErr(t, err)
674+
}
675+
}
676+
677+
t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
678+
tt, this, other := newConnTest(t, nil, nil)
679+
680+
_ = this.CloseRead(tt.ctx)
681+
thisWriteErr := keepWriting(this)
682+
683+
_, got, err := other.Read(tt.ctx)
684+
assert.Success(t, err)
685+
assert.Equal(t, "msg", want, got)
686+
687+
err = other.Close(websocket.StatusNormalClosure, "")
688+
assert.Success(t, err)
689+
690+
select {
691+
case err := <-thisWriteErr:
692+
assert.ErrorIs(t, net.ErrClosed, err)
693+
case <-tt.ctx.Done():
694+
t.Fatal(tt.ctx.Err())
695+
}
696+
697+
checkConnErrs(t, this, other)
698+
})
699+
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
700+
tt, this, other := newConnTest(t, nil, nil)
701+
702+
_ = this.CloseRead(tt.ctx)
703+
thisWriteErr := keepWriting(this)
704+
otherReadErr := keepReading(other)
705+
706+
err := this.Close(websocket.StatusNormalClosure, "")
707+
assert.Success(t, err)
708+
709+
select {
710+
case err := <-thisWriteErr:
711+
assert.ErrorIs(t, net.ErrClosed, err)
712+
case <-tt.ctx.Done():
713+
t.Fatal(tt.ctx.Err())
714+
}
715+
716+
select {
717+
case err := <-otherReadErr:
718+
checkReadErr(t, err)
719+
case <-tt.ctx.Done():
720+
t.Fatal(tt.ctx.Err())
721+
}
722+
723+
checkConnErrs(t, this, other)
724+
})
725+
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
726+
tt, this, other := newConnTest(t, nil, nil)
727+
728+
_ = other.CloseRead(tt.ctx)
729+
errs := keepReading(this)
730+
731+
err := other.Write(tt.ctx, websocket.MessageText, want)
732+
assert.Success(t, err)
733+
734+
err = other.Close(websocket.StatusNormalClosure, "")
735+
assert.Success(t, err)
736+
737+
select {
738+
case err := <-errs:
739+
checkReadErr(t, err)
740+
case <-tt.ctx.Done():
741+
t.Fatal(tt.ctx.Err())
742+
}
743+
744+
checkConnErrs(t, this, other)
745+
})
746+
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
747+
tt, this, other := newConnTest(t, nil, nil)
748+
749+
thisReadErr := keepReading(this)
750+
otherReadErr := keepReading(other)
751+
752+
err := other.Write(tt.ctx, websocket.MessageText, want)
753+
assert.Success(t, err)
754+
755+
err = this.Close(websocket.StatusNormalClosure, "")
756+
assert.Success(t, err)
757+
758+
select {
759+
case err := <-thisReadErr:
760+
checkReadErr(t, err)
761+
case <-tt.ctx.Done():
762+
t.Fatal(tt.ctx.Err())
763+
}
764+
765+
select {
766+
case err := <-otherReadErr:
767+
checkReadErr(t, err)
768+
case <-tt.ctx.Done():
769+
t.Fatal(tt.ctx.Err())
770+
}
771+
772+
checkConnErrs(t, this, other)
773+
})
774+
}

‎read.go

+59-35
Original file line numberDiff line numberDiff line change
@@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
217217
}
218218
}
219219

220-
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
220+
// prepareRead sets the readTimeout context and returns a done function
221+
// to be called after the read is done. It also returns an error if the
222+
// connection is closed. The reference to the error is used to assign
223+
// an error depending on if the connection closed or the context timed
224+
// out during use. Typically the referenced error is a named return
225+
// variable of the function calling this method.
226+
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
221227
select {
222228
case <-c.closed:
223-
return header{}, net.ErrClosed
229+
return nil, net.ErrClosed
224230
case c.readTimeout <- ctx:
225231
}
226232

227-
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
228-
if err != nil {
233+
done := func() {
229234
select {
230235
case <-c.closed:
231-
return header{}, net.ErrClosed
232-
case <-ctx.Done():
233-
return header{}, ctx.Err()
234-
default:
235-
return header{}, err
236+
if *err != nil {
237+
*err = net.ErrClosed
238+
}
239+
case c.readTimeout <- context.Background():
240+
}
241+
if *err != nil && ctx.Err() != nil {
242+
*err = ctx.Err()
236243
}
237244
}
238245

239-
select {
240-
case <-c.closed:
241-
return header{}, net.ErrClosed
242-
case c.readTimeout <- context.Background():
246+
c.closeStateMu.Lock()
247+
closeReceivedErr := c.closeReceivedErr
248+
c.closeStateMu.Unlock()
249+
if closeReceivedErr != nil {
250+
defer done()
251+
return nil, closeReceivedErr
243252
}
244253

245-
return h, nil
254+
return done, nil
246255
}
247256

248-
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
249-
select {
250-
case <-c.closed:
251-
return 0, net.ErrClosed
252-
case c.readTimeout <- ctx:
257+
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
258+
readDone, err := c.prepareRead(ctx, &err)
259+
if err != nil {
260+
return header{}, err
253261
}
262+
defer readDone()
254263

255-
n, err := io.ReadFull(c.br, p)
264+
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
256265
if err != nil {
257-
select {
258-
case <-c.closed:
259-
return n, net.ErrClosed
260-
case <-ctx.Done():
261-
return n, ctx.Err()
262-
default:
263-
return n, fmt.Errorf("failed to read frame payload: %w", err)
264-
}
266+
return header{}, err
265267
}
266268

267-
select {
268-
case <-c.closed:
269-
return n, net.ErrClosed
270-
case c.readTimeout <- context.Background():
269+
return h, nil
270+
}
271+
272+
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
273+
readDone, err := c.prepareRead(ctx, &err)
274+
if err != nil {
275+
return 0, err
276+
}
277+
defer readDone()
278+
279+
n, err := io.ReadFull(c.br, p)
280+
if err != nil {
281+
return n, fmt.Errorf("failed to read frame payload: %w", err)
271282
}
272283

273284
return n, err
@@ -325,9 +336,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
325336
}
326337

327338
err = fmt.Errorf("received close frame: %w", ce)
328-
c.writeClose(ce.Code, ce.Reason)
329-
c.readMu.unlock()
330-
c.close()
339+
c.closeStateMu.Lock()
340+
c.closeReceivedErr = err
341+
closeSent := c.closeSentErr != nil
342+
c.closeStateMu.Unlock()
343+
344+
// Only unlock readMu if this connection is being closed becaue
345+
// c.close will try to acquire the readMu lock. We unlock for
346+
// writeClose as well because it may also call c.close.
347+
if !closeSent {
348+
c.readMu.unlock()
349+
_ = c.writeClose(ce.Code, ce.Reason)
350+
}
351+
if !c.casClosing() {
352+
c.readMu.unlock()
353+
_ = c.close()
354+
}
331355
return err
332356
}
333357

‎write.go

+34-18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package websocket
55

66
import (
77
"bufio"
8+
"compress/flate"
89
"context"
910
"crypto/rand"
1011
"encoding/binary"
@@ -14,8 +15,6 @@ import (
1415
"net"
1516
"time"
1617

17-
"compress/flate"
18-
1918
"github.com/coder/websocket/internal/errd"
2019
"github.com/coder/websocket/internal/util"
2120
)
@@ -249,22 +248,36 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
249248
}
250249
defer c.writeFrameMu.unlock()
251250

251+
defer func() {
252+
if c.isClosed() && opcode == opClose {
253+
err = nil
254+
}
255+
if err != nil {
256+
if ctx.Err() != nil {
257+
err = ctx.Err()
258+
} else if c.isClosed() {
259+
err = net.ErrClosed
260+
}
261+
err = fmt.Errorf("failed to write frame: %w", err)
262+
}
263+
}()
264+
265+
c.closeStateMu.Lock()
266+
closeSentErr := c.closeSentErr
267+
c.closeStateMu.Unlock()
268+
if closeSentErr != nil {
269+
return 0, net.ErrClosed
270+
}
271+
252272
select {
253273
case <-c.closed:
254274
return 0, net.ErrClosed
255275
case c.writeTimeout <- ctx:
256276
}
257-
258277
defer func() {
259-
if err != nil {
260-
select {
261-
case <-c.closed:
262-
err = net.ErrClosed
263-
case <-ctx.Done():
264-
err = ctx.Err()
265-
default:
266-
}
267-
err = fmt.Errorf("failed to write frame: %w", err)
278+
select {
279+
case <-c.closed:
280+
case c.writeTimeout <- context.Background():
268281
}
269282
}()
270283

@@ -303,13 +316,16 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
303316
}
304317
}
305318

306-
select {
307-
case <-c.closed:
308-
if opcode == opClose {
309-
return n, nil
319+
if opcode == opClose {
320+
c.closeStateMu.Lock()
321+
c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed)
322+
closeReceived := c.closeReceivedErr != nil
323+
c.closeStateMu.Unlock()
324+
325+
if closeReceived && !c.casClosing() {
326+
c.writeFrameMu.unlock()
327+
_ = c.close()
310328
}
311-
return n, net.ErrClosed
312-
case c.writeTimeout <- context.Background():
313329
}
314330

315331
return n, nil

0 commit comments

Comments
 (0)
Please sign in to comment.