From d4946ea088bc32ac5f07c6acfafc442af7943aea Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 18:36:29 +0900 Subject: [PATCH 01/18] import retryable http package --- go.mod | 5 +++++ go.sum | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/go.mod b/go.mod index 7aad2ed61..04e281d52 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,11 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) +require ( + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-retryablehttp v0.7.7 // indirect +) + require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index 20e11ddf6..f1759263d 100644 --- a/go.sum +++ b/go.sum @@ -176,9 +176,13 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/hackerwins/go-memdb v1.3.3-0.20211225080334-513a74641622 h1:7UYuTq6zV83XV4zqn14gUuTtcywzbxGhUnj+hr/MUrE= github.com/hackerwins/go-memdb v1.3.3-0.20211225080334-513a74641622/go.mod h1:uBTr1oQbtuMgd1SSGoR8YV27eT3sBHbYiNm53bMpgSg= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-immutable-radix v1.3.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= From 0485a95d48ee3d0797315203c7f7c02812879ed8 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 18:38:22 +0900 Subject: [PATCH 02/18] refactor client with retryable http package --- pkg/webhook/client.go | 212 +++++++++++++++++++------------------ server/backend/backend.go | 33 +++--- server/rpc/auth/webhook.go | 21 ++-- 3 files changed, 139 insertions(+), 127 deletions(-) diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 9fb76bde4..fa947e1cb 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -18,7 +18,6 @@ package webhook import ( - "bytes" "context" "crypto/hmac" "crypto/sha256" @@ -26,14 +25,15 @@ import ( "encoding/json" "errors" "fmt" - "math" - "net/http" - "syscall" - "time" - + "github.com/hashicorp/go-retryablehttp" "github.com/yorkie-team/yorkie/pkg/cache" "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/server/logging" + "io" + "log" + "net/http" + "syscall" + "time" ) var ( @@ -49,159 +49,167 @@ var ( // Options are the options for the webhook client. type Options struct { - CacheKeyPrefix string - CacheTTL time.Duration + CacheTTL time.Duration MaxRetries uint64 + MinWaitInterval time.Duration MaxWaitInterval time.Duration - - HMACKey string + RequestTimeout time.Duration } // Client is a client for the webhook. type Client[Req any, Res any] struct { - cache *cache.LRUExpireCache[string, types.Pair[int, *Res]] - url string - options Options + cache *cache.LRUExpireCache[string, types.Pair[int, *Res]] + retryClient *retryablehttp.Client + options Options } // NewClient creates a new instance of Client. func NewClient[Req any, Res any]( - url string, Cache *cache.LRUExpireCache[string, types.Pair[int, *Res]], options Options, ) *Client[Req, Res] { return &Client[Req, Res]{ - url: url, - cache: Cache, + cache: Cache, + retryClient: &retryablehttp.Client{ + HTTPClient: &http.Client{ + Timeout: options.RequestTimeout, + }, + RetryMax: int(options.MaxRetries), + RetryWaitMin: options.MinWaitInterval, + RetryWaitMax: options.MaxWaitInterval, + CheckRetry: shouldRetry, + Logger: nil, + Backoff: retryablehttp.DefaultBackoff, + ErrorHandler: func(resp *http.Response, err error, numTries int) (*http.Response, error) { + if err == nil && numTries == int(options.MaxRetries)+1 { + return nil, ErrWebhookTimeout + } + return resp, fmt.Errorf("after %d attempts, errors were: %w", numTries, err) + }, + }, options: options, } } // Send sends the given request to the webhook. -func (c *Client[Req, Res]) Send(ctx context.Context, req Req) (*Res, int, error) { - body, err := json.Marshal(req) +func (c *Client[Req, Res]) Send( + ctx context.Context, + CacheKeyPrefix, url, HMACKey string, + reqData Req, +) (*Res, int, error) { + body, err := json.Marshal(reqData) if err != nil { return nil, 0, fmt.Errorf("marshal webhook request: %w", err) } - cacheKey := c.options.CacheKeyPrefix + ":" + string(body) + cacheKey := CacheKeyPrefix + ":" + string(body) if entry, ok := c.cache.Get(cacheKey); ok { return entry.Second, entry.First, nil } - var res Res - status, err := c.withExponentialBackoff(ctx, func() (int, error) { - resp, err := c.post("application/json", body) - if err != nil { - return 0, fmt.Errorf("post to webhook: %w", err) - } - defer func() { - if err := resp.Body.Close(); err != nil { - // TODO(hackerwins): Consider to remove the dependency of logging. - logging.From(ctx).Error(err) - } - }() - - if resp.StatusCode != http.StatusOK && - resp.StatusCode != http.StatusUnauthorized && - resp.StatusCode != http.StatusForbidden { - return resp.StatusCode, ErrUnexpectedStatusCode + req, err := c.buildRequest(ctx, url, HMACKey, body) + if err != nil { + return nil, 0, fmt.Errorf("build request: %w", err) + } + + resp, err := c.retryClient.Do(req) + if err != nil { + var statusCode int + if resp != nil { + statusCode = resp.StatusCode } + log.Println(err) - if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { - return resp.StatusCode, ErrUnexpectedResponse + return nil, statusCode, fmt.Errorf("post to webhook: %w", err) + } + defer func() { + io.Copy(io.Discard, resp.Body) + if err := resp.Body.Close(); err != nil { + // TODO(hackerwins): Consider to remove the dependency of logging. + logging.From(ctx).Error(err) } + }() - return resp.StatusCode, nil - }) - if err != nil { - return nil, status, err + if !isExpectedCode(resp.StatusCode) { + return nil, resp.StatusCode, fmt.Errorf("%d: %w", resp.StatusCode, ErrUnexpectedStatusCode) + } + + var res Res + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil, resp.StatusCode, ErrUnexpectedResponse } // TODO(hackerwins): We should consider caching the response of Unauthorized as well. - if status != http.StatusUnauthorized { - c.cache.Add(cacheKey, types.Pair[int, *Res]{First: status, Second: &res}, c.options.CacheTTL) + if resp.StatusCode != http.StatusUnauthorized { + c.cache.Add(cacheKey, types.Pair[int, *Res]{First: resp.StatusCode, Second: &res}, c.options.CacheTTL) } - return &res, status, nil + return &res, resp.StatusCode, nil } -// post sends an HTTP POST request with HMAC-SHA256 signature headers. -// If key is empty, post sends an HTTP POST without signature. -func (c *Client[Req, Res]) post(contentType string, body []byte) (*http.Response, error) { - req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(body)) +func (c *Client[Req, Res]) buildRequest(ctx context.Context, url, HMACKey string, body []byte) (*retryablehttp.Request, error) { + req, err := retryablehttp.NewRequestWithContext(ctx, "POST", url, body) if err != nil { - return nil, fmt.Errorf("create HTTP request: %w", err) + return nil, fmt.Errorf("create POST request with context: %w", err) } - req.Header.Set("Content-Type", contentType) - if c.options.HMACKey != "" { - mac := hmac.New(sha256.New, []byte(c.options.HMACKey)) - if _, err := mac.Write(body); err != nil { - return nil, fmt.Errorf("write HMAC body: %w", err) + req.Header.Set("Content-Type", "application/json") + + if HMACKey != "" { + if err := setSignature(req, body, HMACKey); err != nil { + return req, fmt.Errorf("set HMAC signature: %w", err) } - signature := mac.Sum(nil) - signatureHex := hex.EncodeToString(signature) // Convert to hex string - req.Header.Set("X-Signature-256", fmt.Sprintf("sha256=%s", signatureHex)) } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("send to %s: %w", c.url, err) // Wrapped with context + return req, nil +} + +func setSignature(req *retryablehttp.Request, Data []byte, HMACKey string) error { + mac := hmac.New(sha256.New, []byte(HMACKey)) + if _, err := mac.Write(Data); err != nil { + return fmt.Errorf("write HMAC body: %w", err) } + signature := mac.Sum(nil) + signatureHex := hex.EncodeToString(signature) + req.Header.Set("X-Signature-256", fmt.Sprintf("sha256=%s", signatureHex)) - return resp, nil + return nil } -func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn func() (int, error)) (int, error) { - var retries uint64 - var statusCode int - for retries <= c.options.MaxRetries { - statusCode, err := webhookFn() - if !shouldRetry(statusCode, err) { - if err == ErrUnexpectedStatusCode { - return statusCode, fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode) - } - - return statusCode, err +// shouldRetry returns true if the given error should be retried. +// Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry +func shouldRetry(ctx context.Context, resp *http.Response, err error) (bool, error) { + // If the connection is reset, we should retry. + if err != nil { + var errno syscall.Errno + if errors.As(err, &errno) && errors.Is(errno, syscall.ECONNRESET) { + return true, nil } - waitBeforeRetry := waitInterval(retries, c.options.MaxWaitInterval) + return false, err + } - select { - case <-ctx.Done(): - return 0, ctx.Err() - case <-time.After(waitBeforeRetry): + if resp != nil { + code := resp.StatusCode + if isExpectedCode(code) { + return false, nil + } + if isRetryCode(code) { + return true, nil } - - retries++ } - return statusCode, fmt.Errorf("unexpected status code from webhook %d: %w", statusCode, ErrWebhookTimeout) + return false, nil } -// waitInterval returns the interval of given retries. (2^retries * 100) milliseconds. -func waitInterval(retries uint64, maxWaitInterval time.Duration) time.Duration { - interval := time.Duration(math.Pow(2, float64(retries))) * 100 * time.Millisecond - if maxWaitInterval < interval { - return maxWaitInterval - } - - return interval +func isExpectedCode(code int) bool { + return code == http.StatusOK || code == http.StatusUnauthorized || code == http.StatusForbidden } -// shouldRetry returns true if the given error should be retried. -// Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry -func shouldRetry(statusCode int, err error) bool { - // If the connection is reset, we should retry. - var errno syscall.Errno - if errors.As(err, &errno) { - return errno == syscall.ECONNRESET - } - - return statusCode == http.StatusInternalServerError || - statusCode == http.StatusServiceUnavailable || - statusCode == http.StatusGatewayTimeout || - statusCode == http.StatusTooManyRequests +func isRetryCode(code int) bool { + return code == http.StatusInternalServerError || + code == http.StatusServiceUnavailable || + code == http.StatusGatewayTimeout || + code == http.StatusTooManyRequests } diff --git a/server/backend/backend.go b/server/backend/backend.go index cf9bd4bf3..65e0f54b9 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -22,7 +22,9 @@ package backend import ( "context" "fmt" + "github.com/yorkie-team/yorkie/pkg/webhook" "os" + "time" "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/pkg/cache" @@ -43,11 +45,9 @@ import ( type Backend struct { Config *Config - // AuthWebhookCache is used to cache the response of the auth webhook. - WebhookCache *cache.LRUExpireCache[string, pkgtypes.Pair[ - int, - *types.AuthWebhookResponse, - ]] + // AuthWebhookClient is used to send auth webhook. + AuthWebhookClient *webhook.Client[types.AuthWebhookRequest, types.AuthWebhookResponse] + // PubSub is used to publish/subscribe events to/from clients. PubSub *pubsub.PubSub // Locker is used to lock/unlock resources. @@ -82,9 +82,18 @@ func New( conf.Hostname = hostname } - // 02. Create in-memory cache, pubsub, and locker. - cache := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( - conf.AuthWebhookCacheSize, + // 02. Create auth webhook client with in-memory cache, pubsub, and locker. + auth := webhook.NewClient[types.AuthWebhookRequest, types.AuthWebhookResponse]( + cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( + conf.AuthWebhookCacheSize, + ), + webhook.Options{ + CacheTTL: conf.ParseAuthWebhookCacheTTL(), + MaxRetries: conf.AuthWebhookMaxRetries, + MinWaitInterval: 200 * time.Millisecond, + MaxWaitInterval: conf.ParseAuthWebhookMaxWaitInterval(), + RequestTimeout: 30 * time.Second, + }, ) locker := sync.New() pubsub := pubsub.New() @@ -141,10 +150,10 @@ func New( ) return &Backend{ - Config: conf, - WebhookCache: cache, - Locker: locker, - PubSub: pubsub, + Config: conf, + AuthWebhookClient: auth, + Locker: locker, + PubSub: pubsub, Metrics: metrics, DB: db, diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index 218390f78..cd866bea0 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -44,22 +44,17 @@ func verifyAccess( token string, accessInfo *types.AccessInfo, ) error { - cli := webhook.NewClient[types.AuthWebhookRequest]( + res, status, err := be.AuthWebhookClient.Send( + ctx, + prj.PublicKey+":auth", prj.AuthWebhookURL, - be.WebhookCache, - webhook.Options{ - CacheKeyPrefix: prj.PublicKey + ":auth", - CacheTTL: be.Config.ParseAuthWebhookCacheTTL(), - MaxRetries: be.Config.AuthWebhookMaxRetries, - MaxWaitInterval: be.Config.ParseAuthWebhookMaxWaitInterval(), + "", + types.AuthWebhookRequest{ + Token: token, + Method: accessInfo.Method, + Attributes: accessInfo.Attributes, }, ) - - res, status, err := cli.Send(ctx, types.AuthWebhookRequest{ - Token: token, - Method: accessInfo.Method, - Attributes: accessInfo.Attributes, - }) if err != nil { return fmt.Errorf("send to webhook: %w", err) } From 50349e108c94b34aff838bffe93b575ea08d6dd6 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 18:40:00 +0900 Subject: [PATCH 03/18] import webhook config in test --- server/packs/packs_test.go | 20 +++++++++++--------- server/rpc/server_test.go | 20 +++++++++++--------- test/complex/main_test.go | 1 + 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/server/packs/packs_test.go b/server/packs/packs_test.go index fa7413cfd..66494b923 100644 --- a/server/packs/packs_test.go +++ b/server/packs/packs_test.go @@ -94,15 +94,17 @@ func TestMain(m *testing.M) { testBackend, err = backend.New( &backend.Config{ - AdminUser: helper.AdminUser, - AdminPassword: helper.AdminPassword, - UseDefaultProject: helper.UseDefaultProject, - ClientDeactivateThreshold: helper.ClientDeactivateThreshold, - SnapshotThreshold: helper.SnapshotThreshold, - AuthWebhookCacheSize: helper.AuthWebhookSize, - ProjectCacheSize: helper.ProjectCacheSize, - ProjectCacheTTL: helper.ProjectCacheTTL.String(), - AdminTokenDuration: helper.AdminTokenDuration, + AdminUser: helper.AdminUser, + AdminPassword: helper.AdminPassword, + UseDefaultProject: helper.UseDefaultProject, + ClientDeactivateThreshold: helper.ClientDeactivateThreshold, + SnapshotThreshold: helper.SnapshotThreshold, + AuthWebhookCacheSize: helper.AuthWebhookSize, + AuthWebhookCacheTTL: helper.AuthWebhookCacheTTL.String(), + AuthWebhookMaxWaitInterval: helper.AuthWebhookMaxWaitInterval.String(), + ProjectCacheSize: helper.ProjectCacheSize, + ProjectCacheTTL: helper.ProjectCacheTTL.String(), + AdminTokenDuration: helper.AdminTokenDuration, }, &mongo.Config{ ConnectionURI: helper.MongoConnectionURI, YorkieDatabase: helper.TestDBName(), diff --git a/server/rpc/server_test.go b/server/rpc/server_test.go index 205733b35..81a2a1fdc 100644 --- a/server/rpc/server_test.go +++ b/server/rpc/server_test.go @@ -68,15 +68,17 @@ func TestMain(m *testing.M) { } be, err := backend.New(&backend.Config{ - AdminUser: helper.AdminUser, - AdminPassword: helper.AdminPassword, - UseDefaultProject: helper.UseDefaultProject, - ClientDeactivateThreshold: helper.ClientDeactivateThreshold, - SnapshotThreshold: helper.SnapshotThreshold, - AuthWebhookCacheSize: helper.AuthWebhookSize, - ProjectCacheSize: helper.ProjectCacheSize, - ProjectCacheTTL: helper.ProjectCacheTTL.String(), - AdminTokenDuration: helper.AdminTokenDuration, + AdminUser: helper.AdminUser, + AdminPassword: helper.AdminPassword, + UseDefaultProject: helper.UseDefaultProject, + ClientDeactivateThreshold: helper.ClientDeactivateThreshold, + SnapshotThreshold: helper.SnapshotThreshold, + AuthWebhookCacheSize: helper.AuthWebhookSize, + AuthWebhookCacheTTL: helper.AuthWebhookCacheTTL.String(), + AuthWebhookMaxWaitInterval: helper.AuthWebhookMaxWaitInterval.String(), + ProjectCacheSize: helper.ProjectCacheSize, + ProjectCacheTTL: helper.ProjectCacheTTL.String(), + AdminTokenDuration: helper.AdminTokenDuration, }, &mongo.Config{ ConnectionURI: helper.MongoConnectionURI, YorkieDatabase: helper.TestDBName(), diff --git a/test/complex/main_test.go b/test/complex/main_test.go index 26b530b99..cd96b3886 100644 --- a/test/complex/main_test.go +++ b/test/complex/main_test.go @@ -79,6 +79,7 @@ func TestMain(m *testing.M) { ClientDeactivateThreshold: helper.ClientDeactivateThreshold, SnapshotThreshold: helper.SnapshotThreshold, AuthWebhookCacheSize: helper.AuthWebhookSize, + AuthWebhookCacheTTL: helper.AuthWebhookCacheTTL.String(), ProjectCacheSize: helper.ProjectCacheSize, ProjectCacheTTL: helper.ProjectCacheTTL.String(), AdminTokenDuration: helper.AdminTokenDuration, From 9cf8a7637ba78d882a38c1bc9bb344aa81456b79 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 18:40:34 +0900 Subject: [PATCH 04/18] fix client test --- pkg/webhook/client_test.go | 83 +++++++++++++++----------------------- 1 file changed, 33 insertions(+), 50 deletions(-) diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 61f7004eb..41cd9b4f6 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/yorkie-team/yorkie/pkg/cache" - "github.com/yorkie-team/yorkie/pkg/types" + pkgtypes "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/pkg/webhook" ) @@ -46,7 +46,6 @@ func verifySignature(signatureHeader, secret string, body []byte) error { func TestHMAC(t *testing.T) { const secretKey = "my-secret-key" const wrongKey = "wrong-key" - reqData := testRequest{Name: "HMAC Tester"} resData := testResponse{Greeting: "HMAC OK"} testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -69,22 +68,26 @@ func TestHMAC(t *testing.T) { assert.NoError(t, json.NewEncoder(w).Encode(resData)) })) defer testServer.Close() - + client := webhook.NewClient[testRequest, testResponse]( + cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *testResponse]]( + 100, + ), + webhook.Options{ + CacheTTL: 10 * time.Second, + MaxRetries: 0, + MinWaitInterval: 2 * time.Second, + MaxWaitInterval: 10 * time.Second, + RequestTimeout: 30 * time.Second, + }, + ) t.Run("webhook client with valid HMAC key test", func(t *testing.T) { - client := webhook.NewClient[testRequest, testResponse]( + resp, statusCode, err := client.Send( + context.Background(), + ":auth", testServer.URL, - cache.NewLRUExpireCache[string, types.Pair[int, *testResponse]](100), - webhook.Options{ - CacheKeyPrefix: "testPrefix-hmac", - CacheTTL: 5 * time.Second, - MaxRetries: 0, - MaxWaitInterval: 200 * time.Millisecond, - HMACKey: secretKey, - }, + secretKey, + testRequest{Name: t.Name()}, ) - - ctx := context.Background() - resp, statusCode, err := client.Send(ctx, reqData) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) assert.NotNil(t, resp) @@ -92,59 +95,39 @@ func TestHMAC(t *testing.T) { }) t.Run("webhook client with invalid HMAC key test", func(t *testing.T) { - client := webhook.NewClient[testRequest, testResponse]( + resp, statusCode, err := client.Send( + context.Background(), + ":auth", testServer.URL, - cache.NewLRUExpireCache[string, types.Pair[int, *testResponse]](100), - webhook.Options{ - CacheKeyPrefix: "testPrefix-hmac", - CacheTTL: 5 * time.Second, - MaxRetries: 0, - MaxWaitInterval: 200 * time.Millisecond, - HMACKey: wrongKey, - }, + wrongKey, + testRequest{Name: t.Name()}, ) - - ctx := context.Background() - resp, statusCode, err := client.Send(ctx, reqData) assert.Error(t, err) assert.Equal(t, http.StatusForbidden, statusCode) assert.Nil(t, resp) }) t.Run("webhook client without HMAC key test", func(t *testing.T) { - client := webhook.NewClient[testRequest]( + resp, statusCode, err := client.Send( + context.Background(), + ":auth", testServer.URL, - cache.NewLRUExpireCache[string, types.Pair[int, *testResponse]](100), - webhook.Options{ - CacheKeyPrefix: "testPrefix-hmac", - CacheTTL: 5 * time.Second, - MaxRetries: 0, - MaxWaitInterval: 200 * time.Millisecond, - }, + "", + testRequest{Name: t.Name()}, ) - - ctx := context.Background() - resp, statusCode, err := client.Send(ctx, reqData) assert.Error(t, err) assert.Equal(t, http.StatusUnauthorized, statusCode) assert.Nil(t, resp) }) t.Run("webhook client with empty body test", func(t *testing.T) { - client := webhook.NewClient[testRequest]( + resp, statusCode, err := client.Send( + context.Background(), + ":auth", testServer.URL, - cache.NewLRUExpireCache[string, types.Pair[int, *testResponse]](100), - webhook.Options{ - CacheKeyPrefix: "testPrefix-hmac", - CacheTTL: 5 * time.Second, - MaxRetries: 0, - MaxWaitInterval: 200 * time.Millisecond, - HMACKey: secretKey, - }, + secretKey, + testRequest{}, ) - - ctx := context.Background() - resp, statusCode, err := client.Send(ctx, testRequest{}) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) assert.NotNil(t, resp) From d5fe2fc67d4229bb46d909c7df8450668fb9829f Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 18:56:23 +0900 Subject: [PATCH 05/18] lint --- pkg/webhook/client.go | 98 ++++++++++++++++++++++--------------------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index fa947e1cb..45f0be665 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -25,15 +25,16 @@ import ( "encoding/json" "errors" "fmt" - "github.com/hashicorp/go-retryablehttp" - "github.com/yorkie-team/yorkie/pkg/cache" - "github.com/yorkie-team/yorkie/pkg/types" - "github.com/yorkie-team/yorkie/server/logging" "io" - "log" "net/http" "syscall" "time" + + "github.com/hashicorp/go-retryablehttp" + + "github.com/yorkie-team/yorkie/pkg/cache" + "github.com/yorkie-team/yorkie/pkg/types" + "github.com/yorkie-team/yorkie/server/logging" ) var ( @@ -66,36 +67,38 @@ type Client[Req any, Res any] struct { // NewClient creates a new instance of Client. func NewClient[Req any, Res any]( - Cache *cache.LRUExpireCache[string, types.Pair[int, *Res]], + cache *cache.LRUExpireCache[string, types.Pair[int, *Res]], options Options, ) *Client[Req, Res] { - return &Client[Req, Res]{ - cache: Cache, - retryClient: &retryablehttp.Client{ - HTTPClient: &http.Client{ - Timeout: options.RequestTimeout, - }, - RetryMax: int(options.MaxRetries), - RetryWaitMin: options.MinWaitInterval, - RetryWaitMax: options.MaxWaitInterval, - CheckRetry: shouldRetry, - Logger: nil, - Backoff: retryablehttp.DefaultBackoff, - ErrorHandler: func(resp *http.Response, err error, numTries int) (*http.Response, error) { - if err == nil && numTries == int(options.MaxRetries)+1 { - return nil, ErrWebhookTimeout - } - return resp, fmt.Errorf("after %d attempts, errors were: %w", numTries, err) - }, + retryClient := &retryablehttp.Client{ + HTTPClient: &http.Client{ + Timeout: options.RequestTimeout, }, - options: options, + RetryMax: int(options.MaxRetries), + RetryWaitMin: options.MinWaitInterval, + RetryWaitMax: options.MaxWaitInterval, + CheckRetry: shouldRetry, + Logger: nil, + Backoff: retryablehttp.DefaultBackoff, + ErrorHandler: func(resp *http.Response, err error, numTries int) (*http.Response, error) { + if err == nil && numTries == int(options.MaxRetries)+1 { + return nil, ErrWebhookTimeout + } + return resp, fmt.Errorf("after %d attempts, errors were: %w", numTries, err) + }, + } + + return &Client[Req, Res]{ + cache: cache, + retryClient: retryClient, + options: options, } } // Send sends the given request to the webhook. func (c *Client[Req, Res]) Send( ctx context.Context, - CacheKeyPrefix, url, HMACKey string, + cacheKeyPrefix, url, hmacKey string, reqData Req, ) (*Res, int, error) { body, err := json.Marshal(reqData) @@ -103,31 +106,30 @@ func (c *Client[Req, Res]) Send( return nil, 0, fmt.Errorf("marshal webhook request: %w", err) } - cacheKey := CacheKeyPrefix + ":" + string(body) + cacheKey := fmt.Sprintf("%s:%s", cacheKeyPrefix, string(body)) if entry, ok := c.cache.Get(cacheKey); ok { return entry.Second, entry.First, nil } - req, err := c.buildRequest(ctx, url, HMACKey, body) + req, err := c.buildRequest(ctx, url, hmacKey, body) if err != nil { return nil, 0, fmt.Errorf("build request: %w", err) } resp, err := c.retryClient.Do(req) if err != nil { - var statusCode int + statusCode := 0 if resp != nil { statusCode = resp.StatusCode } - log.Println(err) - return nil, statusCode, fmt.Errorf("post to webhook: %w", err) + return nil, statusCode, fmt.Errorf("post webhook request: %w", err) } defer func() { io.Copy(io.Discard, resp.Body) - if err := resp.Body.Close(); err != nil { + if cerr := resp.Body.Close(); cerr != nil { // TODO(hackerwins): Consider to remove the dependency of logging. - logging.From(ctx).Error(err) + logging.From(ctx).Error(cerr) } }() @@ -140,7 +142,6 @@ func (c *Client[Req, Res]) Send( return nil, resp.StatusCode, ErrUnexpectedResponse } - // TODO(hackerwins): We should consider caching the response of Unauthorized as well. if resp.StatusCode != http.StatusUnauthorized { c.cache.Add(cacheKey, types.Pair[int, *Res]{First: resp.StatusCode, Second: &res}, c.options.CacheTTL) } @@ -148,32 +149,31 @@ func (c *Client[Req, Res]) Send( return &res, resp.StatusCode, nil } -func (c *Client[Req, Res]) buildRequest(ctx context.Context, url, HMACKey string, body []byte) (*retryablehttp.Request, error) { - req, err := retryablehttp.NewRequestWithContext(ctx, "POST", url, body) +// buildRequest creates a new HTTP POST request with the appropriate headers. +func (c *Client[Req, Res]) buildRequest(ctx context.Context, url, hmacKey string, body []byte) (*retryablehttp.Request, error) { + req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, url, body) if err != nil { return nil, fmt.Errorf("create POST request with context: %w", err) } req.Header.Set("Content-Type", "application/json") - - if HMACKey != "" { - if err := setSignature(req, body, HMACKey); err != nil { - return req, fmt.Errorf("set HMAC signature: %w", err) + if hmacKey != "" { + if err := setSignature(req, body, hmacKey); err != nil { + return nil, fmt.Errorf("set HMAC signature: %w", err) } } return req, nil } -func setSignature(req *retryablehttp.Request, Data []byte, HMACKey string) error { - mac := hmac.New(sha256.New, []byte(HMACKey)) - if _, err := mac.Write(Data); err != nil { +// setSignature sets the HMAC signature header for the request. +func setSignature(req *retryablehttp.Request, data []byte, hmacKey string) error { + mac := hmac.New(sha256.New, []byte(hmacKey)) + if _, err := mac.Write(data); err != nil { return fmt.Errorf("write HMAC body: %w", err) } - signature := mac.Sum(nil) - signatureHex := hex.EncodeToString(signature) + signatureHex := hex.EncodeToString(mac.Sum(nil)) req.Header.Set("X-Signature-256", fmt.Sprintf("sha256=%s", signatureHex)) - return nil } @@ -203,10 +203,14 @@ func shouldRetry(ctx context.Context, resp *http.Response, err error) (bool, err return false, nil } +// isExpectedCode checks if the status code is acceptable. func isExpectedCode(code int) bool { - return code == http.StatusOK || code == http.StatusUnauthorized || code == http.StatusForbidden + return code == http.StatusOK || + code == http.StatusUnauthorized || + code == http.StatusForbidden } +// isRetryCode checks if the status code is one that should trigger a retry. func isRetryCode(code int) bool { return code == http.StatusInternalServerError || code == http.StatusServiceUnavailable || From c6e0febdecd5c518c5272e6c7fb20966057f5129 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 19:21:56 +0900 Subject: [PATCH 06/18] extract cache from `webhook/client` --- pkg/webhook/client.go | 24 ++------------ server/backend/backend.go | 33 +++++++++++-------- server/rpc/auth/webhook.go | 66 ++++++++++++++++++++++++++------------ 3 files changed, 67 insertions(+), 56 deletions(-) diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 45f0be665..6cf1d9fef 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -32,8 +32,6 @@ import ( "github.com/hashicorp/go-retryablehttp" - "github.com/yorkie-team/yorkie/pkg/cache" - "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/server/logging" ) @@ -50,8 +48,6 @@ var ( // Options are the options for the webhook client. type Options struct { - CacheTTL time.Duration - MaxRetries uint64 MinWaitInterval time.Duration MaxWaitInterval time.Duration @@ -60,14 +56,12 @@ type Options struct { // Client is a client for the webhook. type Client[Req any, Res any] struct { - cache *cache.LRUExpireCache[string, types.Pair[int, *Res]] retryClient *retryablehttp.Client options Options } // NewClient creates a new instance of Client. func NewClient[Req any, Res any]( - cache *cache.LRUExpireCache[string, types.Pair[int, *Res]], options Options, ) *Client[Req, Res] { retryClient := &retryablehttp.Client{ @@ -89,7 +83,6 @@ func NewClient[Req any, Res any]( } return &Client[Req, Res]{ - cache: cache, retryClient: retryClient, options: options, } @@ -98,18 +91,9 @@ func NewClient[Req any, Res any]( // Send sends the given request to the webhook. func (c *Client[Req, Res]) Send( ctx context.Context, - cacheKeyPrefix, url, hmacKey string, - reqData Req, + url, hmacKey string, + body []byte, ) (*Res, int, error) { - body, err := json.Marshal(reqData) - if err != nil { - return nil, 0, fmt.Errorf("marshal webhook request: %w", err) - } - - cacheKey := fmt.Sprintf("%s:%s", cacheKeyPrefix, string(body)) - if entry, ok := c.cache.Get(cacheKey); ok { - return entry.Second, entry.First, nil - } req, err := c.buildRequest(ctx, url, hmacKey, body) if err != nil { @@ -142,10 +126,6 @@ func (c *Client[Req, Res]) Send( return nil, resp.StatusCode, ErrUnexpectedResponse } - if resp.StatusCode != http.StatusUnauthorized { - c.cache.Add(cacheKey, types.Pair[int, *Res]{First: resp.StatusCode, Second: &res}, c.options.CacheTTL) - } - return &res, resp.StatusCode, nil } diff --git a/server/backend/backend.go b/server/backend/backend.go index 65e0f54b9..6eaabb05d 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -45,8 +45,13 @@ import ( type Backend struct { Config *Config - // AuthWebhookClient is used to send auth webhook. - AuthWebhookClient *webhook.Client[types.AuthWebhookRequest, types.AuthWebhookResponse] + // AuthWebhookCache is used to cache the response of the auth webhook. + WebhookCache *cache.LRUExpireCache[string, pkgtypes.Pair[ + int, + *types.AuthWebhookResponse, + ]] + // WebhookClient is used to send auth webhook. + WebhookClient *webhook.Client[types.AuthWebhookRequest, types.AuthWebhookResponse] // PubSub is used to publish/subscribe events to/from clients. PubSub *pubsub.PubSub @@ -82,27 +87,28 @@ func New( conf.Hostname = hostname } - // 02. Create auth webhook client with in-memory cache, pubsub, and locker. + // 02. Create auth webhook client and cache. + cache := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( + conf.AuthWebhookCacheSize, + ) auth := webhook.NewClient[types.AuthWebhookRequest, types.AuthWebhookResponse]( - cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( - conf.AuthWebhookCacheSize, - ), webhook.Options{ - CacheTTL: conf.ParseAuthWebhookCacheTTL(), MaxRetries: conf.AuthWebhookMaxRetries, MinWaitInterval: 200 * time.Millisecond, MaxWaitInterval: conf.ParseAuthWebhookMaxWaitInterval(), RequestTimeout: 30 * time.Second, }, ) + + // 03. Create pubsub, and locker. locker := sync.New() pubsub := pubsub.New() - // 03. Create the background instance. The background instance is used to + // 04. Create the background instance. The background instance is used to // manage background tasks. bg := background.New(metrics) - // 04. Create the database instance. If the MongoDB configuration is given, + // 05. Create the database instance. If the MongoDB configuration is given, // create a MongoDB instance. Otherwise, create a memory database instance. var err error var db database.Database @@ -150,10 +156,11 @@ func New( ) return &Backend{ - Config: conf, - AuthWebhookClient: auth, - Locker: locker, - PubSub: pubsub, + Config: conf, + WebhookCache: cache, + WebhookClient: auth, + Locker: locker, + PubSub: pubsub, Metrics: metrics, DB: db, diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index cd866bea0..a12537103 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -18,12 +18,14 @@ package auth import ( "context" + "encoding/json" "errors" "fmt" "net/http" "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/internal/metaerrors" + pkgtypes "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/pkg/webhook" "github.com/yorkie-team/yorkie/server/backend" ) @@ -44,36 +46,58 @@ func verifyAccess( token string, accessInfo *types.AccessInfo, ) error { - res, status, err := be.AuthWebhookClient.Send( + req := types.AuthWebhookRequest{ + Token: token, + Method: accessInfo.Method, + Attributes: accessInfo.Attributes, + } + + body, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("marshal webhook request: %w", err) + } + + cacheKey := generateCacheKey(prj.PublicKey, body) + if entry, ok := be.WebhookCache.Get(cacheKey); ok { + return handleWebhookResponse(entry.First, entry.Second) + } + + res, status, err := be.WebhookClient.Send( ctx, - prj.PublicKey+":auth", prj.AuthWebhookURL, "", - types.AuthWebhookRequest{ - Token: token, - Method: accessInfo.Method, - Attributes: accessInfo.Attributes, - }, + body, ) if err != nil { return fmt.Errorf("send to webhook: %w", err) } - if status == http.StatusOK && res.Allowed { - return nil - } - if status == http.StatusForbidden && !res.Allowed { - return metaerrors.New( - ErrPermissionDenied, - map[string]string{"reason": res.Reason}, - ) - } - if status == http.StatusUnauthorized && !res.Allowed { - return metaerrors.New( - ErrUnauthenticated, - map[string]string{"reason": res.Reason}, + if status != http.StatusUnauthorized { + be.WebhookCache.Add( + cacheKey, + pkgtypes.Pair[int, *types.AuthWebhookResponse]{First: status, Second: res}, + be.Config.ParseAuthWebhookCacheTTL(), ) } - return fmt.Errorf("%d: %w", status, webhook.ErrUnexpectedResponse) + return handleWebhookResponse(status, res) +} + +// generateCacheKey creates a unique key for caching webhook responses. +func generateCacheKey(publicKey string, body []byte) string { + return fmt.Sprintf("%s:auth:%s", publicKey, body) +} + +// handleWebhookResponse processes the webhook response and returns an error if necessary. +func handleWebhookResponse(status int, res *types.AuthWebhookResponse) error { + switch { + case status == http.StatusOK && res.Allowed: + return nil + case status == http.StatusForbidden && !res.Allowed: + return metaerrors.New(ErrPermissionDenied, map[string]string{"reason": res.Reason}) + case status == http.StatusUnauthorized && !res.Allowed: + return metaerrors.New(ErrUnauthenticated, map[string]string{"reason": res.Reason}) + default: + return fmt.Errorf("%d: %w", status, webhook.ErrUnexpectedResponse) + } } From 745591ea8e1f0d4acc60ee92a3487b6a8602c427 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 19:25:40 +0900 Subject: [PATCH 07/18] fix webhook client test - remove cache --- pkg/webhook/client_test.go | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 41cd9b4f6..0e7a2669f 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -16,8 +16,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/yorkie-team/yorkie/pkg/cache" - pkgtypes "github.com/yorkie-team/yorkie/pkg/types" "github.com/yorkie-team/yorkie/pkg/webhook" ) @@ -69,11 +67,7 @@ func TestHMAC(t *testing.T) { })) defer testServer.Close() client := webhook.NewClient[testRequest, testResponse]( - cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *testResponse]]( - 100, - ), webhook.Options{ - CacheTTL: 10 * time.Second, MaxRetries: 0, MinWaitInterval: 2 * time.Second, MaxWaitInterval: 10 * time.Second, @@ -81,12 +75,14 @@ func TestHMAC(t *testing.T) { }, ) t.Run("webhook client with valid HMAC key test", func(t *testing.T) { + body, err := json.Marshal(testRequest{Name: t.Name()}) + assert.NoError(t, err) + resp, statusCode, err := client.Send( context.Background(), - ":auth", testServer.URL, secretKey, - testRequest{Name: t.Name()}, + body, ) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) @@ -95,12 +91,14 @@ func TestHMAC(t *testing.T) { }) t.Run("webhook client with invalid HMAC key test", func(t *testing.T) { + body, err := json.Marshal(testRequest{Name: t.Name()}) + assert.NoError(t, err) + resp, statusCode, err := client.Send( context.Background(), - ":auth", testServer.URL, wrongKey, - testRequest{Name: t.Name()}, + body, ) assert.Error(t, err) assert.Equal(t, http.StatusForbidden, statusCode) @@ -108,12 +106,14 @@ func TestHMAC(t *testing.T) { }) t.Run("webhook client without HMAC key test", func(t *testing.T) { + body, err := json.Marshal(testRequest{Name: t.Name()}) + assert.NoError(t, err) + resp, statusCode, err := client.Send( context.Background(), - ":auth", testServer.URL, "", - testRequest{Name: t.Name()}, + body, ) assert.Error(t, err) assert.Equal(t, http.StatusUnauthorized, statusCode) @@ -121,12 +121,14 @@ func TestHMAC(t *testing.T) { }) t.Run("webhook client with empty body test", func(t *testing.T) { + body, err := json.Marshal(testRequest{}) + assert.NoError(t, err) + resp, statusCode, err := client.Send( context.Background(), - ":auth", testServer.URL, secretKey, - testRequest{}, + body, ) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) From 6de812a2c040fada3b225778f156de169ef3e3d7 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 19:58:54 +0900 Subject: [PATCH 08/18] remove body.Close() - body is closed in retryablehttp.drainBody(resp.Body) --- pkg/webhook/client.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 6cf1d9fef..079764a59 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -109,13 +109,6 @@ func (c *Client[Req, Res]) Send( return nil, statusCode, fmt.Errorf("post webhook request: %w", err) } - defer func() { - io.Copy(io.Discard, resp.Body) - if cerr := resp.Body.Close(); cerr != nil { - // TODO(hackerwins): Consider to remove the dependency of logging. - logging.From(ctx).Error(cerr) - } - }() if !isExpectedCode(resp.StatusCode) { return nil, resp.StatusCode, fmt.Errorf("%d: %w", resp.StatusCode, ErrUnexpectedStatusCode) From 24d1984aa72b7df5afedd4b149e6f39219eb594b Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 19:59:07 +0900 Subject: [PATCH 09/18] lint --- pkg/webhook/client.go | 12 ++++++------ server/backend/backend.go | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 079764a59..165c76647 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -25,14 +25,11 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "syscall" "time" "github.com/hashicorp/go-retryablehttp" - - "github.com/yorkie-team/yorkie/server/logging" ) var ( @@ -94,7 +91,6 @@ func (c *Client[Req, Res]) Send( url, hmacKey string, body []byte, ) (*Res, int, error) { - req, err := c.buildRequest(ctx, url, hmacKey, body) if err != nil { return nil, 0, fmt.Errorf("build request: %w", err) @@ -123,7 +119,11 @@ func (c *Client[Req, Res]) Send( } // buildRequest creates a new HTTP POST request with the appropriate headers. -func (c *Client[Req, Res]) buildRequest(ctx context.Context, url, hmacKey string, body []byte) (*retryablehttp.Request, error) { +func (c *Client[Req, Res]) buildRequest( + ctx context.Context, + url, hmacKey string, + body []byte, +) (*retryablehttp.Request, error) { req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, url, body) if err != nil { return nil, fmt.Errorf("create POST request with context: %w", err) @@ -152,7 +152,7 @@ func setSignature(req *retryablehttp.Request, data []byte, hmacKey string) error // shouldRetry returns true if the given error should be retried. // Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry -func shouldRetry(ctx context.Context, resp *http.Response, err error) (bool, error) { +func shouldRetry(_ context.Context, resp *http.Response, err error) (bool, error) { // If the connection is reset, we should retry. if err != nil { var errno syscall.Errno diff --git a/server/backend/backend.go b/server/backend/backend.go index 6eaabb05d..57198cdab 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -22,13 +22,13 @@ package backend import ( "context" "fmt" - "github.com/yorkie-team/yorkie/pkg/webhook" "os" "time" "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/pkg/cache" pkgtypes "github.com/yorkie-team/yorkie/pkg/types" + "github.com/yorkie-team/yorkie/pkg/webhook" "github.com/yorkie-team/yorkie/server/backend/background" "github.com/yorkie-team/yorkie/server/backend/database" memdb "github.com/yorkie-team/yorkie/server/backend/database/memory" From bfd2cde7c5435a17112a377c74c0a07fdb569955 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Wed, 5 Feb 2025 20:30:03 +0900 Subject: [PATCH 10/18] add config value - AuthWebhookMinWaitInterval - AuthWebhookRequestTimeout --- cmd/yorkie/server.go | 16 +++++++++++++ pkg/webhook/client.go | 7 +++--- server/backend/backend.go | 5 ++-- server/backend/config.go | 44 +++++++++++++++++++++++++++++++++++ server/backend/config_test.go | 14 +++++++++-- server/config.go | 10 ++++++++ server/config.sample.yml | 6 +++++ server/config_test.go | 8 +++++++ server/packs/packs_test.go | 2 ++ server/rpc/server_test.go | 2 ++ test/helper/helper.go | 4 ++++ 11 files changed, 110 insertions(+), 8 deletions(-) diff --git a/cmd/yorkie/server.go b/cmd/yorkie/server.go index 4e5ef04e4..911f6faec 100644 --- a/cmd/yorkie/server.go +++ b/cmd/yorkie/server.go @@ -48,6 +48,8 @@ var ( mongoPingTimeout time.Duration authWebhookMaxWaitInterval time.Duration + authWebhookMinWaitInterval time.Duration + authWebhookRequestTimeout time.Duration authWebhookCacheTTL time.Duration projectCacheTTL time.Duration @@ -64,6 +66,8 @@ func newServerCmd() *cobra.Command { conf.Backend.ClientDeactivateThreshold = clientDeactivateThreshold conf.Backend.AuthWebhookMaxWaitInterval = authWebhookMaxWaitInterval.String() + conf.Backend.AuthWebhookMinWaitInterval = authWebhookMinWaitInterval.String() + conf.Backend.AuthWebhookRequestTimeout = authWebhookRequestTimeout.String() conf.Backend.AuthWebhookCacheTTL = authWebhookCacheTTL.String() conf.Backend.ProjectCacheTTL = projectCacheTTL.String() @@ -307,6 +311,18 @@ func init() { server.DefaultAuthWebhookMaxWaitInterval, "Maximum wait interval for authorization webhook.", ) + cmd.Flags().DurationVar( + &authWebhookMinWaitInterval, + "auth-webhook-min-wait-interval", + server.DefaultAuthWebhookMinWaitInterval, + "Minimum wait interval for authorization webhook.", + ) + cmd.Flags().DurationVar( + &authWebhookRequestTimeout, + "auth-webhook-request-timeout", + server.DefaultAuthWebhookRequestTimeout, + "Maximum wait time per authorization webhook request.", + ) cmd.Flags().IntVar( &conf.Backend.AuthWebhookCacheSize, "auth-webhook-cache-size", diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 165c76647..7ba167bbf 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -68,9 +68,10 @@ func NewClient[Req any, Res any]( RetryMax: int(options.MaxRetries), RetryWaitMin: options.MinWaitInterval, RetryWaitMax: options.MaxWaitInterval, - CheckRetry: shouldRetry, - Logger: nil, - Backoff: retryablehttp.DefaultBackoff, + // Note(window9u): I think we could replace shouldRetry with `retryablehttp.DefaultRetryPolicy` + CheckRetry: shouldRetry, + Logger: nil, + Backoff: retryablehttp.DefaultBackoff, ErrorHandler: func(resp *http.Response, err error, numTries int) (*http.Response, error) { if err == nil && numTries == int(options.MaxRetries)+1 { return nil, ErrWebhookTimeout diff --git a/server/backend/backend.go b/server/backend/backend.go index 57198cdab..e55787c32 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -23,7 +23,6 @@ import ( "context" "fmt" "os" - "time" "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/pkg/cache" @@ -94,9 +93,9 @@ func New( auth := webhook.NewClient[types.AuthWebhookRequest, types.AuthWebhookResponse]( webhook.Options{ MaxRetries: conf.AuthWebhookMaxRetries, - MinWaitInterval: 200 * time.Millisecond, + MinWaitInterval: conf.ParseAuthWebhookMinWaitInterval(), MaxWaitInterval: conf.ParseAuthWebhookMaxWaitInterval(), - RequestTimeout: 30 * time.Second, + RequestTimeout: conf.ParseAuthWebhookRequestTimeout(), }, ) diff --git a/server/backend/config.go b/server/backend/config.go index 89399ddb7..5827029ac 100644 --- a/server/backend/config.go +++ b/server/backend/config.go @@ -64,6 +64,12 @@ type Config struct { // AuthWebhookMaxWaitInterval is the max interval that waits before retrying the authorization webhook. AuthWebhookMaxWaitInterval string `yaml:"AuthWebhookMaxWaitInterval"` + // AuthWebhookMinWaitInterval is the max interval that waits before retrying the authorization webhook. + AuthWebhookMinWaitInterval string `yaml:"AuthWebhookMinWaitInterval"` + + // AuthWebhookRequestTimeout is the max waiting time per auth webhook request + AuthWebhookRequestTimeout string `yaml:"AuthWebhookRequestTimeout"` + // AuthWebhookCacheSize is the cache size of the authorization webhook. AuthWebhookCacheSize int `yaml:"AuthWebhookCacheSize"` @@ -101,6 +107,22 @@ func (c *Config) Validate() error { ) } + if _, err := time.ParseDuration(c.AuthWebhookMinWaitInterval); err != nil { + return fmt.Errorf( + `invalid argument "%s" for "--auth-webhook-min-wait-interval" flag: %w`, + c.AuthWebhookMinWaitInterval, + err, + ) + } + + if _, err := time.ParseDuration(c.AuthWebhookRequestTimeout); err != nil { + return fmt.Errorf( + `invalid argument "%s" for "--auth-webhook-reqeust-timeout" flag: %w`, + c.AuthWebhookRequestTimeout, + err, + ) + } + if _, err := time.ParseDuration(c.AuthWebhookCacheTTL); err != nil { return fmt.Errorf( `invalid argument "%s" for "--auth-webhook-cache-ttl" flag: %w`, @@ -142,6 +164,28 @@ func (c *Config) ParseAuthWebhookMaxWaitInterval() time.Duration { return result } +// ParseAuthWebhookMinWaitInterval returns max wait interval. +func (c *Config) ParseAuthWebhookMinWaitInterval() time.Duration { + result, err := time.ParseDuration(c.AuthWebhookMinWaitInterval) + if err != nil { + fmt.Fprintln(os.Stderr, "parse auth webhook min wait interval: %w", err) + os.Exit(1) + } + + return result +} + +// ParseAuthWebhookRequestTimeout returns max wait interval. +func (c *Config) ParseAuthWebhookRequestTimeout() time.Duration { + result, err := time.ParseDuration(c.AuthWebhookRequestTimeout) + if err != nil { + fmt.Fprintln(os.Stderr, "parse auth webhook request timeout: %w", err) + os.Exit(1) + } + + return result +} + // ParseAuthWebhookCacheTTL returns TTL for authorized cache. func (c *Config) ParseAuthWebhookCacheTTL() time.Duration { result, err := time.ParseDuration(c.AuthWebhookCacheTTL) diff --git a/server/backend/config_test.go b/server/backend/config_test.go index c4b8247eb..a7b3a34ba 100644 --- a/server/backend/config_test.go +++ b/server/backend/config_test.go @@ -29,6 +29,8 @@ func TestConfig(t *testing.T) { validConf := backend.Config{ ClientDeactivateThreshold: "1h", AuthWebhookMaxWaitInterval: "0ms", + AuthWebhookMinWaitInterval: "0ms", + AuthWebhookRequestTimeout: "0ms", AuthWebhookCacheTTL: "10s", ProjectCacheTTL: "10m", } @@ -43,11 +45,19 @@ func TestConfig(t *testing.T) { assert.Error(t, conf2.Validate()) conf3 := validConf - conf3.AuthWebhookCacheTTL = "s" + conf3.AuthWebhookMinWaitInterval = "3" assert.Error(t, conf3.Validate()) conf4 := validConf - conf4.ProjectCacheTTL = "10 minutes" + conf4.AuthWebhookRequestTimeout = "1" assert.Error(t, conf4.Validate()) + + conf5 := validConf + conf5.AuthWebhookCacheTTL = "s" + assert.Error(t, conf5.Validate()) + + conf6 := validConf + conf6.ProjectCacheTTL = "10 minutes" + assert.Error(t, conf6.Validate()) }) } diff --git a/server/config.go b/server/config.go index fd8589dbd..32e215382 100644 --- a/server/config.go +++ b/server/config.go @@ -62,6 +62,8 @@ const ( DefaultAuthWebhookMaxRetries = 10 DefaultAuthWebhookMaxWaitInterval = 3000 * time.Millisecond + DefaultAuthWebhookMinWaitInterval = 200 * time.Millisecond + DefaultAuthWebhookRequestTimeout = 30 * time.Second DefaultAuthWebhookCacheSize = 5000 DefaultAuthWebhookCacheTTL = 10 * time.Second DefaultProjectCacheSize = 256 @@ -185,6 +187,14 @@ func (c *Config) ensureDefaultValue() { c.Backend.AuthWebhookMaxWaitInterval = DefaultAuthWebhookMaxWaitInterval.String() } + if c.Backend.AuthWebhookMinWaitInterval == "" { + c.Backend.AuthWebhookMinWaitInterval = DefaultAuthWebhookMinWaitInterval.String() + } + + if c.Backend.AuthWebhookRequestTimeout == "" { + c.Backend.AuthWebhookRequestTimeout = DefaultAuthWebhookRequestTimeout.String() + } + if c.Backend.AuthWebhookCacheTTL == "" { c.Backend.AuthWebhookCacheTTL = DefaultAuthWebhookCacheTTL.String() } diff --git a/server/config.sample.yml b/server/config.sample.yml index b7f124b11..07fbbb932 100644 --- a/server/config.sample.yml +++ b/server/config.sample.yml @@ -71,6 +71,12 @@ Backend: # AuthWebhookMaxWaitInterval is the max interval that waits before retrying the authorization webhook. AuthWebhookMaxWaitInterval: "3s" + # AuthWebhookMinWaitInterval is the min interval that waits before retrying the authorization webhook. + AuthWebhookMinWaitInterval: "200ms" + + # AuthWebhookRequestTimeout is the max waiting time per authorization webhook request. + AuthWebhookRequestTimeout: "30s" + # AuthWebhookCacheTTL is the TTL value to set when caching the authorized result. AuthWebhookCacheTTL: "10s" diff --git a/server/config_test.go b/server/config_test.go index 5ca5008a9..d4b83ee41 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -70,6 +70,14 @@ func TestNewConfigFromFile(t *testing.T) { assert.NoError(t, err) assert.Equal(t, authWebhookMaxWaitInterval, server.DefaultAuthWebhookMaxWaitInterval) + authWebhookMinWaitInterval, err := time.ParseDuration(conf.Backend.AuthWebhookMinWaitInterval) + assert.NoError(t, err) + assert.Equal(t, authWebhookMinWaitInterval, server.DefaultAuthWebhookMinWaitInterval) + + authWebhookRequestTimeout, err := time.ParseDuration(conf.Backend.AuthWebhookRequestTimeout) + assert.NoError(t, err) + assert.Equal(t, authWebhookRequestTimeout, server.DefaultAuthWebhookRequestTimeout) + authWebhookCacheTTL, err := time.ParseDuration(conf.Backend.AuthWebhookCacheTTL) assert.NoError(t, err) assert.Equal(t, authWebhookCacheTTL, server.DefaultAuthWebhookCacheTTL) diff --git a/server/packs/packs_test.go b/server/packs/packs_test.go index 66494b923..b5ac2bede 100644 --- a/server/packs/packs_test.go +++ b/server/packs/packs_test.go @@ -102,6 +102,8 @@ func TestMain(m *testing.M) { AuthWebhookCacheSize: helper.AuthWebhookSize, AuthWebhookCacheTTL: helper.AuthWebhookCacheTTL.String(), AuthWebhookMaxWaitInterval: helper.AuthWebhookMaxWaitInterval.String(), + AuthWebhookMinWaitInterval: helper.AuthWebhookMinWaitInterval.String(), + AuthWebhookRequestTimeout: helper.AuthWebhookRequestTimeout.String(), ProjectCacheSize: helper.ProjectCacheSize, ProjectCacheTTL: helper.ProjectCacheTTL.String(), AdminTokenDuration: helper.AdminTokenDuration, diff --git a/server/rpc/server_test.go b/server/rpc/server_test.go index 81a2a1fdc..a0175417b 100644 --- a/server/rpc/server_test.go +++ b/server/rpc/server_test.go @@ -76,6 +76,8 @@ func TestMain(m *testing.M) { AuthWebhookCacheSize: helper.AuthWebhookSize, AuthWebhookCacheTTL: helper.AuthWebhookCacheTTL.String(), AuthWebhookMaxWaitInterval: helper.AuthWebhookMaxWaitInterval.String(), + AuthWebhookMinWaitInterval: helper.AuthWebhookMinWaitInterval.String(), + AuthWebhookRequestTimeout: helper.AuthWebhookRequestTimeout.String(), ProjectCacheSize: helper.ProjectCacheSize, ProjectCacheTTL: helper.ProjectCacheTTL.String(), AdminTokenDuration: helper.AdminTokenDuration, diff --git a/test/helper/helper.go b/test/helper/helper.go index 9068b5302..4137c7321 100644 --- a/test/helper/helper.go +++ b/test/helper/helper.go @@ -75,6 +75,8 @@ var ( SnapshotThreshold = int64(10) SnapshotWithPurgingChanges = false AuthWebhookMaxWaitInterval = 3 * gotime.Millisecond + AuthWebhookMinWaitInterval = 3 * gotime.Millisecond + AuthWebhookRequestTimeout = 100 * gotime.Millisecond AuthWebhookSize = 100 AuthWebhookCacheTTL = 10 * gotime.Second ProjectCacheSize = 256 @@ -264,6 +266,8 @@ func TestConfig() *server.Config { SnapshotThreshold: SnapshotThreshold, SnapshotWithPurgingChanges: SnapshotWithPurgingChanges, AuthWebhookMaxWaitInterval: AuthWebhookMaxWaitInterval.String(), + AuthWebhookMinWaitInterval: AuthWebhookMinWaitInterval.String(), + AuthWebhookRequestTimeout: AuthWebhookRequestTimeout.String(), AuthWebhookCacheSize: AuthWebhookSize, AuthWebhookCacheTTL: AuthWebhookCacheTTL.String(), ProjectCacheSize: ProjectCacheSize, From da6101acc03a2bb21466d0393ed2e6dae655d9ac Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Thu, 6 Feb 2025 13:41:33 +0900 Subject: [PATCH 11/18] remove retryable http package --- go.mod | 5 -- go.sum | 4 - pkg/webhook/client.go | 172 ++++++++++++++++++++----------------- server/rpc/auth/webhook.go | 1 + 4 files changed, 93 insertions(+), 89 deletions(-) diff --git a/go.mod b/go.mod index 04e281d52..7aad2ed61 100644 --- a/go.mod +++ b/go.mod @@ -31,11 +31,6 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) -require ( - github.com/hashicorp/go-cleanhttp v0.5.2 // indirect - github.com/hashicorp/go-retryablehttp v0.7.7 // indirect -) - require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index f1759263d..20e11ddf6 100644 --- a/go.sum +++ b/go.sum @@ -176,13 +176,9 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/hackerwins/go-memdb v1.3.3-0.20211225080334-513a74641622 h1:7UYuTq6zV83XV4zqn14gUuTtcywzbxGhUnj+hr/MUrE= github.com/hackerwins/go-memdb v1.3.3-0.20211225080334-513a74641622/go.mod h1:uBTr1oQbtuMgd1SSGoR8YV27eT3sBHbYiNm53bMpgSg= -github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= -github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-immutable-radix v1.3.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= -github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= -github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 7ba167bbf..02917e14b 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -18,6 +18,7 @@ package webhook import ( + "bytes" "context" "crypto/hmac" "crypto/sha256" @@ -25,11 +26,12 @@ import ( "encoding/json" "errors" "fmt" + "math" "net/http" "syscall" "time" - "github.com/hashicorp/go-retryablehttp" + "github.com/yorkie-team/yorkie/server/logging" ) var ( @@ -53,36 +55,19 @@ type Options struct { // Client is a client for the webhook. type Client[Req any, Res any] struct { - retryClient *retryablehttp.Client - options Options + client *http.Client + options Options } // NewClient creates a new instance of Client. func NewClient[Req any, Res any]( options Options, ) *Client[Req, Res] { - retryClient := &retryablehttp.Client{ - HTTPClient: &http.Client{ + return &Client[Req, Res]{ + client: &http.Client{ Timeout: options.RequestTimeout, }, - RetryMax: int(options.MaxRetries), - RetryWaitMin: options.MinWaitInterval, - RetryWaitMax: options.MaxWaitInterval, - // Note(window9u): I think we could replace shouldRetry with `retryablehttp.DefaultRetryPolicy` - CheckRetry: shouldRetry, - Logger: nil, - Backoff: retryablehttp.DefaultBackoff, - ErrorHandler: func(resp *http.Response, err error, numTries int) (*http.Response, error) { - if err == nil && numTries == int(options.MaxRetries)+1 { - return nil, ErrWebhookTimeout - } - return resp, fmt.Errorf("after %d attempts, errors were: %w", numTries, err) - }, - } - - return &Client[Req, Res]{ - retryClient: retryClient, - options: options, + options: options, } } @@ -92,102 +77,129 @@ func (c *Client[Req, Res]) Send( url, hmacKey string, body []byte, ) (*Res, int, error) { - req, err := c.buildRequest(ctx, url, hmacKey, body) + signature, err := createSignature(body, hmacKey) if err != nil { - return nil, 0, fmt.Errorf("build request: %w", err) + return nil, 0, fmt.Errorf("create signature: %w", err) } - resp, err := c.retryClient.Do(req) - if err != nil { - statusCode := 0 - if resp != nil { - statusCode = resp.StatusCode + var res Res + status, err := c.withExponentialBackoff(ctx, func() (int, error) { + req, err := c.buildRequest(ctx, url, signature, body) + if err != nil { + return 0, fmt.Errorf("build request: %w", err) } - return nil, statusCode, fmt.Errorf("post webhook request: %w", err) - } + resp, err := c.client.Do(req) + if err != nil { + return 0, fmt.Errorf("do request: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + // TODO(hackerwins): Consider to remove the dependency of logging. + logging.From(ctx).Error(err) + } + }() - if !isExpectedCode(resp.StatusCode) { - return nil, resp.StatusCode, fmt.Errorf("%d: %w", resp.StatusCode, ErrUnexpectedStatusCode) - } + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusUnauthorized && + resp.StatusCode != http.StatusForbidden { + return resp.StatusCode, ErrUnexpectedStatusCode + } - var res Res - if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { - return nil, resp.StatusCode, ErrUnexpectedResponse + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return resp.StatusCode, ErrUnexpectedResponse + } + + return resp.StatusCode, nil + }) + if err != nil { + return nil, status, err } - return &res, resp.StatusCode, nil + return &res, status, nil } // buildRequest creates a new HTTP POST request with the appropriate headers. func (c *Client[Req, Res]) buildRequest( ctx context.Context, - url, hmacKey string, + url, hmac string, body []byte, -) (*retryablehttp.Request, error) { - req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, url, body) +) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(body)) if err != nil { return nil, fmt.Errorf("create POST request with context: %w", err) } req.Header.Set("Content-Type", "application/json") - if hmacKey != "" { - if err := setSignature(req, body, hmacKey); err != nil { - return nil, fmt.Errorf("set HMAC signature: %w", err) - } + + if hmac != "" { + req.Header.Set("X-Signature-256", hmac) } return req, nil } -// setSignature sets the HMAC signature header for the request. -func setSignature(req *retryablehttp.Request, data []byte, hmacKey string) error { +// createSignature sets the HMAC signature header for the request. +func createSignature(data []byte, hmacKey string) (string, error) { + if hmacKey == "" { + return "", nil + } mac := hmac.New(sha256.New, []byte(hmacKey)) if _, err := mac.Write(data); err != nil { - return fmt.Errorf("write HMAC body: %w", err) + return "", fmt.Errorf("write HMAC body: %w", err) } signatureHex := hex.EncodeToString(mac.Sum(nil)) - req.Header.Set("X-Signature-256", fmt.Sprintf("sha256=%s", signatureHex)) - return nil + return fmt.Sprintf("sha256=%s", signatureHex), nil } -// shouldRetry returns true if the given error should be retried. -// Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry -func shouldRetry(_ context.Context, resp *http.Response, err error) (bool, error) { - // If the connection is reset, we should retry. - if err != nil { - var errno syscall.Errno - if errors.As(err, &errno) && errors.Is(errno, syscall.ECONNRESET) { - return true, nil +func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn func() (int, error)) (int, error) { + var retries uint64 + var statusCode int + for retries <= c.options.MaxRetries { + statusCode, err := webhookFn() + if !shouldRetry(statusCode, err) { + if errors.Is(err, ErrUnexpectedStatusCode) { + return statusCode, fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode) + } + + return statusCode, err } - return false, err - } + waitBeforeRetry := waitInterval(retries, c.options.MinWaitInterval, c.options.MaxWaitInterval) - if resp != nil { - code := resp.StatusCode - if isExpectedCode(code) { - return false, nil - } - if isRetryCode(code) { - return true, nil + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-time.After(waitBeforeRetry): } + + retries++ } - return false, nil + return statusCode, fmt.Errorf("unexpected status code from webhook %d: %w", statusCode, ErrWebhookTimeout) } -// isExpectedCode checks if the status code is acceptable. -func isExpectedCode(code int) bool { - return code == http.StatusOK || - code == http.StatusUnauthorized || - code == http.StatusForbidden +// waitInterval returns the interval of given retries. (2^retries * minWaitInterval) . +func waitInterval(retries uint64, minWaitInterval, maxWaitInterval time.Duration) time.Duration { + interval := time.Duration(math.Pow(2, float64(retries))) * minWaitInterval + if maxWaitInterval < interval { + return maxWaitInterval + } + + return interval } -// isRetryCode checks if the status code is one that should trigger a retry. -func isRetryCode(code int) bool { - return code == http.StatusInternalServerError || - code == http.StatusServiceUnavailable || - code == http.StatusGatewayTimeout || - code == http.StatusTooManyRequests +// shouldRetry returns true if the given error should be retried. +// Refer to https://github.com/kubernetes/kubernetes/search?q=DefaultShouldRetry +func shouldRetry(statusCode int, err error) bool { + // If the connection is reset, we should retry. + var errno syscall.Errno + if errors.As(err, &errno) { + return errors.Is(errno, syscall.ECONNRESET) + } + + return statusCode == http.StatusInternalServerError || + statusCode == http.StatusServiceUnavailable || + statusCode == http.StatusGatewayTimeout || + statusCode == http.StatusTooManyRequests } diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index a12537103..1d952e4a6 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -72,6 +72,7 @@ func verifyAccess( return fmt.Errorf("send to webhook: %w", err) } + // TODO(hackerwins): We should consider caching the response of Unauthorized as well. if status != http.StatusUnauthorized { be.WebhookCache.Add( cacheKey, From 7489ad601ad8d5f98d36d2593e671a05af9cff46 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Thu, 6 Feb 2025 13:49:36 +0900 Subject: [PATCH 12/18] add test - retry test - request timeout - unexpected error handling --- pkg/webhook/client_test.go | 248 +++++++++++++++++++++++++++++-------- 1 file changed, 195 insertions(+), 53 deletions(-) diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 0e7a2669f..d4b1d8f99 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" @@ -29,6 +30,7 @@ type testResponse struct { Greeting string `json:"greeting"` } +// verifySignature verifies that the HMAC signature in the header matches the expected value. func verifySignature(signatureHeader, secret string, body []byte) error { mac := hmac.New(sha256.New, []byte(secret)) mac.Write(body) @@ -37,102 +39,242 @@ func verifySignature(signatureHeader, secret string, body []byte) error { if !hmac.Equal([]byte(signatureHeader), []byte(expectedSigHeader)) { return errors.New("signature validation failed") } - return nil } -func TestHMAC(t *testing.T) { - const secretKey = "my-secret-key" - const wrongKey = "wrong-key" - resData := testResponse{Greeting: "HMAC OK"} - - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// newHMACTestServer creates a new httptest.Server that verifies the HMAC signature. +// It returns a valid JSON response if the signature is correct. +func newHMACTestServer(validSecret string, responseData testResponse) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { signatureHeader := r.Header.Get("X-Signature-256") if signatureHeader == "" { - w.WriteHeader(http.StatusUnauthorized) + http.Error(w, "unauthorized", http.StatusUnauthorized) return } + bodyBytes, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "bad request", http.StatusBadRequest) return } - if err := verifySignature(signatureHeader, secretKey, bodyBytes); err != nil { - w.WriteHeader(http.StatusForbidden) + if err := verifySignature(signatureHeader, validSecret, bodyBytes); err != nil { + http.Error(w, "forbidden", http.StatusForbidden) return } + + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - assert.NoError(t, json.NewEncoder(w).Encode(resData)) + _ = json.NewEncoder(w).Encode(responseData) })) +} + +func newRetryServer(replyAfter int) *httptest.Server { + var requestCount int32 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count := int(atomic.AddInt32(&requestCount, 1)) + if count < replyAfter { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(testResponse{Greeting: "Recovered Response"}) + })) +} + +func newDelayServer(delayTime time.Duration) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(delayTime) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(testResponse{Greeting: "Recovered Response"}) + })) +} + +func TestHMAC(t *testing.T) { + const validSecret = "my-secret-key" + const invalidSecret = "wrong-key" + expectedResponse := testResponse{Greeting: "HMAC OK"} + + testServer := newHMACTestServer(validSecret, expectedResponse) defer testServer.Close() - client := webhook.NewClient[testRequest, testResponse]( - webhook.Options{ - MaxRetries: 0, - MinWaitInterval: 2 * time.Second, - MaxWaitInterval: 10 * time.Second, - RequestTimeout: 30 * time.Second, - }, - ) - t.Run("webhook client with valid HMAC key test", func(t *testing.T) { - body, err := json.Marshal(testRequest{Name: t.Name()}) + + client := webhook.NewClient[testRequest, testResponse](webhook.Options{ + MaxRetries: 0, + MinWaitInterval: 0, + MaxWaitInterval: 0, + RequestTimeout: 1 * time.Second, + }) + + t.Run("valid HMAC key test", func(t *testing.T) { + reqPayload := testRequest{Name: "ValidHMAC"} + body, err := json.Marshal(reqPayload) assert.NoError(t, err) - resp, statusCode, err := client.Send( - context.Background(), - testServer.URL, - secretKey, - body, - ) + resp, statusCode, err := client.Send(context.Background(), testServer.URL, validSecret, body) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) assert.NotNil(t, resp) - assert.Equal(t, resData.Greeting, resp.Greeting) + assert.Equal(t, expectedResponse.Greeting, resp.Greeting) }) - t.Run("webhook client with invalid HMAC key test", func(t *testing.T) { - body, err := json.Marshal(testRequest{Name: t.Name()}) + t.Run("invalid HMAC key test", func(t *testing.T) { + reqPayload := testRequest{Name: "InvalidHMAC"} + body, err := json.Marshal(reqPayload) assert.NoError(t, err) - resp, statusCode, err := client.Send( - context.Background(), - testServer.URL, - wrongKey, - body, - ) + resp, statusCode, err := client.Send(context.Background(), testServer.URL, invalidSecret, body) assert.Error(t, err) + // The server responds with 403 Forbidden if the signature is invalid. assert.Equal(t, http.StatusForbidden, statusCode) assert.Nil(t, resp) }) - t.Run("webhook client without HMAC key test", func(t *testing.T) { - body, err := json.Marshal(testRequest{Name: t.Name()}) + t.Run("missing HMAC key test", func(t *testing.T) { + reqPayload := testRequest{Name: "MissingHMAC"} + body, err := json.Marshal(reqPayload) assert.NoError(t, err) - resp, statusCode, err := client.Send( - context.Background(), - testServer.URL, - "", - body, - ) + resp, statusCode, err := client.Send(context.Background(), testServer.URL, "", body) assert.Error(t, err) + // The server responds with 401 Unauthorized if no signature header is provided. assert.Equal(t, http.StatusUnauthorized, statusCode) assert.Nil(t, resp) }) - t.Run("webhook client with empty body test", func(t *testing.T) { - body, err := json.Marshal(testRequest{}) + t.Run("empty body test", func(t *testing.T) { + reqPayload := testRequest{} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := client.Send(context.Background(), testServer.URL, validSecret, body) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, statusCode) + assert.NotNil(t, resp) + assert.Equal(t, expectedResponse.Greeting, resp.Greeting) + }) +} + +func TestRetryRequest(t *testing.T) { + replyAfter := 4 + reachableRetries := replyAfter - 1 + unreachableRetries := replyAfter - 2 + server := newRetryServer(replyAfter) + defer server.Close() + + reachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + MaxRetries: uint64(reachableRetries), + MinWaitInterval: 1 * time.Millisecond, + MaxWaitInterval: 5 * time.Millisecond, + RequestTimeout: 10 * time.Millisecond, + }) + + unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + MaxRetries: uint64(unreachableRetries), + MinWaitInterval: 1 * time.Millisecond, + MaxWaitInterval: 5 * time.Millisecond, + RequestTimeout: 10 * time.Millisecond, + }) + + t.Run("request fails with timeout test", func(t *testing.T) { + reqPayload := testRequest{Name: "TimeoutTest"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := unreachableClient.Send(context.Background(), server.URL, "", body) + assert.Error(t, err) + assert.Equal(t, 0, statusCode) + assert.Nil(t, resp) + }) + + t.Run("request succeed after timeout", func(t *testing.T) { + reqPayload := testRequest{Name: "TimeoutTest"} + body, err := json.Marshal(reqPayload) assert.NoError(t, err) - resp, statusCode, err := client.Send( - context.Background(), - testServer.URL, - secretKey, - body, - ) + resp, statusCode, err := reachableClient.Send(context.Background(), server.URL, "", body) + assert.NoError(t, err) + assert.Equal(t, 200, statusCode) + assert.NotNil(t, resp) + }) +} + +func TestRequestTimeout(t *testing.T) { + delayTime := 10 * time.Millisecond + server := newDelayServer(delayTime) + defer server.Close() + + reachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + MaxRetries: 0, + MinWaitInterval: 0, + MaxWaitInterval: 0, + RequestTimeout: 15 * time.Millisecond, + }) + + unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + MaxRetries: 0, + MinWaitInterval: 0, + MaxWaitInterval: 0, + RequestTimeout: 5 * time.Millisecond, + }) + + t.Run("request succeed after timeout", func(t *testing.T) { + reqPayload := testRequest{Name: "TimeoutTest"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := reachableClient.Send(context.Background(), server.URL, "", body) assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) assert.NotNil(t, resp) - assert.Equal(t, resData.Greeting, resp.Greeting) + }) + + t.Run("request fails with timeout test", func(t *testing.T) { + reqPayload := testRequest{Name: "TimeoutTest"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := unreachableClient.Send(context.Background(), server.URL, "", body) + assert.Error(t, err) + assert.Equal(t, 0, statusCode) + assert.Nil(t, resp) + }) +} + +func TestErrorHandling(t *testing.T) { + server := newDelayServer(1 * time.Second) + defer server.Close() + + unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + MaxRetries: 0, + MinWaitInterval: 0, + MaxWaitInterval: 0, + RequestTimeout: 50 * time.Millisecond, + }) + + t.Run("request fails with context done test", func(t *testing.T) { + reqPayload := testRequest{Name: "ContextDone"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + resp, statusCode, err := unreachableClient.Send(ctx, server.URL, "", body) + assert.Error(t, err) + assert.Equal(t, 0, statusCode) + assert.Nil(t, resp) + }) + + t.Run("request fails with unreachable url test", func(t *testing.T) { + reqPayload := testRequest{Name: "invalidURL"} + body, err := json.Marshal(reqPayload) + assert.NoError(t, err) + + resp, statusCode, err := unreachableClient.Send(context.Background(), "", "", body) + assert.Error(t, err) + assert.Equal(t, 0, statusCode) + assert.Nil(t, resp) }) } From dcca5af055c5fc46f7ca37b71f5af46750df2273 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Thu, 6 Feb 2025 14:58:14 +0900 Subject: [PATCH 13/18] save status code - for distinguish status code 0 and shouldRetry status codes --- pkg/webhook/client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 02917e14b..36c800c5c 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -155,8 +155,10 @@ func createSignature(data []byte, hmacKey string) (string, error) { func (c *Client[Req, Res]) withExponentialBackoff(ctx context.Context, webhookFn func() (int, error)) (int, error) { var retries uint64 var statusCode int + var err error + for retries <= c.options.MaxRetries { - statusCode, err := webhookFn() + statusCode, err = webhookFn() if !shouldRetry(statusCode, err) { if errors.Is(err, ErrUnexpectedStatusCode) { return statusCode, fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode) From 17a8dc07d22e35b397625e70c7f6485a763c012b Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Thu, 6 Feb 2025 14:59:18 +0900 Subject: [PATCH 14/18] lint --- pkg/webhook/client.go | 22 ++++++++++++--------- pkg/webhook/client_test.go | 40 ++++++++++++++++++++++---------------- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 36c800c5c..757aa77ea 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -45,7 +45,7 @@ var ( ErrWebhookTimeout = errors.New("webhook timeout") ) -// Options are the options for the webhook client. +// Options are the options for the webhook httpClient. type Options struct { MaxRetries uint64 MinWaitInterval time.Duration @@ -53,10 +53,10 @@ type Options struct { RequestTimeout time.Duration } -// Client is a client for the webhook. +// Client is a httpClient for the webhook. type Client[Req any, Res any] struct { - client *http.Client - options Options + httpClient *http.Client + options Options } // NewClient creates a new instance of Client. @@ -64,7 +64,7 @@ func NewClient[Req any, Res any]( options Options, ) *Client[Req, Res] { return &Client[Req, Res]{ - client: &http.Client{ + httpClient: &http.Client{ Timeout: options.RequestTimeout, }, options: options, @@ -89,7 +89,7 @@ func (c *Client[Req, Res]) Send( return 0, fmt.Errorf("build request: %w", err) } - resp, err := c.client.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return 0, fmt.Errorf("do request: %w", err) } @@ -100,9 +100,7 @@ func (c *Client[Req, Res]) Send( } }() - if resp.StatusCode != http.StatusOK && - resp.StatusCode != http.StatusUnauthorized && - resp.StatusCode != http.StatusForbidden { + if !isExpectedStatus(resp.StatusCode) { return resp.StatusCode, ErrUnexpectedStatusCode } @@ -205,3 +203,9 @@ func shouldRetry(statusCode int, err error) bool { statusCode == http.StatusGatewayTimeout || statusCode == http.StatusTooManyRequests } + +func isExpectedStatus(statusCode int) bool { + return statusCode == http.StatusOK || + statusCode == http.StatusUnauthorized || + statusCode == http.StatusForbidden +} diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index d4b1d8f99..8761cd135 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -44,7 +44,7 @@ func verifySignature(signatureHeader, secret string, body []byte) error { // newHMACTestServer creates a new httptest.Server that verifies the HMAC signature. // It returns a valid JSON response if the signature is correct. -func newHMACTestServer(validSecret string, responseData testResponse) *httptest.Server { +func newHMACTestServer(t *testing.T, validSecret string, responseData testResponse) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { signatureHeader := r.Header.Get("X-Signature-256") if signatureHeader == "" { @@ -65,11 +65,11 @@ func newHMACTestServer(validSecret string, responseData testResponse) *httptest. w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(responseData) + assert.NoError(t, json.NewEncoder(w).Encode(responseData)) })) } -func newRetryServer(replyAfter int) *httptest.Server { +func newRetryServer(t *testing.T, replyAfter int, responseData testResponse) *httptest.Server { var requestCount int32 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count := int(atomic.AddInt32(&requestCount, 1)) @@ -80,16 +80,16 @@ func newRetryServer(replyAfter int) *httptest.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(testResponse{Greeting: "Recovered Response"}) + assert.NoError(t, json.NewEncoder(w).Encode(responseData)) })) } -func newDelayServer(delayTime time.Duration) *httptest.Server { +func newDelayServer(t *testing.T, delayTime time.Duration, responseData testResponse) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(delayTime) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(testResponse{Greeting: "Recovered Response"}) + assert.NoError(t, json.NewEncoder(w).Encode(responseData)) })) } @@ -98,7 +98,7 @@ func TestHMAC(t *testing.T) { const invalidSecret = "wrong-key" expectedResponse := testResponse{Greeting: "HMAC OK"} - testServer := newHMACTestServer(validSecret, expectedResponse) + testServer := newHMACTestServer(t, validSecret, expectedResponse) defer testServer.Close() client := webhook.NewClient[testRequest, testResponse](webhook.Options{ @@ -161,7 +161,8 @@ func TestRetryRequest(t *testing.T) { replyAfter := 4 reachableRetries := replyAfter - 1 unreachableRetries := replyAfter - 2 - server := newRetryServer(replyAfter) + expectedResponse := testResponse{Greeting: "retry succeed"} + server := newRetryServer(t, replyAfter, expectedResponse) defer server.Close() reachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ @@ -178,32 +179,35 @@ func TestRetryRequest(t *testing.T) { RequestTimeout: 10 * time.Millisecond, }) - t.Run("request fails with timeout test", func(t *testing.T) { - reqPayload := testRequest{Name: "TimeoutTest"} + t.Run("retry fail test", func(t *testing.T) { + reqPayload := testRequest{Name: "retry fails"} body, err := json.Marshal(reqPayload) assert.NoError(t, err) resp, statusCode, err := unreachableClient.Send(context.Background(), server.URL, "", body) assert.Error(t, err) - assert.Equal(t, 0, statusCode) + assert.ErrorContains(t, err, webhook.ErrWebhookTimeout.Error()) + assert.Equal(t, http.StatusServiceUnavailable, statusCode) assert.Nil(t, resp) }) - t.Run("request succeed after timeout", func(t *testing.T) { - reqPayload := testRequest{Name: "TimeoutTest"} + t.Run("retry succeed timeout", func(t *testing.T) { + reqPayload := testRequest{Name: "retry succeed"} body, err := json.Marshal(reqPayload) assert.NoError(t, err) resp, statusCode, err := reachableClient.Send(context.Background(), server.URL, "", body) assert.NoError(t, err) - assert.Equal(t, 200, statusCode) + assert.Equal(t, http.StatusOK, statusCode) assert.NotNil(t, resp) + assert.Equal(t, expectedResponse.Greeting, resp.Greeting) }) } func TestRequestTimeout(t *testing.T) { delayTime := 10 * time.Millisecond - server := newDelayServer(delayTime) + expectedResponse := testResponse{Greeting: "hello"} + server := newDelayServer(t, delayTime, expectedResponse) defer server.Close() reachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ @@ -229,6 +233,7 @@ func TestRequestTimeout(t *testing.T) { assert.NoError(t, err) assert.Equal(t, http.StatusOK, statusCode) assert.NotNil(t, resp) + assert.Equal(t, expectedResponse.Greeting, resp.Greeting) }) t.Run("request fails with timeout test", func(t *testing.T) { @@ -244,7 +249,8 @@ func TestRequestTimeout(t *testing.T) { } func TestErrorHandling(t *testing.T) { - server := newDelayServer(1 * time.Second) + expectedResponse := testResponse{Greeting: "hello"} + server := newRetryServer(t, 2, expectedResponse) defer server.Close() unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ @@ -263,7 +269,7 @@ func TestErrorHandling(t *testing.T) { defer cancel() resp, statusCode, err := unreachableClient.Send(ctx, server.URL, "", body) assert.Error(t, err) - assert.Equal(t, 0, statusCode) + assert.Equal(t, http.StatusServiceUnavailable, statusCode) assert.Nil(t, resp) }) From 871dcc7d7a16129ba163ea01192fa9729792efb8 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Thu, 6 Feb 2025 15:29:47 +0900 Subject: [PATCH 15/18] fix default config - DefaultAuthWebhookMinWaitInterval - DefaultAuthWebhookRequestTimeout --- server/config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/config.go b/server/config.go index 32e215382..6af8a684e 100644 --- a/server/config.go +++ b/server/config.go @@ -62,8 +62,8 @@ const ( DefaultAuthWebhookMaxRetries = 10 DefaultAuthWebhookMaxWaitInterval = 3000 * time.Millisecond - DefaultAuthWebhookMinWaitInterval = 200 * time.Millisecond - DefaultAuthWebhookRequestTimeout = 30 * time.Second + DefaultAuthWebhookMinWaitInterval = 100 * time.Millisecond + DefaultAuthWebhookRequestTimeout = 10 * time.Second DefaultAuthWebhookCacheSize = 5000 DefaultAuthWebhookCacheTTL = 10 * time.Second DefaultProjectCacheSize = 256 From d2b18e6c6aa928d1467e0b312c45bee7b6d7c2d8 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Thu, 6 Feb 2025 16:00:17 +0900 Subject: [PATCH 16/18] apply code rabbit's suggestion --- pkg/webhook/client_test.go | 14 +++++++++----- server/backend/config.go | 8 ++++---- server/rpc/auth/webhook.go | 4 ++++ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 8761cd135..73d1c8320 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -86,13 +86,17 @@ func newRetryServer(t *testing.T, replyAfter int, responseData testResponse) *ht func newDelayServer(t *testing.T, delayTime time.Duration, responseData testResponse) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(delayTime) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - assert.NoError(t, json.NewEncoder(w).Encode(responseData)) + ctx, cancel := context.WithTimeout(r.Context(), delayTime) + defer cancel() + + select { + case <-ctx.Done(): + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + assert.NoError(t, json.NewEncoder(w).Encode(responseData)) + } })) } - func TestHMAC(t *testing.T) { const validSecret = "my-secret-key" const invalidSecret = "wrong-key" diff --git a/server/backend/config.go b/server/backend/config.go index 5827029ac..9f1767c6c 100644 --- a/server/backend/config.go +++ b/server/backend/config.go @@ -64,7 +64,7 @@ type Config struct { // AuthWebhookMaxWaitInterval is the max interval that waits before retrying the authorization webhook. AuthWebhookMaxWaitInterval string `yaml:"AuthWebhookMaxWaitInterval"` - // AuthWebhookMinWaitInterval is the max interval that waits before retrying the authorization webhook. + // AuthWebhookMinWaitInterval is the min interval that waits before retrying the authorization webhook. AuthWebhookMinWaitInterval string `yaml:"AuthWebhookMinWaitInterval"` // AuthWebhookRequestTimeout is the max waiting time per auth webhook request @@ -117,7 +117,7 @@ func (c *Config) Validate() error { if _, err := time.ParseDuration(c.AuthWebhookRequestTimeout); err != nil { return fmt.Errorf( - `invalid argument "%s" for "--auth-webhook-reqeust-timeout" flag: %w`, + `invalid argument "%s" for "--auth-webhook-request-timeout" flag: %w`, c.AuthWebhookRequestTimeout, err, ) @@ -164,7 +164,7 @@ func (c *Config) ParseAuthWebhookMaxWaitInterval() time.Duration { return result } -// ParseAuthWebhookMinWaitInterval returns max wait interval. +// ParseAuthWebhookMinWaitInterval returns min wait interval. func (c *Config) ParseAuthWebhookMinWaitInterval() time.Duration { result, err := time.ParseDuration(c.AuthWebhookMinWaitInterval) if err != nil { @@ -175,7 +175,7 @@ func (c *Config) ParseAuthWebhookMinWaitInterval() time.Duration { return result } -// ParseAuthWebhookRequestTimeout returns max wait interval. +// ParseAuthWebhookRequestTimeout returns request timeout. func (c *Config) ParseAuthWebhookRequestTimeout() time.Duration { result, err := time.ParseDuration(c.AuthWebhookRequestTimeout) if err != nil { diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index 1d952e4a6..cb27e5df3 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -91,6 +91,10 @@ func generateCacheKey(publicKey string, body []byte) string { // handleWebhookResponse processes the webhook response and returns an error if necessary. func handleWebhookResponse(status int, res *types.AuthWebhookResponse) error { + if res == nil { + return fmt.Errorf("nil response for status %d: %w", status, webhook.ErrUnexpectedResponse) + } + switch { case status == http.StatusOK && res.Allowed: return nil From 27bad001ea41fae01454924f8bfb67075f183f60 Mon Sep 17 00:00:00 2001 From: Changyu Moon Date: Thu, 6 Feb 2025 16:05:00 +0900 Subject: [PATCH 17/18] fix config.sample.yml --- server/config.sample.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/config.sample.yml b/server/config.sample.yml index 07fbbb932..16cee8318 100644 --- a/server/config.sample.yml +++ b/server/config.sample.yml @@ -72,10 +72,10 @@ Backend: AuthWebhookMaxWaitInterval: "3s" # AuthWebhookMinWaitInterval is the min interval that waits before retrying the authorization webhook. - AuthWebhookMinWaitInterval: "200ms" + AuthWebhookMinWaitInterval: "100ms" # AuthWebhookRequestTimeout is the max waiting time per authorization webhook request. - AuthWebhookRequestTimeout: "30s" + AuthWebhookRequestTimeout: "10s" # AuthWebhookCacheTTL is the TTL value to set when caching the authorized result. AuthWebhookCacheTTL: "10s" From 638932fb6356559c839b2f3c8f09136988ad5417 Mon Sep 17 00:00:00 2001 From: Youngteac Hong Date: Thu, 6 Feb 2025 18:08:50 +0900 Subject: [PATCH 18/18] Revise the codes --- cmd/yorkie/server.go | 24 ++++++++++++------------ pkg/webhook/client.go | 2 +- pkg/webhook/client_test.go | 12 ++++++------ server/backend/backend.go | 18 ++++++++++-------- server/config.go | 2 +- server/config.sample.yml | 14 +++++++------- 6 files changed, 37 insertions(+), 35 deletions(-) diff --git a/cmd/yorkie/server.go b/cmd/yorkie/server.go index 911f6faec..6d9714f9f 100644 --- a/cmd/yorkie/server.go +++ b/cmd/yorkie/server.go @@ -299,29 +299,29 @@ func init() { server.DefaultSnapshotDisableGC, "Whether to disable garbage collection of snapshots.", ) + cmd.Flags().DurationVar( + &authWebhookRequestTimeout, + "auth-webhook-request-timeout", + server.DefaultAuthWebhookRequestTimeout, + "Timeout for each authorization webhook request.", + ) cmd.Flags().Uint64Var( &conf.Backend.AuthWebhookMaxRetries, "auth-webhook-max-retries", server.DefaultAuthWebhookMaxRetries, - "Maximum number of retries for an authorization webhook.", - ) - cmd.Flags().DurationVar( - &authWebhookMaxWaitInterval, - "auth-webhook-max-wait-interval", - server.DefaultAuthWebhookMaxWaitInterval, - "Maximum wait interval for authorization webhook.", + "Maximum number of retries for authorization webhook.", ) cmd.Flags().DurationVar( &authWebhookMinWaitInterval, "auth-webhook-min-wait-interval", server.DefaultAuthWebhookMinWaitInterval, - "Minimum wait interval for authorization webhook.", + "Minimum wait interval between retries(exponential backoff).", ) cmd.Flags().DurationVar( - &authWebhookRequestTimeout, - "auth-webhook-request-timeout", - server.DefaultAuthWebhookRequestTimeout, - "Maximum wait time per authorization webhook request.", + &authWebhookMaxWaitInterval, + "auth-webhook-max-wait-interval", + server.DefaultAuthWebhookMaxWaitInterval, + "Maximum wait interval between retries(exponential backoff).", ) cmd.Flags().IntVar( &conf.Backend.AuthWebhookCacheSize, diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 757aa77ea..f91afa34b 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -47,10 +47,10 @@ var ( // Options are the options for the webhook httpClient. type Options struct { + RequestTimeout time.Duration MaxRetries uint64 MinWaitInterval time.Duration MaxWaitInterval time.Duration - RequestTimeout time.Duration } // Client is a httpClient for the webhook. diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 73d1c8320..b0afa1a34 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -161,7 +161,7 @@ func TestHMAC(t *testing.T) { }) } -func TestRetryRequest(t *testing.T) { +func TestBackoff(t *testing.T) { replyAfter := 4 reachableRetries := replyAfter - 1 unreachableRetries := replyAfter - 2 @@ -170,17 +170,17 @@ func TestRetryRequest(t *testing.T) { defer server.Close() reachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 10 * time.Millisecond, MaxRetries: uint64(reachableRetries), MinWaitInterval: 1 * time.Millisecond, MaxWaitInterval: 5 * time.Millisecond, - RequestTimeout: 10 * time.Millisecond, }) unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 10 * time.Millisecond, MaxRetries: uint64(unreachableRetries), MinWaitInterval: 1 * time.Millisecond, MaxWaitInterval: 5 * time.Millisecond, - RequestTimeout: 10 * time.Millisecond, }) t.Run("retry fail test", func(t *testing.T) { @@ -215,17 +215,17 @@ func TestRequestTimeout(t *testing.T) { defer server.Close() reachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 15 * time.Millisecond, MaxRetries: 0, MinWaitInterval: 0, MaxWaitInterval: 0, - RequestTimeout: 15 * time.Millisecond, }) unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 5 * time.Millisecond, MaxRetries: 0, MinWaitInterval: 0, MaxWaitInterval: 0, - RequestTimeout: 5 * time.Millisecond, }) t.Run("request succeed after timeout", func(t *testing.T) { @@ -258,10 +258,10 @@ func TestErrorHandling(t *testing.T) { defer server.Close() unreachableClient := webhook.NewClient[testRequest, testResponse](webhook.Options{ + RequestTimeout: 50 * time.Millisecond, MaxRetries: 0, MinWaitInterval: 0, MaxWaitInterval: 0, - RequestTimeout: 50 * time.Millisecond, }) t.Run("request fails with context done test", func(t *testing.T) { diff --git a/server/backend/backend.go b/server/backend/backend.go index e55787c32..9d1a95839 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -86,11 +86,11 @@ func New( conf.Hostname = hostname } - // 02. Create auth webhook client and cache. - cache := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( + // 02. Create the webhook webhookCache and client. + webhookCache := cache.NewLRUExpireCache[string, pkgtypes.Pair[int, *types.AuthWebhookResponse]]( conf.AuthWebhookCacheSize, ) - auth := webhook.NewClient[types.AuthWebhookRequest, types.AuthWebhookResponse]( + webhookClient := webhook.NewClient[types.AuthWebhookRequest, types.AuthWebhookResponse]( webhook.Options{ MaxRetries: conf.AuthWebhookMaxRetries, MinWaitInterval: conf.ParseAuthWebhookMinWaitInterval(), @@ -155,11 +155,13 @@ func New( ) return &Backend{ - Config: conf, - WebhookCache: cache, - WebhookClient: auth, - Locker: locker, - PubSub: pubsub, + Config: conf, + + WebhookCache: webhookCache, + WebhookClient: webhookClient, + + Locker: locker, + PubSub: pubsub, Metrics: metrics, DB: db, diff --git a/server/config.go b/server/config.go index 6af8a684e..ec63a8958 100644 --- a/server/config.go +++ b/server/config.go @@ -60,10 +60,10 @@ const ( DefaultSnapshotWithPurgingChanges = false DefaultSnapshotDisableGC = false + DefaultAuthWebhookRequestTimeout = 3 * time.Second DefaultAuthWebhookMaxRetries = 10 DefaultAuthWebhookMaxWaitInterval = 3000 * time.Millisecond DefaultAuthWebhookMinWaitInterval = 100 * time.Millisecond - DefaultAuthWebhookRequestTimeout = 10 * time.Second DefaultAuthWebhookCacheSize = 5000 DefaultAuthWebhookCacheTTL = 10 * time.Second DefaultProjectCacheSize = 256 diff --git a/server/config.sample.yml b/server/config.sample.yml index 16cee8318..71b09fb5d 100644 --- a/server/config.sample.yml +++ b/server/config.sample.yml @@ -65,17 +65,17 @@ Backend: # AuthWebhookMethods is the list of methods to use for authorization. AuthWebhookMethods: [] - # AuthWebhookMaxRetries is the max count that retries the authorization webhook. - AuthWebhookMaxRetries: 10 + # AuthWebhookRequestTimeout is the timeout for each authorization webhook request. + AuthWebhookRequestTimeout: "3s" - # AuthWebhookMaxWaitInterval is the max interval that waits before retrying the authorization webhook. - AuthWebhookMaxWaitInterval: "3s" + # AuthWebhookMaxRetries is the max number of retries for the authorization webhook. + AuthWebhookMaxRetries: 10 - # AuthWebhookMinWaitInterval is the min interval that waits before retrying the authorization webhook. + # AuthWebhookMinWaitInterval is the minimum wait interval between retries(exponential backoff). AuthWebhookMinWaitInterval: "100ms" - # AuthWebhookRequestTimeout is the max waiting time per authorization webhook request. - AuthWebhookRequestTimeout: "10s" + # AuthWebhookMaxWaitInterval is the maximum wait interval between retries(exponential backoff). + AuthWebhookMaxWaitInterval: "3s" # AuthWebhookCacheTTL is the TTL value to set when caching the authorized result. AuthWebhookCacheTTL: "10s"