From e5b56e48db7ba2557d7e00ea20b535301562574a Mon Sep 17 00:00:00 2001 From: Yen-Ming Lee Date: Thu, 22 Feb 2024 18:01:39 -0800 Subject: [PATCH] Refactor and reorder the code and document --- .github/workflows/tests.yml | 8 +- README.md | 10 +- concurrent_buffer.go | 108 +++++++------------- docker-compose.yml | 10 +- fixedwindow.go | 92 +++++++++-------- fixedwindow_test.go | 2 +- leakybucket.go | 188 +++++++++++++++++------------------ leakybucket_test.go | 4 +- slidingwindow.go | 130 ++++++++++++------------ slidingwindow_test.go | 2 +- tokenbucket.go | 191 ++++++++++++++++++------------------ tokenbucket_test.go | 4 +- 12 files changed, 349 insertions(+), 400 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9a6da20..55e01cb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,6 +23,10 @@ jobs: ALLOW_EMPTY_PASSWORD: yes ports: - 6379:6379 + memcached: + image: bitnami/memcached + ports: + - 11211:11211 consul: image: bitnami/consul ports: @@ -33,10 +37,6 @@ jobs: ALLOW_ANONYMOUS_LOGIN: yes ports: - 2181:2181 - memcached: - image: bitnami/memcached - ports: - - 11211:11211 steps: - uses: actions/checkout@v3 diff --git a/README.md b/README.md index 4b522fc..6c86c47 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,9 @@ Most common implementations are already provided. - [`Token bucket`](https://en.wikipedia.org/wiki/Token_bucket) - in-memory (local) - redis + - memcached - etcd - dynamodb - - memcached Allows requests at a certain input rate with possible bursts configured by the capacity parameter. The output rate equals to the input rate. @@ -22,9 +22,9 @@ Most common implementations are already provided. - [`Leaky bucket`](https://en.wikipedia.org/wiki/Leaky_bucket#As_a_queue) - in-memory (local) - redis + - memcached - etcd - dynamodb - - memcached Puts requests in a FIFO queue to be processed at a constant rate. There are no restrictions on the input rate except for the capacity of the queue. @@ -33,8 +33,8 @@ Most common implementations are already provided. - [`Fixed window counter`](https://konghq.com/blog/how-to-design-a-scalable-rate-limiting-algorithm/) - in-memory (local) - redis - - dynamodb - memcached + - dynamodb Simple and resources efficient algorithm that does not need a lock. Precision may be adjusted by the size of the window. @@ -43,8 +43,8 @@ Most common implementations are already provided. - [`Sliding window counter`](https://konghq.com/blog/how-to-design-a-scalable-rate-limiting-algorithm/) - in-memory (local) - redis - - dynamodb - memcached + - dynamodb Smoothes out the bursts around the boundary between 2 adjacent windows. Needs as twice more memory as the `Fixed Window` algorithm (2 windows instead of 1 at a time). @@ -133,7 +133,7 @@ Supported backends: Run tests locally: ```bash -docker-compose up -d # start etcd, Redis, zookeeper, consul, memcached, and localstack +docker-compose up -d # start etcd, Redis, memcached, zookeeper, consul, and localstack ETCD_ENDPOINTS="127.0.0.1:2379" REDIS_ADDR="127.0.0.1:6379" ZOOKEEPER_ENDPOINTS="127.0.0.1" CONSUL_ADDR="127.0.0.1:8500" AWS_ADDR="127.0.0.1:8000" MEMCACHED_ADDR="127.0.0.1:11211" go test -race -v ``` diff --git a/concurrent_buffer.go b/concurrent_buffer.go index 0a42482..73c1569 100644 --- a/concurrent_buffer.go +++ b/concurrent_buffer.go @@ -176,7 +176,7 @@ type SortedSetNode struct { Value string } -// Add adds the request with the given key to the serialize slice in Memcached and returns the total number of requests in it. +// Add adds the request with the given key to the slice in Memcached and returns the total number of requests in it. // It also removes the keys with expired TTL. func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (int64, error) { var err error @@ -191,7 +191,6 @@ func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (in item, err = c.cli.Get(c.key) if err != nil { if !errors.Is(err, memcache.ErrCacheMiss) { - err = errors.Wrap(err, "failed to Get") return } } else { @@ -200,7 +199,6 @@ func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (in var oldNodes []SortedSetNode err = gob.NewDecoder(b).Decode(&oldNodes) if err != nil { - err = errors.Wrap(err, "failed to Decode") return } for len(oldNodes) > 0 { @@ -215,30 +213,17 @@ func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (in var b bytes.Buffer err = gob.NewEncoder(&b).Encode(newNodes) if err != nil { - err = errors.Wrap(err, "failed to Encode") return } + item = &memcache.Item{ + Key: c.key, + Value: b.Bytes(), + CasID: casId, + } if casId > 0 { - err = c.cli.CompareAndSwap(&memcache.Item{ - Key: c.key, - Value: b.Bytes(), - Expiration: int32(c.clock.Now().Add(c.ttl).Unix()), - CasID: casId, - }) - if err != nil { - err = errors.Wrap(err, "failed to CompareAndSwap") - return - } + err = c.cli.CompareAndSwap(item) } else { - err = c.cli.Add(&memcache.Item{ - Key: c.key, - Value: b.Bytes(), - Expiration: int32(c.clock.Now().Add(c.ttl).Unix()), - }) - if err != nil { - err = errors.Wrap(err, "failed to Add") - return - } + err = c.cli.Add(item) } }() @@ -250,15 +235,14 @@ func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (in if err != nil { if errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss) { return c.Add(ctx, element) - } else { - return 0, errors.Wrap(err, "failed to add in memcached") } + return 0, errors.Wrap(err, "failed to add in memcached") } return int64(len(newNodes)), nil } } -// Remove removes the request identified by the key from the serialized slice in Memcached. +// Remove removes the request identified by the key from the slice in Memcached. func (c *ConcurrentBufferMemcached) Remove(ctx context.Context, key string) error { var err error now := c.clock.Now() @@ -269,60 +253,40 @@ func (c *ConcurrentBufferMemcached) Remove(ctx context.Context, key string) erro if err != nil { if errors.Is(err, memcache.ErrCacheMiss) { return nil - } else { - return errors.Wrap(err, "failed to Get") - } - } else { - casId = item.CasID - var oldNodes []SortedSetNode - err = gob.NewDecoder(bytes.NewBuffer(item.Value)).Decode(&oldNodes) - if err != nil { - return errors.Wrap(err, "failed to Decode") } - for len(oldNodes) > 0 { - node := oldNodes[0] - oldNodes = oldNodes[1:] - if node.Score > now.UnixNano() { - if node.Value == key && !deleted { - deleted = true - } else { - newNodes = append(newNodes, node) - } + return errors.Wrap(err, "failed to Get") + } + casId = item.CasID + var oldNodes []SortedSetNode + err = gob.NewDecoder(bytes.NewBuffer(item.Value)).Decode(&oldNodes) + if err != nil { + return errors.Wrap(err, "failed to Decode") + } + for len(oldNodes) > 0 { + node := oldNodes[0] + oldNodes = oldNodes[1:] + if node.Score > now.UnixNano() { + if node.Value == key && !deleted { + deleted = true + } else { + newNodes = append(newNodes, node) } } } + var b bytes.Buffer err = gob.NewEncoder(&b).Encode(newNodes) if err != nil { return errors.Wrap(err, "failed to Encode") } - if casId > 0 { - err = c.cli.CompareAndSwap(&memcache.Item{ - Key: c.key, - Value: b.Bytes(), - Expiration: int32(c.clock.Now().Add(c.ttl).Unix()), - CasID: casId, - }) - if err != nil { - if errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss) { - return c.Remove(ctx, key) - } else { - return errors.Wrap(err, "failed to CompareAndSwap") - } - } - } else { - err = c.cli.Add(&memcache.Item{ - Key: c.key, - Value: b.Bytes(), - Expiration: int32(c.clock.Now().Add(c.ttl).Unix()), - }) - if err != nil { - if errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss) { - return c.Remove(ctx, key) - } else { - return errors.Wrap(err, "failed to Add") - } - } + item = &memcache.Item{ + Key: c.key, + Value: b.Bytes(), + CasID: casId, + } + err = c.cli.CompareAndSwap(item) + if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) { + return c.Remove(ctx, key) } - return err + return errors.Wrap(err, "failed to CompareAndSwap") } diff --git a/docker-compose.yml b/docker-compose.yml index 73354ad..9e6083d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,11 @@ services: ports: - "6379:6379" + memcached: + image: bitnami/memcached + ports: + - "11211:11211" + consul: image: bitnami/consul ports: @@ -32,8 +37,3 @@ services: command: "-jar DynamoDBLocal.jar -inMemory" ports: - "8000:8000" - - memcached: - image: bitnami/memcached - ports: - - "11211:11211" diff --git a/fixedwindow.go b/fixedwindow.go index 92d55e4..6d8b890 100644 --- a/fixedwindow.go +++ b/fixedwindow.go @@ -133,6 +133,51 @@ func (f *FixedWindowRedis) Increment(ctx context.Context, window time.Time, ttl } } +// FixedWindowMemcached implements FixedWindow in Memcached. +type FixedWindowMemcached struct { + cli *memcache.Client + prefix string +} + +// NewFixedWindowMemcached returns a new instance of FixedWindowMemcached. +// Prefix is the key prefix used to store all the keys used in this implementation in Memcached. +func NewFixedWindowMemcached(cli *memcache.Client, prefix string) *FixedWindowMemcached { + return &FixedWindowMemcached{cli: cli, prefix: prefix + ":FixedWindow"} +} + +// Increment increments the window's counter in Memcached. +func (f *FixedWindowMemcached) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) { + var newValue uint64 + var err error + done := make(chan struct{}) + go func() { + defer close(done) + key := fmt.Sprintf("%s:%d", f.prefix, window.UnixNano()) + newValue, err = f.cli.Increment(key, 1) + if err != nil && errors.Is(err, memcache.ErrCacheMiss) { + newValue = 1 + item := &memcache.Item{ + Key: key, + Value: []byte(strconv.FormatUint(newValue, 10)), + } + err = f.cli.Add(item) + } + }() + + select { + case <-done: + if err != nil { + if errors.Is(err, memcache.ErrNotStored) { + return f.Increment(ctx, window, ttl) + } + return 0, errors.Wrap(err, "failed to Increment or Add") + } + return int64(newValue), err + case <-ctx.Done(): + return 0, ctx.Err() + } +} + // FixedWindowDynamoDB implements FixedWindow in DynamoDB. type FixedWindowDynamoDB struct { client *dynamodb.Client @@ -205,50 +250,3 @@ func (f *FixedWindowDynamoDB) Increment(ctx context.Context, window time.Time, t return int64(count), nil } - -// FixedWindowMemcached implements FixedWindow in Memcached. -type FixedWindowMemcached struct { - cli *memcache.Client - prefix string -} - -// NewFixedWindowMemcached returns a new instance of FixedWindowMemcached. -// Prefix is the key prefix used to store all the keys used in this implementation in Memcached. -func NewFixedWindowMemcached(cli *memcache.Client, prefix string) *FixedWindowMemcached { - return &FixedWindowMemcached{cli: cli, prefix: prefix + ":FixedWindow"} -} - -// Increment increments the window's counter in Memcached. -func (f *FixedWindowMemcached) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) { - var newValue uint64 - var err error - done := make(chan struct{}) - go func() { - defer close(done) - key := fmt.Sprintf("%s:%d", f.prefix, window.UnixNano()) - newValue, err = f.cli.Increment(key, 1) - if err != nil && errors.Is(err, memcache.ErrCacheMiss) { - newValue = 1 - err = f.cli.Add(&memcache.Item{ - Key: key, - Value: []byte(strconv.FormatUint(newValue, 10)), - Expiration: int32(time.Now().Add(ttl).Unix()), - }) - } - }() - - select { - case <-done: - if err != nil { - if errors.Is(err, memcache.ErrNotStored) { - return f.Increment(ctx, window, ttl) - } else { - return 0, errors.Wrap(err, "memcached transaction failed") - } - } else { - return int64(newValue), err - } - case <-ctx.Done(): - return 0, ctx.Err() - } -} diff --git a/fixedwindow_test.go b/fixedwindow_test.go index d7b4d0c..0b33710 100644 --- a/fixedwindow_test.go +++ b/fixedwindow_test.go @@ -22,8 +22,8 @@ func (s *LimitersTestSuite) fixedWindowIncrementers() []l.FixedWindowIncrementer return []l.FixedWindowIncrementer{ l.NewFixedWindowInMemory(), l.NewFixedWindowRedis(s.redisClient, uuid.New().String()), - l.NewFixedWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps), l.NewFixedWindowMemcached(s.memcacheClient, uuid.New().String()), + l.NewFixedWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps), } } diff --git a/leakybucket.go b/leakybucket.go index 5f41693..adad103 100644 --- a/leakybucket.go +++ b/leakybucket.go @@ -5,7 +5,6 @@ import ( "context" "encoding/gob" "fmt" - "github.com/bradfitz/gomemcache/memcache" "reflect" "strconv" "sync" @@ -15,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/bradfitz/gomemcache/memcache" "github.com/pkg/errors" "github.com/redis/go-redis/v9" "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" @@ -399,6 +399,96 @@ func (t *LeakyBucketRedis) SetState(ctx context.Context, state LeakyBucketState) return errors.Wrap(err, "failed to save keys to redis") } +// LeakyBucketMemcached is a Memcached implementation of a LeakyBucketStateBackend. +type LeakyBucketMemcached struct { + cli *memcache.Client + key string + ttl time.Duration + raceCheck bool + casId uint64 +} + +// NewLeakyBucketMemcached creates a new LeakyBucketMemcached instance. +// Key is the key used to store all the keys used in this implementation in Memcached. +// TTL is the TTL of the stored keys. +// +// If raceCheck is true and the keys in Memcached are modified in between State() and SetState() calls then +// ErrRaceCondition is returned. +func NewLeakyBucketMemcached(cli *memcache.Client, key string, ttl time.Duration, raceCheck bool) *LeakyBucketMemcached { + return &LeakyBucketMemcached{cli: cli, key: key + ":LeakyBucket", ttl: ttl, raceCheck: raceCheck} +} + +// State gets the bucket's state from Memcached. +func (t *LeakyBucketMemcached) State(ctx context.Context) (LeakyBucketState, error) { + var item *memcache.Item + var err error + state := LeakyBucketState{} + done := make(chan struct{}, 1) + go func() { + defer close(done) + item, err = t.cli.Get(t.key) + }() + + select { + case <-done: + + case <-ctx.Done(): + return state, ctx.Err() + } + + if err != nil { + if errors.Is(err, memcache.ErrCacheMiss) { + // Keys don't exist, return an empty state. + return state, nil + } + return state, errors.Wrap(err, "failed to get keys from memcached") + } + b := bytes.NewBuffer(item.Value) + err = gob.NewDecoder(b).Decode(&state) + if err != nil { + return state, errors.Wrap(err, "failed to Decode") + } + t.casId = item.CasID + return state, nil +} + +// SetState updates the state in Memcached. +// The provided fencing token is checked on the Memcached side before saving the keys. +func (t *LeakyBucketMemcached) SetState(ctx context.Context, state LeakyBucketState) error { + var err error + done := make(chan struct{}, 1) + var b bytes.Buffer + err = gob.NewEncoder(&b).Encode(state) + if err != nil { + return errors.Wrap(err, "failed to Encode") + } + go func() { + defer close(done) + item := &memcache.Item{ + Key: t.key, + Value: b.Bytes(), + CasID: t.casId, + } + if t.raceCheck && t.casId > 0 { + err = t.cli.CompareAndSwap(item) + } else { + err = t.cli.Set(item) + } + }() + + select { + case <-done: + + case <-ctx.Done(): + return ctx.Err() + } + + if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) { + return ErrRaceCondition + } + return errors.Wrap(err, "failed to save keys to memcached") +} + // LeakyBucketDynamoDB is a DyanamoDB implementation of a LeakyBucketStateBackend. type LeakyBucketDynamoDB struct { client *dynamodb.Client @@ -524,99 +614,3 @@ func (t *LeakyBucketDynamoDB) loadStateFromDynamoDB(resp *dynamodb.GetItemOutput return state, nil } - -// LeakyBucketMemcached is a Memcached implementation of a LeakyBucketStateBackend. -type LeakyBucketMemcached struct { - cli *memcache.Client - key string - ttl time.Duration - raceCheck bool - casId uint64 -} - -// NewLeakyBucketMemcached creates a new LeakyBucketMemcached instance. -// Key is the key used to store all the keys used in this implementation in Memcached. -// TTL is the TTL of the stored keys. -// -// If raceCheck is true and the keys in Memcached are modified in between State() and SetState() calls then -// ErrRaceCondition is returned. -func NewLeakyBucketMemcached(cli *memcache.Client, key string, ttl time.Duration, raceCheck bool) *LeakyBucketMemcached { - return &LeakyBucketMemcached{cli: cli, key: key + ":LeakyBucket", ttl: ttl, raceCheck: raceCheck} -} - -// State gets the bucket's state from Memcached. -func (t *LeakyBucketMemcached) State(ctx context.Context) (LeakyBucketState, error) { - var item *memcache.Item - var err error - state := LeakyBucketState{} - done := make(chan struct{}, 1) - go func() { - defer close(done) - item, err = t.cli.Get(t.key) - }() - - select { - case <-done: - - case <-ctx.Done(): - return state, ctx.Err() - } - - if err != nil { - if errors.Is(err, memcache.ErrCacheMiss) { - // Keys don't exist, return an empty state. - return state, nil - } else { - return state, errors.Wrap(err, "failed to get keys from memcached") - } - } - b := bytes.NewBuffer(item.Value) - err = gob.NewDecoder(b).Decode(&state) - if err != nil { - return state, errors.Wrap(err, "failed to Decide") - } - t.casId = item.CasID - return state, nil -} - -// SetState updates the state in Memcached. -// The provided fencing token is checked on the Memcached side before saving the keys. -func (t *LeakyBucketMemcached) SetState(ctx context.Context, state LeakyBucketState) error { - var err error - done := make(chan struct{}, 1) - var b bytes.Buffer - err = gob.NewEncoder(&b).Encode(state) - if err != nil { - return errors.Wrap(err, "failed to Encode") - } - go func() { - defer close(done) - if t.raceCheck && t.casId > 0 { - err = t.cli.CompareAndSwap(&memcache.Item{ - Key: t.key, - Value: b.Bytes(), - Expiration: int32(time.Now().Add(t.ttl).Unix()), - CasID: t.casId, - }) - } else { - err = t.cli.Set(&memcache.Item{ - Key: t.key, - Value: b.Bytes(), - Expiration: int32(time.Now().Add(t.ttl).Unix()), - }) - } - }() - - select { - case <-done: - - case <-ctx.Done(): - return ctx.Err() - } - - if errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss) { - return ErrRaceCondition - } else { - return errors.Wrap(err, "failed to save keys to memcached") - } -} diff --git a/leakybucket_test.go b/leakybucket_test.go index 75feb62..b21f9b3 100644 --- a/leakybucket_test.go +++ b/leakybucket_test.go @@ -28,10 +28,10 @@ func (s *LimitersTestSuite) leakyBucketBackends() []l.LeakyBucketStateBackend { l.NewLeakyBucketEtcd(s.etcdClient, uuid.New().String(), time.Second, true), l.NewLeakyBucketRedis(s.redisClient, uuid.New().String(), time.Second, false), l.NewLeakyBucketRedis(s.redisClient, uuid.New().String(), time.Second, true), - l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, false), - l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, true), l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), time.Second, false), l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), time.Second, true), + l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, false), + l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, true), } } diff --git a/slidingwindow.go b/slidingwindow.go index 3de7fb6..f467473 100644 --- a/slidingwindow.go +++ b/slidingwindow.go @@ -153,6 +153,70 @@ func (s *SlidingWindowRedis) Increment(ctx context.Context, prev, curr time.Time } } +// SlidingWindowMemcached implements SlidingWindow in Memcached. +type SlidingWindowMemcached struct { + cli *memcache.Client + prefix string +} + +// NewSlidingWindowMemcached creates a new instance of SlidingWindowMemcached. +func NewSlidingWindowMemcached(cli *memcache.Client, prefix string) *SlidingWindowMemcached { + return &SlidingWindowMemcached{cli: cli, prefix: prefix + ":SlidingWindow"} +} + +// Increment increments the current window's counter in Memcached and returns the number of requests in the previous window +// and the current one. +func (s *SlidingWindowMemcached) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { + var prevCount uint64 + var currCount uint64 + var err error + done := make(chan struct{}) + go func() { + defer close(done) + + var item *memcache.Item + prevKey := fmt.Sprintf("%s:%d", s.prefix, prev.UnixNano()) + item, err = s.cli.Get(prevKey) + if err != nil { + if errors.Is(err, memcache.ErrCacheMiss) { + err = nil + prevCount = 0 + } else { + return + } + } else { + prevCount, err = strconv.ParseUint(string(item.Value), 10, 64) + if err != nil { + return + } + } + + currKey := fmt.Sprintf("%s:%d", s.prefix, curr.UnixNano()) + currCount, err = s.cli.Increment(currKey, 1) + if err != nil && errors.Is(err, memcache.ErrCacheMiss) { + currCount = 1 + item = &memcache.Item{ + Key: currKey, + Value: []byte(strconv.FormatUint(currCount, 10)), + } + err = s.cli.Add(item) + } + }() + + select { + case <-done: + if err != nil { + if errors.Is(err, memcache.ErrNotStored) { + return s.Increment(ctx, prev, curr, ttl) + } + return 0, 0, err + } + return int64(prevCount), int64(currCount), nil + case <-ctx.Done(): + return 0, 0, ctx.Err() + } +} + // SlidingWindowDynamoDB implements SlidingWindow in DynamoDB. type SlidingWindowDynamoDB struct { client *dynamodb.Client @@ -270,69 +334,3 @@ func (s *SlidingWindowDynamoDB) Increment(ctx context.Context, prev, curr time.T return priorCount, currentCount, nil } - -// SlidingWindowMemcached implements SlidingWindow in Memcached. -type SlidingWindowMemcached struct { - cli *memcache.Client - prefix string -} - -// NewSlidingWindowMemcached creates a new instance of SlidingWindowMemcached. -func NewSlidingWindowMemcached(cli *memcache.Client, prefix string) *SlidingWindowMemcached { - return &SlidingWindowMemcached{cli: cli, prefix: prefix + ":SlidingWindow"} -} - -// Increment increments the current window's counter in Memcached and returns the number of requests in the previous window -// and the current one. -func (s *SlidingWindowMemcached) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { - var prevCount uint64 - var currCount uint64 - var err error - done := make(chan struct{}) - go func() { - defer close(done) - - var item *memcache.Item - prevKey := fmt.Sprintf("%s:%d", s.prefix, prev.UnixNano()) - item, err = s.cli.Get(prevKey) - if err != nil { - if errors.Is(err, memcache.ErrCacheMiss) { - err = nil - prevCount = 0 - } else { - return - } - } else { - prevCount, err = strconv.ParseUint(string(item.Value), 10, 64) - if err != nil { - return - } - } - - currKey := fmt.Sprintf("%s:%d", s.prefix, curr.UnixNano()) - currCount, err = s.cli.Increment(currKey, 1) - if err != nil && errors.Is(err, memcache.ErrCacheMiss) { - currCount = 1 - err = s.cli.Add(&memcache.Item{ - Key: currKey, - Value: []byte(strconv.FormatUint(currCount, 10)), - Expiration: int32(time.Now().Add(ttl).Unix()), - }) - } - }() - - select { - case <-done: - if err != nil { - if errors.Is(err, memcache.ErrNotStored) { - return s.Increment(ctx, prev, curr, ttl) - } else { - return 0, 0, err - } - } else { - return int64(prevCount), int64(currCount), nil - } - case <-ctx.Done(): - return 0, 0, ctx.Err() - } -} diff --git a/slidingwindow_test.go b/slidingwindow_test.go index 7baf230..215a20e 100644 --- a/slidingwindow_test.go +++ b/slidingwindow_test.go @@ -22,8 +22,8 @@ func (s *LimitersTestSuite) slidingWindowIncrementers() []l.SlidingWindowIncreme return []l.SlidingWindowIncrementer{ l.NewSlidingWindowInMemory(), l.NewSlidingWindowRedis(s.redisClient, uuid.New().String()), - l.NewSlidingWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps), l.NewSlidingWindowMemcached(s.memcacheClient, uuid.New().String()), + l.NewSlidingWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps), } } diff --git a/tokenbucket.go b/tokenbucket.go index a67dbaa..5bd95d0 100644 --- a/tokenbucket.go +++ b/tokenbucket.go @@ -5,7 +5,6 @@ import ( "context" "encoding/gob" "fmt" - "github.com/bradfitz/gomemcache/memcache" "strconv" "sync" "time" @@ -14,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/bradfitz/gomemcache/memcache" "github.com/pkg/errors" "github.com/redis/go-redis/v9" "go.etcd.io/etcd/api/v3/mvccpb" @@ -481,6 +481,98 @@ func (t *TokenBucketRedis) SetState(ctx context.Context, state TokenBucketState) return errors.Wrap(err, "failed to save keys to redis") } +// TokenBucketMemcached is a Memcached implementation of a TokenBucketStateBackend. +// +// Memcached is a distributed memory object caching system. +type TokenBucketMemcached struct { + cli *memcache.Client + key string + ttl time.Duration + raceCheck bool + casId uint64 +} + +// NewTokenBucketMemcached creates a new TokenBucketMemcached instance. +// Key is the key used to store all the keys used in this implementation in Memcached. +// TTL is the TTL of the stored keys. +// +// If raceCheck is true and the keys in Memcached are modified in between State() and SetState() calls then +// ErrRaceCondition is returned. +// This adds an extra overhead since a Lua script has to be executed on the Memcached side which locks the entire database. +func NewTokenBucketMemcached(cli *memcache.Client, key string, ttl time.Duration, raceCheck bool) *TokenBucketMemcached { + return &TokenBucketMemcached{cli: cli, key: key + ":TokenBucket", ttl: ttl, raceCheck: raceCheck} +} + +// State gets the bucket's state from Memcached. +func (t *TokenBucketMemcached) State(ctx context.Context) (TokenBucketState, error) { + var item *memcache.Item + var state TokenBucketState + var err error + + done := make(chan struct{}, 1) + t.casId = 0 + + go func() { + defer close(done) + item, err = t.cli.Get(t.key) + }() + + select { + case <-done: + + case <-ctx.Done(): + return state, ctx.Err() + } + + if err != nil { + if errors.Is(err, memcache.ErrCacheMiss) { + // Keys don't exist, return the initial state. + return state, nil + } + return state, errors.Wrap(err, "failed to get key from memcached") + } + b := bytes.NewBuffer(item.Value) + err = gob.NewDecoder(b).Decode(&state) + if err != nil { + return state, errors.Wrap(err, "failed to Decode") + } + t.casId = item.CasID + return state, nil +} + +// SetState updates the state in Memcached. +func (t *TokenBucketMemcached) SetState(ctx context.Context, state TokenBucketState) error { + var err error + done := make(chan struct{}, 1) + var b bytes.Buffer + err = gob.NewEncoder(&b).Encode(state) + if err != nil { + return errors.Wrap(err, "failed to Encode") + } + go func() { + defer close(done) + item := &memcache.Item{ + Key: t.key, + Value: b.Bytes(), + CasID: t.casId, + } + if t.raceCheck && t.casId > 0 { + err = t.cli.CompareAndSwap(item) + } else { + err = t.cli.Set(item) + } + }() + + select { + case <-done: + + case <-ctx.Done(): + return ctx.Err() + } + + return errors.Wrap(err, "failed to save keys to memcached") +} + // TokenBucketDynamoDB is a DynamoDB implementation of a TokenBucketStateBackend. type TokenBucketDynamoDB struct { client *dynamodb.Client @@ -609,100 +701,3 @@ func (t *TokenBucketDynamoDB) loadStateFromDynamoDB(resp *dynamodb.GetItemOutput return state, nil } - -// TokenBucketMemcached is a Memcached implementation of a TokenBucketStateBackend. -// -// Memcached is a distributed memory object caching system. -type TokenBucketMemcached struct { - cli *memcache.Client - key string - ttl time.Duration - raceCheck bool - casId uint64 -} - -// NewTokenBucketMemcached creates a new TokenBucketMemcached instance. -// Key is the key used to store all the keys used in this implementation in Memcached. -// TTL is the TTL of the stored keys. -// -// If raceCheck is true and the keys in Memcached are modified in between State() and SetState() calls then -// ErrRaceCondition is returned. -// This adds an extra overhead since a Lua script has to be executed on the Memcached side which locks the entire database. -func NewTokenBucketMemcached(cli *memcache.Client, key string, ttl time.Duration, raceCheck bool) *TokenBucketMemcached { - return &TokenBucketMemcached{cli: cli, key: key + ":TokenBucket", ttl: ttl, raceCheck: raceCheck} -} - -// State gets the bucket's state from Memcached. -func (t *TokenBucketMemcached) State(ctx context.Context) (TokenBucketState, error) { - var item *memcache.Item - var state TokenBucketState - var err error - - done := make(chan struct{}, 1) - t.casId = 0 - - go func() { - defer close(done) - item, err = t.cli.Get(t.key) - }() - - select { - case <-done: - - case <-ctx.Done(): - return state, ctx.Err() - } - - if err != nil { - if errors.Is(err, memcache.ErrCacheMiss) { - // Keys don't exist, return the initial state. - return state, nil - } else { - return state, errors.Wrap(err, "failed to get key from memcached") - } - } - b := bytes.NewBuffer(item.Value) - err = gob.NewDecoder(b).Decode(&state) - if err != nil { - return state, errors.Wrap(err, "failed to Decide") - } - t.casId = item.CasID - return state, nil -} - -// SetState updates the state in Memcached. -func (t *TokenBucketMemcached) SetState(ctx context.Context, state TokenBucketState) error { - var err error - done := make(chan struct{}, 1) - var b bytes.Buffer - err = gob.NewEncoder(&b).Encode(state) - if err != nil { - return errors.Wrap(err, "failed to Encode") - } - go func() { - defer close(done) - if t.raceCheck && t.casId > 0 { - err = t.cli.CompareAndSwap(&memcache.Item{ - Key: t.key, - Value: b.Bytes(), - Expiration: int32(time.Now().Add(t.ttl).Unix()), - CasID: t.casId, - }) - } else { - err = t.cli.Set(&memcache.Item{ - Key: t.key, - Value: b.Bytes(), - Expiration: int32(time.Now().Add(t.ttl).Unix()), - }) - } - }() - - select { - case <-done: - - case <-ctx.Done(): - return ctx.Err() - } - - return errors.Wrap(err, "failed to save keys to memcached") -} diff --git a/tokenbucket_test.go b/tokenbucket_test.go index 664f762..ced16a6 100644 --- a/tokenbucket_test.go +++ b/tokenbucket_test.go @@ -85,10 +85,10 @@ func (s *LimitersTestSuite) tokenBucketBackends() []l.TokenBucketStateBackend { l.NewTokenBucketEtcd(s.etcdClient, uuid.New().String(), time.Second, true), l.NewTokenBucketRedis(s.redisClient, uuid.New().String(), time.Second, false), l.NewTokenBucketRedis(s.redisClient, uuid.New().String(), time.Second, true), - l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, false), - l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, true), l.NewTokenBucketMemcached(s.memcacheClient, uuid.New().String(), time.Second, false), l.NewTokenBucketMemcached(s.memcacheClient, uuid.New().String(), time.Second, true), + l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, false), + l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, true), } }