Skip to content

Commit 0a3a76b

Browse files
authored
Add cache middleware to zrouter and fix ttl on combinedCache (#63)
* Add cache middleware to zrouter * Add headers to response * Some improvements * minor fix
1 parent fcd31fe commit 0a3a76b

File tree

6 files changed

+174
-3
lines changed

6 files changed

+174
-3
lines changed

pkg/zcache/combined_cache.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type combinedCache struct {
2424
appName string
2525
}
2626

27-
func (c *combinedCache) Set(ctx context.Context, key string, value interface{}, _ time.Duration) error {
27+
func (c *combinedCache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
2828
c.logger.Sugar().Debugf("set key on combined cache, key: [%s]", key)
2929

3030
if err := c.remoteCache.Set(ctx, key, value, c.ttl); err != nil {
@@ -35,8 +35,7 @@ func (c *combinedCache) Set(ctx context.Context, key string, value interface{},
3535
}
3636
}
3737

38-
// ttl is controlled by cache instantiation, so it does not matter here
39-
if err := c.localCache.Set(ctx, key, value, c.ttl); err != nil {
38+
if err := c.localCache.Set(ctx, key, value, ttl); err != nil {
4039
c.logger.Sugar().Errorf("error setting key on combined/local cache, key: [%s], err: %s", key, err)
4140
return err
4241
}

pkg/zcache/zcache_mock.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,29 @@ package zcache
33
import (
44
"context"
55
"github.com/stretchr/testify/mock"
6+
"github.com/zondax/golem/pkg/metrics"
67
"time"
78
)
89

910
type MockZCache struct {
1011
mock.Mock
1112
}
1213

14+
func (m *MockZCache) GetStats() ZCacheStats {
15+
args := m.Called()
16+
return args.Get(0).(ZCacheStats)
17+
}
18+
19+
func (m *MockZCache) IsNotFoundError(err error) bool {
20+
args := m.Called(err)
21+
return args.Bool(0)
22+
}
23+
24+
func (m *MockZCache) SetupAndMonitorMetrics(appName string, metricsServer metrics.TaskMetrics, updateInterval time.Duration) []error {
25+
args := m.Called(appName, metricsServer, updateInterval)
26+
return args.Get(0).([]error)
27+
}
28+
1329
func (m *MockZCache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
1430
args := m.Called(ctx, key, value, ttl)
1531
return args.Error(0)

pkg/zrouter/domain/cache.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package domain
2+
3+
import "time"
4+
5+
type CacheConfig struct {
6+
Paths map[string]time.Duration
7+
}

pkg/zrouter/zmiddlewares/cache.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package zmiddlewares
2+
3+
import (
4+
"fmt"
5+
"github.com/zondax/golem/pkg/zcache"
6+
"github.com/zondax/golem/pkg/zrouter/domain"
7+
"go.uber.org/zap"
8+
"net/http"
9+
"runtime/debug"
10+
"time"
11+
)
12+
13+
const (
14+
cacheKeyPrefix = "zrouter_cache"
15+
)
16+
17+
func CacheMiddleware(cache zcache.ZCache, config domain.CacheConfig, logger *zap.SugaredLogger) func(next http.Handler) http.Handler {
18+
return func(next http.Handler) http.Handler {
19+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20+
path := r.URL.Path
21+
fullURL := constructFullURL(r)
22+
23+
if ttl, found := config.Paths[path]; found {
24+
key := constructCacheKey(fullURL)
25+
26+
if tryServeFromCache(w, r, cache, key) {
27+
return
28+
}
29+
30+
mrw := &metricsResponseWriter{ResponseWriter: w}
31+
next.ServeHTTP(mrw, r) // Important: This line needs to be BEFORE setting the cache.
32+
cacheResponseIfNeeded(mrw, r, cache, key, ttl, logger)
33+
}
34+
})
35+
}
36+
}
37+
38+
func constructFullURL(r *http.Request) string {
39+
fullURL := r.URL.Path
40+
if queryString := r.URL.RawQuery; queryString != "" {
41+
fullURL += "?" + queryString
42+
}
43+
return fullURL
44+
}
45+
46+
func constructCacheKey(fullURL string) string {
47+
return fmt.Sprintf("%s:%s", cacheKeyPrefix, fullURL)
48+
}
49+
50+
func tryServeFromCache(w http.ResponseWriter, r *http.Request, cache zcache.ZCache, key string) bool {
51+
var cachedResponse []byte
52+
err := cache.Get(r.Context(), key, &cachedResponse)
53+
if err == nil && cachedResponse != nil {
54+
w.Header().Set(domain.ContentTypeHeader, domain.ContentTypeApplicationJSON)
55+
_, _ = w.Write(cachedResponse)
56+
return true
57+
}
58+
return false
59+
}
60+
61+
func cacheResponseIfNeeded(mrw *metricsResponseWriter, r *http.Request, cache zcache.ZCache, key string, ttl time.Duration, logger *zap.SugaredLogger) {
62+
if mrw.status != http.StatusOK {
63+
return
64+
}
65+
66+
responseBody := mrw.Body()
67+
if err := cache.Set(r.Context(), key, responseBody, ttl); err != nil {
68+
logger.Errorf("Internal error when setting cache response: %v\n%s", err, debug.Stack())
69+
}
70+
}

pkg/zrouter/zmiddlewares/middleware.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package zmiddlewares
22

33
import (
4+
"bytes"
45
"net/http"
56
)
67

@@ -10,6 +11,7 @@ type metricsResponseWriter struct {
1011
http.ResponseWriter
1112
status int
1213
written int64
14+
body *bytes.Buffer
1315
}
1416

1517
func (mrw *metricsResponseWriter) WriteHeader(statusCode int) {
@@ -18,7 +20,18 @@ func (mrw *metricsResponseWriter) WriteHeader(statusCode int) {
1820
}
1921

2022
func (mrw *metricsResponseWriter) Write(p []byte) (int, error) {
23+
if mrw.body == nil {
24+
mrw.body = new(bytes.Buffer)
25+
}
26+
mrw.body.Write(p)
2127
n, err := mrw.ResponseWriter.Write(p)
2228
mrw.written += int64(n)
2329
return n, err
2430
}
31+
32+
func (mrw *metricsResponseWriter) Body() []byte {
33+
if mrw.body != nil {
34+
return mrw.body.Bytes()
35+
}
36+
return nil
37+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package zmiddlewares
2+
3+
import (
4+
"github.com/zondax/golem/pkg/zrouter/domain"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
"time"
9+
10+
"github.com/go-chi/chi/v5"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/mock"
13+
"github.com/zondax/golem/pkg/zcache"
14+
"go.uber.org/zap"
15+
)
16+
17+
func TestCacheMiddleware(t *testing.T) {
18+
r := chi.NewRouter()
19+
logger, _ := zap.NewDevelopment()
20+
mockCache := new(zcache.MockZCache)
21+
22+
cacheConfig := domain.CacheConfig{Paths: map[string]time.Duration{
23+
"/cached-path": 5 * time.Minute,
24+
}}
25+
26+
r.Use(CacheMiddleware(mockCache, cacheConfig, logger.Sugar()))
27+
28+
// Simulate a response that should be cached
29+
r.Get("/cached-path", func(w http.ResponseWriter, r *http.Request) {
30+
w.WriteHeader(http.StatusOK)
31+
_, _ = w.Write([]byte("Test!"))
32+
})
33+
34+
cachedResponseBody := []byte("Test!")
35+
36+
// Setup the mock for the first request (cache miss)
37+
mockCache.On("Get", mock.Anything, "zrouter_cache:/cached-path", mock.AnythingOfType("*[]uint8")).Return(nil).Once()
38+
mockCache.On("Set", mock.Anything, "zrouter_cache:/cached-path", cachedResponseBody, 5*time.Minute).Return(nil).Once()
39+
40+
// Setup the mock for the second request (cache hit)
41+
mockCache.On("Get", mock.Anything, "zrouter_cache:/cached-path", mock.AnythingOfType("*[]uint8")).Return(nil).Run(func(args mock.Arguments) {
42+
arg := args.Get(2).(*[]byte) // Get the argument where the cached response will be stored
43+
*arg = cachedResponseBody // Simulate the cached response
44+
})
45+
46+
// Perform the first request: the response should be generated and cached
47+
req := httptest.NewRequest("GET", "/cached-path", nil)
48+
rec := httptest.NewRecorder()
49+
r.ServeHTTP(rec, req)
50+
51+
assert.Equal(t, http.StatusOK, rec.Code)
52+
assert.Equal(t, "Test!", rec.Body.String())
53+
54+
// Verify that the cache mock was invoked correctly
55+
mockCache.AssertExpectations(t)
56+
57+
// Perform the second request: the response should be served from the cache
58+
rec2 := httptest.NewRecorder()
59+
r.ServeHTTP(rec2, req)
60+
61+
assert.Equal(t, http.StatusOK, rec2.Code)
62+
assert.Equal(t, "Test!", rec2.Body.String())
63+
64+
// Verify that the cache mock was invoked correctly for the second request
65+
mockCache.AssertExpectations(t)
66+
}

0 commit comments

Comments
 (0)