From 44451828843542837c2a3197a84a7a685a941b3d Mon Sep 17 00:00:00 2001 From: Yen-Ming Lee Date: Thu, 29 Feb 2024 11:52:34 -0800 Subject: [PATCH] add TestConcurrentBufferDuplicateKeys --- concurrent_buffer.go | 17 ++++++----------- concurrent_buffer_test.go | 12 ++++++++++++ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/concurrent_buffer.go b/concurrent_buffer.go index b69d540..e0c2dd9 100644 --- a/concurrent_buffer.go +++ b/concurrent_buffer.go @@ -199,7 +199,7 @@ func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (in var oldNodes []SortedSetNode _ = gob.NewDecoder(b).Decode(&oldNodes) for _, node := range oldNodes { - if node.CreatedAt > now.UnixNano() { + if node.CreatedAt > now.UnixNano() && node.Value != element { newNodes = append(newNodes, node) } } @@ -239,8 +239,7 @@ func (c *ConcurrentBufferMemcached) Remove(ctx context.Context, key string) erro var err error now := c.clock.Now() var newNodes []SortedSetNode - var casId uint64 = 0 - deleted := false + var casID uint64 item, err := c.cli.Get(c.key) if err != nil { if errors.Is(err, memcache.ErrCacheMiss) { @@ -248,16 +247,12 @@ func (c *ConcurrentBufferMemcached) Remove(ctx context.Context, key string) erro } return errors.Wrap(err, "failed to Get") } - casId = item.CasID + casID = item.CasID var oldNodes []SortedSetNode _ = gob.NewDecoder(bytes.NewBuffer(item.Value)).Decode(&oldNodes) for _, node := range oldNodes { - if node.CreatedAt > now.UnixNano() { - if node.Value == key && !deleted { - deleted = true - } else { - newNodes = append(newNodes, node) - } + if node.CreatedAt > now.UnixNano() && node.Value != key { + newNodes = append(newNodes, node) } } @@ -266,7 +261,7 @@ func (c *ConcurrentBufferMemcached) Remove(ctx context.Context, key string) erro item = &memcache.Item{ Key: c.key, Value: b.Bytes(), - CasID: casId, + CasID: casID, } err = c.cli.CompareAndSwap(item) if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) { diff --git a/concurrent_buffer_test.go b/concurrent_buffer_test.go index 89ae954..2c8b803 100644 --- a/concurrent_buffer_test.go +++ b/concurrent_buffer_test.go @@ -89,3 +89,15 @@ func (s *LimitersTestSuite) TestConcurrentBufferExpiredKeys() { s.NoError(buffer.Limit(context.TODO(), "key3")) } } + +func (s *LimitersTestSuite) TestConcurrentBufferDuplicateKeys() { + clock := newFakeClock() + capacity := int64(2) + ttl := time.Second + for _, buffer := range s.concurrentBuffers(capacity, ttl, clock) { + s.Require().NoError(buffer.Limit(context.TODO(), "key1")) + s.Require().NoError(buffer.Limit(context.TODO(), "key2")) + // No error is expected as it should just update the timestamp of the existing key. + s.NoError(buffer.Limit(context.TODO(), "key1")) + } +}