Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions modules/web/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {

func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value {
defer func() {
if err := recover(); err != nil {
if recovered := recover(); recovered != nil {
err := fmt.Errorf("%v\n%s", recovered, log.Stack(2))
log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err)
panic(err)
}
Expand Down Expand Up @@ -117,7 +118,17 @@ func hasResponseBeenWritten(argsIn []reflect.Value) bool {
return false
}

func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) func(next http.Handler) http.Handler {
type middlewareProvider = func(next http.Handler) http.Handler

func executeMiddlewaresHandler(w http.ResponseWriter, r *http.Request, middlewares []middlewareProvider, endpoint http.HandlerFunc) {
handler := endpoint
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i](handler).ServeHTTP
}
handler(w, r)
}

func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) middlewareProvider {
return func(next http.Handler) http.Handler {
h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
Expand All @@ -129,14 +140,14 @@ func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo

// toHandlerProvider converts a handler to a handler provider
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
func toHandlerProvider(handler any) middlewareProvider {
funcInfo := routing.GetFuncInfo(handler)
fn := reflect.ValueOf(handler)
if fn.Type().Kind() != reflect.Func {
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
}

if hp, ok := handler.(func(next http.Handler) http.Handler); ok {
if hp, ok := handler.(middlewareProvider); ok {
return wrapHandlerProvider(hp, funcInfo)
} else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok {
return wrapHandlerProvider(hp, funcInfo)
Expand Down
87 changes: 65 additions & 22 deletions modules/web/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ import (
"github.com/go-chi/chi/v5"
)

// PreMiddlewareProvider is a special middleware provider which will be executed
// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting).
// A route can do something (e.g.: set middleware options) at the place where it is declared,
// and the code will be executed before other middlewares which are added before the declaration.
// Use cases: mark a route with some meta info, set some options for middlewares, etc.
type PreMiddlewareProvider func(next http.Handler) http.Handler

// Bind binding an obj to a handler's context data
func Bind[T any](_ T) http.HandlerFunc {
return func(resp http.ResponseWriter, req *http.Request) {
Expand All @@ -41,7 +48,10 @@ func GetForm(dataStore reqctx.RequestDataStore) any {

// Router defines a route based on chi's router
type Router struct {
chiRouter *chi.Mux
chiRouter *chi.Mux

afterRouting []any

curGroupPrefix string
curMiddlewares []any
}
Expand All @@ -52,16 +62,23 @@ func NewRouter() *Router {
return &Router{chiRouter: r}
}

// Use supports two middlewares
func (r *Router) Use(middlewares ...any) {
// BeforeRouting adds middlewares which will be executed before the request path gets routed
// It should only be used for framework-level global middlewares when it needs to change request method & path.
func (r *Router) BeforeRouting(middlewares ...any) {
for _, m := range middlewares {
if !isNilOrFuncNil(m) {
r.chiRouter.Use(toHandlerProvider(m))
}
}
}

// Group mounts a sub-Router along a `pattern` string.
// AfterRouting adds middlewares which will be executed after the request path gets routed
// It can see the routed path and resolved path parameters
func (r *Router) AfterRouting(middlewares ...any) {
r.afterRouting = append(r.afterRouting, middlewares...)
}

// Group mounts a sub-router along a "pattern" string.
func (r *Router) Group(pattern string, fn func(), middlewares ...any) {
previousGroupPrefix := r.curGroupPrefix
previousMiddlewares := r.curMiddlewares
Expand Down Expand Up @@ -93,36 +110,54 @@ func isNilOrFuncNil(v any) bool {
return r.Kind() == reflect.Func && r.IsNil()
}

func wrapMiddlewareAndHandler(curMiddlewares, h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) {
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(curMiddlewares)+len(h)+1)
for _, m := range curMiddlewares {
if !isNilOrFuncNil(m) {
handlerProviders = append(handlerProviders, toHandlerProvider(m))
func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider {
for _, m := range middlewares {
if h, ok := m.(PreMiddlewareProvider); ok && h != nil {
all = append(all, toHandlerProvider(middlewareProvider(h)))
}
}
return all
}

func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider {
for _, m := range middlewares {
if _, ok := m.(PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) {
all = append(all, toHandlerProvider(m))
}
}
return all
}

func wrapMiddlewareAndHandler(useMiddlewares, curMiddlewares, h []any) (_ []middlewareProvider, _ http.HandlerFunc, hasPreMiddlewares bool) {
if len(h) == 0 {
panic("no endpoint handler provided")
}
for i, m := range h {
if !isNilOrFuncNil(m) {
handlerProviders = append(handlerProviders, toHandlerProvider(m))
} else if i == len(h)-1 {
panic("endpoint handler can't be nil")
}
if isNilOrFuncNil(h[len(h)-1]) {
panic("endpoint handler can't be nil")
}

handlerProviders := make([]middlewareProvider, 0, len(useMiddlewares)+len(curMiddlewares)+len(h)+1)
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, useMiddlewares)
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, curMiddlewares)
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, h)
hasPreMiddlewares = len(handlerProviders) > 0
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, useMiddlewares)
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, curMiddlewares)
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, h)

middlewares := handlerProviders[:len(handlerProviders)-1]
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP
mockPoint := RouterMockPoint(MockAfterMiddlewares)
if mockPoint != nil {
middlewares = append(middlewares, mockPoint)
}
return middlewares, handlerFunc
return middlewares, handlerFunc, hasPreMiddlewares
}

// Methods adds the same handlers for multiple http "methods" (separated by ",").
// If any method is invalid, the lower level router will panic.
func (r *Router) Methods(methods, pattern string, h ...any) {
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
fullPattern := r.getPattern(pattern)
if strings.Contains(methods, ",") {
methods := strings.SplitSeq(methods, ",")
Expand All @@ -134,15 +169,19 @@ func (r *Router) Methods(methods, pattern string, h ...any) {
}
}

// Mount attaches another Router along ./pattern/*
// Mount attaches another Router along "/pattern/*"
func (r *Router) Mount(pattern string, subRouter *Router) {
subRouter.Use(r.curMiddlewares...)
r.chiRouter.Mount(r.getPattern(pattern), subRouter.chiRouter)
handlerProviders := make([]middlewareProvider, 0, len(r.afterRouting)+len(r.curMiddlewares))
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.afterRouting)
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.curMiddlewares)
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.afterRouting)
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.curMiddlewares)
r.chiRouter.With(handlerProviders...).Mount(r.getPattern(pattern), subRouter.chiRouter)
}

// Any delegate requests for all methods
func (r *Router) Any(pattern string, h ...any) {
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
r.chiRouter.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc)
}

Expand Down Expand Up @@ -178,12 +217,16 @@ func (r *Router) Patch(pattern string, h ...any) {

// ServeHTTP implements http.Handler
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient.
r.normalizeRequestPath(w, req, r.chiRouter)
}

// NotFound defines a handler to respond whenever a route could not be found.
func (r *Router) NotFound(h http.HandlerFunc) {
r.chiRouter.NotFound(h)
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, []any{h})
r.chiRouter.NotFound(func(w http.ResponseWriter, r *http.Request) {
executeMiddlewaresHandler(w, r, middlewares, handlerFunc)
})
}

func (r *Router) normalizeRequestPath(resp http.ResponseWriter, req *http.Request, next http.Handler) {
Expand Down
13 changes: 6 additions & 7 deletions modules/web/router_path.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ func (g *RouterPathGroup) ServeHTTP(resp http.ResponseWriter, req *http.Request)
for _, m := range g.matchers {
if m.matchPath(chiCtx, path) {
chiCtx.RoutePatterns = append(chiCtx.RoutePatterns, m.pattern)
handler := m.handlerFunc
for i := len(m.middlewares) - 1; i >= 0; i-- {
handler = m.middlewares[i](handler).ServeHTTP
}
handler(resp, req)
executeMiddlewaresHandler(resp, req, m.middlewares, m.handlerFunc)
return
}
}
Expand Down Expand Up @@ -67,7 +63,7 @@ type routerPathMatcher struct {
pattern string
re *regexp.Regexp
params []routerPathParam
middlewares []func(http.Handler) http.Handler
middlewares []middlewareProvider
handlerFunc http.HandlerFunc
}

Expand Down Expand Up @@ -111,7 +107,10 @@ func isValidMethod(name string) bool {
}

func newRouterPathMatcher(methods string, patternRegexp *RouterPathGroupPattern, h ...any) *routerPathMatcher {
middlewares, handlerFunc := wrapMiddlewareAndHandler(patternRegexp.middlewares, h)
middlewares, handlerFunc, hasPreMiddlewares := wrapMiddlewareAndHandler(nil, patternRegexp.middlewares, h)
if hasPreMiddlewares {
panic("pre-middlewares are not supported in router path matcher")
}
p := &routerPathMatcher{methods: make(container.Set[string]), middlewares: middlewares, handlerFunc: handlerFunc}
for method := range strings.SplitSeq(methods, ",") {
method = strings.TrimSpace(method)
Expand Down
Loading
Loading