Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added some helper methods to kid.Context #55

Merged
merged 6 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package kid

import (
"context"
"net/http"
"net/url"
"sync"
)

const contentTypeHeader string = "Content-Type"

// Context is the context of current HTTP request.
// It holds data related to current HTTP request.
type Context struct {
Expand Down Expand Up @@ -56,6 +59,20 @@ func (c *Context) Params() Params {
return c.params
}

// Path returns request's path used for matching request to a handler.
func (c *Context) Path() string {
u := c.request.URL
if u.RawPath != "" {
return u.RawPath
}
return u.Path
}

// Method returns request method.
func (c *Context) Method() string {
return c.request.Method
}

// QueryParam returns value of a query parameter
func (c *Context) QueryParam(name string) string {
queryParam := c.request.URL.Query().Get(name)
Expand Down Expand Up @@ -205,7 +222,6 @@ func (c *Context) GetRequestHeader(key string) string {
// writeContentType sets content type header for response.
// It won't overwrite content type if it's already set.
func (c *Context) writeContentType(contentType string) {
contentTypeHeader := "Content-Type"
if c.GetResponseHeader(contentTypeHeader) == "" {
c.SetResponseHeader(contentTypeHeader, contentType)
}
Expand All @@ -228,6 +244,36 @@ func (c *Context) Get(key string) (any, bool) {
return val, ok
}

// Clone clones the context and returns it.
//
// Should be used when context is passed to the background jobs.
//
// Writes to the response of a cloned context will panic.
func (c *Context) Clone() *Context {
ctx := Context{
request: c.request.Clone(context.Background()),
response: c.response.(*response).clone(),
kid: c.kid,
lock: sync.Mutex{},
}

// Copy path params.
params := make(Params, len(c.params))
for k, v := range c.params {
params[k] = v
}
ctx.params = params

// Copy storage.
storage := make(Map, len(c.storage))
for k, v := range c.storage {
storage[k] = v
}
ctx.storage = storage

return &ctx
}

// Debug returns whether we are in debug mode or not.
func (c *Context) Debug() bool {
return c.kid.Debug()
Expand Down
58 changes: 58 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,61 @@ func TestContext_SetRequestHeader(t *testing.T) {
ctx.SetRequestHeader("key", "value")
assert.Equal(t, "value", ctx.GetRequestHeader("key"))
}

func TestContext_Path(t *testing.T) {
ctx := newContext(New())

req := httptest.NewRequest(http.MethodGet, "/path", nil)
req.URL.Path = "/path"
req.URL.RawPath = ""

ctx.reset(req, nil)

assert.Equal(t, "/path", ctx.Path())

req.URL.RawPath = "/path2"
assert.Equal(t, "/path2", ctx.Path())
}

func TestContext_Method(t *testing.T) {
ctx := newContext(New())

req := httptest.NewRequest(http.MethodDelete, "/path", nil)

ctx.reset(req, nil)

assert.Equal(t, http.MethodDelete, ctx.Method())
}

func TestContext_Clone(t *testing.T) {
ctx := newContext(New())

req := httptest.NewRequest(http.MethodDelete, "/path", nil)
res := httptest.NewRecorder()

ctx.reset(req, res)

ctx.Set("key", "value1")
ctx.params["key"] = "value2"

clonedCtx := ctx.Clone()

ctx.Set("key", "value")
ctx.params["key"] = "value"

assert.NotEqual(t, ctx, clonedCtx)
assert.NotEqual(t, ctx.request, clonedCtx.request)
assert.NotEqual(t, ctx.response, clonedCtx.response)
assert.Equal(t, "value1", clonedCtx.storage["key"])
assert.Equal(t, "value2", clonedCtx.Param("key"))
assert.Equal(t, ctx.kid, clonedCtx.kid)
assert.NotNil(t, ctx.response.(*response).ResponseWriter)

assert.Panics(t, func() {
clonedCtx.Byte(http.StatusAccepted, []byte("test"))
})

assert.NotPanics(t, func() {
clonedCtx.Response().Status()
})
}
11 changes: 1 addition & 10 deletions kid.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"reflect"
"runtime"
Expand Down Expand Up @@ -223,7 +222,7 @@ func (k *Kid) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c := k.pool.Get().(*Context)
c.reset(r, w)

route, params, err := k.router.search(getPath(r.URL), r.Method)
route, params, err := k.router.search(c.Path(), r.Method)

c.setParams(params)

Expand Down Expand Up @@ -294,14 +293,6 @@ func (k *Kid) printDebug(w io.Writer, format string, values ...any) {
}
}

// getPath returns request's path.
func getPath(u *url.URL) string {
if u.RawPath != "" {
return u.RawPath
}
return u.Path
}

// resolveAddress returns the address which server will run on.
func resolveAddress(addresses []string, goos string) string {
if len(addresses) == 0 {
Expand Down
15 changes: 0 additions & 15 deletions kid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -622,20 +621,6 @@ func TestResolveAddress(t *testing.T) {
assert.Equal(t, ":2377", addr)
}

func TestGetPath(t *testing.T) {
u, err := url.Parse("http://localhost/foo%25fbar")
assert.NoError(t, err)

assert.Empty(t, u.RawPath)
assert.Equal(t, u.Path, getPath(u))

u, err = url.Parse("http://localhost/foo%fbar")
assert.NoError(t, err)

assert.NotEmpty(t, u.RawPath)
assert.Equal(t, u.RawPath, getPath(u))
}

func TestApplyOptions(t *testing.T) {
k := New()

Expand Down
7 changes: 3 additions & 4 deletions middlewares/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ func NewLoggerWithConfig(cfg LoggerConfig) kid.MiddlewareFunc {
next(c)

end := time.Now()
req := c.Request()
duration := end.Sub(start)

status := c.Response().Status()
Expand All @@ -108,9 +107,9 @@ func NewLoggerWithConfig(cfg LoggerConfig) kid.MiddlewareFunc {
slog.Duration("latency_ns", duration),
slog.String("latency", duration.String()),
slog.Int("status", status),
slog.String("path", req.URL.Path),
slog.String("method", req.Method),
slog.String("user_agent", req.Header.Get("User-Agent")),
slog.String("path", c.Path()),
slog.String("method", c.Method()),
slog.String("user_agent", c.GetRequestHeader("User-Agent")),
}

if status < 400 {
Expand Down
8 changes: 8 additions & 0 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,11 @@ func (r *response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
r.written = true
return r.ResponseWriter.(http.Hijacker).Hijack()
}

// clone clones the current response instance.
//
// No writes are permitted.
func (r response) clone() *response {
r.ResponseWriter = nil
return &r
}
16 changes: 16 additions & 0 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,19 @@ func TestResponseWriter_Status(t *testing.T) {

assert.Equal(t, http.StatusAccepted, res.Status())
}

func TestResponse_clone(t *testing.T) {
res := newResponse(httptest.NewRecorder()).(*response)
res.size = 10
res.status = 200
res.written = true

clonedRes := res.clone()

assert.NotEqual(t, res, clonedRes)
assert.NotNil(t, res.ResponseWriter)
assert.Nil(t, clonedRes.ResponseWriter)
assert.Equal(t, res.size, clonedRes.size)
assert.Equal(t, res.written, clonedRes.written)
assert.Equal(t, res.status, clonedRes.status)
}