Skip to content

Commit 735c6d3

Browse files
committed
update: coverage test
1 parent 885ca92 commit 735c6d3

File tree

4 files changed

+227
-52
lines changed

4 files changed

+227
-52
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@
2121
go.work
2222

2323
vendor/
24-
.idea/
24+
.idea/
25+
coverage.txt

response_writer_test.go

+102-3
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,27 @@ import (
1313
)
1414

1515
// mockResponseWriter implements http.ResponseWriter for testing
16+
// 更新 mockResponseWriter 以支持错误测试
1617
type mockResponseWriter struct {
1718
headers http.Header
1819
statuscode int
1920
body bytes.Buffer
21+
writeError error // 添加这个字段
22+
}
23+
24+
func (m *mockResponseWriter) Write(b []byte) (int, error) {
25+
if m.writeError != nil {
26+
return 0, m.writeError
27+
}
28+
return m.body.Write(b)
2029
}
2130

2231
func newMockResponseWriter() *mockResponseWriter {
2332
return &mockResponseWriter{headers: make(http.Header)}
2433
}
2534

26-
func (m *mockResponseWriter) Header() http.Header { return m.headers }
27-
func (m *mockResponseWriter) Write(b []byte) (int, error) { return m.body.Write(b) }
28-
func (m *mockResponseWriter) WriteHeader(code int) { m.statuscode = code }
35+
func (m *mockResponseWriter) Header() http.Header { return m.headers }
36+
func (m *mockResponseWriter) WriteHeader(code int) { m.statuscode = code }
2937

3038
// TestResponseWriterBasic tests basic functionality of ResponseWriter
3139
func TestResponseWriterBasic(t *testing.T) {
@@ -208,3 +216,94 @@ func TestResponseWriterFlush(t *testing.T) {
208216
t.Errorf("Expected chunk2, got %s, err: %v", chunk2, err)
209217
}
210218
}
219+
220+
// TestResponseWriterRead 测试 Read 方法
221+
func TestResponseWriterRead(t *testing.T) {
222+
mock := newMockResponseWriter()
223+
w := newResponseWriter(mock)
224+
225+
// 写入测试数据
226+
testData := []byte("test data for reading")
227+
_, err := w.Write(testData)
228+
if err != nil {
229+
t.Fatalf("Failed to write test data: %v", err)
230+
}
231+
232+
// 测试读取
233+
buf := make([]byte, len(testData))
234+
n, err := w.Read(buf)
235+
if err != nil {
236+
t.Fatalf("Read failed: %v", err)
237+
}
238+
if n != len(testData) {
239+
t.Errorf("Expected to read %d bytes, got %d", len(testData), n)
240+
}
241+
if string(buf) != string(testData) {
242+
t.Errorf("Expected to read '%s', got '%s'", string(testData), string(buf))
243+
}
244+
245+
// 测试读取完后的EOF
246+
n, err = w.Read(buf)
247+
if err != io.EOF {
248+
t.Errorf("Expected EOF after reading all data, got %v", err)
249+
}
250+
}
251+
252+
// TestResponseWriterPush 测试 HTTP/2 Push 功能
253+
func TestResponseWriterPush(t *testing.T) {
254+
// 创建支持 HTTP/2 的测试服务器
255+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
256+
rw := newResponseWriter(w)
257+
err := rw.Push("/style.css", &http.PushOptions{
258+
Method: "GET",
259+
Header: http.Header{
260+
"Content-Type": []string{"text/css"},
261+
},
262+
})
263+
if err != nil && err != http.ErrNotSupported {
264+
t.Errorf("Push failed: %v", err)
265+
}
266+
rw.Write([]byte("main content"))
267+
})
268+
269+
server := httptest.NewUnstartedServer(handler)
270+
server.EnableHTTP2 = true
271+
server.StartTLS()
272+
defer server.Close()
273+
274+
// 发起请求
275+
client := server.Client()
276+
resp, err := client.Get(server.URL)
277+
if err != nil {
278+
t.Fatalf("Request failed: %v", err)
279+
}
280+
defer resp.Body.Close()
281+
282+
body, err := io.ReadAll(resp.Body)
283+
if err != nil {
284+
t.Fatalf("Failed to read response: %v", err)
285+
}
286+
if string(body) != "main content" {
287+
t.Errorf("Expected 'main content', got '%s'", string(body))
288+
}
289+
}
290+
291+
// TestResponseWriterWriteError 测试 Write 方法的错误处理
292+
func TestResponseWriterWriteError(t *testing.T) {
293+
// 创建一个会返回错误的 mock
294+
errMock := &mockResponseWriter{
295+
headers: make(http.Header),
296+
writeError: fmt.Errorf("write error"),
297+
}
298+
299+
w := newResponseWriter(errMock)
300+
301+
// 测试写入错误
302+
n, err := w.Write([]byte("test"))
303+
if err == nil {
304+
t.Error("Expected write error, got nil")
305+
}
306+
if n != 0 {
307+
t.Errorf("Expected 0 bytes written on error, got %d", n)
308+
}
309+
}

transport.go

-48
Original file line numberDiff line numberDiff line change
@@ -78,51 +78,3 @@ func newTransport(opts ...Option) *http.Transport {
7878
},
7979
}
8080
}
81-
82-
// RoundTrip implements the RoundTripper interface.
83-
// It processes requests by calling the RoundTripper method.
84-
// func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
85-
// return t.RoundTripper().RoundTrip(r)
86-
// }
87-
88-
// RoundTripper returns a configured http.RoundTripper.
89-
// It applies all registered middleware in reverse order.
90-
// func (t *Transport) RoundTripper(opts ...Option) http.RoundTripper {
91-
// return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
92-
// options := newOptions(t.opts, opts...)
93-
// if options.Transport == nil {
94-
// options.Transport = t.Transport
95-
// }
96-
// // Apply middleware in reverse order
97-
// for i := len(options.HttpRoundTripper) - 1; i >= 0; i-- {
98-
// options.Transport = options.HttpRoundTripper[i](options.Transport)
99-
// }
100-
// return options.Transport.RoundTrip(r)
101-
// })
102-
// }
103-
104-
// Redirect creates a middleware for handling HTTP redirects.
105-
// It handles 301 (Moved Permanently) and 302 (Found) status codes.
106-
func Redirect(next http.RoundTripper) http.RoundTripper {
107-
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
108-
response, err := next.RoundTrip(req)
109-
if err != nil {
110-
return response, err
111-
}
112-
// Check if redirection is needed
113-
if response.StatusCode != http.StatusMovedPermanently && response.StatusCode != http.StatusFound {
114-
return response, err
115-
}
116-
// Create redirect request
117-
if req, err = NewRequestWithContext(req.Context(), Options{
118-
Method: req.Method,
119-
URL: response.Header.Get("Location"),
120-
Header: req.Header,
121-
body: req.Body,
122-
}); err != nil {
123-
return response, err
124-
}
125-
// Execute redirect request
126-
return next.RoundTrip(req)
127-
})
128-
}

transport_test.go

+123
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ package requests
33
import (
44
"context"
55
"io"
6+
"net"
67
"net/http"
78
"net/http/httptest"
9+
"net/url"
810
"strings"
911
"testing"
12+
"time"
1013
)
1114

1215
func Test_Setup(t *testing.T) {
@@ -50,3 +53,123 @@ func Test_Setup(t *testing.T) {
5053
}
5154

5255
}
56+
57+
func TestWarpRoundTripper(t *testing.T) {
58+
// 测试装饰器链
59+
var order []string
60+
rt1 := RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
61+
order = append(order, "rt1")
62+
return &http.Response{StatusCode: 200}, nil
63+
})
64+
65+
rt2 := WarpRoundTripper(rt1)(http.DefaultTransport)
66+
_, err := rt2.RoundTrip(&http.Request{})
67+
if err != nil {
68+
t.Fatal(err)
69+
}
70+
if len(order) != 1 || order[0] != "rt1" {
71+
t.Error("装饰器执行顺序错误")
72+
}
73+
}
74+
75+
func TestNewTransport(t *testing.T) {
76+
tests := []struct {
77+
name string
78+
opts []Option
79+
test func(*testing.T, *http.Transport)
80+
}{
81+
{
82+
name: "Unix套接字",
83+
opts: []Option{URL("unix:///tmp/test.sock")},
84+
test: func(t *testing.T, tr *http.Transport) {
85+
_, err := tr.DialContext(context.Background(), "unix", "/tmp/test.sock")
86+
if err == nil {
87+
t.Error("期望Unix套接字连接失败")
88+
}
89+
},
90+
},
91+
{
92+
name: "本地地址绑定",
93+
opts: []Option{LocalAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})},
94+
test: func(t *testing.T, tr *http.Transport) {
95+
conn, err := tr.DialContext(context.Background(), "tcp", "example.com:80")
96+
if err == nil {
97+
conn.Close()
98+
}
99+
},
100+
},
101+
{
102+
name: "TLS配置",
103+
opts: []Option{Verify(false)},
104+
test: func(t *testing.T, tr *http.Transport) {
105+
if tr.TLSClientConfig.InsecureSkipVerify != true {
106+
t.Error("TLS验证配置错误")
107+
}
108+
},
109+
},
110+
{
111+
name: "连接池配置",
112+
opts: []Option{MaxConns(100)},
113+
test: func(t *testing.T, tr *http.Transport) {
114+
if tr.MaxIdleConns != 100 || tr.MaxIdleConnsPerHost != 100 {
115+
t.Error("连接池配置错误")
116+
}
117+
},
118+
},
119+
}
120+
121+
for _, tt := range tests {
122+
t.Run(tt.name, func(t *testing.T) {
123+
tr := newTransport(tt.opts...)
124+
tt.test(t, tr)
125+
})
126+
}
127+
}
128+
129+
func TestTransportWithRealServer(t *testing.T) {
130+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131+
time.Sleep(10 * time.Millisecond) // 模拟处理延迟
132+
w.Write([]byte("ok"))
133+
}))
134+
defer server.Close()
135+
136+
tr := newTransport(
137+
Timeout(100*time.Millisecond),
138+
MaxConns(10),
139+
Verify(false),
140+
)
141+
142+
client := &http.Client{Transport: tr}
143+
144+
// 并发测试
145+
for i := 0; i < 10; i++ {
146+
go func() {
147+
resp, err := client.Get(server.URL)
148+
if err != nil {
149+
t.Error(err)
150+
return
151+
}
152+
defer resp.Body.Close()
153+
}()
154+
}
155+
156+
time.Sleep(200 * time.Millisecond)
157+
}
158+
159+
func TestTransportProxy(t *testing.T) {
160+
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
161+
w.Write([]byte("proxy response"))
162+
}))
163+
defer proxyServer.Close()
164+
165+
tr := newTransport(Proxy(proxyServer.URL))
166+
167+
// 验证代理设置是否生效
168+
proxyURL, err := tr.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}})
169+
if err != nil {
170+
t.Fatal(err)
171+
}
172+
if proxyURL == nil {
173+
t.Error("代理未正确设置")
174+
}
175+
}

0 commit comments

Comments
 (0)