Skip to content

Commit 81cb257

Browse files
authored
feat: added some helper methods to kid.Context (#55)
1 parent 28b3f7a commit 81cb257

File tree

7 files changed

+133
-30
lines changed

7 files changed

+133
-30
lines changed

context.go

+47-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package kid
22

33
import (
4+
"context"
45
"net/http"
56
"net/url"
67
"sync"
78
)
89

10+
const contentTypeHeader string = "Content-Type"
11+
912
// Context is the context of current HTTP request.
1013
// It holds data related to current HTTP request.
1114
type Context struct {
@@ -56,6 +59,20 @@ func (c *Context) Params() Params {
5659
return c.params
5760
}
5861

62+
// Path returns request's path used for matching request to a handler.
63+
func (c *Context) Path() string {
64+
u := c.request.URL
65+
if u.RawPath != "" {
66+
return u.RawPath
67+
}
68+
return u.Path
69+
}
70+
71+
// Method returns request method.
72+
func (c *Context) Method() string {
73+
return c.request.Method
74+
}
75+
5976
// QueryParam returns value of a query parameter
6077
func (c *Context) QueryParam(name string) string {
6178
queryParam := c.request.URL.Query().Get(name)
@@ -205,7 +222,6 @@ func (c *Context) GetRequestHeader(key string) string {
205222
// writeContentType sets content type header for response.
206223
// It won't overwrite content type if it's already set.
207224
func (c *Context) writeContentType(contentType string) {
208-
contentTypeHeader := "Content-Type"
209225
if c.GetResponseHeader(contentTypeHeader) == "" {
210226
c.SetResponseHeader(contentTypeHeader, contentType)
211227
}
@@ -228,6 +244,36 @@ func (c *Context) Get(key string) (any, bool) {
228244
return val, ok
229245
}
230246

247+
// Clone clones the context and returns it.
248+
//
249+
// Should be used when context is passed to the background jobs.
250+
//
251+
// Writes to the response of a cloned context will panic.
252+
func (c *Context) Clone() *Context {
253+
ctx := Context{
254+
request: c.request.Clone(context.Background()),
255+
response: c.response.(*response).clone(),
256+
kid: c.kid,
257+
lock: sync.Mutex{},
258+
}
259+
260+
// Copy path params.
261+
params := make(Params, len(c.params))
262+
for k, v := range c.params {
263+
params[k] = v
264+
}
265+
ctx.params = params
266+
267+
// Copy storage.
268+
storage := make(Map, len(c.storage))
269+
for k, v := range c.storage {
270+
storage[k] = v
271+
}
272+
ctx.storage = storage
273+
274+
return &ctx
275+
}
276+
231277
// Debug returns whether we are in debug mode or not.
232278
func (c *Context) Debug() bool {
233279
return c.kid.Debug()

context_test.go

+58
Original file line numberDiff line numberDiff line change
@@ -534,3 +534,61 @@ func TestContext_SetRequestHeader(t *testing.T) {
534534
ctx.SetRequestHeader("key", "value")
535535
assert.Equal(t, "value", ctx.GetRequestHeader("key"))
536536
}
537+
538+
func TestContext_Path(t *testing.T) {
539+
ctx := newContext(New())
540+
541+
req := httptest.NewRequest(http.MethodGet, "/path", nil)
542+
req.URL.Path = "/path"
543+
req.URL.RawPath = ""
544+
545+
ctx.reset(req, nil)
546+
547+
assert.Equal(t, "/path", ctx.Path())
548+
549+
req.URL.RawPath = "/path2"
550+
assert.Equal(t, "/path2", ctx.Path())
551+
}
552+
553+
func TestContext_Method(t *testing.T) {
554+
ctx := newContext(New())
555+
556+
req := httptest.NewRequest(http.MethodDelete, "/path", nil)
557+
558+
ctx.reset(req, nil)
559+
560+
assert.Equal(t, http.MethodDelete, ctx.Method())
561+
}
562+
563+
func TestContext_Clone(t *testing.T) {
564+
ctx := newContext(New())
565+
566+
req := httptest.NewRequest(http.MethodDelete, "/path", nil)
567+
res := httptest.NewRecorder()
568+
569+
ctx.reset(req, res)
570+
571+
ctx.Set("key", "value1")
572+
ctx.params["key"] = "value2"
573+
574+
clonedCtx := ctx.Clone()
575+
576+
ctx.Set("key", "value")
577+
ctx.params["key"] = "value"
578+
579+
assert.NotEqual(t, ctx, clonedCtx)
580+
assert.NotEqual(t, ctx.request, clonedCtx.request)
581+
assert.NotEqual(t, ctx.response, clonedCtx.response)
582+
assert.Equal(t, "value1", clonedCtx.storage["key"])
583+
assert.Equal(t, "value2", clonedCtx.Param("key"))
584+
assert.Equal(t, ctx.kid, clonedCtx.kid)
585+
assert.NotNil(t, ctx.response.(*response).ResponseWriter)
586+
587+
assert.Panics(t, func() {
588+
clonedCtx.Byte(http.StatusAccepted, []byte("test"))
589+
})
590+
591+
assert.NotPanics(t, func() {
592+
clonedCtx.Response().Status()
593+
})
594+
}

kid.go

+1-10
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"fmt"
66
"io"
77
"net/http"
8-
"net/url"
98
"os"
109
"reflect"
1110
"runtime"
@@ -223,7 +222,7 @@ func (k *Kid) ServeHTTP(w http.ResponseWriter, r *http.Request) {
223222
c := k.pool.Get().(*Context)
224223
c.reset(r, w)
225224

226-
route, params, err := k.router.search(getPath(r.URL), r.Method)
225+
route, params, err := k.router.search(c.Path(), r.Method)
227226

228227
c.setParams(params)
229228

@@ -294,14 +293,6 @@ func (k *Kid) printDebug(w io.Writer, format string, values ...any) {
294293
}
295294
}
296295

297-
// getPath returns request's path.
298-
func getPath(u *url.URL) string {
299-
if u.RawPath != "" {
300-
return u.RawPath
301-
}
302-
return u.Path
303-
}
304-
305296
// resolveAddress returns the address which server will run on.
306297
func resolveAddress(addresses []string, goos string) string {
307298
if len(addresses) == 0 {

kid_test.go

-15
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"io"
99
"net/http"
1010
"net/http/httptest"
11-
"net/url"
1211
"testing"
1312
"time"
1413

@@ -622,20 +621,6 @@ func TestResolveAddress(t *testing.T) {
622621
assert.Equal(t, ":2377", addr)
623622
}
624623

625-
func TestGetPath(t *testing.T) {
626-
u, err := url.Parse("http://localhost/foo%25fbar")
627-
assert.NoError(t, err)
628-
629-
assert.Empty(t, u.RawPath)
630-
assert.Equal(t, u.Path, getPath(u))
631-
632-
u, err = url.Parse("http://localhost/foo%fbar")
633-
assert.NoError(t, err)
634-
635-
assert.NotEmpty(t, u.RawPath)
636-
assert.Equal(t, u.RawPath, getPath(u))
637-
}
638-
639624
func TestApplyOptions(t *testing.T) {
640625
k := New()
641626

middlewares/logger.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ func NewLoggerWithConfig(cfg LoggerConfig) kid.MiddlewareFunc {
9898
next(c)
9999

100100
end := time.Now()
101-
req := c.Request()
102101
duration := end.Sub(start)
103102

104103
status := c.Response().Status()
@@ -108,9 +107,9 @@ func NewLoggerWithConfig(cfg LoggerConfig) kid.MiddlewareFunc {
108107
slog.Duration("latency_ns", duration),
109108
slog.String("latency", duration.String()),
110109
slog.Int("status", status),
111-
slog.String("path", req.URL.Path),
112-
slog.String("method", req.Method),
113-
slog.String("user_agent", req.Header.Get("User-Agent")),
110+
slog.String("path", c.Path()),
111+
slog.String("method", c.Method()),
112+
slog.String("user_agent", c.GetRequestHeader("User-Agent")),
114113
}
115114

116115
if status < 400 {

response.go

+8
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,11 @@ func (r *response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
104104
r.written = true
105105
return r.ResponseWriter.(http.Hijacker).Hijack()
106106
}
107+
108+
// clone clones the current response instance.
109+
//
110+
// No writes are permitted.
111+
func (r response) clone() *response {
112+
r.ResponseWriter = nil
113+
return &r
114+
}

response_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,19 @@ func TestResponseWriter_Status(t *testing.T) {
104104

105105
assert.Equal(t, http.StatusAccepted, res.Status())
106106
}
107+
108+
func TestResponse_clone(t *testing.T) {
109+
res := newResponse(httptest.NewRecorder()).(*response)
110+
res.size = 10
111+
res.status = 200
112+
res.written = true
113+
114+
clonedRes := res.clone()
115+
116+
assert.NotEqual(t, res, clonedRes)
117+
assert.NotNil(t, res.ResponseWriter)
118+
assert.Nil(t, clonedRes.ResponseWriter)
119+
assert.Equal(t, res.size, clonedRes.size)
120+
assert.Equal(t, res.written, clonedRes.written)
121+
assert.Equal(t, res.status, clonedRes.status)
122+
}

0 commit comments

Comments
 (0)