Skip to content

Commit 3dd723a

Browse files
authored
accept: Add unwrapping for hijack like http.ResponseController (#472)
Since we rely on the connection not being hijacked too early (i.e. detecting the presence of http.Hijacker) to set headers, we must manually implement the unwrapping of the http.ResponseController. By doing so, we also retain Go 1.19 compatibility without build tags. Closes #455
1 parent 641f4f5 commit 3dd723a

File tree

4 files changed

+110
-1
lines changed

4 files changed

+110
-1
lines changed

accept.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
105105
}
106106
}
107107

108-
hj, ok := w.(http.Hijacker)
108+
hj, ok := hijacker(w)
109109
if !ok {
110110
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
111111
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)

accept_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,33 @@ func TestAccept(t *testing.T) {
143143
_, err := Accept(w, r, nil)
144144
assert.Contains(t, err, `failed to hijack connection`)
145145
})
146+
147+
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
148+
t.Parallel()
149+
150+
rr := httptest.NewRecorder()
151+
w := mockUnwrapper{
152+
ResponseWriter: rr,
153+
unwrap: func() http.ResponseWriter {
154+
return mockHijacker{
155+
ResponseWriter: rr,
156+
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
157+
return nil, nil, errors.New("haha")
158+
},
159+
}
160+
},
161+
}
162+
163+
r := httptest.NewRequest("GET", "/", nil)
164+
r.Header.Set("Connection", "Upgrade")
165+
r.Header.Set("Upgrade", "websocket")
166+
r.Header.Set("Sec-WebSocket-Version", "13")
167+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
168+
169+
_, err := Accept(w, r, nil)
170+
assert.Contains(t, err, "failed to hijack connection")
171+
})
172+
146173
t.Run("closeRace", func(t *testing.T) {
147174
t.Parallel()
148175

@@ -534,3 +561,14 @@ var _ http.Hijacker = mockHijacker{}
534561
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
535562
return mj.hijack()
536563
}
564+
565+
type mockUnwrapper struct {
566+
http.ResponseWriter
567+
unwrap func() http.ResponseWriter
568+
}
569+
570+
var _ rwUnwrapper = mockUnwrapper{}
571+
572+
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
573+
return mu.unwrap()
574+
}

hijack.go

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//go:build !js
2+
3+
package websocket
4+
5+
import (
6+
"net/http"
7+
)
8+
9+
type rwUnwrapper interface {
10+
Unwrap() http.ResponseWriter
11+
}
12+
13+
// hijacker returns the Hijacker interface of the http.ResponseWriter.
14+
// It follows the Unwrap method of the http.ResponseWriter if available,
15+
// matching the behavior of http.ResponseController. If the Hijacker
16+
// interface is not found, it returns false.
17+
//
18+
// Since the http.ResponseController is not available in Go 1.19, and
19+
// does not support checking the presence of the Hijacker interface,
20+
// this function is used to provide a consistent way to check for the
21+
// Hijacker interface across Go versions.
22+
func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) {
23+
for {
24+
switch t := rw.(type) {
25+
case http.Hijacker:
26+
return t, true
27+
case rwUnwrapper:
28+
rw = t.Unwrap()
29+
default:
30+
return nil, false
31+
}
32+
}
33+
}

hijack_go120_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//go:build !js && go1.20
2+
3+
package websocket
4+
5+
import (
6+
"bufio"
7+
"errors"
8+
"net"
9+
"net/http"
10+
"net/http/httptest"
11+
"testing"
12+
13+
"github.com/coder/websocket/internal/test/assert"
14+
)
15+
16+
func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) {
17+
t.Parallel()
18+
19+
rr := httptest.NewRecorder()
20+
w := mockUnwrapper{
21+
ResponseWriter: rr,
22+
unwrap: func() http.ResponseWriter {
23+
return mockHijacker{
24+
ResponseWriter: rr,
25+
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
26+
return nil, nil, errors.New("haha")
27+
},
28+
}
29+
},
30+
}
31+
32+
_, _, err := http.NewResponseController(w).Hijack()
33+
assert.Contains(t, err, "haha")
34+
hj, ok := hijacker(w)
35+
assert.Equal(t, "hijacker found", ok, true)
36+
_, _, err = hj.Hijack()
37+
assert.Contains(t, err, "haha")
38+
}

0 commit comments

Comments
 (0)