Skip to content

Commit 70708e4

Browse files
peterldownsjulienschmidt
authored andcommitted
Allow chaining of any http.Handler, not just http.HandlerFunc.
1 parent 8c199fb commit 70708e4

File tree

3 files changed

+53
-11
lines changed

3 files changed

+53
-11
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ func main() {
297297

298298
**NOTE: It might be required to set [Router.HandleMethodNotAllowed](http://godoc.org/github.com/julienschmidt/httprouter#Router.HandleMethodNotAllowed) to `false` to avoid problems.**
299299

300-
You can use another [http.HandlerFunc](http://golang.org/pkg/net/http/#HandlerFunc), for example another router, to handle requests which could not be matched by this router by using the [Router.NotFound](http://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) handler. This allows chaining.
300+
You can use another [http.Handler](http://golang.org/pkg/net/http/#Handler), for example another router, to handle requests which could not be matched by this router by using the [Router.NotFound](http://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) handler. This allows chaining.
301301

302302
### Static files
303303
The `NotFound` handler can for example be used to serve static files from the root path `/` (like an index.html file along with other assets):

router.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,14 @@ type Router struct {
138138
// handler.
139139
HandleMethodNotAllowed bool
140140

141-
// Configurable http.HandlerFunc which is called when no matching route is
141+
// Configurable http.Handler which is called when no matching route is
142142
// found. If it is not set, http.NotFound is used.
143-
NotFound http.HandlerFunc
143+
NotFound http.Handler
144144

145-
// Configurable http.HandlerFunc which is called when a request
145+
// Configurable http.Handler which is called when a request
146146
// cannot be routed and HandleMethodNotAllowed is true.
147147
// If it is not set, http.Error with http.StatusMethodNotAllowed is used.
148-
MethodNotAllowed http.HandlerFunc
148+
MethodNotAllowed http.Handler
149149

150150
// Function to handle panics recovered from http handlers.
151151
// It should be used to generate a error page and return the http error code
@@ -342,7 +342,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
342342
handle, _, _ := r.trees[method].getValue(req.URL.Path)
343343
if handle != nil {
344344
if r.MethodNotAllowed != nil {
345-
r.MethodNotAllowed(w, req)
345+
r.MethodNotAllowed.ServeHTTP(w, req)
346346
} else {
347347
http.Error(w,
348348
http.StatusText(http.StatusMethodNotAllowed),
@@ -356,7 +356,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
356356

357357
// Handle 404
358358
if r.NotFound != nil {
359-
r.NotFound(w, req)
359+
r.NotFound.ServeHTTP(w, req)
360360
} else {
361361
http.NotFound(w, req)
362362
}

router_test.go

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,48 @@ func TestRouterRoot(t *testing.T) {
174174
}
175175
}
176176

177+
func TestRouterChaining(t *testing.T) {
178+
router1 := New()
179+
router2 := New()
180+
router1.NotFound = router2
181+
182+
fooHit := false
183+
router1.POST("/foo", func(w http.ResponseWriter, req *http.Request, _ Params) {
184+
fooHit = true
185+
w.WriteHeader(http.StatusOK)
186+
})
187+
188+
barHit := false
189+
router2.POST("/bar", func(w http.ResponseWriter, req *http.Request, _ Params) {
190+
barHit = true
191+
w.WriteHeader(http.StatusOK)
192+
})
193+
194+
r, _ := http.NewRequest("POST", "/foo", nil)
195+
w := httptest.NewRecorder()
196+
router1.ServeHTTP(w, r)
197+
if !(w.Code == http.StatusOK && fooHit) {
198+
t.Errorf("Regular routing failed with router chaining.")
199+
t.FailNow()
200+
}
201+
202+
r, _ = http.NewRequest("POST", "/bar", nil)
203+
w = httptest.NewRecorder()
204+
router1.ServeHTTP(w, r)
205+
if !(w.Code == http.StatusOK && barHit) {
206+
t.Errorf("Chained routing failed with router chaining.")
207+
t.FailNow()
208+
}
209+
210+
r, _ = http.NewRequest("POST", "/qax", nil)
211+
w = httptest.NewRecorder()
212+
router1.ServeHTTP(w, r)
213+
if !(w.Code == http.StatusNotFound) {
214+
t.Errorf("NotFound behavior failed with router chaining.")
215+
t.FailNow()
216+
}
217+
}
218+
177219
func TestRouterNotAllowed(t *testing.T) {
178220
handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}
179221

@@ -190,10 +232,10 @@ func TestRouterNotAllowed(t *testing.T) {
190232

191233
w = httptest.NewRecorder()
192234
responseText := "custom method"
193-
router.MethodNotAllowed = func(w http.ResponseWriter, req *http.Request) {
235+
router.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
194236
w.WriteHeader(http.StatusTeapot)
195237
w.Write([]byte(responseText))
196-
}
238+
})
197239
router.ServeHTTP(w, r)
198240
if got := w.Body.String(); !(got == responseText) {
199241
t.Errorf("unexpected response got %q want %q", got, responseText)
@@ -237,10 +279,10 @@ func TestRouterNotFound(t *testing.T) {
237279

238280
// Test custom not found handler
239281
var notFound bool
240-
router.NotFound = func(rw http.ResponseWriter, r *http.Request) {
282+
router.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
241283
rw.WriteHeader(404)
242284
notFound = true
243-
}
285+
})
244286
r, _ := http.NewRequest("GET", "/nope", nil)
245287
w := httptest.NewRecorder()
246288
router.ServeHTTP(w, r)

0 commit comments

Comments
 (0)