@@ -97,80 +97,106 @@ func CloseStatus(err error) StatusCode {
97
97
//
98
98
// Close will unblock all goroutines interacting with the connection once
99
99
// complete.
100
- func (c * Conn ) Close (code StatusCode , reason string ) error {
101
- defer c .wg .Wait ()
102
- return c .closeHandshake (code , reason )
100
+ func (c * Conn ) Close (code StatusCode , reason string ) (err error ) {
101
+ defer errd .Wrap (& err , "failed to close WebSocket" )
102
+
103
+ if ! c .casClosing () {
104
+ err = c .waitGoroutines ()
105
+ if err != nil {
106
+ return err
107
+ }
108
+ return net .ErrClosed
109
+ }
110
+ defer func () {
111
+ if errors .Is (err , net .ErrClosed ) {
112
+ err = nil
113
+ }
114
+ }()
115
+
116
+ err = c .closeHandshake (code , reason )
117
+
118
+ err2 := c .close ()
119
+ if err == nil && err2 != nil {
120
+ err = err2
121
+ }
122
+
123
+ err2 = c .waitGoroutines ()
124
+ if err == nil && err2 != nil {
125
+ err = err2
126
+ }
127
+
128
+ return err
103
129
}
104
130
105
131
// CloseNow closes the WebSocket connection without attempting a close handshake.
106
132
// Use when you do not want the overhead of the close handshake.
107
133
func (c * Conn ) CloseNow () (err error ) {
108
- defer c .wg .Wait ()
109
134
defer errd .Wrap (& err , "failed to close WebSocket" )
110
135
111
- if c .isClosed () {
136
+ if ! c .casClosing () {
137
+ err = c .waitGoroutines ()
138
+ if err != nil {
139
+ return err
140
+ }
112
141
return net .ErrClosed
113
142
}
143
+ defer func () {
144
+ if errors .Is (err , net .ErrClosed ) {
145
+ err = nil
146
+ }
147
+ }()
114
148
115
- c .close (nil )
116
- return c .closeErr
117
- }
118
-
119
- func (c * Conn ) closeHandshake (code StatusCode , reason string ) (err error ) {
120
- defer errd .Wrap (& err , "failed to close WebSocket" )
121
-
122
- writeErr := c .writeClose (code , reason )
123
- closeHandshakeErr := c .waitCloseHandshake ()
149
+ err = c .close ()
124
150
125
- if writeErr != nil {
126
- return writeErr
151
+ err2 := c .waitGoroutines ()
152
+ if err == nil && err2 != nil {
153
+ err = err2
127
154
}
155
+ return err
156
+ }
128
157
129
- if CloseStatus (closeHandshakeErr ) == - 1 && ! errors .Is (net .ErrClosed , closeHandshakeErr ) {
130
- return closeHandshakeErr
158
+ func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
159
+ err := c .writeClose (code , reason )
160
+ if err != nil {
161
+ return err
131
162
}
132
163
164
+ err = c .waitCloseHandshake ()
165
+ if CloseStatus (err ) != code {
166
+ return err
167
+ }
133
168
return nil
134
169
}
135
170
136
171
func (c * Conn ) writeClose (code StatusCode , reason string ) error {
137
- c .closeMu .Lock ()
138
- wroteClose := c .wroteClose
139
- c .wroteClose = true
140
- c .closeMu .Unlock ()
141
- if wroteClose {
142
- return net .ErrClosed
143
- }
144
-
145
172
ce := CloseError {
146
173
Code : code ,
147
174
Reason : reason ,
148
175
}
149
176
150
177
var p []byte
151
- var marshalErr error
178
+ var err error
152
179
if ce .Code != StatusNoStatusRcvd {
153
- p , marshalErr = ce .bytes ()
154
- }
155
-
156
- writeErr := c .writeControl (context .Background (), opClose , p )
157
- if CloseStatus (writeErr ) != - 1 {
158
- // Not a real error if it's due to a close frame being received.
159
- writeErr = nil
180
+ p , err = ce .bytes ()
181
+ if err != nil {
182
+ return err
183
+ }
160
184
}
161
185
162
- // We do this after in case there was an error writing the close frame.
163
- c . setCloseErr ( fmt . Errorf ( "sent close frame: %w" , ce ) )
186
+ ctx , cancel := context . WithTimeout ( context . Background (), time . Second * 5 )
187
+ defer cancel ( )
164
188
165
- if marshalErr != nil {
166
- return marshalErr
189
+ err = c .writeControl (ctx , opClose , p )
190
+ // If the connection closed as we're writing we ignore the error as we might
191
+ // have written the close frame, the peer responded and then someone else read it
192
+ // and closed the connection.
193
+ if err != nil && ! errors .Is (err , net .ErrClosed ) {
194
+ return err
167
195
}
168
- return writeErr
196
+ return nil
169
197
}
170
198
171
199
func (c * Conn ) waitCloseHandshake () error {
172
- defer c .close (nil )
173
-
174
200
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
175
201
defer cancel ()
176
202
@@ -180,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error {
180
206
}
181
207
defer c .readMu .unlock ()
182
208
183
- if c .readCloseFrameErr != nil {
184
- return c .readCloseFrameErr
185
- }
186
-
187
209
for i := int64 (0 ); i < c .msgReader .payloadLength ; i ++ {
188
210
_ , err := c .br .ReadByte ()
189
211
if err != nil {
@@ -206,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error {
206
228
}
207
229
}
208
230
231
+ func (c * Conn ) waitGoroutines () error {
232
+ t := time .NewTimer (time .Second * 15 )
233
+ defer t .Stop ()
234
+
235
+ select {
236
+ case <- c .timeoutLoopDone :
237
+ case <- t .C :
238
+ return errors .New ("failed to wait for timeoutLoop goroutine to exit" )
239
+ }
240
+
241
+ c .closeReadMu .Lock ()
242
+ closeRead := c .closeReadCtx != nil
243
+ c .closeReadMu .Unlock ()
244
+ if closeRead {
245
+ select {
246
+ case <- c .closeReadDone :
247
+ case <- t .C :
248
+ return errors .New ("failed to wait for close read goroutine to exit" )
249
+ }
250
+ }
251
+
252
+ select {
253
+ case <- c .closed :
254
+ case <- t .C :
255
+ return errors .New ("failed to wait for connection to be closed" )
256
+ }
257
+
258
+ return nil
259
+ }
260
+
209
261
func parseClosePayload (p []byte ) (CloseError , error ) {
210
262
if len (p ) == 0 {
211
263
return CloseError {
@@ -276,16 +328,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
276
328
return buf , nil
277
329
}
278
330
279
- func (c * Conn ) setCloseErr ( err error ) {
331
+ func (c * Conn ) casClosing () bool {
280
332
c .closeMu .Lock ()
281
- c .setCloseErrLocked (err )
282
- c .closeMu .Unlock ()
283
- }
284
-
285
- func (c * Conn ) setCloseErrLocked (err error ) {
286
- if c .closeErr == nil && err != nil {
287
- c .closeErr = fmt .Errorf ("WebSocket closed: %w" , err )
333
+ defer c .closeMu .Unlock ()
334
+ if ! c .closing {
335
+ c .closing = true
336
+ return true
288
337
}
338
+ return false
289
339
}
290
340
291
341
func (c * Conn ) isClosed () bool {
0 commit comments