Skip to content

Commit b0ec201

Browse files
authored
Merge pull request #427 from alixander/fix-race
fix closenow race
2 parents 8d2374e + 250db1e commit b0ec201

10 files changed

+226
-162
lines changed

Diff for: accept_test.go

+37
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net/http"
1111
"net/http/httptest"
1212
"strings"
13+
"sync"
1314
"testing"
1415

1516
"nhooyr.io/websocket/internal/test/assert"
@@ -142,6 +143,42 @@ func TestAccept(t *testing.T) {
142143
_, err := Accept(w, r, nil)
143144
assert.Contains(t, err, `failed to hijack connection`)
144145
})
146+
t.Run("closeRace", func(t *testing.T) {
147+
t.Parallel()
148+
149+
server, _ := net.Pipe()
150+
151+
rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server))
152+
newResponseWriter := func() http.ResponseWriter {
153+
return mockHijacker{
154+
ResponseWriter: httptest.NewRecorder(),
155+
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
156+
return server, rw, nil
157+
},
158+
}
159+
}
160+
w := newResponseWriter()
161+
162+
r := httptest.NewRequest("GET", "/", nil)
163+
r.Header.Set("Connection", "Upgrade")
164+
r.Header.Set("Upgrade", "websocket")
165+
r.Header.Set("Sec-WebSocket-Version", "13")
166+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
167+
168+
c, err := Accept(w, r, nil)
169+
wg := &sync.WaitGroup{}
170+
wg.Add(2)
171+
go func() {
172+
c.Close(StatusInternalError, "the sky is falling")
173+
wg.Done()
174+
}()
175+
go func() {
176+
c.CloseNow()
177+
wg.Done()
178+
}()
179+
wg.Wait()
180+
assert.Success(t, err)
181+
})
145182
}
146183

147184
func Test_verifyClientHandshake(t *testing.T) {

Diff for: close.go

+103-53
Original file line numberDiff line numberDiff line change
@@ -97,80 +97,106 @@ func CloseStatus(err error) StatusCode {
9797
//
9898
// Close will unblock all goroutines interacting with the connection once
9999
// complete.
100-
func (c *Conn) Close(code StatusCode, reason string) error {
101-
defer c.wg.Wait()
102-
return c.closeHandshake(code, reason)
100+
func (c *Conn) Close(code StatusCode, reason string) (err error) {
101+
defer errd.Wrap(&err, "failed to close WebSocket")
102+
103+
if !c.casClosing() {
104+
err = c.waitGoroutines()
105+
if err != nil {
106+
return err
107+
}
108+
return net.ErrClosed
109+
}
110+
defer func() {
111+
if errors.Is(err, net.ErrClosed) {
112+
err = nil
113+
}
114+
}()
115+
116+
err = c.closeHandshake(code, reason)
117+
118+
err2 := c.close()
119+
if err == nil && err2 != nil {
120+
err = err2
121+
}
122+
123+
err2 = c.waitGoroutines()
124+
if err == nil && err2 != nil {
125+
err = err2
126+
}
127+
128+
return err
103129
}
104130

105131
// CloseNow closes the WebSocket connection without attempting a close handshake.
106132
// Use when you do not want the overhead of the close handshake.
107133
func (c *Conn) CloseNow() (err error) {
108-
defer c.wg.Wait()
109134
defer errd.Wrap(&err, "failed to close WebSocket")
110135

111-
if c.isClosed() {
136+
if !c.casClosing() {
137+
err = c.waitGoroutines()
138+
if err != nil {
139+
return err
140+
}
112141
return net.ErrClosed
113142
}
143+
defer func() {
144+
if errors.Is(err, net.ErrClosed) {
145+
err = nil
146+
}
147+
}()
114148

115-
c.close(nil)
116-
return c.closeErr
117-
}
118-
119-
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
120-
defer errd.Wrap(&err, "failed to close WebSocket")
121-
122-
writeErr := c.writeClose(code, reason)
123-
closeHandshakeErr := c.waitCloseHandshake()
149+
err = c.close()
124150

125-
if writeErr != nil {
126-
return writeErr
151+
err2 := c.waitGoroutines()
152+
if err == nil && err2 != nil {
153+
err = err2
127154
}
155+
return err
156+
}
128157

129-
if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) {
130-
return closeHandshakeErr
158+
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
159+
err := c.writeClose(code, reason)
160+
if err != nil {
161+
return err
131162
}
132163

164+
err = c.waitCloseHandshake()
165+
if CloseStatus(err) != code {
166+
return err
167+
}
133168
return nil
134169
}
135170

136171
func (c *Conn) writeClose(code StatusCode, reason string) error {
137-
c.closeMu.Lock()
138-
wroteClose := c.wroteClose
139-
c.wroteClose = true
140-
c.closeMu.Unlock()
141-
if wroteClose {
142-
return net.ErrClosed
143-
}
144-
145172
ce := CloseError{
146173
Code: code,
147174
Reason: reason,
148175
}
149176

150177
var p []byte
151-
var marshalErr error
178+
var err error
152179
if ce.Code != StatusNoStatusRcvd {
153-
p, marshalErr = ce.bytes()
154-
}
155-
156-
writeErr := c.writeControl(context.Background(), opClose, p)
157-
if CloseStatus(writeErr) != -1 {
158-
// Not a real error if it's due to a close frame being received.
159-
writeErr = nil
180+
p, err = ce.bytes()
181+
if err != nil {
182+
return err
183+
}
160184
}
161185

162-
// We do this after in case there was an error writing the close frame.
163-
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
186+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
187+
defer cancel()
164188

165-
if marshalErr != nil {
166-
return marshalErr
189+
err = c.writeControl(ctx, opClose, p)
190+
// If the connection closed as we're writing we ignore the error as we might
191+
// have written the close frame, the peer responded and then someone else read it
192+
// and closed the connection.
193+
if err != nil && !errors.Is(err, net.ErrClosed) {
194+
return err
167195
}
168-
return writeErr
196+
return nil
169197
}
170198

171199
func (c *Conn) waitCloseHandshake() error {
172-
defer c.close(nil)
173-
174200
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
175201
defer cancel()
176202

@@ -180,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error {
180206
}
181207
defer c.readMu.unlock()
182208

183-
if c.readCloseFrameErr != nil {
184-
return c.readCloseFrameErr
185-
}
186-
187209
for i := int64(0); i < c.msgReader.payloadLength; i++ {
188210
_, err := c.br.ReadByte()
189211
if err != nil {
@@ -206,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error {
206228
}
207229
}
208230

231+
func (c *Conn) waitGoroutines() error {
232+
t := time.NewTimer(time.Second * 15)
233+
defer t.Stop()
234+
235+
select {
236+
case <-c.timeoutLoopDone:
237+
case <-t.C:
238+
return errors.New("failed to wait for timeoutLoop goroutine to exit")
239+
}
240+
241+
c.closeReadMu.Lock()
242+
closeRead := c.closeReadCtx != nil
243+
c.closeReadMu.Unlock()
244+
if closeRead {
245+
select {
246+
case <-c.closeReadDone:
247+
case <-t.C:
248+
return errors.New("failed to wait for close read goroutine to exit")
249+
}
250+
}
251+
252+
select {
253+
case <-c.closed:
254+
case <-t.C:
255+
return errors.New("failed to wait for connection to be closed")
256+
}
257+
258+
return nil
259+
}
260+
209261
func parseClosePayload(p []byte) (CloseError, error) {
210262
if len(p) == 0 {
211263
return CloseError{
@@ -276,16 +328,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
276328
return buf, nil
277329
}
278330

279-
func (c *Conn) setCloseErr(err error) {
331+
func (c *Conn) casClosing() bool {
280332
c.closeMu.Lock()
281-
c.setCloseErrLocked(err)
282-
c.closeMu.Unlock()
283-
}
284-
285-
func (c *Conn) setCloseErrLocked(err error) {
286-
if c.closeErr == nil && err != nil {
287-
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
333+
defer c.closeMu.Unlock()
334+
if !c.closing {
335+
c.closing = true
336+
return true
288337
}
338+
return false
289339
}
290340

291341
func (c *Conn) isClosed() bool {

0 commit comments

Comments
 (0)