Skip to content

Commit ee7b9af

Browse files
committed
Add tests
1 parent 12cbf54 commit ee7b9af

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

accept_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,33 @@ func TestAccept(t *testing.T) {
143143
_, err := Accept(w, r, nil)
144144
assert.Contains(t, err, `failed to hijack connection`)
145145
})
146+
147+
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
148+
t.Parallel()
149+
150+
rr := httptest.NewRecorder()
151+
w := mockUnwrapper{
152+
ResponseWriter: rr,
153+
unwrap: func() http.ResponseWriter {
154+
return mockHijacker{
155+
ResponseWriter: rr,
156+
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
157+
return nil, nil, errors.New("haha")
158+
},
159+
}
160+
},
161+
}
162+
163+
r := httptest.NewRequest("GET", "/", nil)
164+
r.Header.Set("Connection", "Upgrade")
165+
r.Header.Set("Upgrade", "websocket")
166+
r.Header.Set("Sec-WebSocket-Version", "13")
167+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
168+
169+
_, err := Accept(w, r, nil)
170+
assert.Contains(t, err, "failed to hijack connection")
171+
})
172+
146173
t.Run("closeRace", func(t *testing.T) {
147174
t.Parallel()
148175

@@ -534,3 +561,14 @@ var _ http.Hijacker = mockHijacker{}
534561
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
535562
return mj.hijack()
536563
}
564+
565+
type mockUnwrapper struct {
566+
http.ResponseWriter
567+
unwrap func() http.ResponseWriter
568+
}
569+
570+
var _ rwUnwrapper = mockUnwrapper{}
571+
572+
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
573+
return mu.unwrap()
574+
}

hijack_go120_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//go:build !js && go1.20
2+
3+
package websocket
4+
5+
import (
6+
"bufio"
7+
"errors"
8+
"net"
9+
"net/http"
10+
"net/http/httptest"
11+
"testing"
12+
13+
"github.com/coder/websocket/internal/test/assert"
14+
)
15+
16+
func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) {
17+
t.Parallel()
18+
19+
rr := httptest.NewRecorder()
20+
w := mockUnwrapper{
21+
ResponseWriter: rr,
22+
unwrap: func() http.ResponseWriter {
23+
return mockHijacker{
24+
ResponseWriter: rr,
25+
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
26+
return nil, nil, errors.New("haha")
27+
},
28+
}
29+
},
30+
}
31+
32+
_, _, err := http.NewResponseController(w).Hijack()
33+
assert.Contains(t, err, "haha")
34+
hj, ok := hijacker(w)
35+
assert.Equal(t, "hijacker found", ok, true)
36+
_, _, err = hj.Hijack()
37+
assert.Contains(t, err, "haha")
38+
}

0 commit comments

Comments
 (0)