|
1 | 1 | package requests
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "compress/gzip" |
4 | 5 | "context"
|
5 | 6 | "encoding/base64"
|
| 7 | + "errors" |
| 8 | + "fmt" |
| 9 | + "io" |
6 | 10 | "net"
|
7 | 11 | "net/http"
|
| 12 | + "net/http/httptest" |
8 | 13 | "net/url"
|
| 14 | + "os" |
| 15 | + "strings" |
9 | 16 | "testing"
|
10 | 17 | "time"
|
11 | 18 | )
|
@@ -205,14 +212,64 @@ func TestLogfOption(t *testing.T) {
|
205 | 212 | }
|
206 | 213 | }
|
207 | 214 |
|
208 |
| -func TestProxyOption(t *testing.T) { |
209 |
| - proxyURL := "http://localhost:8080" |
210 |
| - opts := newOptions([]Option{ |
211 |
| - Proxy(proxyURL), |
212 |
| - }) |
| 215 | +func TestProxy(t *testing.T) { |
| 216 | + tests := []struct { |
| 217 | + name string |
| 218 | + proxyAddr string |
| 219 | + wantPanic bool |
| 220 | + }{ |
| 221 | + { |
| 222 | + name: "空代理地址", |
| 223 | + proxyAddr: "", |
| 224 | + wantPanic: false, |
| 225 | + }, |
| 226 | + { |
| 227 | + name: "有效的代理地址", |
| 228 | + proxyAddr: "http://127.0.0.1:8080", |
| 229 | + wantPanic: false, |
| 230 | + }, |
| 231 | + { |
| 232 | + name: "无效的代理地址", |
| 233 | + proxyAddr: "://invalid-proxy", |
| 234 | + wantPanic: true, |
| 235 | + }, |
| 236 | + } |
213 | 237 |
|
214 |
| - if opts.Proxy == nil { |
215 |
| - t.Error("Proxy 函数未设置") |
| 238 | + for _, tt := range tests { |
| 239 | + t.Run(tt.name, func(t *testing.T) { |
| 240 | + defer func() { |
| 241 | + r := recover() |
| 242 | + if (r != nil) != tt.wantPanic { |
| 243 | + t.Errorf("Proxy() panic = %v, wantPanic = %v", r, tt.wantPanic) |
| 244 | + } |
| 245 | + if tt.wantPanic && r != nil { |
| 246 | + // 验证 panic 信息 |
| 247 | + if panicMsg, ok := r.(string); !ok || !strings.Contains(panicMsg, "parse proxy addr:") { |
| 248 | + t.Errorf("期望的 panic 信息包含 'parse proxy addr:', 得到 %v", r) |
| 249 | + } |
| 250 | + } |
| 251 | + }() |
| 252 | + |
| 253 | + opt := newOptions([]Option{Proxy(tt.proxyAddr)}) |
| 254 | + |
| 255 | + if tt.proxyAddr == "" { |
| 256 | + proxyURL, err := url.Parse(tt.proxyAddr) |
| 257 | + if err != nil { |
| 258 | + t.Fatalf("解析代理地址时出错: %v", err) |
| 259 | + } |
| 260 | + |
| 261 | + if proxyURL.String() != os.Getenv("HTTP_PROXY") { |
| 262 | + t.Error("空代理地址应该使用默认的环境代理设置") |
| 263 | + } |
| 264 | + t.Log("3") |
| 265 | + |
| 266 | + return |
| 267 | + } |
| 268 | + |
| 269 | + if !tt.wantPanic && opt.Proxy == nil { |
| 270 | + t.Error("代理函数未被设置") |
| 271 | + } |
| 272 | + }) |
216 | 273 | }
|
217 | 274 | }
|
218 | 275 |
|
@@ -249,3 +306,247 @@ func TestCertKeyOption(t *testing.T) {
|
249 | 306 | opts.certFile, opts.keyFile, certFile, keyFile)
|
250 | 307 | }
|
251 | 308 | }
|
| 309 | + |
| 310 | +func TestRoundTripper(t *testing.T) { |
| 311 | + // 创建一个自定义的 RoundTripper |
| 312 | + customTransport := &http.Transport{ |
| 313 | + MaxIdleConns: 100, |
| 314 | + MaxIdleConnsPerHost: 10, |
| 315 | + } |
| 316 | + |
| 317 | + // 创建测试服务器 |
| 318 | + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 319 | + w.Write([]byte("test response")) |
| 320 | + })) |
| 321 | + defer server.Close() |
| 322 | + |
| 323 | + tests := []struct { |
| 324 | + name string |
| 325 | + transport http.RoundTripper |
| 326 | + wantErr bool |
| 327 | + checkResult func(*testing.T, *http.Response, error) |
| 328 | + }{ |
| 329 | + { |
| 330 | + name: "使用自定义Transport", |
| 331 | + transport: customTransport, |
| 332 | + checkResult: func(t *testing.T, resp *http.Response, err error) { |
| 333 | + if err != nil { |
| 334 | + t.Errorf("请求失败: %v", err) |
| 335 | + } |
| 336 | + if resp.StatusCode != http.StatusOK { |
| 337 | + t.Errorf("期望状态码 200,得到 %d", resp.StatusCode) |
| 338 | + } |
| 339 | + }, |
| 340 | + }, |
| 341 | + { |
| 342 | + name: "使用nil Transport应该使用默认值", |
| 343 | + transport: nil, |
| 344 | + checkResult: func(t *testing.T, resp *http.Response, err error) { |
| 345 | + if err != nil { |
| 346 | + t.Errorf("请求失败: %v", err) |
| 347 | + } |
| 348 | + if resp.StatusCode != http.StatusOK { |
| 349 | + t.Errorf("期望状态码 200,得到 %d", resp.StatusCode) |
| 350 | + } |
| 351 | + }, |
| 352 | + }, |
| 353 | + } |
| 354 | + |
| 355 | + for _, tt := range tests { |
| 356 | + t.Run(tt.name, func(t *testing.T) { |
| 357 | + // 创建客户端 |
| 358 | + client := New( |
| 359 | + RoundTripper(tt.transport), |
| 360 | + ) |
| 361 | + |
| 362 | + // 发送请求 |
| 363 | + resp, err := client.DoRequest( |
| 364 | + context.Background(), |
| 365 | + URL(server.URL), |
| 366 | + Method("GET"), |
| 367 | + ) |
| 368 | + |
| 369 | + // 检查结果 |
| 370 | + tt.checkResult(t, resp.Response, err) |
| 371 | + |
| 372 | + // 验证 Transport 是否正确设置 |
| 373 | + if tt.transport != nil { |
| 374 | + opts := newOptions([]Option{RoundTripper(tt.transport)}) |
| 375 | + if opts.Transport != tt.transport { |
| 376 | + t.Error("Transport 未被正确设置") |
| 377 | + } |
| 378 | + } |
| 379 | + }) |
| 380 | + } |
| 381 | +} |
| 382 | + |
| 383 | +func TestHost(t *testing.T) { |
| 384 | + // 创建测试服务器 |
| 385 | + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 386 | + // 验证 Host 头 |
| 387 | + if host := r.Host; host != "example.com" { |
| 388 | + t.Skipf("期望 Host 为 example.com,得到 %s", host) |
| 389 | + } |
| 390 | + if host := r.Header.Get("Host"); host != "example.com" { |
| 391 | + t.Skipf("期望 Host header 为 example.com,得到 %s", host) |
| 392 | + } |
| 393 | + w.Write([]byte("ok")) |
| 394 | + })) |
| 395 | + defer server.Close() |
| 396 | + |
| 397 | + tests := []struct { |
| 398 | + name string |
| 399 | + host string |
| 400 | + wantHost string |
| 401 | + }{ |
| 402 | + { |
| 403 | + name: "设置自定义Host", |
| 404 | + host: "example.com", |
| 405 | + wantHost: "example.com", |
| 406 | + }, |
| 407 | + { |
| 408 | + name: "设置空Host", |
| 409 | + host: "", |
| 410 | + wantHost: "", |
| 411 | + }, |
| 412 | + } |
| 413 | + |
| 414 | + for _, tt := range tests { |
| 415 | + t.Run(tt.name, func(t *testing.T) { |
| 416 | + // 创建客户端 |
| 417 | + client := New( |
| 418 | + Host(tt.host), |
| 419 | + ) |
| 420 | + |
| 421 | + // 发送请求 |
| 422 | + resp, err := client.DoRequest( |
| 423 | + context.Background(), |
| 424 | + URL(server.URL), |
| 425 | + ) |
| 426 | + |
| 427 | + // 检查错误 |
| 428 | + if err != nil { |
| 429 | + t.Fatalf("请求失败: %v", err) |
| 430 | + } |
| 431 | + |
| 432 | + // 检查响应状态 |
| 433 | + if resp.StatusCode != http.StatusOK { |
| 434 | + t.Errorf("期望状态码 200,得到 %d", resp.StatusCode) |
| 435 | + } |
| 436 | + |
| 437 | + // 验证 Options 中的设置 |
| 438 | + opts := newOptions([]Option{Host(tt.host)}) |
| 439 | + if len(opts.HttpRoundTripper) != 1 { |
| 440 | + t.Error("HttpRoundTripper 未正确设置") |
| 441 | + } |
| 442 | + }) |
| 443 | + } |
| 444 | +} |
| 445 | + |
| 446 | +func TestGzip(t *testing.T) { |
| 447 | + // 创建测试服务器 |
| 448 | + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 449 | + // 验证请求头 |
| 450 | + if r.Header.Get("Accept-Encoding") != "gzip" { |
| 451 | + t.Error("缺少 Accept-Encoding: gzip") |
| 452 | + } |
| 453 | + if r.Header.Get("Content-Encoding") != "gzip" { |
| 454 | + t.Error("缺少 Content-Encoding: gzip") |
| 455 | + } |
| 456 | + |
| 457 | + // 读取并解压缩请求体 |
| 458 | + reader, err := gzip.NewReader(r.Body) |
| 459 | + if err != nil { |
| 460 | + t.Fatalf("创建 gzip reader 失败: %v", err) |
| 461 | + } |
| 462 | + defer reader.Close() |
| 463 | + |
| 464 | + body, err := io.ReadAll(reader) |
| 465 | + if err != nil { |
| 466 | + t.Fatalf("读取解压缩内容失败: %v", err) |
| 467 | + } |
| 468 | + |
| 469 | + if string(body) != "test data" { |
| 470 | + t.Errorf("期望请求体为 'test data',得到 %s", string(body)) |
| 471 | + } |
| 472 | + |
| 473 | + w.Write([]byte("ok")) |
| 474 | + })) |
| 475 | + defer server.Close() |
| 476 | + |
| 477 | + tests := []struct { |
| 478 | + name string |
| 479 | + body any |
| 480 | + wantPanic bool |
| 481 | + }{ |
| 482 | + { |
| 483 | + name: "正常字符串", |
| 484 | + body: "test data", |
| 485 | + wantPanic: false, |
| 486 | + }, |
| 487 | + { |
| 488 | + name: "nil body", |
| 489 | + body: nil, |
| 490 | + wantPanic: true, |
| 491 | + }, |
| 492 | + { |
| 493 | + name: "无效的 body 类型", |
| 494 | + body: make(chan int), |
| 495 | + wantPanic: true, |
| 496 | + }, |
| 497 | + } |
| 498 | + |
| 499 | + for _, tt := range tests { |
| 500 | + t.Run(tt.name, func(t *testing.T) { |
| 501 | + defer func() { |
| 502 | + r := recover() |
| 503 | + if (r != nil) != tt.wantPanic { |
| 504 | + t.Errorf("Gzip() panic = %v, wantPanic = %v", r, tt.wantPanic) |
| 505 | + } |
| 506 | + }() |
| 507 | + |
| 508 | + if !tt.wantPanic { |
| 509 | + // 正常场景测试 |
| 510 | + client := New() |
| 511 | + resp, err := client.DoRequest( |
| 512 | + context.Background(), |
| 513 | + URL(server.URL), |
| 514 | + Method(http.MethodPost), |
| 515 | + Gzip(tt.body), |
| 516 | + ) |
| 517 | + |
| 518 | + if err != nil { |
| 519 | + t.Fatalf("请求失败: %v", err) |
| 520 | + } |
| 521 | + |
| 522 | + if resp.StatusCode != http.StatusOK { |
| 523 | + t.Errorf("期望状态码 200,得到 %d", resp.StatusCode) |
| 524 | + } |
| 525 | + } else { |
| 526 | + // 错误场景测试,直接调用 Gzip 函数 |
| 527 | + opt := Gzip(tt.body) |
| 528 | + _ = opt |
| 529 | + } |
| 530 | + }) |
| 531 | + } |
| 532 | +} |
| 533 | + |
| 534 | +func TestGzipWithErrorReader(t *testing.T) { |
| 535 | + defer func() { |
| 536 | + r := recover() |
| 537 | + if r == nil { |
| 538 | + t.Error("期望发生 panic,但没有") |
| 539 | + return |
| 540 | + } |
| 541 | + |
| 542 | + // 验证 panic 信息 |
| 543 | + if !strings.Contains(fmt.Sprintf("%v", r), "模拟读取错误") { |
| 544 | + t.Errorf("期望的 panic 信息包含 '模拟读取错误',得到 %v", r) |
| 545 | + } |
| 546 | + }() |
| 547 | + |
| 548 | + // 使用会产生错误的 Reader |
| 549 | + body := &errorReader{err: errors.New("模拟读取错误")} |
| 550 | + opt := Gzip(body) |
| 551 | + _ = opt |
| 552 | +} |
0 commit comments