2
2
package websocketproxy
3
3
4
4
import (
5
+ "context"
5
6
"fmt"
6
7
"io"
7
8
"log"
47
48
// If nil, DefaultDialer is used.
48
49
Dialer * websocket.Dialer
49
50
50
- // Done specifies a channel for which all proxied websocket connections
51
+ // done specifies a channel for which all proxied websocket connections
51
52
// can be closed on demand by closing the channel.
52
- Done chan struct {}
53
+ done chan struct {}
53
54
}
54
55
55
56
websocketMsg struct {
@@ -186,6 +187,9 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
186
187
187
188
errClient := make (chan error , 1 )
188
189
errBackend := make (chan error , 1 )
190
+ if w .done == nil {
191
+ w .done = make (chan struct {})
192
+ }
189
193
190
194
replicateWebsocketConn := func (dst , src * websocket.Conn , errc chan error ) {
191
195
websocketMsgRcverC := make (chan websocketMsg , 1 )
@@ -214,7 +218,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
214
218
errc <- err
215
219
break
216
220
}
217
- case <- w .Done :
221
+ case <- w .done :
218
222
m := websocket .FormatCloseMessage (websocket .CloseGoingAway , "websocketproxy: closing connection" )
219
223
dst .WriteMessage (websocket .CloseMessage , m )
220
224
break
@@ -234,8 +238,18 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
234
238
if e , ok := err .(* websocket.CloseError ); ! ok || e .Code == websocket .CloseAbnormalClosure {
235
239
log .Printf ("websocketproxy: Error when copying from client to backend: %v" , err )
236
240
}
237
- case <- w .Done :
241
+ case <- w .done :
242
+ }
243
+ }
244
+
245
+ // Shutdown closes ws connections by closing the done channel they are subscribed to.
246
+ func (w * WebsocketProxy ) Shutdown (ctx context.Context ) error {
247
+ // TODO: support using context for control and return error when applicable
248
+ // Currently implemented such that the method signature matches http.Server.Shutdown()
249
+ if w .done != nil {
250
+ close (w .done )
238
251
}
252
+ return nil
239
253
}
240
254
241
255
func copyHeader (dst , src http.Header ) {
0 commit comments