|
| 1 | +package matchers |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "io" |
| 6 | + "net/http" |
| 7 | + "net/http/httptest" |
| 8 | + "reflect" |
| 9 | + "strings" |
| 10 | + |
| 11 | + "github.com/onsi/gomega/format" |
| 12 | + "github.com/onsi/gomega/types" |
| 13 | +) |
| 14 | + |
| 15 | +var _ types.GomegaMatcher = new(HaveHTTPProtocolMatcher) |
| 16 | + |
| 17 | +// HaveHTTPProtocolMatcher matches the request with the expected protocol |
| 18 | +// This has been inspired by gomega.HaveHTTPStatusMatcher |
| 19 | +type HaveHTTPProtocolMatcher struct { |
| 20 | + Expected interface{} |
| 21 | +} |
| 22 | + |
| 23 | +func (matcher *HaveHTTPProtocolMatcher) Match(actual interface{}) (success bool, err error) { |
| 24 | + var resp *http.Response |
| 25 | + switch a := actual.(type) { |
| 26 | + case *http.Response: |
| 27 | + resp = a |
| 28 | + case *httptest.ResponseRecorder: |
| 29 | + resp = a.Result() |
| 30 | + default: |
| 31 | + return false, fmt.Errorf("HaveHTTPProtocol matcher expects *http.Response or *httptest.ResponseRecorder. Got:\n%s", format.Object(actual, 1)) |
| 32 | + } |
| 33 | + |
| 34 | + switch e := matcher.Expected.(type) { |
| 35 | + case string: |
| 36 | + if resp.Proto == e { |
| 37 | + return true, nil |
| 38 | + } |
| 39 | + default: |
| 40 | + return false, fmt.Errorf("HaveHTTPProtocol matcher must be passed a string type. Got:\n%s", format.Object(matcher.Expected, 1)) |
| 41 | + } |
| 42 | + |
| 43 | + return false, nil |
| 44 | +} |
| 45 | + |
| 46 | +func (matcher *HaveHTTPProtocolMatcher) FailureMessage(actual interface{}) (message string) { |
| 47 | + return fmt.Sprintf("Expected\n%s\n%s\n%s", formatHttpResponse(actual), "to have HTTP protocol", matcher.Expected) |
| 48 | +} |
| 49 | + |
| 50 | +func (matcher *HaveHTTPProtocolMatcher) NegatedFailureMessage(actual interface{}) (message string) { |
| 51 | + return fmt.Sprintf("Expected\n%s\n%s\n%s", formatHttpResponse(actual), "not to have HTTP protocol", matcher.Expected) |
| 52 | +} |
| 53 | + |
| 54 | +func formatHttpResponse(input interface{}) string { |
| 55 | + var resp *http.Response |
| 56 | + switch r := input.(type) { |
| 57 | + case *http.Response: |
| 58 | + resp = r |
| 59 | + case *httptest.ResponseRecorder: |
| 60 | + resp = r.Result() |
| 61 | + default: |
| 62 | + return "cannot format invalid HTTP response" |
| 63 | + } |
| 64 | + |
| 65 | + body := "<nil>" |
| 66 | + if resp.Body != nil { |
| 67 | + defer resp.Body.Close() |
| 68 | + data, err := io.ReadAll(resp.Body) |
| 69 | + if err != nil { |
| 70 | + data = []byte("<error reading body>") |
| 71 | + } |
| 72 | + body = format.Object(string(data), 0) |
| 73 | + } |
| 74 | + |
| 75 | + var s strings.Builder |
| 76 | + s.WriteString(fmt.Sprintf("%s<%s>: {\n", format.Indent, reflect.TypeOf(input))) |
| 77 | + s.WriteString(fmt.Sprintf("%s%sProtocol: %s\n", format.Indent, format.Indent, resp.Proto)) |
| 78 | + s.WriteString(fmt.Sprintf("%s%sStatus: %s\n", format.Indent, format.Indent, format.Object(resp.Status, 0))) |
| 79 | + s.WriteString(fmt.Sprintf("%s%sStatusCode: %s\n", format.Indent, format.Indent, format.Object(resp.StatusCode, 0))) |
| 80 | + s.WriteString(fmt.Sprintf("%s%sBody: %s\n", format.Indent, format.Indent, body)) |
| 81 | + s.WriteString(fmt.Sprintf("%s}", format.Indent)) |
| 82 | + |
| 83 | + return s.String() |
| 84 | +} |
0 commit comments