Skip to content

Commit e8aba0b

Browse files
committed
Add CORS tests
1 parent 6e16be9 commit e8aba0b

File tree

3 files changed

+162
-10
lines changed

3 files changed

+162
-10
lines changed

Diff for: api/api.go

+1-10
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77

88
"github.com/go-chi/chi/v5"
99
"github.com/go-chi/chi/v5/middleware"
10-
"github.com/go-chi/cors"
1110

1211
"cdr.dev/slog"
1312
"github.com/coder/code-marketplace/api/httpapi"
@@ -74,16 +73,8 @@ func New(options *Options) *API {
7473

7574
r := chi.NewRouter()
7675

77-
cors := cors.New(cors.Options{
78-
AllowedOrigins: []string{"*"},
79-
AllowedMethods: []string{"POST", "GET", "OPTIONS"},
80-
AllowedHeaders: []string{"*"},
81-
AllowCredentials: true,
82-
MaxAge: 300,
83-
})
84-
8576
r.Use(
86-
cors.Handler,
77+
httpmw.Cors(),
8778
httpmw.RateLimitPerMinute(options.RateLimit),
8879
middleware.GetHead,
8980
httpmw.AttachRequestID,

Diff for: api/httpmw/cors.go

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package httpmw
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/go-chi/cors"
7+
)
8+
9+
const (
10+
// Server headers.
11+
AccessControlAllowOriginHeader = "Access-Control-Allow-Origin"
12+
AccessControlAllowCredentialsHeader = "Access-Control-Allow-Credentials"
13+
AccessControlAllowMethodsHeader = "Access-Control-Allow-Methods"
14+
AccessControlAllowHeadersHeader = "Access-Control-Allow-Headers"
15+
VaryHeader = "Vary"
16+
17+
// Client headers.
18+
OriginHeader = "Origin"
19+
AccessControlRequestMethodHeader = "Access-Control-Request-Method"
20+
AccessControlRequestHeadersHeader = "Access-Control-Request-Headers"
21+
)
22+
23+
func Cors() func(next http.Handler) http.Handler {
24+
return cors.Handler(cors.Options{
25+
AllowedOrigins: []string{"*"},
26+
AllowedMethods: []string{
27+
http.MethodHead,
28+
http.MethodGet,
29+
http.MethodPost,
30+
http.MethodPut,
31+
http.MethodPatch,
32+
http.MethodDelete,
33+
},
34+
AllowedHeaders: []string{"*"},
35+
AllowCredentials: true,
36+
MaxAge: 300,
37+
})
38+
}

Diff for: api/httpmw/cors_test.go

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package httpmw_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/code-marketplace/api/httpmw"
11+
)
12+
13+
func TestCors(t *testing.T) {
14+
t.Parallel()
15+
16+
methods := []string{
17+
http.MethodOptions,
18+
http.MethodHead,
19+
http.MethodGet,
20+
http.MethodPost,
21+
http.MethodPut,
22+
http.MethodPatch,
23+
http.MethodDelete,
24+
}
25+
26+
tests := []struct {
27+
name string
28+
origin string
29+
allowedOrigin string
30+
headers string
31+
allowedHeaders string
32+
}{
33+
{
34+
name: "LocalHTTP",
35+
origin: "http://localhost:3000",
36+
allowedOrigin: "*",
37+
},
38+
{
39+
name: "LocalHTTPS",
40+
origin: "https://localhost:3000",
41+
allowedOrigin: "*",
42+
},
43+
{
44+
name: "HTTP",
45+
origin: "http://code-server.domain.tld",
46+
allowedOrigin: "*",
47+
},
48+
{
49+
name: "HTTPS",
50+
origin: "https://code-server.domain.tld",
51+
allowedOrigin: "*",
52+
},
53+
{
54+
// VS Code appears to use this origin.
55+
name: "VSCode",
56+
origin: "vscode-file://vscode-app",
57+
allowedOrigin: "*",
58+
},
59+
{
60+
name: "NoOrigin",
61+
allowedOrigin: "",
62+
},
63+
{
64+
name: "Headers",
65+
origin: "foobar",
66+
allowedOrigin: "*",
67+
headers: "X-TEST,X-TEST2",
68+
allowedHeaders: "X-Test, X-Test2",
69+
},
70+
}
71+
72+
for _, test := range tests {
73+
test := test
74+
t.Run(test.name, func(t *testing.T) {
75+
t.Parallel()
76+
77+
for _, method := range methods {
78+
method := method
79+
t.Run(method, func(t *testing.T) {
80+
t.Parallel()
81+
82+
r := httptest.NewRequest(method, "http://dev.coder.com", nil)
83+
if test.origin != "" {
84+
r.Header.Set(httpmw.OriginHeader, test.origin)
85+
}
86+
87+
// OPTIONS requests need to know what method will be requested, or
88+
// go-chi/cors will error. Both request headers and methods should be
89+
// ignored for regular requests even if they are set, although that is
90+
// not tested here.
91+
if method == http.MethodOptions {
92+
r.Header.Set(httpmw.AccessControlRequestMethodHeader, http.MethodGet)
93+
if test.headers != "" {
94+
r.Header.Set(httpmw.AccessControlRequestHeadersHeader, test.headers)
95+
}
96+
}
97+
98+
rw := httptest.NewRecorder()
99+
handler := httpmw.Cors()(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
100+
rw.WriteHeader(http.StatusNoContent)
101+
}))
102+
handler.ServeHTTP(rw, r)
103+
104+
// Should always set some kind of allowed origin, if allowed.
105+
require.Equal(t, test.allowedOrigin, rw.Header().Get(httpmw.AccessControlAllowOriginHeader))
106+
107+
// OPTIONS should echo back the request method and headers and we
108+
// should never get to our handler as the middleware short-circuits
109+
// with a 200.
110+
if method == http.MethodOptions {
111+
require.Equal(t, http.MethodGet, rw.Header().Get(httpmw.AccessControlAllowMethodsHeader))
112+
require.Equal(t, test.allowedHeaders, rw.Header().Get(httpmw.AccessControlAllowHeadersHeader))
113+
require.Equal(t, http.StatusOK, rw.Code)
114+
} else {
115+
require.Equal(t, "", rw.Header().Get(httpmw.AccessControlAllowMethodsHeader))
116+
require.Equal(t, "", rw.Header().Get(httpmw.AccessControlAllowHeadersHeader))
117+
require.Equal(t, http.StatusNoContent, rw.Code)
118+
}
119+
})
120+
}
121+
})
122+
}
123+
}

0 commit comments

Comments
 (0)