Skip to content

Commit 502cce2

Browse files
authored
Merge pull request #1651 from curvegrid/cors-allow-origin-func
CORS: add an optional custom function to validate the origin
2 parents 17a5fca + e6f24aa commit 502cce2

File tree

2 files changed

+88
-24
lines changed

2 files changed

+88
-24
lines changed

middleware/cors.go

+41-24
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ type (
1919
// Optional. Default value []string{"*"}.
2020
AllowOrigins []string `yaml:"allow_origins"`
2121

22+
// AllowOriginFunc is a custom function to validate the origin. It takes the
23+
// origin as an argument and returns true if allowed or false otherwise. If
24+
// an error is returned, it is returned by the handler. If this option is
25+
// set, AllowOrigins is ignored.
26+
// Optional.
27+
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
28+
2229
// AllowMethods defines a list methods allowed when accessing the resource.
2330
// This is used in response to a preflight request.
2431
// Optional. Default value DefaultCORSConfig.AllowMethods.
@@ -113,40 +120,50 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
113120
return c.NoContent(http.StatusNoContent)
114121
}
115122

116-
// Check allowed origins
117-
for _, o := range config.AllowOrigins {
118-
if o == "*" && config.AllowCredentials {
119-
allowOrigin = origin
120-
break
121-
}
122-
if o == "*" || o == origin {
123-
allowOrigin = o
124-
break
123+
if config.AllowOriginFunc != nil {
124+
allowed, err := config.AllowOriginFunc(origin)
125+
if err != nil {
126+
return err
125127
}
126-
if matchSubdomain(origin, o) {
128+
if allowed {
127129
allowOrigin = origin
128-
break
129130
}
130-
}
131-
132-
// Check allowed origin patterns
133-
for _, re := range allowOriginPatterns {
134-
if allowOrigin == "" {
135-
didx := strings.Index(origin, "://")
136-
if didx == -1 {
137-
continue
131+
} else {
132+
// Check allowed origins
133+
for _, o := range config.AllowOrigins {
134+
if o == "*" && config.AllowCredentials {
135+
allowOrigin = origin
136+
break
138137
}
139-
domAuth := origin[didx+3:]
140-
// to avoid regex cost by invalid long domain
141-
if len(domAuth) > 253 {
138+
if o == "*" || o == origin {
139+
allowOrigin = o
142140
break
143141
}
144-
145-
if match, _ := regexp.MatchString(re, origin); match {
142+
if matchSubdomain(origin, o) {
146143
allowOrigin = origin
147144
break
148145
}
149146
}
147+
148+
// Check allowed origin patterns
149+
for _, re := range allowOriginPatterns {
150+
if allowOrigin == "" {
151+
didx := strings.Index(origin, "://")
152+
if didx == -1 {
153+
continue
154+
}
155+
domAuth := origin[didx+3:]
156+
// to avoid regex cost by invalid long domain
157+
if len(domAuth) > 253 {
158+
break
159+
}
160+
161+
if match, _ := regexp.MatchString(re, origin); match {
162+
allowOrigin = origin
163+
break
164+
}
165+
}
166+
}
150167
}
151168

152169
// Origin not allowed

middleware/cors_test.go

+47
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package middleware
22

33
import (
4+
"errors"
45
"net/http"
56
"net/http/httptest"
67
"testing"
@@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) {
360361
}
361362
}
362363
}
364+
365+
func Test_allowOriginFunc(t *testing.T) {
366+
returnTrue := func(origin string) (bool, error) {
367+
return true, nil
368+
}
369+
returnFalse := func(origin string) (bool, error) {
370+
return false, nil
371+
}
372+
returnError := func(origin string) (bool, error) {
373+
return true, errors.New("this is a test error")
374+
}
375+
376+
allowOriginFuncs := []func(origin string) (bool, error){
377+
returnTrue,
378+
returnFalse,
379+
returnError,
380+
}
381+
382+
const origin = "http://example.com"
383+
384+
e := echo.New()
385+
for _, allowOriginFunc := range allowOriginFuncs {
386+
req := httptest.NewRequest(http.MethodOptions, "/", nil)
387+
rec := httptest.NewRecorder()
388+
c := e.NewContext(req, rec)
389+
req.Header.Set(echo.HeaderOrigin, origin)
390+
cors := CORSWithConfig(CORSConfig{
391+
AllowOriginFunc: allowOriginFunc,
392+
})
393+
h := cors(echo.NotFoundHandler)
394+
err := h(c)
395+
396+
expected, expectedErr := allowOriginFunc(origin)
397+
if expectedErr != nil {
398+
assert.Equal(t, expectedErr, err)
399+
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
400+
continue
401+
}
402+
403+
if expected {
404+
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
405+
} else {
406+
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
407+
}
408+
}
409+
}

0 commit comments

Comments
 (0)