Skip to content

Commit

Permalink
Merge branch 'master' into routegroup
Browse files Browse the repository at this point in the history
  • Loading branch information
David Budworth authored Jan 5, 2017
2 parents bb7dc83 + 8a45e95 commit 8d90774
Show file tree
Hide file tree
Showing 6 changed files with 454 additions and 196 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ go:
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
- 1.7
- tip
231 changes: 92 additions & 139 deletions README.md

Large diffs are not rendered by default.

70 changes: 59 additions & 11 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,19 @@ type Router struct {
// handler.
HandleMethodNotAllowed bool

// If enabled, the router automatically replies to OPTIONS requests.
// Custom OPTIONS handlers take priority over automatic replies.
HandleOPTIONS bool

// Configurable http.Handler which is called when no matching route is
// found. If it is not set, http.NotFound is used.
NotFound http.Handler

// Configurable http.Handler which is called when a request
// cannot be routed and HandleMethodNotAllowed is true.
// If it is not set, http.Error with http.StatusMethodNotAllowed is used.
// The "Allow" header with allowed request methods is set before the handler
// is called.
MethodNotAllowed http.Handler

// Function to handle panics recovered from http handlers.
Expand All @@ -165,6 +171,7 @@ func New() *Router {
RedirectTrailingSlash: true,
RedirectFixedPath: true,
HandleMethodNotAllowed: true,
HandleOPTIONS: true,
}
}

Expand Down Expand Up @@ -290,15 +297,53 @@ func (r *Router) Lookup(method, path string) (Handle, Params, bool) {
return nil, nil, false
}

func (r *Router) allowed(path, reqMethod string) (allow string) {
if path == "*" { // server-wide
for method := range r.trees {
if method == "OPTIONS" {
continue
}

// add request method to list of allowed methods
if len(allow) == 0 {
allow = method
} else {
allow += ", " + method
}
}
} else { // specific path
for method := range r.trees {
// Skip the requested method - we already tried this one
if method == reqMethod || method == "OPTIONS" {
continue
}

handle, _, _ := r.trees[method].getValue(path)
if handle != nil {
// add request method to list of allowed methods
if len(allow) == 0 {
allow = method
} else {
allow += ", " + method
}
}
}
}
if len(allow) > 0 {
allow += ", OPTIONS"
}
return
}

// ServeHTTP makes the router implement the http.Handler interface.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.PanicHandler != nil {
defer r.recv(w, req)
}

if root := r.trees[req.Method]; root != nil {
path := req.URL.Path
path := req.URL.Path

if root := r.trees[req.Method]; root != nil {
if handle, ps, tsr := root.getValue(path); handle != nil {
handle(w, req, ps)
return
Expand Down Expand Up @@ -335,16 +380,19 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
}

// Handle 405
if r.HandleMethodNotAllowed {
for method := range r.trees {
// Skip the requested method - we already tried this one
if method == req.Method {
continue
if req.Method == "OPTIONS" {
// Handle OPTIONS requests
if r.HandleOPTIONS {
if allow := r.allowed(path, req.Method); len(allow) > 0 {
w.Header().Set("Allow", allow)
return
}

handle, _, _ := r.trees[method].getValue(req.URL.Path)
if handle != nil {
}
} else {
// Handle 405
if r.HandleMethodNotAllowed {
if allow := r.allowed(path, req.Method); len(allow) > 0 {
w.Header().Set("Allow", allow)
if r.MethodNotAllowed != nil {
r.MethodNotAllowed.ServeHTTP(w, req)
} else {
Expand Down
116 changes: 113 additions & 3 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ func TestRouter(t *testing.T) {
}

type handlerStruct struct {
handeled *bool
handled *bool
}

func (h handlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
*h.handeled = true
*h.handled = true
}

func TestRouterAPI(t *testing.T) {
Expand Down Expand Up @@ -216,20 +216,127 @@ func TestRouterChaining(t *testing.T) {
}
}

func TestRouterOPTIONS(t *testing.T) {
handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}

router := New()
router.POST("/path", handlerFunc)

// test not allowed
// * (server)
r, _ := http.NewRequest("OPTIONS", "*", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusOK) {
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
} else if allow := w.Header().Get("Allow"); allow != "POST, OPTIONS" {
t.Error("unexpected Allow header value: " + allow)
}

// path
r, _ = http.NewRequest("OPTIONS", "/path", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusOK) {
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
} else if allow := w.Header().Get("Allow"); allow != "POST, OPTIONS" {
t.Error("unexpected Allow header value: " + allow)
}

r, _ = http.NewRequest("OPTIONS", "/doesnotexist", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusNotFound) {
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
}

// add another method
router.GET("/path", handlerFunc)

// test again
// * (server)
r, _ = http.NewRequest("OPTIONS", "*", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusOK) {
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
} else if allow := w.Header().Get("Allow"); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" {
t.Error("unexpected Allow header value: " + allow)
}

// path
r, _ = http.NewRequest("OPTIONS", "/path", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusOK) {
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
} else if allow := w.Header().Get("Allow"); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" {
t.Error("unexpected Allow header value: " + allow)
}

// custom handler
var custom bool
router.OPTIONS("/path", func(w http.ResponseWriter, r *http.Request, _ Params) {
custom = true
})

// test again
// * (server)
r, _ = http.NewRequest("OPTIONS", "*", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusOK) {
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
} else if allow := w.Header().Get("Allow"); allow != "POST, GET, OPTIONS" && allow != "GET, POST, OPTIONS" {
t.Error("unexpected Allow header value: " + allow)
}
if custom {
t.Error("custom handler called on *")
}

// path
r, _ = http.NewRequest("OPTIONS", "/path", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusOK) {
t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
}
if !custom {
t.Error("custom handler not called")
}
}

func TestRouterNotAllowed(t *testing.T) {
handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}

router := New()
router.POST("/path", handlerFunc)

// Test not allowed
// test not allowed
r, _ := http.NewRequest("GET", "/path", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusMethodNotAllowed) {
t.Errorf("NotAllowed handling failed: Code=%d, Header=%v", w.Code, w.Header())
} else if allow := w.Header().Get("Allow"); allow != "POST, OPTIONS" {
t.Error("unexpected Allow header value: " + allow)
}

// add another method
router.DELETE("/path", handlerFunc)
router.OPTIONS("/path", handlerFunc) // must be ignored

// test again
r, _ = http.NewRequest("GET", "/path", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, r)
if !(w.Code == http.StatusMethodNotAllowed) {
t.Errorf("NotAllowed handling failed: Code=%d, Header=%v", w.Code, w.Header())
} else if allow := w.Header().Get("Allow"); allow != "POST, DELETE, OPTIONS" && allow != "DELETE, POST, OPTIONS" {
t.Error("unexpected Allow header value: " + allow)
}

// test custom handler
w = httptest.NewRecorder()
responseText := "custom method"
router.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
Expand All @@ -243,6 +350,9 @@ func TestRouterNotAllowed(t *testing.T) {
if w.Code != http.StatusTeapot {
t.Errorf("unexpected response code %d want %d", w.Code, http.StatusTeapot)
}
if allow := w.Header().Get("Allow"); allow != "POST, DELETE, OPTIONS" && allow != "DELETE, POST, OPTIONS" {
t.Error("unexpected Allow header value: " + allow)
}
}

func TestRouterNotFound(t *testing.T) {
Expand Down
Loading

0 comments on commit 8d90774

Please sign in to comment.