diff --git a/README.rst b/README.rst index 39c17e6a36..c5e6484212 100644 --- a/README.rst +++ b/README.rst @@ -180,7 +180,7 @@ The following options can be configured on the server: verbosity info Log level (trace, debug, info, warn, error) httpclient.timeout 30s Request time-out for HTTP clients, such as '10s'. Refer to Golang's 'time.Duration' syntax for a more elaborate description of the syntax. **Crypto** - crypto.storage Storage to use, ''%s' for an external backend (deprecated).' for file system (for development purposes), 'fs' for HashiCorp Vault KV store, 'vaultkv' for Azure Key Vault.%!(EXTRA string=azure-keyvault, string=external) + crypto.storage Storage to use, 'fs' for file system (for development purposes), 'vaultkv' for HashiCorp Vault KV store, 'azure-keyvault' for Azure Key Vault, 'external' for an external backend (deprecated). crypto.azurekv.hsm false Whether to store the key in a hardware security module (HSM). If true, the Azure Key Vault must be configured for HSM usage. Default: false crypto.azurekv.timeout 10s Timeout of client calls to Azure Key Vault, in Golang time.Duration string format (e.g. 10s). crypto.azurekv.url The URL of the Azure Key Vault. @@ -194,13 +194,14 @@ The following options can be configured on the server: discovery.server.ids [] IDs of the Discovery Service for which to act as server. If an ID does not map to a loaded service definition, the node will fail to start. **HTTP** http.log metadata What to log about HTTP requests. Options are 'nothing', 'metadata' (log request method, URI, IP and response code), and 'metadata-and-body' (log the request and response body, in addition to the metadata). When debug vebosity is set the authorization headers are also logged when the request is fully logged. + http.cache.maxbytes 10485760 HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses. http.internal.address 127.0.0.1:8081 Address and port the server will be listening to for internal-facing endpoints. http.internal.auth.audience Expected audience for JWT tokens (default: hostname) http.internal.auth.authorizedkeyspath Path to an authorized_keys file for trusted JWT signers http.internal.auth.type Whether to enable authentication for /internal endpoints, specify 'token_v2' for bearer token mode or 'token' for legacy bearer token mode. http.public.address \:8080 Address and port the server will be listening to for public-facing endpoints. **JSONLD** - jsonld.contexts.localmapping [https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson,https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist. + jsonld.contexts.localmapping [https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson,https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist. jsonld.contexts.remoteallowlist [https://schema.org,https://www.w3.org/2018/credentials/v1,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json,https://w3id.org/vc/status-list/2021/v1] In strict mode, fetching external JSON-LD contexts is not allowed except for context-URLs listed here. **PKI** pki.maxupdatefailhours 4 Maximum number of hours that a denylist update can fail diff --git a/auth/client/iam/openid4vp.go b/auth/client/iam/openid4vp.go index 23329b864e..76414ba892 100644 --- a/auth/client/iam/openid4vp.go +++ b/auth/client/iam/openid4vp.go @@ -23,6 +23,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/nuts-foundation/nuts-node/http/client" "net/http" "net/url" "time" @@ -56,7 +57,7 @@ func NewClient(wallet holder.Wallet, keyResolver resolver.KeyResolver, jwtSigner return &OpenID4VPClient{ httpClient: HTTPClient{ strictMode: strictMode, - httpClient: core.NewStrictHTTPClient(strictMode, httpClientTimeout, nil), + httpClient: client.NewWithCache(httpClientTimeout), }, keyResolver: keyResolver, jwtSigner: jwtSigner, diff --git a/auth/client/iam/openid4vp_test.go b/auth/client/iam/openid4vp_test.go index 3d36d2e0b1..555dc426ed 100644 --- a/auth/client/iam/openid4vp_test.go +++ b/auth/client/iam/openid4vp_test.go @@ -23,13 +23,12 @@ import ( "crypto/tls" "encoding/json" "fmt" + "github.com/nuts-foundation/nuts-node/http/client" "net/http" "net/http/httptest" "testing" "time" - "github.com/nuts-foundation/nuts-node/core" - ssi "github.com/nuts-foundation/go-did" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/go-did/vc" @@ -399,7 +398,7 @@ func createClientTestContext(t *testing.T, tlsConfig *tls.Config) *clientTestCon wallet: wallet, httpClient: HTTPClient{ strictMode: false, - httpClient: core.NewStrictHTTPClient(false, 10*time.Second, tlsConfig), + httpClient: client.NewWithTLSConfig(10*time.Second, tlsConfig), }, jwtSigner: jwtSigner, keyResolver: keyResolver, diff --git a/core/http_client.go b/core/http_client.go index 2778c690f7..dc9aff03bb 100644 --- a/core/http_client.go +++ b/core/http_client.go @@ -21,13 +21,10 @@ package core import ( "context" - "crypto/tls" - "errors" "fmt" "github.com/sirupsen/logrus" "io" "net/http" - "time" ) // HttpResponseBodyLogClipAt is the maximum length of a response body to log. @@ -175,41 +172,3 @@ func newEmptyTokenGenerator() AuthorizationTokenGenerator { return "", nil } } - -// NewStrictHTTPClient creates a HTTPRequestDoer that only allows HTTPS calls when strictmode is enabled. -func NewStrictHTTPClient(strictmode bool, timeout time.Duration, tlsConfig *tls.Config) *StrictHTTPClient { - if tlsConfig == nil { - tlsConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - } - } - - transport := http.DefaultTransport - // Might not be http.Transport in testing - if httpTransport, ok := transport.(*http.Transport); ok { - // cloning the transport might reduce performance. - httpTransport = httpTransport.Clone() - httpTransport.TLSClientConfig = tlsConfig - transport = httpTransport - } - - return &StrictHTTPClient{ - client: &http.Client{ - Transport: transport, - Timeout: timeout, - }, - strictMode: strictmode, - } -} - -type StrictHTTPClient struct { - client *http.Client - strictMode bool -} - -func (s *StrictHTTPClient) Do(req *http.Request) (*http.Response, error) { - if s.strictMode && req.URL.Scheme != "https" { - return nil, errors.New("strictmode is enabled, but request is not over HTTPS") - } - return s.client.Do(req) -} diff --git a/core/http_client_test.go b/core/http_client_test.go index 8a82a5711d..891364b279 100644 --- a/core/http_client_test.go +++ b/core/http_client_test.go @@ -31,7 +31,6 @@ import ( "net/url" "strings" "testing" - "time" ) func TestHTTPClient(t *testing.T) { @@ -91,17 +90,6 @@ func TestHTTPClient(t *testing.T) { }) } -func TestStrictHTTPClient_Do(t *testing.T) { - t.Run("error on HTTP call when strictmode is enabled", func(t *testing.T) { - client := NewStrictHTTPClient(true, time.Second, nil) - httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) - - _, err := client.Do(httpRequest) - - assert.Error(t, err) - }) -} - func TestUserAgentRequestEditor(t *testing.T) { GitVersion = "" req := &stdHttp.Request{Header: map[string][]string{}} diff --git a/discovery/api/server/client/http.go b/discovery/api/server/client/http.go index 0fb18e05fb..cc47de45cf 100644 --- a/discovery/api/server/client/http.go +++ b/discovery/api/server/client/http.go @@ -27,6 +27,7 @@ import ( "github.com/nuts-foundation/go-did/vc" "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/discovery/log" + "github.com/nuts-foundation/nuts-node/http/client" "io" "net/http" "net/url" @@ -36,7 +37,7 @@ import ( // New creates a new DefaultHTTPClient. func New(strictMode bool, timeout time.Duration, tlsConfig *tls.Config) *DefaultHTTPClient { return &DefaultHTTPClient{ - client: core.NewStrictHTTPClient(strictMode, timeout, tlsConfig), + client: client.NewWithTLSConfig(timeout, tlsConfig), } } diff --git a/docs/pages/deployment/cli-reference.rst b/docs/pages/deployment/cli-reference.rst index 91c903b6f5..7b60dfcb59 100755 --- a/docs/pages/deployment/cli-reference.rst +++ b/docs/pages/deployment/cli-reference.rst @@ -23,7 +23,7 @@ The following options apply to the server commands below: --crypto.azurekv.hsm Whether to store the key in a hardware security module (HSM). If true, the Azure Key Vault must be configured for HSM usage. Default: false --crypto.azurekv.timeout duration Timeout of client calls to Azure Key Vault, in Golang time.Duration string format (e.g. 10s). (default 10s) --crypto.azurekv.url string The URL of the Azure Key Vault. - --crypto.storage string Storage to use, ''%s' for an external backend (deprecated).' for file system (for development purposes), 'fs' for HashiCorp Vault KV store, 'vaultkv' for Azure Key Vault.%!(EXTRA string=azure-keyvault, string=external) + --crypto.storage string Storage to use, 'fs' for file system (for development purposes), 'vaultkv' for HashiCorp Vault KV store, 'azure-keyvault' for Azure Key Vault, 'external' for an external backend (deprecated). --crypto.vault.address string The Vault address. If set it overwrites the VAULT_ADDR env var. --crypto.vault.pathprefix string The Vault path prefix. (default "kv") --crypto.vault.timeout duration Timeout of client calls to Vault, in Golang time.Duration string format (e.g. 1s). (default 5s) @@ -38,6 +38,7 @@ The following options apply to the server commands below: --events.nats.timeout int Timeout for NATS server operations (default 30) --goldenhammer.enabled Whether to enable automatically fixing DID documents with the required endpoints. (default true) --goldenhammer.interval duration The interval in which to check for DID documents to fix. (default 10m0s) + --http.cache.maxbytes int HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses. (default 10485760) --http.internal.address string Address and port the server will be listening to for internal-facing endpoints. (default "127.0.0.1:8081") --http.internal.auth.audience string Expected audience for JWT tokens (default: hostname) --http.internal.auth.authorizedkeyspath string Path to an authorized_keys file for trusted JWT signers diff --git a/docs/pages/deployment/server_options.rst b/docs/pages/deployment/server_options.rst index 3076e601ff..34a73ca25f 100755 --- a/docs/pages/deployment/server_options.rst +++ b/docs/pages/deployment/server_options.rst @@ -15,7 +15,7 @@ verbosity info Log level (trace, debug, info, warn, error) httpclient.timeout 30s Request time-out for HTTP clients, such as '10s'. Refer to Golang's 'time.Duration' syntax for a more elaborate description of the syntax. **Crypto** - crypto.storage Storage to use, ''%s' for an external backend (deprecated).' for file system (for development purposes), 'fs' for HashiCorp Vault KV store, 'vaultkv' for Azure Key Vault.%!(EXTRA string=azure-keyvault, string=external) + crypto.storage Storage to use, 'fs' for file system (for development purposes), 'vaultkv' for HashiCorp Vault KV store, 'azure-keyvault' for Azure Key Vault, 'external' for an external backend (deprecated). crypto.azurekv.hsm false Whether to store the key in a hardware security module (HSM). If true, the Azure Key Vault must be configured for HSM usage. Default: false crypto.azurekv.timeout 10s Timeout of client calls to Azure Key Vault, in Golang time.Duration string format (e.g. 10s). crypto.azurekv.url The URL of the Azure Key Vault. @@ -29,13 +29,14 @@ discovery.server.ids [] IDs of the Discovery Service for which to act as server. If an ID does not map to a loaded service definition, the node will fail to start. **HTTP** http.log metadata What to log about HTTP requests. Options are 'nothing', 'metadata' (log request method, URI, IP and response code), and 'metadata-and-body' (log the request and response body, in addition to the metadata). When debug vebosity is set the authorization headers are also logged when the request is fully logged. + http.cache.maxbytes 10485760 HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses. http.internal.address 127.0.0.1:8081 Address and port the server will be listening to for internal-facing endpoints. http.internal.auth.audience Expected audience for JWT tokens (default: hostname) http.internal.auth.authorizedkeyspath Path to an authorized_keys file for trusted JWT signers http.internal.auth.type Whether to enable authentication for /internal endpoints, specify 'token_v2' for bearer token mode or 'token' for legacy bearer token mode. http.public.address \:8080 Address and port the server will be listening to for public-facing endpoints. **JSONLD** - jsonld.contexts.localmapping [https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson,https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist. + jsonld.contexts.localmapping [https://w3id.org/vc/status-list/2021/v1=assets/contexts/w3c-statuslist2021.ldjson,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json=assets/contexts/lds-jws2020-v1.ldjson,https://schema.org=assets/contexts/schema-org-v13.ldjson,https://nuts.nl/credentials/v1=assets/contexts/nuts.ldjson,https://www.w3.org/2018/credentials/v1=assets/contexts/w3c-credentials-v1.ldjson] This setting allows mapping external URLs to local files for e.g. preventing external dependencies. These mappings have precedence over those in remoteallowlist. jsonld.contexts.remoteallowlist [https://schema.org,https://www.w3.org/2018/credentials/v1,https://w3c-ccg.github.io/lds-jws2020/contexts/lds-jws2020-v1.json,https://w3id.org/vc/status-list/2021/v1] In strict mode, fetching external JSON-LD contexts is not allowed except for context-URLs listed here. **PKI** pki.maxupdatefailhours 4 Maximum number of hours that a denylist update can fail diff --git a/http/client/caching.go b/http/client/caching.go new file mode 100644 index 0000000000..40e5afc0bf --- /dev/null +++ b/http/client/caching.go @@ -0,0 +1,214 @@ +package client + +import ( + "bytes" + "fmt" + "github.com/nuts-foundation/nuts-node/http/log" + "github.com/pquerna/cachecontrol" + "io" + "net/http" + "net/url" + "sync" + "time" +) + +// DefaultCachingTransport is a http.RoundTripper that can be used as a default transport for HTTP clients. +// If caching is enabled, it will cache responses according to RFC 7234. +// If caching is disabled, it will behave like http.DefaultTransport. +var DefaultCachingTransport = http.DefaultTransport + +// maxCacheTime is the maximum time responses are cached. +// Even if the server responds with a longer cache time, responses are never cached longer than maxCacheTime. +const maxCacheTime = time.Hour + +var _ http.RoundTripper = &CachingRoundTripper{} + +// NewCachingTransport creates a new CachingHTTPTransport with the given underlying transport and cache size. +func NewCachingTransport(underlyingTransport http.RoundTripper, responsesCacheSize int) *CachingRoundTripper { + return &CachingRoundTripper{ + cache: newCache(responsesCacheSize), + wrappedTransport: underlyingTransport, + } +} + +// CachingRoundTripper is a simple HTTP client cache for HTTP responses. +// It only caches GET requests (since for POST request caching, request bodies need to be cached as well), +// and only if the response is cacheable according to RFC 7234. +// It only works on expiration time and does not respect ETags headers. +// When the cache is full, the entries that expire first are removed to make room for new entries (since those are the first ones to be pruned any ways). +type CachingRoundTripper struct { + cache *responseCache + wrappedTransport http.RoundTripper +} + +func (r *CachingRoundTripper) RoundTrip(httpRequest *http.Request) (*http.Response, error) { + if httpRequest.Method == http.MethodGet { + if response := r.cache.get(httpRequest); response != nil { + return response, nil + } + } + httpResponse, err := r.wrappedTransport.RoundTrip(httpRequest) + if err != nil { + return nil, err + } + err = r.cacheResponse(httpRequest, httpResponse) + if err != nil { + return nil, err + } + return httpResponse, nil +} + +// cacheResponse caches the response if it's cacheable. +func (r *CachingRoundTripper) cacheResponse(httpRequest *http.Request, httpResponse *http.Response) error { + if httpRequest.Method != http.MethodGet { + return nil + } + reasons, expirationTime, err := cachecontrol.CachableResponse(httpRequest, httpResponse, cachecontrol.Options{PrivateCache: false}) + if err != nil { + log.Logger().WithError(err).Infof("error while checking cacheability of response (url=%s), not caching", httpRequest.URL.String()) + return nil + } + // We don't want to cache responses for too long, as that increases the risk of staleness, + // and could keep cause very long-lived entries to never be pruned. + maxExpirationTime := time.Now().Add(maxCacheTime) + if expirationTime.After(maxExpirationTime) { + expirationTime = maxExpirationTime + } + if len(reasons) > 0 || expirationTime.IsZero() { + log.Logger().Debugf("response (url=%s) is not cacheable: %v", httpRequest.URL.String(), reasons) + return nil + } + responseBytes, err := io.ReadAll(httpResponse.Body) + if err != nil { + return fmt.Errorf("error while reading response body for caching: %w", err) + } + r.cache.insert(&cacheEntry{ + responseData: responseBytes, + requestMethod: httpRequest.Method, + requestURL: httpRequest.URL, + requestRawQuery: httpRequest.URL.RawQuery, + responseStatus: httpResponse.StatusCode, + responseHeaders: httpResponse.Header, + expirationTime: expirationTime, + }) + httpResponse.Body = io.NopCloser(bytes.NewReader(responseBytes)) + return nil +} + +func newCache(responsesCacheSize int) *responseCache { + return &responseCache{ + maxBytes: responsesCacheSize, + entriesByURL: map[string][]*cacheEntry{}, + mux: sync.RWMutex{}, + } +} + +type responseCache struct { + maxBytes int + // currentSizeBytes is the current size of the cache in bytes. + // It's used to make room for new entries when the cache is full. + currentSizeBytes int + // head is the first entry of a linked list of cache entries, ordered by expiration time. + // The first entry is the one that will expire first, which optimizes the removal of expired entries. + // When an entry is inserted in the cache, it's inserted in the right place in the linked list (ordered by expiry). + head *cacheEntry + // entriesByURL is a map of cache entries, indexed by the URL of the request. + // This optimizes the lookup of cache entries by URL. + entriesByURL map[string][]*cacheEntry + mux sync.RWMutex +} + +type cacheEntry struct { + responseData []byte + requestURL *url.URL + requestMethod string + requestRawQuery string + expirationTime time.Time + next *cacheEntry + responseStatus int + responseHeaders http.Header +} + +// get is called by the transport to get a cached response. +func (h *responseCache) get(httpRequest *http.Request) *http.Response { + h.mux.Lock() + defer h.mux.Unlock() + h.removeExpiredEntries() + // Find cached response + entries := h.entriesByURL[httpRequest.URL.String()] + for _, entry := range entries { + if entry.requestMethod == httpRequest.Method && entry.requestRawQuery == httpRequest.URL.RawQuery { + return &http.Response{ + StatusCode: entry.responseStatus, + Header: entry.responseHeaders, + Body: io.NopCloser(bytes.NewReader(entry.responseData)), + } + } + } + return nil +} + +// insert is called by the transport to insert a new entry to the cache. +func (h *responseCache) insert(entry *cacheEntry) { + if len(entry.responseData) > h.maxBytes { // sanity check: don't cache responses that are larger than the cache + return + } + h.mux.Lock() + defer h.mux.Unlock() + // See if we need to make room for the new entry + for h.currentSizeBytes+len(entry.responseData) >= h.maxBytes { + _ = h.pop() + } + if h.head == nil { + // First entry + h.head = entry + } else { + // Insert in the linked list, ordered by expiration time + var current = h.head + for current.next != nil && current.next.expirationTime.Before(entry.expirationTime) { + current = current.next + } + if current == h.head { + h.head = entry + } + entry.next = current.next + current.next = entry + } + // Insert in the URL map for quick lookup + h.entriesByURL[entry.requestURL.String()] = append(h.entriesByURL[entry.requestURL.String()], entry) + + h.currentSizeBytes += len(entry.responseData) +} + +// removeExpiredEntries removes all entries that have expired. Do not call it directly. +func (h *responseCache) removeExpiredEntries() { + var current = h.head + for current != nil { + if current.expirationTime.Before(time.Now()) { + current = h.pop() + } else { + break + } + } +} + +// pop removes the first entry from the linked list. Do not call it directly. +func (h *responseCache) pop() *cacheEntry { + if h.head == nil { + return nil + } + requestURL := h.head.requestURL.String() + entries := h.entriesByURL[requestURL] + for i, entry := range entries { + if entry == h.head { + h.entriesByURL[requestURL] = append(entries[:i], entries[i+1:]...) + if len(h.entriesByURL[requestURL]) == 0 { + delete(h.entriesByURL, requestURL) + } + break + } + } + h.currentSizeBytes -= len(h.head.responseData) + h.head = h.head.next + return h.head +} diff --git a/http/client/caching_test.go b/http/client/caching_test.go new file mode 100644 index 0000000000..7c8ec33528 --- /dev/null +++ b/http/client/caching_test.go @@ -0,0 +1,228 @@ +/* + * Copyright (C) 2024 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package client + +import ( + "bytes" + "fmt" + "github.com/nuts-foundation/nuts-node/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "net/http" + "testing" + "time" +) + +func Test_httpClientCache(t *testing.T) { + httpRequest := &http.Request{ + Method: http.MethodGet, + URL: test.MustParseURL("http://example.com"), + } + t.Run("does not cache POST requests", func(t *testing.T) { + client := NewCachingTransport(&stubRoundTripper{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "public, max-age=3600", + }, + }, 1000) + + _, err := client.RoundTrip(&http.Request{ + Method: http.MethodPost, + }) + + require.NoError(t, err) + assert.Equal(t, 0, client.cache.currentSizeBytes) + }) + t.Run("caches GET request with max-age", func(t *testing.T) { + requestSink := &stubRoundTripper{ + statusCode: http.StatusCreated, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + client := NewCachingTransport(requestSink, 1000) + + // Initial fetch + httpResponse, err := client.RoundTrip(httpRequest) + require.NoError(t, err) + assert.Equal(t, http.StatusCreated, httpResponse.StatusCode) + fetchedResponseData, _ := io.ReadAll(httpResponse.Body) + assert.Equal(t, "Hello, World!", string(fetchedResponseData)) + + // Fetch the response again, should be taken from cache + httpResponse, err = client.RoundTrip(httpRequest) + require.NoError(t, err) + assert.Equal(t, http.StatusCreated, httpResponse.StatusCode) + cachedResponseData, _ := io.ReadAll(httpResponse.Body) + assert.Equal(t, "Hello, World!", string(cachedResponseData)) + + assert.Equal(t, 13, client.cache.currentSizeBytes) + assert.Equal(t, 1, requestSink.invocations) + }) + t.Run("does not cache responses with no-store", func(t *testing.T) { + client := NewCachingTransport(&stubRoundTripper{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "nothing", + }, + }, 1000) + + _, err := client.RoundTrip(httpRequest) + require.NoError(t, err) + assert.Equal(t, 0, client.cache.currentSizeBytes) + }) + t.Run("max-age is too long", func(t *testing.T) { + requestSink := &stubRoundTripper{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": fmt.Sprintf("max-age=%d", int(time.Hour.Seconds()*24)), + }, + } + client := NewCachingTransport(requestSink, 1000) + + _, err := client.RoundTrip(httpRequest) + require.NoError(t, err) + assert.LessOrEqual(t, time.Now().Sub(client.cache.head.expirationTime), time.Hour) + }) + t.Run("2 cache entries with different query parameters", func(t *testing.T) { + requestSink := &stubRoundTripper{ + statusCode: http.StatusOK, + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + requestSink.dataFn = func(req *http.Request) []byte { + return []byte(req.URL.String()) + } + client := NewCachingTransport(requestSink, 1000) + + // Initial fetch of the resources + _, err := client.RoundTrip(httpRequest) + require.NoError(t, err) + alternativeRequest := &http.Request{ + Method: http.MethodGet, + URL: test.MustParseURL("http://example.com?foo=bar"), + } + _, err = client.RoundTrip(alternativeRequest) + require.NoError(t, err) + assert.Equal(t, 2, requestSink.invocations) + + // Fetch the responses again, should be taken from cache + response1, _ := client.RoundTrip(httpRequest) + response1Data, _ := io.ReadAll(response1.Body) + response2, _ := client.RoundTrip(alternativeRequest) + response2Data, _ := io.ReadAll(response2.Body) + assert.Equal(t, 2, requestSink.invocations) + assert.Equal(t, "http://example.com", string(response1Data)) + assert.Equal(t, "http://example.com?foo=bar", string(response2Data)) + }) + t.Run("prunes cache when full", func(t *testing.T) { + requestSink := &stubRoundTripper{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + client := NewCachingTransport(requestSink, 14) + client.cache.insert(&cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com"), + expirationTime: time.Now().Add(time.Hour), + }) + + _, err := client.RoundTrip(httpRequest) + require.NoError(t, err) + assert.Equal(t, 13, client.cache.currentSizeBytes) + }) + t.Run("orders entries by expirationTime for optimized pruning", func(t *testing.T) { + requestSink := &stubRoundTripper{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + client := NewCachingTransport(requestSink, 10000) + client.cache.insert(&cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com/3"), + expirationTime: time.Now().Add(time.Hour * 3), + }) + assert.Equal(t, client.cache.head.requestURL.String(), "http://example.com/3") + client.cache.insert(&cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com/2"), + expirationTime: time.Now().Add(time.Hour * 2), + }) + assert.Equal(t, client.cache.head.requestURL.String(), "http://example.com/2") + client.cache.insert(&cacheEntry{ + responseData: []byte("Hello"), + requestURL: test.MustParseURL("http://example.com/1"), + expirationTime: time.Now().Add(time.Hour), + }) + assert.Equal(t, client.cache.head.requestURL.String(), "http://example.com/1") + }) + t.Run("entries that exceed max cache size aren't cached", func(t *testing.T) { + requestSink := &stubRoundTripper{ + statusCode: http.StatusOK, + data: []byte("Hello, World!"), + headers: map[string]string{ + "Cache-Control": "max-age=3600", + }, + } + client := NewCachingTransport(requestSink, 5) + + httpResponse, err := client.RoundTrip(httpRequest) + require.NoError(t, err) + data, _ := io.ReadAll(httpResponse.Body) + assert.Equal(t, "Hello, World!", string(data)) + assert.Equal(t, 0, client.cache.currentSizeBytes) + }) +} + +type stubRoundTripper struct { + statusCode int + data []byte + dataFn func(r *http.Request) []byte + headers map[string]string + invocations int +} + +func (s *stubRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + s.invocations++ + response := &http.Response{ + StatusCode: s.statusCode, + } + if s.dataFn != nil { + response.Body = io.NopCloser(bytes.NewReader(s.dataFn(req))) + } else { + response.Body = io.NopCloser(bytes.NewReader(s.data)) + } + response.Header = make(http.Header) + for key, value := range s.headers { + response.Header.Set(key, value) + } + return response, nil +} diff --git a/http/client/client.go b/http/client/client.go new file mode 100644 index 0000000000..0ff36908d2 --- /dev/null +++ b/http/client/client.go @@ -0,0 +1,55 @@ +package client + +import ( + "crypto/tls" + "errors" + "net/http" + "time" +) + +func init() { + httpTransport := http.DefaultTransport.(*http.Transport) + if httpTransport.TLSClientConfig == nil { + httpTransport.TLSClientConfig = &tls.Config{} + } + httpTransport.TLSClientConfig.MinVersion = tls.VersionTLS12 +} + +// StrictMode is a flag that can be set to true to enable strict mode for the HTTP client. +var StrictMode bool + +// NewWithCache creates a new HTTP client with the given timeout. +// It uses the DefaultCachingTransport as the underlying transport. +func NewWithCache(timeout time.Duration) *StrictHTTPClient { + return &StrictHTTPClient{ + client: &http.Client{ + Transport: DefaultCachingTransport, + Timeout: timeout, + }, + } +} + +// NewWithTLSConfig creates a new HTTP client with the given timeout and TLS configuration. +// It copies the http.DefaultTransport and sets the TLSClientConfig to the given tls.Config. +// As such, it can't be used in conjunction with the CachingRoundTripper. +func NewWithTLSConfig(timeout time.Duration, tlsConfig *tls.Config) *StrictHTTPClient { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = tlsConfig + return &StrictHTTPClient{ + client: &http.Client{ + Transport: transport, + Timeout: timeout, + }, + } +} + +type StrictHTTPClient struct { + client *http.Client +} + +func (s *StrictHTTPClient) Do(req *http.Request) (*http.Response, error) { + if StrictMode && req.URL.Scheme != "https" { + return nil, errors.New("strictmode is enabled, but request is not over HTTPS") + } + return s.client.Do(req) +} diff --git a/http/client/client_test.go b/http/client/client_test.go new file mode 100644 index 0000000000..6e42649956 --- /dev/null +++ b/http/client/client_test.go @@ -0,0 +1,71 @@ +package client + +import ( + "crypto/tls" + "github.com/stretchr/testify/assert" + stdHttp "net/http" + "testing" + "time" +) + +func TestStrictHTTPClient(t *testing.T) { + t.Run("caching transport", func(t *testing.T) { + t.Run("strict mode enabled", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = true + + client := NewWithCache(time.Second) + httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + _, err := client.Do(httpRequest) + + assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.Equal(t, 0, rt.invocations) + }) + t.Run("strict mode disabled", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = false + + client := NewWithCache(time.Second) + httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + _, err := client.Do(httpRequest) + + assert.NoError(t, err) + assert.Equal(t, 1, rt.invocations) + }) + }) + t.Run("TLS transport", func(t *testing.T) { + t.Run("strict mode enabled", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = true + + client := NewWithCache(time.Second) + httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + _, err := client.Do(httpRequest) + + assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.Equal(t, 0, rt.invocations) + }) + t.Run("sets TLS config", func(t *testing.T) { + client := NewWithTLSConfig(time.Second, &tls.Config{ + InsecureSkipVerify: true, + }) + ts := client.client.Transport.(*stdHttp.Transport) + assert.True(t, ts.TLSClientConfig.InsecureSkipVerify) + }) + }) + t.Run("error on HTTP call when strictmode is enabled", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = true + + client := NewWithCache(time.Second) + httpRequest, _ := stdHttp.NewRequest("GET", "http://example.com", nil) + _, err := client.Do(httpRequest) + + assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.Equal(t, 0, rt.invocations) + }) +} diff --git a/http/cmd/cmd.go b/http/cmd/cmd.go index caf0a42764..89936654a7 100644 --- a/http/cmd/cmd.go +++ b/http/cmd/cmd.go @@ -43,6 +43,7 @@ func FlagSet() *pflag.FlagSet { flags.String("http.internal.auth.audience", defs.Internal.Auth.Audience, "Expected audience for JWT tokens (default: hostname)") flags.String("http.internal.auth.authorizedkeyspath", defs.Internal.Auth.AuthorizedKeysPath, "Path to an authorized_keys file for trusted JWT signers") flags.String("http.log", string(defs.Log), fmt.Sprintf("What to log about HTTP requests. Options are '%s', '%s' (log request method, URI, IP and response code), and '%s' (log the request and response body, in addition to the metadata). When debug vebosity is set the authorization headers are also logged when the request is fully logged.", http.LogNothingLevel, http.LogMetadataLevel, http.LogMetadataAndBodyLevel)) + flags.Int("http.cache.maxbytes", defs.ResponseCacheSize, "HTTP client maximum size of the response cache in bytes. If 0, the HTTP client does not cache responses.") return flags } diff --git a/http/config.go b/http/config.go index 02fe3d096d..4554802dfb 100644 --- a/http/config.go +++ b/http/config.go @@ -28,6 +28,7 @@ func DefaultConfig() Config { Public: PublicConfig{ Address: ":8080", }, + ResponseCacheSize: 10 * 1024 * 1024, // 10mb } } @@ -37,6 +38,8 @@ type Config struct { Log LogLevel `koanf:"log"` Public PublicConfig `koanf:"public"` Internal InternalConfig `koanf:"internal"` + // ResponseCacheSize is the maximum number of bytes cached by HTTP clients. + ResponseCacheSize int `koanf:"cache.maxbytes"` } // PublicConfig contains the configuration for outside-facing HTTP endpoints. diff --git a/http/engine.go b/http/engine.go index 5a1859deb3..8e9e4aadbf 100644 --- a/http/engine.go +++ b/http/engine.go @@ -23,6 +23,7 @@ import ( "crypto" "errors" "fmt" + "github.com/nuts-foundation/nuts-node/http/client" "net/http" "os" "strings" @@ -68,6 +69,8 @@ func (h Engine) Router() core.EchoRouter { // Configure loads the configuration for the HTTP engine. func (h *Engine) Configure(serverConfig core.ServerConfig) error { + h.configureClient(serverConfig) + // Override default Echo HTTP error when bearer token is expected but not provided. // Echo returns "Bad Request (400)" by default, but we use this for incorrect use of API parameters. // "Unauthorized (401)" is a better fit. @@ -98,6 +101,14 @@ func (h *Engine) Configure(serverConfig core.ServerConfig) error { return h.applyAuthMiddleware(h.server, "/internal", h.config.Internal.Auth) } +func (h *Engine) configureClient(serverConfig core.ServerConfig) { + client.StrictMode = serverConfig.Strictmode + // Configure the HTTP caching client, if enabled. Set it to http.DefaultTransport so it can be used by any subsystem. + if h.config.ResponseCacheSize > 0 { + client.DefaultCachingTransport = client.NewCachingTransport(http.DefaultTransport, h.config.ResponseCacheSize) + } +} + func (h *Engine) createEchoServer() (EchoServer, error) { echoServer := echo.New() echoServer.HideBanner = true diff --git a/policy/api/v1/client/client.go b/policy/api/v1/client/client.go index 0902ef5d72..431d41db67 100644 --- a/policy/api/v1/client/client.go +++ b/policy/api/v1/client/client.go @@ -22,6 +22,7 @@ import ( "context" "crypto/tls" "fmt" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/vcr/pe" "net/http" "time" @@ -40,7 +41,7 @@ type HTTPClient struct { func NewHTTPClient(strictMode bool, timeout time.Duration, tlsConfig *tls.Config) HTTPClient { return HTTPClient{ strictMode: strictMode, - httpClient: core.NewStrictHTTPClient(strictMode, timeout, tlsConfig), + httpClient: client.NewWithTLSConfig(timeout, tlsConfig), } } diff --git a/vcr/vcr.go b/vcr/vcr.go index 76a5fa65fc..8dc3b6834f 100644 --- a/vcr/vcr.go +++ b/vcr/vcr.go @@ -25,6 +25,7 @@ import ( "errors" "fmt" "github.com/nuts-foundation/go-leia/v4" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/pki" "github.com/nuts-foundation/nuts-node/vcr/credential" "github.com/nuts-foundation/nuts-node/vcr/openid4vci" @@ -219,12 +220,12 @@ func (c *vcr) Configure(config core.ServerConfig) error { // meaning while the issuer allocated an HTTP connection the wallet will try to allocate one as well. // This moved back to 1 http.Client when the credential is requested asynchronously. // Should be fixed as part of https://github.com/nuts-foundation/nuts-node/issues/2039 (also fix core.NewStrictHTTPClient) - c.issuerHttpClient = core.NewStrictHTTPClient(config.Strictmode, c.config.OpenID4VCI.Timeout, tlsConfig) - c.walletHttpClient = core.NewStrictHTTPClient(config.Strictmode, c.config.OpenID4VCI.Timeout, tlsConfig) + c.issuerHttpClient = client.NewWithTLSConfig(c.config.OpenID4VCI.Timeout, tlsConfig) + c.walletHttpClient = client.NewWithTLSConfig(c.config.OpenID4VCI.Timeout, tlsConfig) c.openidSessionStore = c.storageClient.GetSessionDatabase() } - status := revocation.NewStatusList2021(c.storageClient.GetSQLDatabase(), core.NewStrictHTTPClient(config.Strictmode, config.HTTPClient.Timeout, tlsConfig)) + status := revocation.NewStatusList2021(c.storageClient.GetSQLDatabase(), client.NewWithCache(config.HTTPClient.Timeout)) c.issuer = issuer.NewIssuer(c.issuerStore, c, networkPublisher, openidHandlerFn, didResolver, c.keyStore, c.jsonldManager, c.trustConfig, status) c.verifier = verifier.NewVerifier(c.verifierStore, didResolver, c.keyResolver, c.jsonldManager, c.trustConfig, status) diff --git a/vcr/vcr_test.go b/vcr/vcr_test.go index 79e2fdd985..9a5cc60589 100644 --- a/vcr/vcr_test.go +++ b/vcr/vcr_test.go @@ -27,6 +27,7 @@ import ( "github.com/nuts-foundation/go-leia/v4" "github.com/nuts-foundation/go-stoabs" bbolt2 "github.com/nuts-foundation/go-stoabs/bbolt" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/pki" "github.com/nuts-foundation/nuts-node/storage" "github.com/nuts-foundation/nuts-node/vcr/openid4vci" @@ -83,6 +84,7 @@ func TestVCR_Configure(t *testing.T) { }) t.Run("strictmode passed to client APIs", func(t *testing.T) { // load test VC + client.StrictMode = true testVC := test.ValidNutsOrganizationCredential(t) issuerDID := did.MustParseDID(testVC.Issuer.String()) testDirectory := io.TestDirectory(t) diff --git a/vdr/didweb/web.go b/vdr/didweb/web.go index 75d1e6e283..50d9fe50fc 100644 --- a/vdr/didweb/web.go +++ b/vdr/didweb/web.go @@ -19,11 +19,11 @@ package didweb import ( - "crypto/tls" "errors" "fmt" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/nuts-node/core" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/vdr/resolver" "mime" "net/http" @@ -42,18 +42,9 @@ type Resolver struct { // NewResolver creates a new did:web Resolver with default TLS configuration. func NewResolver() *Resolver { - transport := http.DefaultTransport - if httpTransport, ok := transport.(*http.Transport); ok { - // Might not be http.Transport in testing - httpTransport = httpTransport.Clone() - httpTransport.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - } - transport = httpTransport - } return &Resolver{ HttpClient: &http.Client{ - Transport: transport, + Transport: client.DefaultCachingTransport, Timeout: 5 * time.Second, }, } diff --git a/vdr/didweb/web_test.go b/vdr/didweb/web_test.go index d8c7093a45..519464be39 100644 --- a/vdr/didweb/web_test.go +++ b/vdr/didweb/web_test.go @@ -19,8 +19,8 @@ package didweb import ( - "crypto/tls" "github.com/nuts-foundation/go-did/did" + "github.com/nuts-foundation/nuts-node/http/client" http2 "github.com/nuts-foundation/nuts-node/test/http" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -56,8 +56,8 @@ func TestResolver_NewResolver(t *testing.T) { resolver := NewResolver() assert.NotNil(t, resolver.HttpClient) - t.Run("it sets min TLS version", func(t *testing.T) { - assert.Equal(t, uint16(tls.VersionTLS12), resolver.HttpClient.Transport.(*http.Transport).TLSClientConfig.MinVersion) + t.Run("it uses cached transport", func(t *testing.T) { + assert.Same(t, client.DefaultCachingTransport, resolver.HttpClient.Transport) }) } diff --git a/vdr/vdr_test.go b/vdr/vdr_test.go index acc16b68ff..834af56e1a 100644 --- a/vdr/vdr_test.go +++ b/vdr/vdr_test.go @@ -30,6 +30,7 @@ import ( ssi "github.com/nuts-foundation/go-did" "github.com/nuts-foundation/nuts-node/audit" "github.com/nuts-foundation/nuts-node/core" + "github.com/nuts-foundation/nuts-node/http/client" "github.com/nuts-foundation/nuts-node/storage" "github.com/nuts-foundation/nuts-node/vdr/didnuts" "github.com/nuts-foundation/nuts-node/vdr/didnuts/didstore" @@ -443,7 +444,7 @@ func TestVDR_Configure(t *testing.T) { storageInstance := storage.NewTestStorageEngine(t) t.Run("it can resolve using did:web", func(t *testing.T) { t.Run("not in database", func(t *testing.T) { - http.DefaultTransport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + client.DefaultCachingTransport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { return &http.Response{ Header: map[string][]string{"Content-Type": {"application/json"}}, StatusCode: http.StatusOK,