|
| 1 | +//go:build go1.18 |
| 2 | +// +build go1.18 |
| 3 | + |
| 4 | +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 5 | +package lambdaurl |
| 6 | + |
| 7 | +import ( |
| 8 | + "bytes" |
| 9 | + "context" |
| 10 | + _ "embed" |
| 11 | + "encoding/json" |
| 12 | + "io" |
| 13 | + "io/ioutil" |
| 14 | + "log" |
| 15 | + "net/http" |
| 16 | + "testing" |
| 17 | + "time" |
| 18 | + |
| 19 | + "github.com/aws/aws-lambda-go/events" |
| 20 | + "github.com/stretchr/testify/assert" |
| 21 | + "github.com/stretchr/testify/require" |
| 22 | +) |
| 23 | + |
| 24 | +//go:embed testdata/function-url-request-with-headers-and-cookies-and-text-body.json |
| 25 | +var helloRequest []byte |
| 26 | + |
| 27 | +//go:embed testdata/function-url-domain-only-get-request.json |
| 28 | +var domainOnlyGetRequest []byte |
| 29 | + |
| 30 | +//go:embed testdata/function-url-domain-only-get-request-trailing-slash.json |
| 31 | +var domainOnlyWithSlashGetRequest []byte |
| 32 | + |
| 33 | +//go:embed testdata/function-url-domain-only-request-with-base64-encoded-body.json |
| 34 | +var base64EncodedBodyRequest []byte |
| 35 | + |
| 36 | +func TestWrap(t *testing.T) { |
| 37 | + for name, params := range map[string]struct { |
| 38 | + input []byte |
| 39 | + handler http.HandlerFunc |
| 40 | + expectStatus int |
| 41 | + expectBody string |
| 42 | + expectHeaders map[string]string |
| 43 | + expectCookies []string |
| 44 | + }{ |
| 45 | + "hello": { |
| 46 | + input: helloRequest, |
| 47 | + handler: func(w http.ResponseWriter, r *http.Request) { |
| 48 | + w.Header().Add("Hello", "world1") |
| 49 | + w.Header().Add("Hello", "world2") |
| 50 | + http.SetCookie(w, &http.Cookie{Name: "yummy", Value: "cookie"}) |
| 51 | + http.SetCookie(w, &http.Cookie{Name: "yummy", Value: "cake"}) |
| 52 | + http.SetCookie(w, &http.Cookie{Name: "fruit", Value: "banana", Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)}) |
| 53 | + for _, c := range r.Cookies() { |
| 54 | + http.SetCookie(w, c) |
| 55 | + } |
| 56 | + |
| 57 | + w.WriteHeader(http.StatusTeapot) |
| 58 | + encoder := json.NewEncoder(w) |
| 59 | + _ = encoder.Encode(struct{ RequestQueryParams, Method any }{r.URL.Query(), r.Method}) |
| 60 | + }, |
| 61 | + expectStatus: http.StatusTeapot, |
| 62 | + expectHeaders: map[string]string{ |
| 63 | + "Hello": "world1,world2", |
| 64 | + }, |
| 65 | + expectCookies: []string{ |
| 66 | + "yummy=cookie", |
| 67 | + "yummy=cake", |
| 68 | + "fruit=banana; Expires=Fri, 31 Dec 1999 00:00:00 GMT", |
| 69 | + "foo=bar", |
| 70 | + "hello=hello", |
| 71 | + }, |
| 72 | + expectBody: `{"RequestQueryParams":{"foo":["bar"],"hello":["world"]},"Method":"POST"}` + "\n", |
| 73 | + }, |
| 74 | + "mux": { |
| 75 | + input: helloRequest, |
| 76 | + handler: func(w http.ResponseWriter, r *http.Request) { |
| 77 | + log.Println(r.URL) |
| 78 | + mux := http.NewServeMux() |
| 79 | + mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { |
| 80 | + w.WriteHeader(200) |
| 81 | + _, _ = w.Write([]byte("Hello World!")) |
| 82 | + }) |
| 83 | + mux.ServeHTTP(w, r) |
| 84 | + }, |
| 85 | + expectStatus: 200, |
| 86 | + expectBody: "Hello World!", |
| 87 | + }, |
| 88 | + "get-implicit-trailing-slash": { |
| 89 | + input: domainOnlyGetRequest, |
| 90 | + handler: func(w http.ResponseWriter, r *http.Request) { |
| 91 | + encoder := json.NewEncoder(w) |
| 92 | + _ = encoder.Encode(r.Method) |
| 93 | + _ = encoder.Encode(r.URL.String()) |
| 94 | + }, |
| 95 | + expectStatus: http.StatusOK, |
| 96 | + expectBody: "\"GET\"\n\"https://lambda-url-id.lambda-url.us-west-2.on.aws/\"\n", |
| 97 | + }, |
| 98 | + "get-explicit-trailing-slash": { |
| 99 | + input: domainOnlyWithSlashGetRequest, |
| 100 | + handler: func(w http.ResponseWriter, r *http.Request) { |
| 101 | + encoder := json.NewEncoder(w) |
| 102 | + _ = encoder.Encode(r.Method) |
| 103 | + _ = encoder.Encode(r.URL.String()) |
| 104 | + }, |
| 105 | + expectStatus: http.StatusOK, |
| 106 | + expectBody: "\"GET\"\n\"https://lambda-url-id.lambda-url.us-west-2.on.aws/\"\n", |
| 107 | + }, |
| 108 | + "empty handler": { |
| 109 | + input: helloRequest, |
| 110 | + handler: func(w http.ResponseWriter, r *http.Request) {}, |
| 111 | + expectStatus: http.StatusOK, |
| 112 | + }, |
| 113 | + "base64request": { |
| 114 | + input: base64EncodedBodyRequest, |
| 115 | + handler: func(w http.ResponseWriter, r *http.Request) { |
| 116 | + _, _ = io.Copy(w, r.Body) |
| 117 | + }, |
| 118 | + expectStatus: http.StatusOK, |
| 119 | + expectBody: "<idk/>", |
| 120 | + }, |
| 121 | + } { |
| 122 | + t.Run(name, func(t *testing.T) { |
| 123 | + handler := Wrap(params.handler) |
| 124 | + var req events.LambdaFunctionURLRequest |
| 125 | + require.NoError(t, json.Unmarshal(params.input, &req)) |
| 126 | + res, err := handler(context.Background(), &req) |
| 127 | + require.NoError(t, err) |
| 128 | + resultBodyBytes, err := ioutil.ReadAll(res) |
| 129 | + require.NoError(t, err) |
| 130 | + resultHeaderBytes, resultBodyBytes, ok := bytes.Cut(resultBodyBytes, []byte{0, 0, 0, 0, 0, 0, 0, 0}) |
| 131 | + require.True(t, ok) |
| 132 | + var resultHeader struct { |
| 133 | + StatusCode int |
| 134 | + Headers map[string]string |
| 135 | + Cookies []string |
| 136 | + } |
| 137 | + require.NoError(t, json.Unmarshal(resultHeaderBytes, &resultHeader)) |
| 138 | + assert.Equal(t, params.expectBody, string(resultBodyBytes)) |
| 139 | + assert.Equal(t, params.expectStatus, resultHeader.StatusCode) |
| 140 | + assert.Equal(t, params.expectHeaders, resultHeader.Headers) |
| 141 | + assert.Equal(t, params.expectCookies, resultHeader.Cookies) |
| 142 | + }) |
| 143 | + } |
| 144 | +} |
| 145 | + |
| 146 | +func TestRequestContext(t *testing.T) { |
| 147 | + var req *events.LambdaFunctionURLRequest |
| 148 | + require.NoError(t, json.Unmarshal(helloRequest, &req)) |
| 149 | + handler := Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 150 | + reqFromContext, exists := RequestFromContext(r.Context()) |
| 151 | + require.True(t, exists) |
| 152 | + require.NotNil(t, reqFromContext) |
| 153 | + assert.Equal(t, req, reqFromContext) |
| 154 | + })) |
| 155 | + _, err := handler(context.Background(), req) |
| 156 | + require.NoError(t, err) |
| 157 | +} |
0 commit comments