Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add recovery middleware #44

Merged
merged 1 commit into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,8 @@ func (c *Context) Get(key string) (any, bool) {
val, ok := c.storage[key]
return val, ok
}

// Debug returns whether we are in debug mode or not.
func (c *Context) Debug() bool {
return c.kid.Debug()
}
9 changes: 9 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,12 @@ func TestContext_HTMLString(t *testing.T) {
assert.Equal(t, "<p>Hello</p>", res.Body.String())
assert.Equal(t, "text/html", res.Header().Get("Content-Type"))
}

func TestContext_Debug(t *testing.T) {
k := New()
k.debug = true

ctx := newContext(k)

assert.True(t, ctx.Debug())
}
7 changes: 7 additions & 0 deletions kid.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ func (k *Kid) Debug() bool {
return k.debug
}

// NewContext basically is a helper function and can be used in testing.
func (k *Kid) NewContext(req *http.Request, res http.ResponseWriter) *Context {
ctx := newContext(k)
ctx.reset(req, res)
return ctx
}

// ApplyOptions applies the given options.
func (k *Kid) ApplyOptions(opts ...Option) {
for _, opt := range opts {
Expand Down
8 changes: 8 additions & 0 deletions kid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,11 @@ func TestPanicIfNil(t *testing.T) {
panicIfNil(x, "")
})
}

func TestKid_NewContext(t *testing.T) {
k := New()

ctx := k.NewContext(nil, nil)

assert.NotNil(t, ctx)
}
66 changes: 66 additions & 0 deletions middlewares/recovery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package middlewares

import (
"fmt"
"io"
"net/http"
"os"
"runtime/debug"

"github.com/mojixcoder/kid"
)

// RecoveryConfig is the config used to build a Recovery middleware.
type RecoveryConfig struct {
// LogRecovers logs when a recovery happens, only in debug mode.
LogRecovers bool

// PrintStacktrace prints the entire stacktrace if true, only in debug mode.
PrintStacktrace bool

// Writer is the writer for logging recoveries and stacktraces.
Writer io.Writer

// OnRecovery is the function which will be called when a recovery occurs.
OnRecovery func(c *kid.Context, err any)
}

// DefaultRecoverConfig is the default Recovery config.
var DefaultRecoveryConfig = RecoveryConfig{
LogRecovers: true,
Writer: os.Stdout,
OnRecovery: func(c *kid.Context, err any) {
c.JSON(http.StatusInternalServerError, kid.Map{"message": http.StatusText(http.StatusInternalServerError)})
},
}

// NewRecovery returns a new Recovery middleware.
func NewRecovery() kid.MiddlewareFunc {
return NewRecoveryWithConfig(DefaultRecoveryConfig)
}

// NewRecoveryWithConfig returns a new Recovery middleware with the given config.
func NewRecoveryWithConfig(cfg RecoveryConfig) kid.MiddlewareFunc {
return func(next kid.HandlerFunc) kid.HandlerFunc {
return func(c *kid.Context) {
defer func() {
if err := recover(); err != nil {
if cfg.LogRecovers && c.Debug() {
fmt.Fprintf(cfg.Writer, "[RECOVERY] panic recovered: %v\n", err)
}

if cfg.PrintStacktrace && c.Debug() {
stack := debug.Stack()
fmt.Fprintf(cfg.Writer, "%s", string(stack))
}

if cfg.OnRecovery != nil {
cfg.OnRecovery(c, err)
}
}
}()

next(c)
}
}
}
64 changes: 64 additions & 0 deletions middlewares/recovery_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package middlewares

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"

"github.com/mojixcoder/kid"
"github.com/stretchr/testify/assert"
)

var flag bool

var recoveryHandler kid.HandlerFunc = func(c *kid.Context) {
panic("err")
}

func TestNewRecoveryWithConfig(t *testing.T) {
k := kid.New()
var buf bytes.Buffer

recovery := NewRecoveryWithConfig(RecoveryConfig{LogRecovers: true, Writer: &buf})

ctx := k.NewContext(nil, httptest.NewRecorder())
recovery(recoveryHandler)(ctx)
assert.Equal(t, "[RECOVERY] panic recovered: err\n", buf.String())

buf.Reset()
recovery = NewRecoveryWithConfig(RecoveryConfig{PrintStacktrace: true, Writer: &buf})

ctx = k.NewContext(nil, httptest.NewRecorder())
recovery(recoveryHandler)(ctx)
assert.NotEmpty(t, buf.String())

buf.Reset()
k.ApplyOptions(kid.WithDebug(false))
recovery(recoveryHandler)(ctx)
assert.Empty(t, buf.String())

buf.Reset()
recovery = NewRecoveryWithConfig(RecoveryConfig{
OnRecovery: func(c *kid.Context, err any) {
flag = true
},
})

ctx = k.NewContext(nil, httptest.NewRecorder())
recovery(recoveryHandler)(ctx)
assert.True(t, flag)
}

func TestNewRecovery(t *testing.T) {
k := kid.New()

recovery := NewRecovery()

res := httptest.NewRecorder()
ctx := k.NewContext(nil, res)
recovery(recoveryHandler)(ctx)

assert.Equal(t, res.Code, http.StatusInternalServerError)
assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", res.Body.String())
}
8 changes: 8 additions & 0 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type (

// Written returns true if response has already been written otherwise returns false.
Written() bool

// Status returns the status code.
Status() int
}

// response implements ResponseWriter.
Expand Down Expand Up @@ -85,6 +88,11 @@ func (r *response) Written() bool {
return r.written
}

// Status returns the status code.
func (r *response) Status() int {
return r.status
}

// Flush implements the http.Flusher interface.
func (r *response) Flush() {
r.WriteHeaderNow()
Expand Down
9 changes: 9 additions & 0 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,12 @@ func TestResponseWriter_Hijack(t *testing.T) {
})
assert.True(t, res.Written())
}

func TestResponseWriter_Status(t *testing.T) {
w := httptest.NewRecorder()
res := newResponse(w).(*response)

res.WriteHeader(http.StatusAccepted)

assert.Equal(t, http.StatusAccepted, res.Status())
}