-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcors.go
122 lines (104 loc) · 2.85 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package rmsgo
import (
"net/http"
"strings"
)
// errCorsFail is the error returned by every failed CORS check in this file.
var errCorsFail = Forbidden("you are not allowed in here")

// allowMethods is advertised verbatim (comma-joined) in
// Access-Control-Allow-Methods. A preflight's Access-Control-Request-Method
// must be one of these, or OPTIONS.
var allowMethods = []string{"HEAD", "GET", "PUT", "DELETE"}

// allowHeaders is advertised verbatim (comma-joined) in
// Access-Control-Allow-Headers. Requested header names are passed through
// http.CanonicalHeaderKey before being compared against this list, so every
// entry is spelled in canonical form.
var allowHeaders = []string{
	"Authorization",
	"Content-Length",
	"Content-Type",
	"Origin",
	"X-Requested-With",
	"If-Match",
	"If-None-Match",
}
// handleCORS puts CORS handling in front of next.
//
// Requests matching the "OPTIONS /" pattern are routed to preflight, which
// writes the response itself and never calls next. Every other request is
// routed through cors, which checks the Origin before handing the request on
// to next.
//
// NOTE(review): assumes MuxWithError follows http.ServeMux pattern semantics,
// where "OPTIONS /" matches OPTIONS requests on every path — confirm.
func handleCORS(next http.Handler) http.Handler {
	m := new(MuxWithError)
	m.HandleFunc("OPTIONS /", preflight) // terminal: next is not invoked for preflights
	m.Handle("/", cors(next))
	return m
}
// preflight answers a CORS preflight (OPTIONS) request directly; it does not
// pass the request on to any further handler.
//
// It returns errCorsFail when any of the following holds:
//   - the Origin is not allowed (and not all origins are allowed),
//   - Retrieve cannot find the requested path,
//   - the path's trailing slash disagrees with whether the node is a folder,
//   - Access-Control-Request-Method is neither OPTIONS nor in allowMethods,
//   - Access-Control-Request-Headers names a header not in allowHeaders.
//
// On success it sets the Access-Control-Allow-* headers, writes
// 204 No Content and returns nil.
func preflight(w http.ResponseWriter, r *http.Request) error {
	path := r.URL.Path
	// HasSuffix rather than indexing the last byte, so an empty path cannot
	// cause an index-out-of-range panic.
	isFolder := strings.HasSuffix(path, "/")

	hs := w.Header()
	// Vary is set unconditionally, before any early return, so that failed
	// preflight responses carry it as well.
	hs.Add("Vary", "Origin")
	hs.Add("Vary", "Access-Control-Request-Method")
	hs.Add("Vary", "Access-Control-Request-Headers")

	origin := r.Header.Get("Origin")
	if !(g.allowAllOrigins || g.allowOrigin(r, origin)) {
		return errCorsFail
	}

	n, err := Retrieve(path)
	if err != nil { // not found
		return errCorsFail
	}
	if n.isFolder != isFolder { // malformed request
		return errCorsFail
	}

	if !preflightMethodAllowed(r.Header.Get("Access-Control-Request-Method")) {
		return errCorsFail
	}
	if !preflightHeadersAllowed(r.Header.Values("Access-Control-Request-Headers")) {
		return errCorsFail
	}

	if g.allowAllOrigins {
		hs.Set("Access-Control-Allow-Origin", "*")
	} else {
		hs.Set("Access-Control-Allow-Origin", origin)
	}
	hs.Set("Access-Control-Allow-Methods", strings.Join(allowMethods, ", "))
	hs.Set("Access-Control-Allow-Headers", strings.Join(allowHeaders, ", "))
	w.WriteHeader(http.StatusNoContent)
	return nil
}

// preflightMethodAllowed reports whether the method named by an
// Access-Control-Request-Method value (upper-cased before comparison) is
// OPTIONS or one of allowMethods.
func preflightMethodAllowed(method string) bool {
	method = strings.ToUpper(method)
	if method == http.MethodOptions {
		return true
	}
	for _, m := range allowMethods {
		if m == method {
			return true
		}
	}
	return false
}

// preflightHeadersAllowed reports whether every header named by the given
// Access-Control-Request-Headers values is in allowHeaders.
//
// A single header value may itself hold a comma-separated list, so each value
// is split on commas. Empty entries are skipped rather than rejected. These
// arise from an absent header (no values at all), an empty value, or stray
// commas. Previously such entries canonicalized to "" and caused every
// preflight without custom request headers to fail.
func preflightHeadersAllowed(values []string) bool {
	for _, v := range values {
		for _, name := range strings.Split(v, ",") {
			name = strings.TrimSpace(name)
			if name == "" {
				continue
			}
			name = http.CanonicalHeaderKey(name)
			allowed := false
			for _, h := range allowHeaders {
				if h == name {
					allowed = true
					break
				}
			}
			if !allowed {
				return false
			}
		}
	}
	return true
}
// cors wraps next with an origin check for non-preflight requests.
//
// If the request's Origin is not allowed (and not all origins are allowed),
// the handler returns errCorsFail and next is never invoked. Otherwise it sets
// Access-Control-Allow-Origin ("*" when all origins are allowed, the request's
// Origin otherwise), lets next serve the request, and returns nil.
func cors(next http.Handler) http.Handler {
	return HandlerWithError(func(w http.ResponseWriter, r *http.Request) error {
		hs := w.Header()
		// Add rather than Set: the response varies by Origin, but Vary values
		// already placed on the response (e.g. by an outer middleware) must be
		// preserved, not overwritten. This matches preflight, which also Adds.
		hs.Add("Vary", "Origin")

		origin := r.Header.Get("Origin")
		if !(g.allowAllOrigins || g.allowOrigin(r, origin)) {
			return errCorsFail
		}

		if g.allowAllOrigins {
			hs.Set("Access-Control-Allow-Origin", "*")
		} else {
			hs.Set("Access-Control-Allow-Origin", origin)
		}
		next.ServeHTTP(w, r)
		return nil
	})
}