From 4cce501050c54834897bd65d66bf54a965ce6893 Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Mon, 27 May 2024 19:44:05 +0200 Subject: [PATCH 01/10] IAM: Add caching to HTTP client --- auth/client/iam/caching.go | 166 ++++++++++++++++++++++++++++++++ auth/client/iam/caching_test.go | 138 ++++++++++++++++++++++++++ auth/client/iam/openid4vp.go | 2 +- 3 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 auth/client/iam/caching.go create mode 100644 auth/client/iam/caching_test.go diff --git a/auth/client/iam/caching.go b/auth/client/iam/caching.go new file mode 100644 index 0000000000..c0c73a1d4b --- /dev/null +++ b/auth/client/iam/caching.go @@ -0,0 +1,166 @@ +package iam + +import ( + "bytes" + "fmt" + "github.com/nuts-foundation/nuts-node/auth/log" + "github.com/nuts-foundation/nuts-node/core" + "github.com/pquerna/cachecontrol" + "io" + "net/http" + "net/url" + "sync" + "time" +) + +// CachingHTTPRequestDoer is a cache for HTTP responses for DID/OAuth2/OpenID/StatusList2021 clients. +// It only caches GET requests (since generally only metadata is cacheable), and only if the response is cacheable. +type CachingHTTPRequestDoer struct { + MaxBytes int + Doer core.HTTPRequestDoer + + // currentSizeBytes is the current size of the cache in bytes. + // It's used to make room for new entries when the cache is full. + currentSizeBytes int + // head is the first entry of a linked list of cache entries, ordered by expiration time. + // The first entry is the one that will expire first, which optimizes the removal of expired entries. + head *cacheEntry + // entriesByURL is a map of cache entries, indexed by the URL of the request. + // This optimizes the lookup of cache entries by URL. + entriesByURL map[string][]*cacheEntry + mux sync.RWMutex +} + +type cacheEntry struct { + responseData []byte + requestURL *url.URL + requestMethod string + requestRawQuery string + expirationTime time.Time + next *cacheEntry + responseStatus int + responseHeaders http.Header +} + +func (h *CachingHTTPRequestDoer) Do(httpRequest *http.Request) (*http.Response, error) { + if httpRequest.Method == http.MethodGet { + if response := h.getCachedEntry(httpRequest); response != nil { + return response, nil + } + } + + httpResponse, err := h.Doer.Do(httpRequest) + if err != nil { + return nil, err + } + if httpRequest.Method == http.MethodGet { + reasons, expirationTime, err := cachecontrol.CachableResponse(httpRequest, httpResponse, cachecontrol.Options{PrivateCache: false}) + if err != nil { + log.Logger().WithError(err).Infof("error while checking cacheability of response (url=%s), not caching", httpRequest.URL.String()) + } + if len(reasons) > 0 { + log.Logger().Debugf("response (url=%s) is not cacheable: %v", httpRequest.URL.String(), reasons) + return httpResponse, nil + } + responseBytes, err := io.ReadAll(httpResponse.Body) + if err != nil { + return nil, fmt.Errorf("error while reading response body for caching: %w", err) + } + h.mux.Lock() + defer h.mux.Unlock() + h.insert(&cacheEntry{ + responseData: responseBytes, + requestMethod: httpRequest.Method, + requestURL: httpRequest.URL, + requestRawQuery: httpRequest.URL.RawQuery, + responseStatus: httpResponse.StatusCode, + responseHeaders: httpResponse.Header, + expirationTime: expirationTime, + }) + httpResponse.Body = io.NopCloser(bytes.NewReader(responseBytes)) + } + return httpResponse, nil +} + +func (h *CachingHTTPRequestDoer) getCachedEntry(httpRequest *http.Request) *http.Response { + h.mux.Lock() + defer h.mux.Unlock() + h.removeExpiredEntries() + // Find cached response + entries := h.entriesByURL[httpRequest.URL.String()] + for _, entry := range entries { + if entry.requestMethod == httpRequest.Method && entry.requestRawQuery == httpRequest.URL.RawQuery { + return &http.Response{ + StatusCode: entry.responseStatus, + Header: entry.responseHeaders, + Body: io.NopCloser(bytes.NewReader(entry.responseData)), + } + } + } + return nil +} + +func (h *CachingHTTPRequestDoer) removeExpiredEntries() { + var current = h.head + for current != nil { + if current.expirationTime.Before(time.Now()) { + current = h.pop() + } else { + break + } + } +} + +func (h *CachingHTTPRequestDoer) prune(bytesRequired int) { + // See if we need to make room for the new entry + for h.currentSizeBytes+bytesRequired >= h.MaxBytes { + _ = h.pop() + } +} + +// insert adds a new entry to the cache. +func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) { + if h.head == nil { + // First entry + h.head = entry + } else { + // Insert in the linked list, ordered by expiration time + var current = h.head + for current.next != nil && current.next.expirationTime.Before(entry.expirationTime) { + current = current.next + } + entry.next = current.next + current.next = entry + } + // Insert in the URL map for quick lookup + h.entriesByURL[entry.requestURL.String()] = append(h.entriesByURL[entry.requestURL.String()], entry) + + h.currentSizeBytes += len(entry.responseData) +} + +// pop removes the first entry from the linked list +func (h *CachingHTTPRequestDoer) pop() *cacheEntry { + requestURL := h.head.requestURL.String() + entries := h.entriesByURL[requestURL] + for i, entry := range entries { + if entry == h.head { + h.entriesByURL[requestURL] = append(entries[:i], entries[i+1:]...) + if len(h.entriesByURL[requestURL]) == 0 { + delete(h.entriesByURL, requestURL) + } + break + } + } + h.currentSizeBytes -= len(h.head.responseData) + h.head = h.head.next + return h.head +} + +func cacheHTTPResponses(requestDoer core.HTTPRequestDoer) *CachingHTTPRequestDoer { + return &CachingHTTPRequestDoer{ + MaxBytes: 10 * 1024 * 1024, + Doer: requestDoer, + entriesByURL: map[string][]*cacheEntry{}, + mux: sync.RWMutex{}, + } +} diff --git a/auth/client/iam/caching_test.go b/auth/client/iam/caching_test.go new file mode 100644 index 0000000000..f5004b2b6e --- /dev/null +++ b/auth/client/iam/caching_test.go @@ -0,0 +1,138 @@ +package iam + +import ( + "bytes" + "github.com/nuts-foundation/nuts-node/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "net/http" + "testing" +) + +func Test_httpClientCache(t *testing.T) { + httpRequest := &http.Request{ + Method: http.MethodGet, + URL: test.MustParseURL("http://example.com"), + } + t.Run("does not cache POST requests", func(t *testing.T) { + client := cacheHTTPResponses(&stubRequestDoer{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "public, max-age=3600", + }, + }) + + _, err := client.Do(&http.Request{ + Method: http.MethodPost, + }) + + require.NoError(t, err) + assert.Equal(t, 0, client.currentSizeBytes) + }) + t.Run("caches GET request with max-age", func(t *testing.T) { + requestSink := &stubRequestDoer{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + client := cacheHTTPResponses(requestSink) + + httpResponse, err := client.Do(httpRequest) + require.NoError(t, err) + fetchedResponseData, _ := io.ReadAll(httpResponse.Body) + httpResponse, err = client.Do(httpRequest) + require.NoError(t, err) + cachedResponseData, _ := io.ReadAll(httpResponse.Body) + + assert.Equal(t, 13, client.currentSizeBytes) + assert.Equal(t, 1, requestSink.invocations) + assert.Equal(t, "Hello, World!", string(fetchedResponseData)) + assert.Equal(t, "Hello, World!", string(cachedResponseData)) + }) + t.Run("2 cache entries with different query parameters", func(t *testing.T) { + requestSink := &stubRequestDoer{ + statusCode: http.StatusOK, + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + requestSink.dataFn = func(req *http.Request) []byte { + return []byte(req.URL.String()) + } + client := cacheHTTPResponses(requestSink) + + // Initial fetch of the resources + _, err := client.Do(httpRequest) + require.NoError(t, err) + alternativeRequest := &http.Request{ + Method: http.MethodGet, + URL: test.MustParseURL("http://example.com?foo=bar"), + } + _, err = client.Do(alternativeRequest) + require.NoError(t, err) + assert.Equal(t, 2, requestSink.invocations) + + // Fetch the responses again, should be taken from cache + response1, _ := client.Do(httpRequest) + response1Data, _ := io.ReadAll(response1.Body) + response2, _ := client.Do(alternativeRequest) + response2Data, _ := io.ReadAll(response2.Body) + assert.Equal(t, 2, requestSink.invocations) + assert.Equal(t, "http://example.com", string(response1Data)) + assert.Equal(t, "http://example.com?foo=bar", string(response2Data)) + }) + t.Run("prunes cache when full", func(t *testing.T) { + requestSink := &stubRequestDoer{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + client := cacheHTTPResponses(requestSink) + client.MaxBytes = 14 + client.currentSizeBytes = 1 + + // Fill the cache + _, err := client.Do(httpRequest) + require.NoError(t, err) + _, err = client.Do(httpRequest) + require.NoError(t, err) + assert.Equal(t, 13, client.currentSizeBytes) + + // Add a new entry, should prune the first one + _, err = client.Do(httpRequest) + require.NoError(t, err) + assert.Equal(t, 13, client.currentSizeBytes) + assert.Equal(t, 2, requestSink.invocations) + }) +} + +type stubRequestDoer struct { + statusCode int + data []byte + dataFn func(r *http.Request) []byte + headers map[string]string + invocations int +} + +func (s *stubRequestDoer) Do(req *http.Request) (*http.Response, error) { + s.invocations++ + response := &http.Response{ + StatusCode: s.statusCode, + } + if s.dataFn != nil { + response.Body = io.NopCloser(bytes.NewReader(s.dataFn(req))) + } else { + response.Body = io.NopCloser(bytes.NewReader(s.data)) + } + response.Header = make(http.Header) + for key, value := range s.headers { + response.Header.Set(key, value) + } + return response, nil +} diff --git a/auth/client/iam/openid4vp.go b/auth/client/iam/openid4vp.go index 23329b864e..4c7f32f076 100644 --- a/auth/client/iam/openid4vp.go +++ b/auth/client/iam/openid4vp.go @@ -56,7 +56,7 @@ func NewClient(wallet holder.Wallet, keyResolver resolver.KeyResolver, jwtSigner return &OpenID4VPClient{ httpClient: HTTPClient{ strictMode: strictMode, - httpClient: core.NewStrictHTTPClient(strictMode, httpClientTimeout, nil), + httpClient: cacheHTTPResponses(core.NewStrictHTTPClient(strictMode, httpClientTimeout, nil)), }, keyResolver: keyResolver, jwtSigner: jwtSigner, From 13ff08376b236112092263ec6b4b116bcc772389 Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Mon, 27 May 2024 20:40:02 +0200 Subject: [PATCH 02/10] fix --- auth/client/iam/caching.go | 32 +++++++++++++++++--------------- auth/client/iam/caching_test.go | 15 +++++---------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/auth/client/iam/caching.go b/auth/client/iam/caching.go index c0c73a1d4b..0056b7abb7 100644 --- a/auth/client/iam/caching.go +++ b/auth/client/iam/caching.go @@ -68,15 +68,17 @@ func (h *CachingHTTPRequestDoer) Do(httpRequest *http.Request) (*http.Response, } h.mux.Lock() defer h.mux.Unlock() - h.insert(&cacheEntry{ - responseData: responseBytes, - requestMethod: httpRequest.Method, - requestURL: httpRequest.URL, - requestRawQuery: httpRequest.URL.RawQuery, - responseStatus: httpResponse.StatusCode, - responseHeaders: httpResponse.Header, - expirationTime: expirationTime, - }) + if len(responseBytes) <= h.MaxBytes { // sanity check + h.insert(&cacheEntry{ + responseData: responseBytes, + requestMethod: httpRequest.Method, + requestURL: httpRequest.URL, + requestRawQuery: httpRequest.URL.RawQuery, + responseStatus: httpResponse.StatusCode, + responseHeaders: httpResponse.Header, + expirationTime: expirationTime, + }) + } httpResponse.Body = io.NopCloser(bytes.NewReader(responseBytes)) } return httpResponse, nil @@ -111,15 +113,12 @@ func (h *CachingHTTPRequestDoer) removeExpiredEntries() { } } -func (h *CachingHTTPRequestDoer) prune(bytesRequired int) { +// insert adds a new entry to the cache. +func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) { // See if we need to make room for the new entry - for h.currentSizeBytes+bytesRequired >= h.MaxBytes { + for h.currentSizeBytes+len(entry.responseData) >= h.MaxBytes { _ = h.pop() } -} - -// insert adds a new entry to the cache. -func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) { if h.head == nil { // First entry h.head = entry @@ -140,6 +139,9 @@ func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) { // pop removes the first entry from the linked list func (h *CachingHTTPRequestDoer) pop() *cacheEntry { + if h.head == nil { + return nil + } requestURL := h.head.requestURL.String() entries := h.entriesByURL[requestURL] for i, entry := range entries { diff --git a/auth/client/iam/caching_test.go b/auth/client/iam/caching_test.go index f5004b2b6e..13cc37cacb 100644 --- a/auth/client/iam/caching_test.go +++ b/auth/client/iam/caching_test.go @@ -95,20 +95,15 @@ func Test_httpClientCache(t *testing.T) { } client := cacheHTTPResponses(requestSink) client.MaxBytes = 14 - client.currentSizeBytes = 1 + client.currentSizeBytes = 5 + client.head = &cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com"), + } - // Fill the cache _, err := client.Do(httpRequest) require.NoError(t, err) - _, err = client.Do(httpRequest) - require.NoError(t, err) - assert.Equal(t, 13, client.currentSizeBytes) - - // Add a new entry, should prune the first one - _, err = client.Do(httpRequest) - require.NoError(t, err) assert.Equal(t, 13, client.currentSizeBytes) - assert.Equal(t, 2, requestSink.invocations) }) } From 95f30028318258a347005355d34d93f14d4c446d Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Mon, 27 May 2024 21:06:06 +0200 Subject: [PATCH 03/10] tests --- auth/client/iam/caching.go | 30 +++++++--- auth/client/iam/caching_test.go | 98 ++++++++++++++++++++++++++++++--- 2 files changed, 110 insertions(+), 18 deletions(-) diff --git a/auth/client/iam/caching.go b/auth/client/iam/caching.go index 0056b7abb7..b491c367b5 100644 --- a/auth/client/iam/caching.go +++ b/auth/client/iam/caching.go @@ -13,17 +13,19 @@ import ( "time" ) -// CachingHTTPRequestDoer is a cache for HTTP responses for DID/OAuth2/OpenID/StatusList2021 clients. +// CachingHTTPRequestDoer is a cache for HTTP responses for DID/OAuth2/OpenID clients. // It only caches GET requests (since generally only metadata is cacheable), and only if the response is cacheable. +// It only works on expiration time and does not respect ETags headers. type CachingHTTPRequestDoer struct { - MaxBytes int - Doer core.HTTPRequestDoer + maxBytes int + requestDoer core.HTTPRequestDoer // currentSizeBytes is the current size of the cache in bytes. // It's used to make room for new entries when the cache is full. currentSizeBytes int // head is the first entry of a linked list of cache entries, ordered by expiration time. // The first entry is the one that will expire first, which optimizes the removal of expired entries. + // When an entry is inserted in the cache, it's inserted in the right place in the linked list (ordered by expiry). head *cacheEntry // entriesByURL is a map of cache entries, indexed by the URL of the request. // This optimizes the lookup of cache entries by URL. @@ -49,7 +51,7 @@ func (h *CachingHTTPRequestDoer) Do(httpRequest *http.Request) (*http.Response, } } - httpResponse, err := h.Doer.Do(httpRequest) + httpResponse, err := h.requestDoer.Do(httpRequest) if err != nil { return nil, err } @@ -57,8 +59,15 @@ func (h *CachingHTTPRequestDoer) Do(httpRequest *http.Request) (*http.Response, reasons, expirationTime, err := cachecontrol.CachableResponse(httpRequest, httpResponse, cachecontrol.Options{PrivateCache: false}) if err != nil { log.Logger().WithError(err).Infof("error while checking cacheability of response (url=%s), not caching", httpRequest.URL.String()) + return httpResponse, nil + } + // We don't want to cache responses for too long, as that increases the risk of staleness, + // and could keep cause very long-lived entries to never be pruned. + maxExpirationTime := time.Now().Add(time.Hour) + if expirationTime.After(maxExpirationTime) { + expirationTime = maxExpirationTime } - if len(reasons) > 0 { + if len(reasons) > 0 || expirationTime.IsZero() { log.Logger().Debugf("response (url=%s) is not cacheable: %v", httpRequest.URL.String(), reasons) return httpResponse, nil } @@ -68,7 +77,7 @@ func (h *CachingHTTPRequestDoer) Do(httpRequest *http.Request) (*http.Response, } h.mux.Lock() defer h.mux.Unlock() - if len(responseBytes) <= h.MaxBytes { // sanity check + if len(responseBytes) <= h.maxBytes { // sanity check h.insert(&cacheEntry{ responseData: responseBytes, requestMethod: httpRequest.Method, @@ -116,7 +125,7 @@ func (h *CachingHTTPRequestDoer) removeExpiredEntries() { // insert adds a new entry to the cache. func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) { // See if we need to make room for the new entry - for h.currentSizeBytes+len(entry.responseData) >= h.MaxBytes { + for h.currentSizeBytes+len(entry.responseData) >= h.maxBytes { _ = h.pop() } if h.head == nil { @@ -128,6 +137,9 @@ func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) { for current.next != nil && current.next.expirationTime.Before(entry.expirationTime) { current = current.next } + if current == h.head { + h.head = entry + } entry.next = current.next current.next = entry } @@ -160,8 +172,8 @@ func (h *CachingHTTPRequestDoer) pop() *cacheEntry { func cacheHTTPResponses(requestDoer core.HTTPRequestDoer) *CachingHTTPRequestDoer { return &CachingHTTPRequestDoer{ - MaxBytes: 10 * 1024 * 1024, - Doer: requestDoer, + maxBytes: 10 * 1024 * 1024, + requestDoer: requestDoer, entriesByURL: map[string][]*cacheEntry{}, mux: sync.RWMutex{}, } diff --git a/auth/client/iam/caching_test.go b/auth/client/iam/caching_test.go index 13cc37cacb..4040a23d1c 100644 --- a/auth/client/iam/caching_test.go +++ b/auth/client/iam/caching_test.go @@ -2,12 +2,14 @@ package iam import ( "bytes" + "fmt" "github.com/nuts-foundation/nuts-node/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io" "net/http" "testing" + "time" ) func Test_httpClientCache(t *testing.T) { @@ -33,7 +35,7 @@ func Test_httpClientCache(t *testing.T) { }) t.Run("caches GET request with max-age", func(t *testing.T) { requestSink := &stubRequestDoer{ - statusCode: http.StatusOK, + statusCode: http.StatusCreated, data: []byte("Hello, World!"), headers: map[string]string{ "Cache-Control": "max-age=3600", @@ -41,17 +43,49 @@ func Test_httpClientCache(t *testing.T) { } client := cacheHTTPResponses(requestSink) + // Initial fetch httpResponse, err := client.Do(httpRequest) require.NoError(t, err) + assert.Equal(t, http.StatusCreated, httpResponse.StatusCode) fetchedResponseData, _ := io.ReadAll(httpResponse.Body) + assert.Equal(t, "Hello, World!", string(fetchedResponseData)) + + // Fetch the response again, should be taken from cache httpResponse, err = client.Do(httpRequest) require.NoError(t, err) + assert.Equal(t, http.StatusCreated, httpResponse.StatusCode) cachedResponseData, _ := io.ReadAll(httpResponse.Body) + assert.Equal(t, "Hello, World!", string(cachedResponseData)) assert.Equal(t, 13, client.currentSizeBytes) assert.Equal(t, 1, requestSink.invocations) - assert.Equal(t, "Hello, World!", string(fetchedResponseData)) - assert.Equal(t, "Hello, World!", string(cachedResponseData)) + }) + t.Run("does not cache responses with no-store", func(t *testing.T) { + client := cacheHTTPResponses(&stubRequestDoer{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "nothing", + }, + }) + + _, err := client.Do(httpRequest) + require.NoError(t, err) + assert.Equal(t, 0, client.currentSizeBytes) + }) + t.Run("max-age is too long", func(t *testing.T) { + requestSink := &stubRequestDoer{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": fmt.Sprintf("max-age=%d", int(time.Hour.Seconds()*24)), + }, + } + client := cacheHTTPResponses(requestSink) + + _, err := client.Do(httpRequest) + require.NoError(t, err) + assert.LessOrEqual(t, time.Now().Sub(client.head.expirationTime), time.Hour) }) t.Run("2 cache entries with different query parameters", func(t *testing.T) { requestSink := &stubRequestDoer{ @@ -94,17 +128,63 @@ func Test_httpClientCache(t *testing.T) { }, } client := cacheHTTPResponses(requestSink) - client.MaxBytes = 14 - client.currentSizeBytes = 5 - client.head = &cacheEntry{ - responseData: []byte("Hello"), - requestURL: test.MustParseURL("http://example.com"), - } + client.maxBytes = 14 + client.insert(&cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com"), + expirationTime: time.Now().Add(time.Hour), + }) _, err := client.Do(httpRequest) require.NoError(t, err) assert.Equal(t, 13, client.currentSizeBytes) }) + t.Run("orders entries by expirationTime for optimized pruning", func(t *testing.T) { + requestSink := &stubRequestDoer{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + client := cacheHTTPResponses(requestSink) + client.maxBytes = 10000 + client.insert(&cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com/3"), + expirationTime: time.Now().Add(time.Hour * 3), + }) + assert.Equal(t, client.head.requestURL.String(), "http://example.com/3") + client.insert(&cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com/2"), + expirationTime: time.Now().Add(time.Hour * 2), + }) + assert.Equal(t, client.head.requestURL.String(), "http://example.com/2") + client.insert(&cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com/1"), + expirationTime: time.Now().Add(time.Hour), + }) + assert.Equal(t, client.head.requestURL.String(), "http://example.com/1") + }) + t.Run("entries that exceed max cache size aren't cached", func(t *testing.T) { + requestSink := &stubRequestDoer{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + client := cacheHTTPResponses(requestSink) + client.maxBytes = 5 + + httpResponse, err := client.Do(httpRequest) + require.NoError(t, err) + data, _ := io.ReadAll(httpResponse.Body) + assert.Equal(t, "Hello, World!", string(data)) + assert.Equal(t, 0, client.currentSizeBytes) + }) } type stubRequestDoer struct { From 326fd8481d7d25a567403bc12df45b0d7c39e739 Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Tue, 28 May 2024 08:52:35 +0200 Subject: [PATCH 04/10] copyright --- auth/client/iam/caching.go | 18 ++++++++++++++++++ auth/client/iam/caching_test.go | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/auth/client/iam/caching.go b/auth/client/iam/caching.go index b491c367b5..4fa352148b 100644 --- a/auth/client/iam/caching.go +++ b/auth/client/iam/caching.go @@ -1,3 +1,21 @@ +/* + * Copyright (C) 2024 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + package iam import ( diff --git a/auth/client/iam/caching_test.go b/auth/client/iam/caching_test.go index 4040a23d1c..00ed27b0b3 100644 --- a/auth/client/iam/caching_test.go +++ b/auth/client/iam/caching_test.go @@ -1,3 +1,21 @@ +/* + * Copyright (C) 2024 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + package iam import ( From ac4f8c7ac252401c145fa7f8278d4f3669bb595e Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Fri, 31 May 2024 12:02:00 +0200 Subject: [PATCH 05/10] PR feedback --- README.rst | 3 +- auth/auth.go | 2 +- auth/client/iam/caching.go | 94 +++++++++++-------- auth/client/iam/caching_test.go | 23 ++--- auth/client/iam/openid4vp.go | 10 +- auth/cmd/cmd.go | 4 + auth/cmd/cmd_test.go | 1 + auth/config.go | 15 +-- docs/pages/deployment/cli-reference.rst | 1 + docs/pages/deployment/server_options.rst | 2 +- .../deployment/server_options_didnuts.rst | 1 + 11 files changed, 93 insertions(+), 63 deletions(-) diff --git a/README.rst b/README.rst index 528698e0ec..fb7d74d70f 100644 --- a/README.rst +++ b/README.rst @@ -199,7 +199,7 @@ The following options can be configured on the server: http.internal.auth.type Whether to enable authentication for /internal endpoints, specify 'token_v2' for bearer token mode or 'token' for legacy bearer token mode. http.public.address \:8080 Address and port the server will be listening to for public-facing endpoints. **JSONLD** - jsonld.contexts.localmapping [https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson,https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist. + jsonld.contexts.localmapping [https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson,https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson,https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist. jsonld.contexts.remoteallowlist [https://schema.org,https://www.w3.org/2018/credentials/v1,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json,https://w3id.org/vc/status-list/2021/v1] In strict mode, fetching external JSON-LD contexts is not allowed except for context-URLs listed here. **PKI** pki.maxupdatefailhours 4 Maximum number of hours that a denylist update can fail @@ -238,6 +238,7 @@ If your use case does not use these features, you can ignore this table. auth.accesstokenlifespan 60 defines how long (in seconds) an access token is valid. Uses default in strict mode. auth.clockskew 5000 allowed JWT Clock skew in milliseconds auth.contractvalidators [irma,dummy,employeeid] sets the different contract validators to use + auth.http.cache.maxbytes 10485760 HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses. auth.irma.autoupdateschemas true set if you want automatically update the IRMA schemas every 60 minutes. auth.irma.schememanager pbdf IRMA schemeManager to use for attributes. Can be either 'pbdf' or 'irma-demo'. **Events** diff --git a/auth/auth.go b/auth/auth.go index aebe86b7a2..4a83fe706c 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -108,7 +108,7 @@ func (auth *Auth) RelyingParty() oauth.RelyingParty { func (auth *Auth) IAMClient() iam.Client { keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()} - return iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, auth.httpClientTimeout) + return iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, auth.httpClientTimeout, auth.config.HTTPResponseCacheSize) } // Configure the Auth struct by creating a validator and create an Irma server diff --git a/auth/client/iam/caching.go b/auth/client/iam/caching.go index 4fa352148b..aaf31930eb 100644 --- a/auth/client/iam/caching.go +++ b/auth/client/iam/caching.go @@ -31,9 +31,14 @@ import ( "time" ) +// maxCacheTime is the maximum time responses are cached. +// Even if the server responds with a longer cache time, responses are never cached longer than maxCacheTime. +const maxCacheTime = time.Hour + // CachingHTTPRequestDoer is a cache for HTTP responses for DID/OAuth2/OpenID clients. // It only caches GET requests (since generally only metadata is cacheable), and only if the response is cacheable. // It only works on expiration time and does not respect ETags headers. +// When maxBytes is reached, the entries that expire first are removed to make room for new entries (since those are the first ones to be pruned any ways). type CachingHTTPRequestDoer struct { maxBytes int requestDoer core.HTTPRequestDoer @@ -64,54 +69,64 @@ type cacheEntry struct { func (h *CachingHTTPRequestDoer) Do(httpRequest *http.Request) (*http.Response, error) { if httpRequest.Method == http.MethodGet { - if response := h.getCachedEntry(httpRequest); response != nil { + if response := h.cachedEntry(httpRequest); response != nil { return response, nil } } - httpResponse, err := h.requestDoer.Do(httpRequest) if err != nil { return nil, err } - if httpRequest.Method == http.MethodGet { - reasons, expirationTime, err := cachecontrol.CachableResponse(httpRequest, httpResponse, cachecontrol.Options{PrivateCache: false}) - if err != nil { - log.Logger().WithError(err).Infof("error while checking cacheability of response (url=%s), not caching", httpRequest.URL.String()) - return httpResponse, nil - } - // We don't want to cache responses for too long, as that increases the risk of staleness, - // and could keep cause very long-lived entries to never be pruned. - maxExpirationTime := time.Now().Add(time.Hour) - if expirationTime.After(maxExpirationTime) { - expirationTime = maxExpirationTime - } - if len(reasons) > 0 || expirationTime.IsZero() { - log.Logger().Debugf("response (url=%s) is not cacheable: %v", httpRequest.URL.String(), reasons) - return httpResponse, nil - } - responseBytes, err := io.ReadAll(httpResponse.Body) - if err != nil { - return nil, fmt.Errorf("error while reading response body for caching: %w", err) - } - h.mux.Lock() - defer h.mux.Unlock() - if len(responseBytes) <= h.maxBytes { // sanity check - h.insert(&cacheEntry{ - responseData: responseBytes, - requestMethod: httpRequest.Method, - requestURL: httpRequest.URL, - requestRawQuery: httpRequest.URL.RawQuery, - responseStatus: httpResponse.StatusCode, - responseHeaders: httpResponse.Header, - expirationTime: expirationTime, - }) - } - httpResponse.Body = io.NopCloser(bytes.NewReader(responseBytes)) + err = h.cacheResponse(httpRequest, httpResponse) + if err != nil { + return nil, err } return httpResponse, nil } -func (h *CachingHTTPRequestDoer) getCachedEntry(httpRequest *http.Request) *http.Response { +// cacheResponse caches the response if it's cacheable. +func (h *CachingHTTPRequestDoer) cacheResponse(httpRequest *http.Request, httpResponse *http.Response) error { + if httpRequest.Method != http.MethodGet { + return nil + } + reasons, expirationTime, err := cachecontrol.CachableResponse(httpRequest, httpResponse, cachecontrol.Options{PrivateCache: false}) + if err != nil { + log.Logger().WithError(err).Infof("error while checking cacheability of response (url=%s), not caching", httpRequest.URL.String()) + return nil + } + // We don't want to cache responses for too long, as that increases the risk of staleness, + // and could keep cause very long-lived entries to never be pruned. + maxExpirationTime := time.Now().Add(maxCacheTime) + if expirationTime.After(maxExpirationTime) { + expirationTime = maxExpirationTime + } + if len(reasons) > 0 || expirationTime.IsZero() { + log.Logger().Debugf("response (url=%s) is not cacheable: %v", httpRequest.URL.String(), reasons) + return nil + } + responseBytes, err := io.ReadAll(httpResponse.Body) + if err != nil { + return fmt.Errorf("error while reading response body for caching: %w", err) + } + h.mux.Lock() + defer h.mux.Unlock() + if len(responseBytes) <= h.maxBytes { // sanity check + h.insert(&cacheEntry{ + responseData: responseBytes, + requestMethod: httpRequest.Method, + requestURL: httpRequest.URL, + requestRawQuery: httpRequest.URL.RawQuery, + responseStatus: httpResponse.StatusCode, + responseHeaders: httpResponse.Header, + expirationTime: expirationTime, + }) + } + httpResponse.Body = io.NopCloser(bytes.NewReader(responseBytes)) + return nil +} + +// cachedEntry returns a cached response if it exists. +func (h *CachingHTTPRequestDoer) cachedEntry(httpRequest *http.Request) *http.Response { h.mux.Lock() defer h.mux.Unlock() h.removeExpiredEntries() @@ -188,9 +203,10 @@ func (h *CachingHTTPRequestDoer) pop() *cacheEntry { return h.head } -func cacheHTTPResponses(requestDoer core.HTTPRequestDoer) *CachingHTTPRequestDoer { +// cachingHTTPClient +func cachingHTTPClient(requestDoer core.HTTPRequestDoer, responsesCacheSize int) *CachingHTTPRequestDoer { return &CachingHTTPRequestDoer{ - maxBytes: 10 * 1024 * 1024, + maxBytes: responsesCacheSize, requestDoer: requestDoer, entriesByURL: map[string][]*cacheEntry{}, mux: sync.RWMutex{}, diff --git a/auth/client/iam/caching_test.go b/auth/client/iam/caching_test.go index 00ed27b0b3..7e67f64411 100644 --- a/auth/client/iam/caching_test.go +++ b/auth/client/iam/caching_test.go @@ -36,13 +36,13 @@ func Test_httpClientCache(t *testing.T) { URL: test.MustParseURL("http://example.com"), } t.Run("does not cache POST requests", func(t *testing.T) { - client := cacheHTTPResponses(&stubRequestDoer{ + client := cachingHTTPClient(&stubRequestDoer{ statusCode: http.StatusOK, data: []byte("Hello, World!"), headers: map[string]string{ "Cache-Control": "public, max-age=3600", }, - }) + }, 1000) _, err := client.Do(&http.Request{ Method: http.MethodPost, @@ -59,7 +59,7 @@ func Test_httpClientCache(t *testing.T) { "Cache-Control": "max-age=3600", }, } - client := cacheHTTPResponses(requestSink) + client := cachingHTTPClient(requestSink, 1000) // Initial fetch httpResponse, err := client.Do(httpRequest) @@ -79,13 +79,13 @@ func Test_httpClientCache(t *testing.T) { assert.Equal(t, 1, requestSink.invocations) }) t.Run("does not cache responses with no-store", func(t *testing.T) { - client := cacheHTTPResponses(&stubRequestDoer{ + client := cachingHTTPClient(&stubRequestDoer{ statusCode: http.StatusOK, data: []byte("Hello, World!"), headers: map[string]string{ "Cache-Control": "nothing", }, - }) + }, 1000) _, err := client.Do(httpRequest) require.NoError(t, err) @@ -99,7 +99,7 @@ func Test_httpClientCache(t *testing.T) { "Cache-Control": fmt.Sprintf("max-age=%d", int(time.Hour.Seconds()*24)), }, } - client := cacheHTTPResponses(requestSink) + client := cachingHTTPClient(requestSink, 1000) _, err := client.Do(httpRequest) require.NoError(t, err) @@ -115,7 +115,7 @@ func Test_httpClientCache(t *testing.T) { requestSink.dataFn = func(req *http.Request) []byte { return []byte(req.URL.String()) } - client := cacheHTTPResponses(requestSink) + client := cachingHTTPClient(requestSink, 1000) // Initial fetch of the resources _, err := client.Do(httpRequest) @@ -145,8 +145,7 @@ func Test_httpClientCache(t *testing.T) { "Cache-Control": "max-age=3600", }, } - client := cacheHTTPResponses(requestSink) - client.maxBytes = 14 + client := cachingHTTPClient(requestSink, 14) client.insert(&cacheEntry{ responseData: []byte("Hello"), requestURL: test.MustParseURL("http://example.com"), @@ -165,8 +164,7 @@ func Test_httpClientCache(t *testing.T) { "Cache-Control": "max-age=3600", }, } - client := cacheHTTPResponses(requestSink) - client.maxBytes = 10000 + client := cachingHTTPClient(requestSink, 10000) client.insert(&cacheEntry{ responseData: []byte("Hello"), requestURL: test.MustParseURL("http://example.com/3"), @@ -194,8 +192,7 @@ func Test_httpClientCache(t *testing.T) { "Cache-Control": "max-age=3600", }, } - client := cacheHTTPResponses(requestSink) - client.maxBytes = 5 + client := cachingHTTPClient(requestSink, 5) httpResponse, err := client.Do(httpRequest) require.NoError(t, err) diff --git a/auth/client/iam/openid4vp.go b/auth/client/iam/openid4vp.go index 4c7f32f076..d488d61157 100644 --- a/auth/client/iam/openid4vp.go +++ b/auth/client/iam/openid4vp.go @@ -52,11 +52,17 @@ type OpenID4VPClient struct { } // NewClient returns an implementation of Holder -func NewClient(wallet holder.Wallet, keyResolver resolver.KeyResolver, jwtSigner nutsCrypto.JWTSigner, strictMode bool, httpClientTimeout time.Duration) *OpenID4VPClient { +// responsesCacheSizeInBytes specifies the max. number of bytes to cache in memory. If 0, no caching is done. +func NewClient(wallet holder.Wallet, keyResolver resolver.KeyResolver, jwtSigner nutsCrypto.JWTSigner, strictMode bool, + httpClientTimeout time.Duration, responsesCacheSizeInBytes int) *OpenID4VPClient { + var requestDoer core.HTTPRequestDoer = core.NewStrictHTTPClient(strictMode, httpClientTimeout, nil) + if responsesCacheSizeInBytes > 0 { + requestDoer = cachingHTTPClient(requestDoer, responsesCacheSizeInBytes) + } return &OpenID4VPClient{ httpClient: HTTPClient{ strictMode: strictMode, - httpClient: cacheHTTPResponses(core.NewStrictHTTPClient(strictMode, httpClientTimeout, nil)), + httpClient: requestDoer, }, keyResolver: keyResolver, jwtSigner: jwtSigner, diff --git a/auth/cmd/cmd.go b/auth/cmd/cmd.go index 50832a1036..135181dcfa 100644 --- a/auth/cmd/cmd.go +++ b/auth/cmd/cmd.go @@ -38,6 +38,9 @@ const ConfIrmaSchemeManager = "auth.irma.schememanager" // ConfHTTPTimeout defines a timeout (in seconds) which is used by the Auth API HTTP client const ConfHTTPTimeout = "auth.http.timeout" +// ConfHTTPResponseCacheSize defines the maximum HTTP client response cache size in bytes. +const ConfHTTPResponseCacheSize = "auth.http.cache.maxbytes" + // ConfAccessTokenLifeSpan defines how long (in seconds) an access token is valid const ConfAccessTokenLifeSpan = "auth.accesstokenlifespan" @@ -49,6 +52,7 @@ func FlagSet() *pflag.FlagSet { flags.String(ConfIrmaSchemeManager, defs.Irma.SchemeManager, "IRMA schemeManager to use for attributes. Can be either 'pbdf' or 'irma-demo'.") flags.Bool(ConfAutoUpdateIrmaSchemas, defs.Irma.AutoUpdateSchemas, "set if you want automatically update the IRMA schemas every 60 minutes.") flags.Int(ConfHTTPTimeout, defs.HTTPTimeout, "HTTP timeout (in seconds) used by the Auth API HTTP client") + flags.Int(ConfHTTPResponseCacheSize, defs.HTTPResponseCacheSize, "HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses.") flags.Int(ConfClockSkew, defs.ClockSkew, "allowed JWT Clock skew in milliseconds") flags.Int(ConfAccessTokenLifeSpan, defs.AccessTokenLifeSpan, "defines how long (in seconds) an access token is valid. Uses default in strict mode.") flags.StringSlice(ConfContractValidators, defs.ContractValidators, "sets the different contract validators to use") diff --git a/auth/cmd/cmd_test.go b/auth/cmd/cmd_test.go index 3293844996..8da1538388 100644 --- a/auth/cmd/cmd_test.go +++ b/auth/cmd/cmd_test.go @@ -45,6 +45,7 @@ func TestFlagSet(t *testing.T) { ConfAccessTokenLifeSpan, ConfClockSkew, ConfContractValidators, + ConfHTTPResponseCacheSize, ConfHTTPTimeout, ConfAutoUpdateIrmaSchemas, ConfIrmaSchemeManager, diff --git a/auth/config.go b/auth/config.go index e482853026..b9748bbc2b 100644 --- a/auth/config.go +++ b/auth/config.go @@ -26,11 +26,13 @@ import ( // Config holds all the configuration params type Config struct { - Irma IrmaConfig `koanf:"irma"` - HTTPTimeout int `koanf:"http.timeout"` - ClockSkew int `koanf:"clockskew"` - ContractValidators []string `koanf:"contractvalidators"` - AccessTokenLifeSpan int `koanf:"accesstokenlifespan"` + Irma IrmaConfig `koanf:"irma"` + HTTPTimeout int `koanf:"http.timeout"` + // HTTPResponseCacheSize is the maximum number of bytes cached by the HTTP client. + HTTPResponseCacheSize int `koanf:"http.cache.maxbytes"` + ClockSkew int `koanf:"clockskew"` + ContractValidators []string `koanf:"contractvalidators"` + AccessTokenLifeSpan int `koanf:"accesstokenlifespan"` } type IrmaConfig struct { @@ -51,6 +53,7 @@ func DefaultConfig() Config { dummy.ContractFormat, selfsigned.ContractFormat, }, - AccessTokenLifeSpan: 60, // seconds, as specced in RFC003 + AccessTokenLifeSpan: 60, // seconds, as specced in RFC003 + HTTPResponseCacheSize: 10 * 1024 * 1024, // 10mb } } diff --git a/docs/pages/deployment/cli-reference.rst b/docs/pages/deployment/cli-reference.rst index 078424ad4c..19dc751e50 100755 --- a/docs/pages/deployment/cli-reference.rst +++ b/docs/pages/deployment/cli-reference.rst @@ -16,6 +16,7 @@ The following options apply to the server commands below: --auth.accesstokenlifespan int defines how long (in seconds) an access token is valid. Uses default in strict mode. (default 60) --auth.clockskew int allowed JWT Clock skew in milliseconds (default 5000) --auth.contractvalidators strings sets the different contract validators to use (default [irma,dummy,employeeid]) + --auth.http.cache.maxbytes int HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses. (default 10485760) --auth.irma.autoupdateschemas set if you want automatically update the IRMA schemas every 60 minutes. (default true) --auth.irma.schememanager string IRMA schemeManager to use for attributes. Can be either 'pbdf' or 'irma-demo'. (default "pbdf") --configfile string Nuts config file (default "./config/nuts.yaml") diff --git a/docs/pages/deployment/server_options.rst b/docs/pages/deployment/server_options.rst index 2f49f7d824..85d8464405 100755 --- a/docs/pages/deployment/server_options.rst +++ b/docs/pages/deployment/server_options.rst @@ -34,7 +34,7 @@ http.internal.auth.type Whether to enable authentication for /internal endpoints, specify 'token_v2' for bearer token mode or 'token' for legacy bearer token mode. http.public.address \:8080 Address and port the server will be listening to for public-facing endpoints. **JSONLD** - jsonld.contexts.localmapping [https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson,https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist. + jsonld.contexts.localmapping [https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson,https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson,https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist. jsonld.contexts.remoteallowlist [https://schema.org,https://www.w3.org/2018/credentials/v1,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json,https://w3id.org/vc/status-list/2021/v1] In strict mode, fetching external JSON-LD contexts is not allowed except for context-URLs listed here. **PKI** pki.maxupdatefailhours 4 Maximum number of hours that a denylist update can fail diff --git a/docs/pages/deployment/server_options_didnuts.rst b/docs/pages/deployment/server_options_didnuts.rst index 3569e6a9a5..8e4073d78a 100755 --- a/docs/pages/deployment/server_options_didnuts.rst +++ b/docs/pages/deployment/server_options_didnuts.rst @@ -14,6 +14,7 @@ auth.accesstokenlifespan 60 defines how long (in seconds) an access token is valid. Uses default in strict mode. auth.clockskew 5000 allowed JWT Clock skew in milliseconds auth.contractvalidators [irma,dummy,employeeid] sets the different contract validators to use + auth.http.cache.maxbytes 10485760 HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses. auth.irma.autoupdateschemas true set if you want automatically update the IRMA schemas every 60 minutes. auth.irma.schememanager pbdf IRMA schemeManager to use for attributes. Can be either 'pbdf' or 'irma-demo'. **Events** From a7b0b44d468bc6b9833f723e60db143cd2d859e9 Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Fri, 31 May 2024 13:39:55 +0200 Subject: [PATCH 06/10] instantiate IAM Client once to allow caching --- auth/auth.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/auth/auth.go b/auth/auth.go index 4a83fe706c..0418f76296 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -60,6 +60,7 @@ type Auth struct { strictMode bool httpClientTimeout time.Duration tlsConfig *tls.Config + iamClient *iam.OpenID4VPClient } // Name returns the name of the module. @@ -163,6 +164,10 @@ func (auth *Auth) Configure(config core.ServerConfig) error { return err } + keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()} + auth.iamClient = iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, + auth.httpClientTimeout, auth.config.HTTPResponseCacheSize) + return nil } From 609c80948385c304347af90827cbe7fac077ce92 Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Fri, 31 May 2024 13:44:58 +0200 Subject: [PATCH 07/10] f2 --- auth/auth.go | 12 ++++++------ auth/auth_test.go | 4 +++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 0418f76296..d8377de36b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -108,8 +108,7 @@ func (auth *Auth) RelyingParty() oauth.RelyingParty { } func (auth *Auth) IAMClient() iam.Client { - keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()} - return iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, auth.httpClientTimeout, auth.config.HTTPResponseCacheSize) + return auth.iamClient } // Configure the Auth struct by creating a validator and create an Irma server @@ -147,18 +146,19 @@ func (auth *Auth) Configure(config core.ServerConfig) error { return err } + var httpClientTimeout time.Duration if auth.config.HTTPTimeout >= 0 { - auth.httpClientTimeout = time.Duration(auth.config.HTTPTimeout) * time.Second + httpClientTimeout = time.Duration(auth.config.HTTPTimeout) * time.Second } else { // auth.http.config got deprecated in favor of httpclient.timeout - auth.httpClientTimeout = config.HTTPClient.Timeout + httpClientTimeout = config.HTTPClient.Timeout } // V1 API related stuff accessTokenLifeSpan := time.Duration(auth.config.AccessTokenLifeSpan) * time.Second auth.authzServer = oauth.NewAuthorizationServer(auth.vdrInstance.Resolver(), auth.vcr, auth.vcr.Verifier(), auth.serviceResolver, auth.keyStore, auth.contractNotary, auth.jsonldManager, accessTokenLifeSpan) auth.relyingParty = oauth.NewRelyingParty(auth.vdrInstance.Resolver(), auth.serviceResolver, - auth.keyStore, auth.vcr.Wallet(), auth.httpClientTimeout, auth.tlsConfig, config.Strictmode) + auth.keyStore, auth.vcr.Wallet(), httpClientTimeout, auth.tlsConfig, config.Strictmode) if err := auth.authzServer.Configure(auth.config.ClockSkew, config.Strictmode); err != nil { return err @@ -166,7 +166,7 @@ func (auth *Auth) Configure(config core.ServerConfig) error { keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()} auth.iamClient = iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, - auth.httpClientTimeout, auth.config.HTTPResponseCacheSize) + httpClientTimeout, auth.config.HTTPResponseCacheSize) return nil } diff --git a/auth/auth_test.go b/auth/auth_test.go index 1c31d6c3c7..bfb712e469 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -114,11 +114,13 @@ func TestAuth_IAMClient(t *testing.T) { config := DefaultConfig() config.ContractValidators = []string{"dummy"} ctrl := gomock.NewController(t) - pkiMock := pki.NewMockProvider(ctrl) // no calls are expected + pkiMock := pki.NewMockProvider(ctrl) + pkiMock.EXPECT().CreateTLSConfig(gomock.Any()) // for v5 HTTP client vdrInstance := vdr.NewMockVDR(ctrl) vdrInstance.EXPECT().Resolver().AnyTimes() i := NewAuthInstance(config, vdrInstance, vcr.NewTestVCRInstance(t), crypto.NewMemoryCryptoInstance(), nil, nil, pkiMock) + require.NoError(t, i.Configure(core.TestServerConfig())) assert.NotNil(t, i.IAMClient()) }) From cff28a19c8fa0d31062c83f8042ceefc3d2e582a Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Fri, 31 May 2024 15:36:45 +0200 Subject: [PATCH 08/10] wip --- auth/auth.go | 3 +- auth/client/iam/openid4vp.go | 10 +--- auth/config.go | 15 ++--- {auth/client/iam => http/client}/caching.go | 45 +++++---------- .../iam => http/client}/caching_test.go | 56 +++++++++---------- http/config.go | 3 + http/engine.go | 7 +++ 7 files changed, 61 insertions(+), 78 deletions(-) rename {auth/client/iam => http/client}/caching.go (82%) rename {auth/client/iam => http/client}/caching_test.go (82%) diff --git a/auth/auth.go b/auth/auth.go index d8377de36b..afcbaf6df1 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -165,8 +165,7 @@ func (auth *Auth) Configure(config core.ServerConfig) error { } keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()} - auth.iamClient = iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, - httpClientTimeout, auth.config.HTTPResponseCacheSize) + auth.iamClient = iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, httpClientTimeout) return nil } diff --git a/auth/client/iam/openid4vp.go b/auth/client/iam/openid4vp.go index d488d61157..23329b864e 100644 --- a/auth/client/iam/openid4vp.go +++ b/auth/client/iam/openid4vp.go @@ -52,17 +52,11 @@ type OpenID4VPClient struct { } // NewClient returns an implementation of Holder -// responsesCacheSizeInBytes specifies the max. number of bytes to cache in memory. If 0, no caching is done. -func NewClient(wallet holder.Wallet, keyResolver resolver.KeyResolver, jwtSigner nutsCrypto.JWTSigner, strictMode bool, - httpClientTimeout time.Duration, responsesCacheSizeInBytes int) *OpenID4VPClient { - var requestDoer core.HTTPRequestDoer = core.NewStrictHTTPClient(strictMode, httpClientTimeout, nil) - if responsesCacheSizeInBytes > 0 { - requestDoer = cachingHTTPClient(requestDoer, responsesCacheSizeInBytes) - } +func NewClient(wallet holder.Wallet, keyResolver resolver.KeyResolver, jwtSigner nutsCrypto.JWTSigner, strictMode bool, httpClientTimeout time.Duration) *OpenID4VPClient { return &OpenID4VPClient{ httpClient: HTTPClient{ strictMode: strictMode, - httpClient: requestDoer, + httpClient: core.NewStrictHTTPClient(strictMode, httpClientTimeout, nil), }, keyResolver: keyResolver, jwtSigner: jwtSigner, diff --git a/auth/config.go b/auth/config.go index b9748bbc2b..e482853026 100644 --- a/auth/config.go +++ b/auth/config.go @@ -26,13 +26,11 @@ import ( // Config holds all the configuration params type Config struct { - Irma IrmaConfig `koanf:"irma"` - HTTPTimeout int `koanf:"http.timeout"` - // HTTPResponseCacheSize is the maximum number of bytes cached by the HTTP client. - HTTPResponseCacheSize int `koanf:"http.cache.maxbytes"` - ClockSkew int `koanf:"clockskew"` - ContractValidators []string `koanf:"contractvalidators"` - AccessTokenLifeSpan int `koanf:"accesstokenlifespan"` + Irma IrmaConfig `koanf:"irma"` + HTTPTimeout int `koanf:"http.timeout"` + ClockSkew int `koanf:"clockskew"` + ContractValidators []string `koanf:"contractvalidators"` + AccessTokenLifeSpan int `koanf:"accesstokenlifespan"` } type IrmaConfig struct { @@ -53,7 +51,6 @@ func DefaultConfig() Config { dummy.ContractFormat, selfsigned.ContractFormat, }, - AccessTokenLifeSpan: 60, // seconds, as specced in RFC003 - HTTPResponseCacheSize: 10 * 1024 * 1024, // 10mb + AccessTokenLifeSpan: 60, // seconds, as specced in RFC003 } } diff --git a/auth/client/iam/caching.go b/http/client/caching.go similarity index 82% rename from auth/client/iam/caching.go rename to http/client/caching.go index aaf31930eb..722881f2c3 100644 --- a/auth/client/iam/caching.go +++ b/http/client/caching.go @@ -1,28 +1,9 @@ -/* - * Copyright (C) 2024 Nuts community - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ - -package iam +package client import ( "bytes" "fmt" - "github.com/nuts-foundation/nuts-node/auth/log" - "github.com/nuts-foundation/nuts-node/core" + "github.com/nuts-foundation/nuts-node/http/log" "github.com/pquerna/cachecontrol" "io" "net/http" @@ -35,13 +16,15 @@ import ( // Even if the server responds with a longer cache time, responses are never cached longer than maxCacheTime. const maxCacheTime = time.Hour +var _ http.RoundTripper = &CachingHTTPRequestDoer{} + // CachingHTTPRequestDoer is a cache for HTTP responses for DID/OAuth2/OpenID clients. // It only caches GET requests (since generally only metadata is cacheable), and only if the response is cacheable. // It only works on expiration time and does not respect ETags headers. // When maxBytes is reached, the entries that expire first are removed to make room for new entries (since those are the first ones to be pruned any ways). type CachingHTTPRequestDoer struct { - maxBytes int - requestDoer core.HTTPRequestDoer + maxBytes int + wrappedTransport http.RoundTripper // currentSizeBytes is the current size of the cache in bytes. // It's used to make room for new entries when the cache is full. @@ -67,13 +50,13 @@ type cacheEntry struct { responseHeaders http.Header } -func (h *CachingHTTPRequestDoer) Do(httpRequest *http.Request) (*http.Response, error) { +func (h *CachingHTTPRequestDoer) RoundTrip(httpRequest *http.Request) (*http.Response, error) { if httpRequest.Method == http.MethodGet { if response := h.cachedEntry(httpRequest); response != nil { return response, nil } } - httpResponse, err := h.requestDoer.Do(httpRequest) + httpResponse, err := h.wrappedTransport.RoundTrip(httpRequest) if err != nil { return nil, err } @@ -203,12 +186,12 @@ func (h *CachingHTTPRequestDoer) pop() *cacheEntry { return h.head } -// cachingHTTPClient -func cachingHTTPClient(requestDoer core.HTTPRequestDoer, responsesCacheSize int) *CachingHTTPRequestDoer { +// NewCachingTransport creates a new CachingHTTPTransport with the given underlying transport and cache size. +func NewCachingTransport(underlyingTransport http.RoundTripper, responsesCacheSize int) *CachingHTTPRequestDoer { return &CachingHTTPRequestDoer{ - maxBytes: responsesCacheSize, - requestDoer: requestDoer, - entriesByURL: map[string][]*cacheEntry{}, - mux: sync.RWMutex{}, + maxBytes: responsesCacheSize, + wrappedTransport: underlyingTransport, + entriesByURL: map[string][]*cacheEntry{}, + mux: sync.RWMutex{}, } } diff --git a/auth/client/iam/caching_test.go b/http/client/caching_test.go similarity index 82% rename from auth/client/iam/caching_test.go rename to http/client/caching_test.go index 7e67f64411..979b1ebc22 100644 --- a/auth/client/iam/caching_test.go +++ b/http/client/caching_test.go @@ -16,7 +16,7 @@ * */ -package iam +package client import ( "bytes" @@ -36,7 +36,7 @@ func Test_httpClientCache(t *testing.T) { URL: test.MustParseURL("http://example.com"), } t.Run("does not cache POST requests", func(t *testing.T) { - client := cachingHTTPClient(&stubRequestDoer{ + client := NewCachingTransport(&stubRoundTripper{ statusCode: http.StatusOK, data: []byte("Hello, World!"), headers: map[string]string{ @@ -44,7 +44,7 @@ func Test_httpClientCache(t *testing.T) { }, }, 1000) - _, err := client.Do(&http.Request{ + _, err := client.RoundTrip(&http.Request{ Method: http.MethodPost, }) @@ -52,24 +52,24 @@ func Test_httpClientCache(t *testing.T) { assert.Equal(t, 0, client.currentSizeBytes) }) t.Run("caches GET request with max-age", func(t *testing.T) { - requestSink := &stubRequestDoer{ + requestSink := &stubRoundTripper{ statusCode: http.StatusCreated, data: []byte("Hello, World!"), headers: map[string]string{ "Cache-Control": "max-age=3600", }, } - client := cachingHTTPClient(requestSink, 1000) + client := NewCachingTransport(requestSink, 1000) // Initial fetch - httpResponse, err := client.Do(httpRequest) + httpResponse, err := client.RoundTrip(httpRequest) require.NoError(t, err) assert.Equal(t, http.StatusCreated, httpResponse.StatusCode) fetchedResponseData, _ := io.ReadAll(httpResponse.Body) assert.Equal(t, "Hello, World!", string(fetchedResponseData)) // Fetch the response again, should be taken from cache - httpResponse, err = client.Do(httpRequest) + httpResponse, err = client.RoundTrip(httpRequest) require.NoError(t, err) assert.Equal(t, http.StatusCreated, httpResponse.StatusCode) cachedResponseData, _ := io.ReadAll(httpResponse.Body) @@ -79,7 +79,7 @@ func Test_httpClientCache(t *testing.T) { assert.Equal(t, 1, requestSink.invocations) }) t.Run("does not cache responses with no-store", func(t *testing.T) { - client := cachingHTTPClient(&stubRequestDoer{ + client := NewCachingTransport(&stubRoundTripper{ statusCode: http.StatusOK, data: []byte("Hello, World!"), headers: map[string]string{ @@ -87,26 +87,26 @@ func Test_httpClientCache(t *testing.T) { }, }, 1000) - _, err := client.Do(httpRequest) + _, err := client.RoundTrip(httpRequest) require.NoError(t, err) assert.Equal(t, 0, client.currentSizeBytes) }) t.Run("max-age is too long", func(t *testing.T) { - requestSink := &stubRequestDoer{ + requestSink := &stubRoundTripper{ statusCode: http.StatusOK, data: []byte("Hello, World!"), headers: map[string]string{ "Cache-Control": fmt.Sprintf("max-age=%d", int(time.Hour.Seconds()*24)), }, } - client := cachingHTTPClient(requestSink, 1000) + client := NewCachingTransport(requestSink, 1000) - _, err := client.Do(httpRequest) + _, err := client.RoundTrip(httpRequest) require.NoError(t, err) assert.LessOrEqual(t, time.Now().Sub(client.head.expirationTime), time.Hour) }) t.Run("2 cache entries with different query parameters", func(t *testing.T) { - requestSink := &stubRequestDoer{ + requestSink := &stubRoundTripper{ statusCode: http.StatusOK, headers: map[string]string{ "Cache-Control": "max-age=3600", @@ -115,56 +115,56 @@ func Test_httpClientCache(t *testing.T) { requestSink.dataFn = func(req *http.Request) []byte { return []byte(req.URL.String()) } - client := cachingHTTPClient(requestSink, 1000) + client := NewCachingTransport(requestSink, 1000) // Initial fetch of the resources - _, err := client.Do(httpRequest) + _, err := client.RoundTrip(httpRequest) require.NoError(t, err) alternativeRequest := &http.Request{ Method: http.MethodGet, URL: test.MustParseURL("http://example.com?foo=bar"), } - _, err = client.Do(alternativeRequest) + _, err = client.RoundTrip(alternativeRequest) require.NoError(t, err) assert.Equal(t, 2, requestSink.invocations) // Fetch the responses again, should be taken from cache - response1, _ := client.Do(httpRequest) + response1, _ := client.RoundTrip(httpRequest) response1Data, _ := io.ReadAll(response1.Body) - response2, _ := client.Do(alternativeRequest) + response2, _ := client.RoundTrip(alternativeRequest) response2Data, _ := io.ReadAll(response2.Body) assert.Equal(t, 2, requestSink.invocations) assert.Equal(t, "http://example.com", string(response1Data)) assert.Equal(t, "http://example.com?foo=bar", string(response2Data)) }) t.Run("prunes cache when full", func(t *testing.T) { - requestSink := &stubRequestDoer{ + requestSink := &stubRoundTripper{ statusCode: http.StatusOK, data: []byte("Hello, World!"), headers: map[string]string{ "Cache-Control": "max-age=3600", }, } - client := cachingHTTPClient(requestSink, 14) + client := NewCachingTransport(requestSink, 14) client.insert(&cacheEntry{ responseData: []byte("Hello"), requestURL: test.MustParseURL("http://example.com"), expirationTime: time.Now().Add(time.Hour), }) - _, err := client.Do(httpRequest) + _, err := client.RoundTrip(httpRequest) require.NoError(t, err) assert.Equal(t, 13, client.currentSizeBytes) }) t.Run("orders entries by expirationTime for optimized pruning", func(t *testing.T) { - requestSink := &stubRequestDoer{ + requestSink := &stubRoundTripper{ statusCode: http.StatusOK, data: []byte("Hello, World!"), headers: map[string]string{ "Cache-Control": "max-age=3600", }, } - client := cachingHTTPClient(requestSink, 10000) + client := NewCachingTransport(requestSink, 10000) client.insert(&cacheEntry{ responseData: []byte("Hello"), requestURL: test.MustParseURL("http://example.com/3"), @@ -185,16 +185,16 @@ func Test_httpClientCache(t *testing.T) { assert.Equal(t, client.head.requestURL.String(), "http://example.com/1") }) t.Run("entries that exceed max cache size aren't cached", func(t *testing.T) { - requestSink := &stubRequestDoer{ + requestSink := &stubRoundTripper{ statusCode: http.StatusOK, data: []byte("Hello, World!"), headers: map[string]string{ "Cache-Control": "max-age=3600", }, } - client := cachingHTTPClient(requestSink, 5) + client := NewCachingTransport(requestSink, 5) - httpResponse, err := client.Do(httpRequest) + httpResponse, err := client.RoundTrip(httpRequest) require.NoError(t, err) data, _ := io.ReadAll(httpResponse.Body) assert.Equal(t, "Hello, World!", string(data)) @@ -202,7 +202,7 @@ func Test_httpClientCache(t *testing.T) { }) } -type stubRequestDoer struct { +type stubRoundTripper struct { statusCode int data []byte dataFn func(r *http.Request) []byte @@ -210,7 +210,7 @@ type stubRequestDoer struct { invocations int } -func (s *stubRequestDoer) Do(req *http.Request) (*http.Response, error) { +func (s *stubRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { s.invocations++ response := &http.Response{ StatusCode: s.statusCode, diff --git a/http/config.go b/http/config.go index 02fe3d096d..4554802dfb 100644 --- a/http/config.go +++ b/http/config.go @@ -28,6 +28,7 @@ func DefaultConfig() Config { Public: PublicConfig{ Address: ":8080", }, + ResponseCacheSize: 10 * 1024 * 1024, // 10mb } } @@ -37,6 +38,8 @@ type Config struct { Log LogLevel `koanf:"log"` Public PublicConfig `koanf:"public"` Internal InternalConfig `koanf:"internal"` + // ResponseCacheSize is the maximum number of bytes cached by HTTP clients. + ResponseCacheSize int `koanf:"cache.maxbytes"` } // PublicConfig contains the configuration for outside-facing HTTP endpoints. diff --git a/http/engine.go b/http/engine.go index df03ee89a3..00c17f4900 100644 --- a/http/engine.go +++ b/http/engine.go @@ -23,6 +23,7 @@ import ( "crypto" "errors" "fmt" + "github.com/nuts-foundation/nuts-node/http/client" "net/http" "os" "strings" @@ -67,6 +68,12 @@ func (h Engine) Router() core.EchoRouter { // Configure loads the configuration for the HTTP engine. func (h *Engine) Configure(serverConfig core.ServerConfig) error { + // Configure the HTTP caching client, if enabled. Set it to http.DefaultTransport so it can be used by any subsystem. + if h.config.ResponseCacheSize > 0 { + defaultTransport := http.DefaultTransport.(*http.Transport) + http.DefaultTransport = client.NewCachingTransport(defaultTransport, h.config.ResponseCacheSize) + } + // Override default Echo HTTP error when bearer token is expected but not provided. // Echo returns "Bad Request (400)" by default, but we use this for incorrect use of API parameters. // "Unauthorized (401)" is a better fit. From ba9974f386ba636f193c77b7250139f992305fef Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Sat, 1 Jun 2024 08:30:39 +0200 Subject: [PATCH 09/10] move core.StrictHTTPClient to HTTP module to allow global caching --- auth/auth.go | 14 +-- auth/auth_test.go | 4 +- auth/client/iam/openid4vp.go | 3 +- auth/client/iam/openid4vp_test.go | 5 +- auth/cmd/cmd.go | 4 - auth/cmd/cmd_test.go | 1 - core/http_client.go | 41 ------- core/http_client_test.go | 12 -- discovery/api/server/client/http.go | 3 +- http/client/caching.go | 163 +++++++++++++++------------- http/client/caching_test.go | 26 ++--- http/client/client.go | 47 ++++++++ http/client/client_test.go | 71 ++++++++++++ http/cmd/cmd.go | 1 + http/engine.go | 20 +++- policy/api/v1/client/client.go | 3 +- vcr/vcr.go | 7 +- vcr/vcr_test.go | 2 + vdr/didweb/web.go | 13 +-- vdr/didweb/web_test.go | 6 +- vdr/vdr_test.go | 3 +- 21 files changed, 264 insertions(+), 185 deletions(-) create mode 100644 http/client/client.go create mode 100644 http/client/client_test.go diff --git a/auth/auth.go b/auth/auth.go index afcbaf6df1..aebe86b7a2 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -60,7 +60,6 @@ type Auth struct { strictMode bool httpClientTimeout time.Duration tlsConfig *tls.Config - iamClient *iam.OpenID4VPClient } // Name returns the name of the module. @@ -108,7 +107,8 @@ func (auth *Auth) RelyingParty() oauth.RelyingParty { } func (auth *Auth) IAMClient() iam.Client { - return auth.iamClient + keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()} + return iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, auth.httpClientTimeout) } // Configure the Auth struct by creating a validator and create an Irma server @@ -146,27 +146,23 @@ func (auth *Auth) Configure(config core.ServerConfig) error { return err } - var httpClientTimeout time.Duration if auth.config.HTTPTimeout >= 0 { - httpClientTimeout = time.Duration(auth.config.HTTPTimeout) * time.Second + auth.httpClientTimeout = time.Duration(auth.config.HTTPTimeout) * time.Second } else { // auth.http.config got deprecated in favor of httpclient.timeout - httpClientTimeout = config.HTTPClient.Timeout + auth.httpClientTimeout = config.HTTPClient.Timeout } // V1 API related stuff accessTokenLifeSpan := time.Duration(auth.config.AccessTokenLifeSpan) * time.Second auth.authzServer = oauth.NewAuthorizationServer(auth.vdrInstance.Resolver(), auth.vcr, auth.vcr.Verifier(), auth.serviceResolver, auth.keyStore, auth.contractNotary, auth.jsonldManager, accessTokenLifeSpan) auth.relyingParty = oauth.NewRelyingParty(auth.vdrInstance.Resolver(), auth.serviceResolver, - auth.keyStore, auth.vcr.Wallet(), httpClientTimeout, auth.tlsConfig, config.Strictmode) + auth.keyStore, auth.vcr.Wallet(), auth.httpClientTimeout, auth.tlsConfig, config.Strictmode) if err := auth.authzServer.Configure(auth.config.ClockSkew, config.Strictmode); err != nil { return err } - keyResolver := resolver.DIDKeyResolver{Resolver: auth.vdrInstance.Resolver()} - auth.iamClient = iam.NewClient(auth.vcr.Wallet(), keyResolver, auth.keyStore, auth.strictMode, httpClientTimeout) - return nil } diff --git a/auth/auth_test.go b/auth/auth_test.go index bfb712e469..1c31d6c3c7 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -114,13 +114,11 @@ func TestAuth_IAMClient(t *testing.T) { config := DefaultConfig() config.ContractValidators = []string{"dummy"} ctrl := gomock.NewController(t) - pkiMock := pki.NewMockProvider(ctrl) - pkiMock.EXPECT().CreateTLSConfig(gomock.Any()) // for v5 HTTP client + pkiMock := pki.NewMockProvider(ctrl) // no calls are expected vdrInstance := vdr.NewMockVDR(ctrl) vdrInstance.EXPECT().Resolver().AnyTimes() i := NewAuthInstance(config, vdrInstance, vcr.NewTestVCRInstance(t), crypto.NewMemoryCryptoInstance(), nil, nil, pkiMock) - require.NoError(t, i.Configure(core.TestServerConfig())) assert.NotNil(t, i.IAMClient()) }) diff --git a/auth/client/iam/openid4vp.go b/auth/client/iam/openid4vp.go index 23329b864e..76414ba892 100644 --- a/auth/client/iam/openid4vp.go +++ b/auth/client/iam/openid4vp.go @@ -23,6 +23,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/nuts-foundation/nuts-node/http/client" "net/http" "net/url" "time" @@ -56,7 +57,7 @@ func NewClient(wallet holder.Wallet, keyResolver resolver.KeyResolver, jwtSigner return &OpenID4VPClient{ httpClient: HTTPClient{ strictMode: strictMode, - httpClient: core.NewStrictHTTPClient(strictMode, httpClientTimeout, nil), + httpClient: client.NewWithCache(httpClientTimeout), }, keyResolver: keyResolver, jwtSigner: jwtSigner, diff --git a/auth/client/iam/openid4vp_test.go b/auth/client/iam/openid4vp_test.go index 86c0dc4137..7204dba354 100644 --- a/auth/client/iam/openid4vp_test.go +++ b/auth/client/iam/openid4vp_test.go @@ -23,13 +23,12 @@ import ( "crypto/tls" "encoding/json" "fmt" + "github.com/nuts-foundation/nuts-node/http/client" "net/http" "net/http/httptest" "testing" "time" - "github.com/nuts-foundation/nuts-node/core" - ssi "github.com/nuts-foundation/go-did" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/go-did/vc" @@ -399,7 +398,7 @@ func createClientTestContext(t *testing.T, tlsConfig *tls.Config) *clientTestCon wallet: wallet, httpClient: HTTPClient{ strictMode: false, - httpClient: core.NewStrictHTTPClient(false, 10*time.Second, tlsConfig), + httpClient: client.NewWithTLSConfig(10*time.Second, tlsConfig), }, jwtSigner: jwtSigner, keyResolver: keyResolver, diff --git a/auth/cmd/cmd.go b/auth/cmd/cmd.go index 135181dcfa..50832a1036 100644 --- a/auth/cmd/cmd.go +++ b/auth/cmd/cmd.go @@ -38,9 +38,6 @@ const ConfIrmaSchemeManager = "auth.irma.schememanager" // ConfHTTPTimeout defines a timeout (in seconds) which is used by the Auth API HTTP client const ConfHTTPTimeout = "auth.http.timeout" -// ConfHTTPResponseCacheSize defines the maximum HTTP client response cache size in bytes. -const ConfHTTPResponseCacheSize = "auth.http.cache.maxbytes" - // ConfAccessTokenLifeSpan defines how long (in seconds) an access token is valid const ConfAccessTokenLifeSpan = "auth.accesstokenlifespan" @@ -52,7 +49,6 @@ func FlagSet() *pflag.FlagSet { flags.String(ConfIrmaSchemeManager, defs.Irma.SchemeManager, "IRMA schemeManager to use for attributes. Can be either 'pbdf' or 'irma-demo'.") flags.Bool(ConfAutoUpdateIrmaSchemas, defs.Irma.AutoUpdateSchemas, "set if you want automatically update the IRMA schemas every 60 minutes.") flags.Int(ConfHTTPTimeout, defs.HTTPTimeout, "HTTP timeout (in seconds) used by the Auth API HTTP client") - flags.Int(ConfHTTPResponseCacheSize, defs.HTTPResponseCacheSize, "HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses.") flags.Int(ConfClockSkew, defs.ClockSkew, "allowed JWT Clock skew in milliseconds") flags.Int(ConfAccessTokenLifeSpan, defs.AccessTokenLifeSpan, "defines how long (in seconds) an access token is valid. Uses default in strict mode.") flags.StringSlice(ConfContractValidators, defs.ContractValidators, "sets the different contract validators to use") diff --git a/auth/cmd/cmd_test.go b/auth/cmd/cmd_test.go index 8da1538388..3293844996 100644 --- a/auth/cmd/cmd_test.go +++ b/auth/cmd/cmd_test.go @@ -45,7 +45,6 @@ func TestFlagSet(t *testing.T) { ConfAccessTokenLifeSpan, ConfClockSkew, ConfContractValidators, - ConfHTTPResponseCacheSize, ConfHTTPTimeout, ConfAutoUpdateIrmaSchemas, ConfIrmaSchemeManager, diff --git a/core/http_client.go b/core/http_client.go index 2778c690f7..dc9aff03bb 100644 --- a/core/http_client.go +++ b/core/http_client.go @@ -21,13 +21,10 @@ package core import ( "context" - "crypto/tls" - "errors" "fmt" "github.com/sirupsen/logrus" "io" "net/http" - "time" ) // HttpResponseBodyLogClipAt is the maximum length of a response body to log. @@ -175,41 +172,3 @@ func newEmptyTokenGenerator() AuthorizationTokenGenerator { return "", nil } } - -// NewStrictHTTPClient creates a HTTPRequestDoer that only allows HTTPS calls when strictmode is enabled. -func NewStrictHTTPClient(strictmode bool, timeout time.Duration, tlsConfig *tls.Config) *StrictHTTPClient { - if tlsConfig == nil { - tlsConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - } - } - - transport := http.DefaultTransport - // Might not be http.Transport in testing - if httpTransport, ok := transport.(*http.Transport); ok { - // cloning the transport might reduce performance. - httpTransport = httpTransport.Clone() - httpTransport.TLSClientConfig = tlsConfig - transport = httpTransport - } - - return &StrictHTTPClient{ - client: &http.Client{ - Transport: transport, - Timeout: timeout, - }, - strictMode: strictmode, - } -} - -type StrictHTTPClient struct { - client *http.Client - strictMode bool -} - -func (s *StrictHTTPClient) Do(req *http.Request) (*http.Response, error) { - if s.strictMode && req.URL.Scheme != "https" { - return nil, errors.New("strictmode is enabled, but request is not over HTTPS") - } - return s.client.Do(req) -} diff --git a/core/http_client_test.go b/core/http_client_test.go index 8a82a5711d..891364b279 100644 --- a/core/http_client_test.go +++ b/core/http_client_test.go @@ -31,7 +31,6 @@ import ( "net/url" "strings" "testing" - "time" ) func TestHTTPClient(t *testing.T) { @@ -91,17 +90,6 @@ func TestHTTPClient(t *testing.T) { }) } -func TestStrictHTTPClient_Do(t *testing.T) { - t.Run("error on HTTP call when strictmode is enabled", func(t *testing.T) { - client := NewStrictHTTPClient(true, time.Second, nil) - httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) - - _, err := client.Do(httpRequest) - - assert.Error(t, err) - }) -} - func TestUserAgentRequestEditor(t *testing.T) { GitVersion = "" req := &stdHttp.Request{Header: map[string][]string{}} diff --git a/discovery/api/server/client/http.go b/discovery/api/server/client/http.go index 0fb18e05fb..cc47de45cf 100644 --- a/discovery/api/server/client/http.go +++ b/discovery/api/server/client/http.go @@ -27,6 +27,7 @@ import ( "github.com/nuts-foundation/go-did/vc" "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/discovery/log" + "github.com/nuts-foundation/nuts-node/http/client" "io" "net/http" "net/url" @@ -36,7 +37,7 @@ import ( // New creates a new DefaultHTTPClient. func New(strictMode bool, timeout time.Duration, tlsConfig *tls.Config) *DefaultHTTPClient { return &DefaultHTTPClient{ - client: core.NewStrictHTTPClient(strictMode, timeout, tlsConfig), + client: client.NewWithTLSConfig(timeout, tlsConfig), } } diff --git a/http/client/caching.go b/http/client/caching.go index 722881f2c3..40e5afc0bf 100644 --- a/http/client/caching.go +++ b/http/client/caching.go @@ -12,55 +12,46 @@ import ( "time" ) +// DefaultCachingTransport is a http.RoundTripper that can be used as a default transport for HTTP clients. +// If caching is enabled, it will cache responses according to RFC 7234. +// If caching is disabled, it will behave like http.DefaultTransport. +var DefaultCachingTransport = http.DefaultTransport + // maxCacheTime is the maximum time responses are cached. // Even if the server responds with a longer cache time, responses are never cached longer than maxCacheTime. const maxCacheTime = time.Hour -var _ http.RoundTripper = &CachingHTTPRequestDoer{} - -// CachingHTTPRequestDoer is a cache for HTTP responses for DID/OAuth2/OpenID clients. -// It only caches GET requests (since generally only metadata is cacheable), and only if the response is cacheable. -// It only works on expiration time and does not respect ETags headers. -// When maxBytes is reached, the entries that expire first are removed to make room for new entries (since those are the first ones to be pruned any ways). -type CachingHTTPRequestDoer struct { - maxBytes int - wrappedTransport http.RoundTripper +var _ http.RoundTripper = &CachingRoundTripper{} - // currentSizeBytes is the current size of the cache in bytes. - // It's used to make room for new entries when the cache is full. - currentSizeBytes int - // head is the first entry of a linked list of cache entries, ordered by expiration time. - // The first entry is the one that will expire first, which optimizes the removal of expired entries. - // When an entry is inserted in the cache, it's inserted in the right place in the linked list (ordered by expiry). - head *cacheEntry - // entriesByURL is a map of cache entries, indexed by the URL of the request. - // This optimizes the lookup of cache entries by URL. - entriesByURL map[string][]*cacheEntry - mux sync.RWMutex +// NewCachingTransport creates a new CachingHTTPTransport with the given underlying transport and cache size. +func NewCachingTransport(underlyingTransport http.RoundTripper, responsesCacheSize int) *CachingRoundTripper { + return &CachingRoundTripper{ + cache: newCache(responsesCacheSize), + wrappedTransport: underlyingTransport, + } } -type cacheEntry struct { - responseData []byte - requestURL *url.URL - requestMethod string - requestRawQuery string - expirationTime time.Time - next *cacheEntry - responseStatus int - responseHeaders http.Header +// CachingRoundTripper is a simple HTTP client cache for HTTP responses. +// It only caches GET requests (since for POST request caching, request bodies need to be cached as well), +// and only if the response is cacheable according to RFC 7234. +// It only works on expiration time and does not respect ETags headers. +// When the cache is full, the entries that expire first are removed to make room for new entries (since those are the first ones to be pruned any ways). +type CachingRoundTripper struct { + cache *responseCache + wrappedTransport http.RoundTripper } -func (h *CachingHTTPRequestDoer) RoundTrip(httpRequest *http.Request) (*http.Response, error) { +func (r *CachingRoundTripper) RoundTrip(httpRequest *http.Request) (*http.Response, error) { if httpRequest.Method == http.MethodGet { - if response := h.cachedEntry(httpRequest); response != nil { + if response := r.cache.get(httpRequest); response != nil { return response, nil } } - httpResponse, err := h.wrappedTransport.RoundTrip(httpRequest) + httpResponse, err := r.wrappedTransport.RoundTrip(httpRequest) if err != nil { return nil, err } - err = h.cacheResponse(httpRequest, httpResponse) + err = r.cacheResponse(httpRequest, httpResponse) if err != nil { return nil, err } @@ -68,7 +59,7 @@ func (h *CachingHTTPRequestDoer) RoundTrip(httpRequest *http.Request) (*http.Res } // cacheResponse caches the response if it's cacheable. -func (h *CachingHTTPRequestDoer) cacheResponse(httpRequest *http.Request, httpResponse *http.Response) error { +func (r *CachingRoundTripper) cacheResponse(httpRequest *http.Request, httpResponse *http.Response) error { if httpRequest.Method != http.MethodGet { return nil } @@ -91,25 +82,55 @@ func (h *CachingHTTPRequestDoer) cacheResponse(httpRequest *http.Request, httpRe if err != nil { return fmt.Errorf("error while reading response body for caching: %w", err) } - h.mux.Lock() - defer h.mux.Unlock() - if len(responseBytes) <= h.maxBytes { // sanity check - h.insert(&cacheEntry{ - responseData: responseBytes, - requestMethod: httpRequest.Method, - requestURL: httpRequest.URL, - requestRawQuery: httpRequest.URL.RawQuery, - responseStatus: httpResponse.StatusCode, - responseHeaders: httpResponse.Header, - expirationTime: expirationTime, - }) - } + r.cache.insert(&cacheEntry{ + responseData: responseBytes, + requestMethod: httpRequest.Method, + requestURL: httpRequest.URL, + requestRawQuery: httpRequest.URL.RawQuery, + responseStatus: httpResponse.StatusCode, + responseHeaders: httpResponse.Header, + expirationTime: expirationTime, + }) httpResponse.Body = io.NopCloser(bytes.NewReader(responseBytes)) return nil } -// cachedEntry returns a cached response if it exists. -func (h *CachingHTTPRequestDoer) cachedEntry(httpRequest *http.Request) *http.Response { +func newCache(responsesCacheSize int) *responseCache { + return &responseCache{ + maxBytes: responsesCacheSize, + entriesByURL: map[string][]*cacheEntry{}, + mux: sync.RWMutex{}, + } +} + +type responseCache struct { + maxBytes int + // currentSizeBytes is the current size of the cache in bytes. + // It's used to make room for new entries when the cache is full. + currentSizeBytes int + // head is the first entry of a linked list of cache entries, ordered by expiration time. + // The first entry is the one that will expire first, which optimizes the removal of expired entries. + // When an entry is inserted in the cache, it's inserted in the right place in the linked list (ordered by expiry). + head *cacheEntry + // entriesByURL is a map of cache entries, indexed by the URL of the request. + // This optimizes the lookup of cache entries by URL. + entriesByURL map[string][]*cacheEntry + mux sync.RWMutex +} + +type cacheEntry struct { + responseData []byte + requestURL *url.URL + requestMethod string + requestRawQuery string + expirationTime time.Time + next *cacheEntry + responseStatus int + responseHeaders http.Header +} + +// get is called by the transport to get a cached response. +func (h *responseCache) get(httpRequest *http.Request) *http.Response { h.mux.Lock() defer h.mux.Unlock() h.removeExpiredEntries() @@ -127,19 +148,13 @@ func (h *CachingHTTPRequestDoer) cachedEntry(httpRequest *http.Request) *http.Re return nil } -func (h *CachingHTTPRequestDoer) removeExpiredEntries() { - var current = h.head - for current != nil { - if current.expirationTime.Before(time.Now()) { - current = h.pop() - } else { - break - } +// insert is called by the transport to insert a new entry to the cache. +func (h *responseCache) insert(entry *cacheEntry) { + if len(entry.responseData) > h.maxBytes { // sanity check: don't cache responses that are larger than the cache + return } -} - -// insert adds a new entry to the cache. -func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) { + h.mux.Lock() + defer h.mux.Unlock() // See if we need to make room for the new entry for h.currentSizeBytes+len(entry.responseData) >= h.maxBytes { _ = h.pop() @@ -165,8 +180,20 @@ func (h *CachingHTTPRequestDoer) insert(entry *cacheEntry) { h.currentSizeBytes += len(entry.responseData) } -// pop removes the first entry from the linked list -func (h *CachingHTTPRequestDoer) pop() *cacheEntry { +// removeExpiredEntries removes all entries that have expired. Do not call it directly. +func (h *responseCache) removeExpiredEntries() { + var current = h.head + for current != nil { + if current.expirationTime.Before(time.Now()) { + current = h.pop() + } else { + break + } + } +} + +// pop removes the first entry from the linked list. Do not call it directly. +func (h *responseCache) pop() *cacheEntry { if h.head == nil { return nil } @@ -185,13 +212,3 @@ func (h *CachingHTTPRequestDoer) pop() *cacheEntry { h.head = h.head.next return h.head } - -// NewCachingTransport creates a new CachingHTTPTransport with the given underlying transport and cache size. -func NewCachingTransport(underlyingTransport http.RoundTripper, responsesCacheSize int) *CachingHTTPRequestDoer { - return &CachingHTTPRequestDoer{ - maxBytes: responsesCacheSize, - wrappedTransport: underlyingTransport, - entriesByURL: map[string][]*cacheEntry{}, - mux: sync.RWMutex{}, - } -} diff --git a/http/client/caching_test.go b/http/client/caching_test.go index 979b1ebc22..7c8ec33528 100644 --- a/http/client/caching_test.go +++ b/http/client/caching_test.go @@ -49,7 +49,7 @@ func Test_httpClientCache(t *testing.T) { }) require.NoError(t, err) - assert.Equal(t, 0, client.currentSizeBytes) + assert.Equal(t, 0, client.cache.currentSizeBytes) }) t.Run("caches GET request with max-age", func(t *testing.T) { requestSink := &stubRoundTripper{ @@ -75,7 +75,7 @@ func Test_httpClientCache(t *testing.T) { cachedResponseData, _ := io.ReadAll(httpResponse.Body) assert.Equal(t, "Hello, World!", string(cachedResponseData)) - assert.Equal(t, 13, client.currentSizeBytes) + assert.Equal(t, 13, client.cache.currentSizeBytes) assert.Equal(t, 1, requestSink.invocations) }) t.Run("does not cache responses with no-store", func(t *testing.T) { @@ -89,7 +89,7 @@ func Test_httpClientCache(t *testing.T) { _, err := client.RoundTrip(httpRequest) require.NoError(t, err) - assert.Equal(t, 0, client.currentSizeBytes) + assert.Equal(t, 0, client.cache.currentSizeBytes) }) t.Run("max-age is too long", func(t *testing.T) { requestSink := &stubRoundTripper{ @@ -103,7 +103,7 @@ func Test_httpClientCache(t *testing.T) { _, err := client.RoundTrip(httpRequest) require.NoError(t, err) - assert.LessOrEqual(t, time.Now().Sub(client.head.expirationTime), time.Hour) + assert.LessOrEqual(t, time.Now().Sub(client.cache.head.expirationTime), time.Hour) }) t.Run("2 cache entries with different query parameters", func(t *testing.T) { requestSink := &stubRoundTripper{ @@ -146,7 +146,7 @@ func Test_httpClientCache(t *testing.T) { }, } client := NewCachingTransport(requestSink, 14) - client.insert(&cacheEntry{ + client.cache.insert(&cacheEntry{ responseData: []byte("Hello"), requestURL: test.MustParseURL("http://example.com"), expirationTime: time.Now().Add(time.Hour), @@ -154,7 +154,7 @@ func Test_httpClientCache(t *testing.T) { _, err := client.RoundTrip(httpRequest) require.NoError(t, err) - assert.Equal(t, 13, client.currentSizeBytes) + assert.Equal(t, 13, client.cache.currentSizeBytes) }) t.Run("orders entries by expirationTime for optimized pruning", func(t *testing.T) { requestSink := &stubRoundTripper{ @@ -165,24 +165,24 @@ func Test_httpClientCache(t *testing.T) { }, } client := NewCachingTransport(requestSink, 10000) - client.insert(&cacheEntry{ + client.cache.insert(&cacheEntry{ responseData: []byte("Hello"), requestURL: test.MustParseURL("http://example.com/3"), expirationTime: time.Now().Add(time.Hour * 3), }) - assert.Equal(t, client.head.requestURL.String(), "http://example.com/3") - client.insert(&cacheEntry{ + assert.Equal(t, client.cache.head.requestURL.String(), "http://example.com/3") + client.cache.insert(&cacheEntry{ responseData: []byte("Hello"), requestURL: test.MustParseURL("http://example.com/2"), expirationTime: time.Now().Add(time.Hour * 2), }) - assert.Equal(t, client.head.requestURL.String(), "http://example.com/2") - client.insert(&cacheEntry{ + assert.Equal(t, client.cache.head.requestURL.String(), "http://example.com/2") + client.cache.insert(&cacheEntry{ responseData: []byte("Hello"), requestURL: test.MustParseURL("http://example.com/1"), expirationTime: time.Now().Add(time.Hour), }) - assert.Equal(t, client.head.requestURL.String(), "http://example.com/1") + assert.Equal(t, client.cache.head.requestURL.String(), "http://example.com/1") }) t.Run("entries that exceed max cache size aren't cached", func(t *testing.T) { requestSink := &stubRoundTripper{ @@ -198,7 +198,7 @@ func Test_httpClientCache(t *testing.T) { require.NoError(t, err) data, _ := io.ReadAll(httpResponse.Body) assert.Equal(t, "Hello, World!", string(data)) - assert.Equal(t, 0, client.currentSizeBytes) + assert.Equal(t, 0, client.cache.currentSizeBytes) }) } diff --git a/http/client/client.go b/http/client/client.go new file mode 100644 index 0000000000..f93a1d8426 --- /dev/null +++ b/http/client/client.go @@ -0,0 +1,47 @@ +package client + +import ( + "crypto/tls" + "errors" + "net/http" + "time" +) + +// StrictMode is a flag that can be set to true to enable strict mode for the HTTP client. +var StrictMode bool + +// NewWithCache creates a new HTTP client with the given timeout. +// It uses the DefaultCachingTransport as the underlying transport. +func NewWithCache(timeout time.Duration) *StrictHTTPClient { + return &StrictHTTPClient{ + client: &http.Client{ + Transport: DefaultCachingTransport, + Timeout: timeout, + }, + } +} + +// NewWithTLSConfig creates a new HTTP client with the given timeout and TLS configuration. +// It copies the http.DefaultTransport and sets the TLSClientConfig to the given tls.Config. +// As such, it can't be used in conjunction with the CachingRoundTripper. +func NewWithTLSConfig(timeout time.Duration, tlsConfig *tls.Config) *StrictHTTPClient { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = tlsConfig + return &StrictHTTPClient{ + client: &http.Client{ + Transport: transport, + Timeout: timeout, + }, + } +} + +type StrictHTTPClient struct { + client *http.Client +} + +func (s *StrictHTTPClient) Do(req *http.Request) (*http.Response, error) { + if StrictMode && req.URL.Scheme != "https" { + return nil, errors.New("strictmode is enabled, but request is not over HTTPS") + } + return s.client.Do(req) +} diff --git a/http/client/client_test.go b/http/client/client_test.go new file mode 100644 index 0000000000..6e42649956 --- /dev/null +++ b/http/client/client_test.go @@ -0,0 +1,71 @@ +package client + +import ( + "crypto/tls" + "github.com/stretchr/testify/assert" + stdHttp "net/http" + "testing" + "time" +) + +func TestStrictHTTPClient(t *testing.T) { + t.Run("caching transport", func(t *testing.T) { + t.Run("strict mode enabled", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = true + + client := NewWithCache(time.Second) + httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + _, err := client.Do(httpRequest) + + assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.Equal(t, 0, rt.invocations) + }) + t.Run("strict mode disabled", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = false + + client := NewWithCache(time.Second) + httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + _, err := client.Do(httpRequest) + + assert.NoError(t, err) + assert.Equal(t, 1, rt.invocations) + }) + }) + t.Run("TLS transport", func(t *testing.T) { + t.Run("strict mode enabled", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = true + + client := NewWithCache(time.Second) + httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + _, err := client.Do(httpRequest) + + assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.Equal(t, 0, rt.invocations) + }) + t.Run("sets TLS config", func(t *testing.T) { + client := NewWithTLSConfig(time.Second, &tls.Config{ + InsecureSkipVerify: true, + }) + ts := client.client.Transport.(*stdHttp.Transport) + assert.True(t, ts.TLSClientConfig.InsecureSkipVerify) + }) + }) + t.Run("error on HTTP call when strictmode is enabled", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = true + + client := NewWithCache(time.Second) + httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + _, err := client.Do(httpRequest) + + assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.Equal(t, 0, rt.invocations) + }) +} diff --git a/http/cmd/cmd.go b/http/cmd/cmd.go index caf0a42764..89936654a7 100644 --- a/http/cmd/cmd.go +++ b/http/cmd/cmd.go @@ -43,6 +43,7 @@ func FlagSet() *pflag.FlagSet { flags.String("http.internal.auth.audience", defs.Internal.Auth.Audience, "Expected audience for JWT tokens (default: hostname)") flags.String("http.internal.auth.authorizedkeyspath", defs.Internal.Auth.AuthorizedKeysPath, "Path to an authorized_keys file for trusted JWT signers") flags.String("http.log", string(defs.Log), fmt.Sprintf("What to log about HTTP requests. Options are '%s', '%s' (log request method, URI, IP and response code), and '%s' (log the request and response body, in addition to the metadata). When debug vebosity is set the authorization headers are also logged when the request is fully logged.", http.LogNothingLevel, http.LogMetadataLevel, http.LogMetadataAndBodyLevel)) + flags.Int("http.cache.maxbytes", defs.ResponseCacheSize, "HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses.") return flags } diff --git a/http/engine.go b/http/engine.go index 00c17f4900..49b4b4b7c9 100644 --- a/http/engine.go +++ b/http/engine.go @@ -21,6 +21,7 @@ package http import ( "context" "crypto" + "crypto/tls" "errors" "fmt" "github.com/nuts-foundation/nuts-node/http/client" @@ -68,11 +69,7 @@ func (h Engine) Router() core.EchoRouter { // Configure loads the configuration for the HTTP engine. func (h *Engine) Configure(serverConfig core.ServerConfig) error { - // Configure the HTTP caching client, if enabled. Set it to http.DefaultTransport so it can be used by any subsystem. - if h.config.ResponseCacheSize > 0 { - defaultTransport := http.DefaultTransport.(*http.Transport) - http.DefaultTransport = client.NewCachingTransport(defaultTransport, h.config.ResponseCacheSize) - } + h.configureClient(serverConfig) // Override default Echo HTTP error when bearer token is expected but not provided. // Echo returns "Bad Request (400)" by default, but we use this for incorrect use of API parameters. @@ -104,6 +101,19 @@ func (h *Engine) Configure(serverConfig core.ServerConfig) error { return h.applyAuthMiddleware(h.server, "/internal", h.config.Internal.Auth) } +func (h *Engine) configureClient(serverConfig core.ServerConfig) { + client.StrictMode = serverConfig.Strictmode + httpTransport := http.DefaultTransport.(*http.Transport) + if httpTransport.TLSClientConfig == nil { + httpTransport.TLSClientConfig = &tls.Config{} + } + httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 + // Configure the HTTP caching client, if enabled. Set it to http.DefaultTransport so it can be used by any subsystem. + if h.config.ResponseCacheSize > 0 { + client.DefaultCachingTransport = client.NewCachingTransport(http.DefaultTransport, h.config.ResponseCacheSize) + } +} + func (h *Engine) createEchoServer() (EchoServer, error) { echoServer := echo.New() echoServer.HideBanner = true diff --git a/policy/api/v1/client/client.go b/policy/api/v1/client/client.go index 0902ef5d72..431d41db67 100644 --- a/policy/api/v1/client/client.go +++ b/policy/api/v1/client/client.go @@ -22,6 +22,7 @@ import ( "context" "crypto/tls" "fmt" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/vcr/pe" "net/http" "time" @@ -40,7 +41,7 @@ type HTTPClient struct { func NewHTTPClient(strictMode bool, timeout time.Duration, tlsConfig *tls.Config) HTTPClient { return HTTPClient{ strictMode: strictMode, - httpClient: core.NewStrictHTTPClient(strictMode, timeout, tlsConfig), + httpClient: client.NewWithTLSConfig(timeout, tlsConfig), } } diff --git a/vcr/vcr.go b/vcr/vcr.go index 76a5fa65fc..8dc3b6834f 100644 --- a/vcr/vcr.go +++ b/vcr/vcr.go @@ -25,6 +25,7 @@ import ( "errors" "fmt" "github.com/nuts-foundation/go-leia/v4" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/pki" "github.com/nuts-foundation/nuts-node/vcr/credential" "github.com/nuts-foundation/nuts-node/vcr/openid4vci" @@ -219,12 +220,12 @@ func (c *vcr) Configure(config core.ServerConfig) error { // meaning while the issuer allocated an HTTP connection the wallet will try to allocate one as well. // This moved back to 1 http.Client when the credential is requested asynchronously. // Should be fixed as part of https://github.com/nuts-foundation/nuts-node/issues/2039 (also fix core.NewStrictHTTPClient) - c.issuerHttpClient = core.NewStrictHTTPClient(config.Strictmode, c.config.OpenID4VCI.Timeout, tlsConfig) - c.walletHttpClient = core.NewStrictHTTPClient(config.Strictmode, c.config.OpenID4VCI.Timeout, tlsConfig) + c.issuerHttpClient = client.NewWithTLSConfig(c.config.OpenID4VCI.Timeout, tlsConfig) + c.walletHttpClient = client.NewWithTLSConfig(c.config.OpenID4VCI.Timeout, tlsConfig) c.openidSessionStore = c.storageClient.GetSessionDatabase() } - status := revocation.NewStatusList2021(c.storageClient.GetSQLDatabase(), core.NewStrictHTTPClient(config.Strictmode, config.HTTPClient.Timeout, tlsConfig)) + status := revocation.NewStatusList2021(c.storageClient.GetSQLDatabase(), client.NewWithCache(config.HTTPClient.Timeout)) c.issuer = issuer.NewIssuer(c.issuerStore, c, networkPublisher, openidHandlerFn, didResolver, c.keyStore, c.jsonldManager, c.trustConfig, status) c.verifier = verifier.NewVerifier(c.verifierStore, didResolver, c.keyResolver, c.jsonldManager, c.trustConfig, status) diff --git a/vcr/vcr_test.go b/vcr/vcr_test.go index 79e2fdd985..9a5cc60589 100644 --- a/vcr/vcr_test.go +++ b/vcr/vcr_test.go @@ -27,6 +27,7 @@ import ( "github.com/nuts-foundation/go-leia/v4" "github.com/nuts-foundation/go-stoabs" bbolt2 "github.com/nuts-foundation/go-stoabs/bbolt" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/pki" "github.com/nuts-foundation/nuts-node/storage" "github.com/nuts-foundation/nuts-node/vcr/openid4vci" @@ -83,6 +84,7 @@ func TestVCR_Configure(t *testing.T) { }) t.Run("strictmode passed to client APIs", func(t *testing.T) { // load test VC + client.StrictMode = true testVC := test.ValidNutsOrganizationCredential(t) issuerDID := did.MustParseDID(testVC.Issuer.String()) testDirectory := io.TestDirectory(t) diff --git a/vdr/didweb/web.go b/vdr/didweb/web.go index 75d1e6e283..50d9fe50fc 100644 --- a/vdr/didweb/web.go +++ b/vdr/didweb/web.go @@ -19,11 +19,11 @@ package didweb import ( - "crypto/tls" "errors" "fmt" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/nuts-node/core" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/vdr/resolver" "mime" "net/http" @@ -42,18 +42,9 @@ type Resolver struct { // NewResolver creates a new did:web Resolver with default TLS configuration. func NewResolver() *Resolver { - transport := http.DefaultTransport - if httpTransport, ok := transport.(*http.Transport); ok { - // Might not be http.Transport in testing - httpTransport = httpTransport.Clone() - httpTransport.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - } - transport = httpTransport - } return &Resolver{ HttpClient: &http.Client{ - Transport: transport, + Transport: client.DefaultCachingTransport, Timeout: 5 * time.Second, }, } diff --git a/vdr/didweb/web_test.go b/vdr/didweb/web_test.go index d8c7093a45..519464be39 100644 --- a/vdr/didweb/web_test.go +++ b/vdr/didweb/web_test.go @@ -19,8 +19,8 @@ package didweb import ( - "crypto/tls" "github.com/nuts-foundation/go-did/did" + "github.com/nuts-foundation/nuts-node/http/client" http2 "github.com/nuts-foundation/nuts-node/test/http" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -56,8 +56,8 @@ func TestResolver_NewResolver(t *testing.T) { resolver := NewResolver() assert.NotNil(t, resolver.HttpClient) - t.Run("it sets min TLS version", func(t *testing.T) { - assert.Equal(t, uint16(tls.VersionTLS12), resolver.HttpClient.Transport.(*http.Transport).TLSClientConfig.MinVersion) + t.Run("it uses cached transport", func(t *testing.T) { + assert.Same(t, client.DefaultCachingTransport, resolver.HttpClient.Transport) }) } diff --git a/vdr/vdr_test.go b/vdr/vdr_test.go index acc16b68ff..834af56e1a 100644 --- a/vdr/vdr_test.go +++ b/vdr/vdr_test.go @@ -30,6 +30,7 @@ import ( ssi "github.com/nuts-foundation/go-did" "github.com/nuts-foundation/nuts-node/audit" "github.com/nuts-foundation/nuts-node/core" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/storage" "github.com/nuts-foundation/nuts-node/vdr/didnuts" "github.com/nuts-foundation/nuts-node/vdr/didnuts/didstore" @@ -443,7 +444,7 @@ func TestVDR_Configure(t *testing.T) { storageInstance := storage.NewTestStorageEngine(t) t.Run("it can resolve using did:web", func(t *testing.T) { t.Run("not in database", func(t *testing.T) { - http.DefaultTransport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + client.DefaultCachingTransport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { return &http.Response{ Header: map[string][]string{"Content-Type": {"application/json"}}, StatusCode: http.StatusOK, From 8a5c8f542d49ea34522e4ab56827b53533f3aa13 Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Sun, 2 Jun 2024 07:01:14 +0200 Subject: [PATCH 10/10] f --- http/client/client.go | 8 ++++++++ http/engine.go | 6 ------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/http/client/client.go b/http/client/client.go index f93a1d8426..0ff36908d2 100644 --- a/http/client/client.go +++ b/http/client/client.go @@ -7,6 +7,14 @@ import ( "time" ) +func init() { + httpTransport := http.DefaultTransport.(*http.Transport) + if httpTransport.TLSClientConfig == nil { + httpTransport.TLSClientConfig = &tls.Config{} + } + httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 +} + // StrictMode is a flag that can be set to true to enable strict mode for the HTTP client. var StrictMode bool diff --git a/http/engine.go b/http/engine.go index 0728c62942..8e9e4aadbf 100644 --- a/http/engine.go +++ b/http/engine.go @@ -21,7 +21,6 @@ package http import ( "context" "crypto" - "crypto/tls" "errors" "fmt" "github.com/nuts-foundation/nuts-node/http/client" @@ -104,11 +103,6 @@ func (h *Engine) Configure(serverConfig core.ServerConfig) error { func (h *Engine) configureClient(serverConfig core.ServerConfig) { client.StrictMode = serverConfig.Strictmode - httpTransport := http.DefaultTransport.(*http.Transport) - if httpTransport.TLSClientConfig == nil { - httpTransport.TLSClientConfig = &tls.Config{} - } - httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 // Configure the HTTP caching client, if enabled. Set it to http.DefaultTransport so it can be used by any subsystem. if h.config.ResponseCacheSize > 0 { client.DefaultCachingTransport = client.NewCachingTransport(http.DefaultTransport, h.config.ResponseCacheSize)