Skip to content

Commit 230a06c

Browse files
committed
Add tests
1 parent 12cbf54 commit 230a06c

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

accept_test.go

+58
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,53 @@ func TestAccept(t *testing.T) {
143143
_, err := Accept(w, r, nil)
144144
assert.Contains(t, err, `failed to hijack connection`)
145145
})
146+
147+
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
148+
t.Parallel()
149+
150+
rr := httptest.NewRecorder()
151+
w := mockUnwrapper{
152+
ResponseWriter: rr,
153+
unwrap: func() http.ResponseWriter {
154+
return mockHijacker{
155+
ResponseWriter: rr,
156+
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
157+
return nil, nil, errors.New("haha")
158+
},
159+
}
160+
},
161+
}
162+
163+
r := httptest.NewRequest("GET", "/", nil)
164+
r.Header.Set("Connection", "Upgrade")
165+
r.Header.Set("Upgrade", "websocket")
166+
r.Header.Set("Sec-WebSocket-Version", "13")
167+
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
168+
169+
_, err := Accept(w, r, nil)
170+
assert.Contains(t, err, "failed to hijack connection")
171+
})
172+
173+
t.Run("hijackerHTTPResponseControllerCompatibility", func(t *testing.T) {
174+
t.Parallel()
175+
176+
rr := httptest.NewRecorder()
177+
w := mockUnwrapper{
178+
ResponseWriter: rr,
179+
unwrap: func() http.ResponseWriter {
180+
return mockHijacker{
181+
ResponseWriter: rr,
182+
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
183+
return nil, nil, errors.New("haha")
184+
},
185+
}
186+
},
187+
}
188+
189+
_, _, err := http.NewResponseController(w).Hijack()
190+
assert.Contains(t, err, "haha")
191+
})
192+
146193
t.Run("closeRace", func(t *testing.T) {
147194
t.Parallel()
148195

@@ -534,3 +581,14 @@ var _ http.Hijacker = mockHijacker{}
534581
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
535582
return mj.hijack()
536583
}
584+
585+
type mockUnwrapper struct {
586+
http.ResponseWriter
587+
unwrap func() http.ResponseWriter
588+
}
589+
590+
var _ rwUnwrapper = mockUnwrapper{}
591+
592+
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
593+
return mu.unwrap()
594+
}

0 commit comments

Comments
 (0)