Commit 4c46468

Merge pull request #2292 from AlexanderYastrebov/ring/reduce-set-addrs-shards-locking-2
fix: reduce `SetAddrs` shards lock contention
2 parents 5502cf6 + 7c6f677

File tree: 3 files changed, +252 −33 lines
internal_test.go (+118)

@@ -3,6 +3,9 @@ package redis
 import (
 	"context"
 	"fmt"
+	"reflect"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"

@@ -107,6 +110,7 @@ func TestRingSetAddrsAndRebalanceRace(t *testing.T) {
 			}
 		},
 	})
+	defer ring.Close()

 	// Continuously update addresses by adding and removing one address
 	updatesDone := make(chan struct{})
@@ -156,13 +160,127 @@ func BenchmarkRingShardingRebalanceLocked(b *testing.B) {
 	}

 	ring := NewRing(opts)
+	defer ring.Close()

 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		ring.sharding.rebalanceLocked()
 	}
 }

+type testCounter struct {
+	mu sync.Mutex
+	t  *testing.T
+	m  map[string]int
+}
+
+func newTestCounter(t *testing.T) *testCounter {
+	return &testCounter{t: t, m: make(map[string]int)}
+}
+
+func (ct *testCounter) increment(key string) {
+	ct.mu.Lock()
+	defer ct.mu.Unlock()
+	ct.m[key]++
+}
+
+func (ct *testCounter) expect(values map[string]int) {
+	ct.mu.Lock()
+	defer ct.mu.Unlock()
+	ct.t.Helper()
+	if !reflect.DeepEqual(values, ct.m) {
+		ct.t.Errorf("expected %v != actual %v", values, ct.m)
+	}
+}
+
+func TestRingShardsCleanup(t *testing.T) {
+	const (
+		ringShard1Name = "ringShardOne"
+		ringShard2Name = "ringShardTwo"
+
+		ringShard1Addr = "shard1.test"
+		ringShard2Addr = "shard2.test"
+	)
+
+	t.Run("closes unused shards", func(t *testing.T) {
+		closeCounter := newTestCounter(t)
+
+		ring := NewRing(&RingOptions{
+			Addrs: map[string]string{
+				ringShard1Name: ringShard1Addr,
+				ringShard2Name: ringShard2Addr,
+			},
+			NewClient: func(opt *Options) *Client {
+				c := NewClient(opt)
+				c.baseClient.onClose = func() error {
+					closeCounter.increment(opt.Addr)
+					return nil
+				}
+				return c
+			},
+		})
+		closeCounter.expect(map[string]int{})
+
+		// no change due to the same addresses
+		ring.SetAddrs(map[string]string{
+			ringShard1Name: ringShard1Addr,
+			ringShard2Name: ringShard2Addr,
+		})
+		closeCounter.expect(map[string]int{})
+
+		ring.SetAddrs(map[string]string{
+			ringShard1Name: ringShard1Addr,
+		})
+		closeCounter.expect(map[string]int{ringShard2Addr: 1})
+
+		ring.SetAddrs(map[string]string{
+			ringShard2Name: ringShard2Addr,
+		})
+		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
+
+		ring.Close()
+		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 2})
+	})
+
+	t.Run("closes created shards if ring was closed", func(t *testing.T) {
+		createCounter := newTestCounter(t)
+		closeCounter := newTestCounter(t)
+
+		var (
+			ring        *Ring
+			shouldClose int32
+		)
+
+		ring = NewRing(&RingOptions{
+			Addrs: map[string]string{
+				ringShard1Name: ringShard1Addr,
+			},
+			NewClient: func(opt *Options) *Client {
+				if atomic.LoadInt32(&shouldClose) != 0 {
+					ring.Close()
+				}
+				createCounter.increment(opt.Addr)
+				c := NewClient(opt)
+				c.baseClient.onClose = func() error {
+					closeCounter.increment(opt.Addr)
+					return nil
+				}
+				return c
+			},
+		})
+		createCounter.expect(map[string]int{ringShard1Addr: 1})
+		closeCounter.expect(map[string]int{})
+
+		atomic.StoreInt32(&shouldClose, 1)
+
+		ring.SetAddrs(map[string]string{
+			ringShard2Name: ringShard2Addr,
+		})
+		createCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
+		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
+	})
+}
+
 //------------------------------------------------------------------------------

 type timeoutErr struct {
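
A note on technique: these tests instrument shard lifecycle through the public `RingOptions.NewClient` factory plus the package-internal `baseClient.onClose` hook, which is reachable only because the test lives in package redis. From outside the package, creation can still be counted with `NewClient` alone; a minimal sketch under that assumption (package name and import path are illustrative and may need adjusting to the go-redis version in use):

package ringexample

import (
	"sync/atomic"

	"github.com/redis/go-redis/v9" // assumed import path; adjust to your version
)

// newCountingRing wires a creation counter into the ring via the public
// RingOptions.NewClient hook; every shard client the ring dials bumps it.
func newCountingRing(addrs map[string]string, created *int64) *redis.Ring {
	return redis.NewRing(&redis.RingOptions{
		Addrs: addrs,
		NewClient: func(opt *redis.Options) *redis.Client {
			atomic.AddInt64(created, 1)
			return redis.NewClient(opt)
		},
	})
}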

ring.go (+41 −27)

@@ -219,6 +219,10 @@ type ringSharding struct {
 	hash      ConsistentHash
 	numShard  int
 	onNewNode []func(rdb *Client)
+
+	// ensures exclusive access to SetAddrs so there is no need
+	// to hold mu for the duration of potentially long shard creation
+	setAddrsMu sync.Mutex
 }

 type ringShards struct {
@@ -245,46 +249,62 @@ func (c *ringSharding) OnNewNode(fn func(rdb *Client)) {
 // decrease number of shards, that you use. It will reuse shards that
 // existed before and close the ones that will not be used anymore.
 func (c *ringSharding) SetAddrs(addrs map[string]string) {
-	c.mu.Lock()
+	c.setAddrsMu.Lock()
+	defer c.setAddrsMu.Unlock()

+	cleanup := func(shards map[string]*ringShard) {
+		for addr, shard := range shards {
+			if err := shard.Client.Close(); err != nil {
+				internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err)
+			}
+		}
+	}
+
+	c.mu.RLock()
 	if c.closed {
-		c.mu.Unlock()
+		c.mu.RUnlock()
 		return
 	}
+	existing := c.shards
+	c.mu.RUnlock()
+
+	shards, created, unused := c.newRingShards(addrs, existing)

-	shards, cleanup := c.newRingShards(addrs, c.shards)
+	c.mu.Lock()
+	if c.closed {
+		cleanup(created)
+		c.mu.Unlock()
+		return
+	}
 	c.shards = shards
 	c.rebalanceLocked()
 	c.mu.Unlock()

-	cleanup()
+	cleanup(unused)
 }

 func (c *ringSharding) newRingShards(
-	addrs map[string]string, existingShards *ringShards,
-) (*ringShards, func()) {
-	shardMap := make(map[string]*ringShard)     // indexed by addr
-	unusedShards := make(map[string]*ringShard) // indexed by addr
-
-	if existingShards != nil {
-		for _, shard := range existingShards.list {
-			addr := shard.Client.opt.Addr
-			shardMap[addr] = shard
-			unusedShards[addr] = shard
-		}
-	}
+	addrs map[string]string, existing *ringShards,
+) (shards *ringShards, created, unused map[string]*ringShard) {
+
+	shards = &ringShards{m: make(map[string]*ringShard, len(addrs))}
+	created = make(map[string]*ringShard) // indexed by addr
+	unused = make(map[string]*ringShard)  // indexed by addr

-	shards := &ringShards{
-		m: make(map[string]*ringShard),
+	if existing != nil {
+		for _, shard := range existing.list {
+			unused[shard.addr] = shard
+		}
 	}

 	for name, addr := range addrs {
-		if shard, ok := shardMap[addr]; ok {
+		if shard, ok := unused[addr]; ok {
 			shards.m[name] = shard
-			delete(unusedShards, addr)
+			delete(unused, addr)
 		} else {
 			shard := newRingShard(c.opt, addr)
 			shards.m[name] = shard
+			created[addr] = shard

 			for _, fn := range c.onNewNode {
 				fn(shard.Client)
@@ -296,13 +316,7 @@ func (c *ringSharding) newRingShards(
 		shards.list = append(shards.list, shard)
 	}

-	return shards, func() {
-		for addr, shard := range unusedShards {
-			if err := shard.Client.Close(); err != nil {
-				internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err)
-			}
-		}
-	}
+	return
 }

 func (c *ringSharding) List() []*ringShard {
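
The shape of the change, in brief: `SetAddrs` now serializes itself on the new `setAddrsMu`, snapshots the current shards under a short read lock, builds the next shard set with no lock held (the slow part, since new shards may dial connections), and takes the write lock only to swap the map, re-checking `closed` so that clients created while a concurrent `Close` ran get cleaned up. A stripped-down sketch of that pattern, using illustrative names (`sharding`, `client`, `dial`) rather than the library's internals:

package ringsketch

import (
	"sync"
	"time"
)

// client stands in for a shard client whose construction is slow.
type client struct{ addr string }

func dial(addr string) *client {
	time.Sleep(100 * time.Millisecond) // simulate slow connection setup
	return &client{addr: addr}
}

type sharding struct {
	mu         sync.RWMutex       // hot path: read-locked by command dispatch
	setAddrsMu sync.Mutex         // serializes concurrent SetAddrs callers
	shards     map[string]*client // shard by name
}

func (s *sharding) SetAddrs(addrs map[string]string) {
	s.setAddrsMu.Lock()
	defer s.setAddrsMu.Unlock()

	s.mu.RLock()
	existing := s.shards // snapshot under a brief read lock
	s.mu.RUnlock()

	// Slow work runs with no lock held: reuse a shard whose address is
	// unchanged, dial only new addresses. (The real code reuses by address
	// across names and also tracks created/unused sets.)
	next := make(map[string]*client, len(addrs))
	for name, addr := range addrs {
		if old, ok := existing[name]; ok && old.addr == addr {
			next[name] = old
			continue
		}
		next[name] = dial(addr)
	}

	s.mu.Lock()
	s.shards = next // the exclusive section is just the swap (plus rebalance)
	s.mu.Unlock()
	// The real code closes now-unused shards here, and under the write lock
	// re-checks closed, discarding freshly created shards if the ring was
	// shut down mid-call.
}

With this split, concurrent `SetAddrs` callers queue on `setAddrsMu`, while command dispatch only ever waits out the brief swap section; that is exactly what `TestRingSetAddrsContention` below measures.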

ring_test.go (+93 −6)

@@ -7,6 +7,7 @@ import (
 	"net"
 	"strconv"
 	"sync"
+	"testing"
 	"time"

 	. "github.com/onsi/ginkgo"
@@ -123,15 +124,15 @@ var _ = Describe("Redis Ring", func() {
 		})
 		Expect(ring.Len(), 1)
 		gotShard := ring.ShardByName("ringShardOne")
-		Expect(gotShard).To(Equal(wantShard))
+		Expect(gotShard).To(BeIdenticalTo(wantShard))

 		ring.SetAddrs(map[string]string{
 			"ringShardOne": ":" + ringShard1Port,
 			"ringShardTwo": ":" + ringShard2Port,
 		})
 		Expect(ring.Len(), 2)
 		gotShard = ring.ShardByName("ringShardOne")
-		Expect(gotShard).To(Equal(wantShard))
+		Expect(gotShard).To(BeIdenticalTo(wantShard))
 	})

 	It("uses 3 shards after setting it to 3 shards", func() {
@@ -155,8 +156,8 @@ var _ = Describe("Redis Ring", func() {
 		gotShard1 := ring.ShardByName(shardName1)
 		gotShard2 := ring.ShardByName(shardName2)
 		gotShard3 := ring.ShardByName(shardName3)
-		Expect(gotShard1).To(Equal(wantShard1))
-		Expect(gotShard2).To(Equal(wantShard2))
+		Expect(gotShard1).To(BeIdenticalTo(wantShard1))
+		Expect(gotShard2).To(BeIdenticalTo(wantShard2))
 		Expect(gotShard3).ToNot(BeNil())

 		ring.SetAddrs(map[string]string{
@@ -167,8 +168,8 @@ var _ = Describe("Redis Ring", func() {
 		gotShard1 = ring.ShardByName(shardName1)
 		gotShard2 = ring.ShardByName(shardName2)
 		gotShard3 = ring.ShardByName(shardName3)
-		Expect(gotShard1).To(Equal(wantShard1))
-		Expect(gotShard2).To(Equal(wantShard2))
+		Expect(gotShard1).To(BeIdenticalTo(wantShard1))
+		Expect(gotShard2).To(BeIdenticalTo(wantShard2))
 		Expect(gotShard3).To(BeNil())
 	})
 })
@@ -739,3 +740,89 @@ var _ = Describe("Ring Tx timeout", func() {
 		testTimeout()
 	})
 })
+
+func TestRingSetAddrsContention(t *testing.T) {
+	const (
+		ringShard1Name = "ringShardOne"
+		ringShard2Name = "ringShardTwo"
+	)
+
+	for _, port := range []string{ringShard1Port, ringShard2Port} {
+		if _, err := startRedis(port); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	t.Cleanup(func() {
+		for _, p := range processes {
+			if err := p.Close(); err != nil {
+				t.Errorf("Failed to stop redis process: %v", err)
+			}
+		}
+		processes = nil
+	})
+
+	ring := redis.NewRing(&redis.RingOptions{
+		Addrs: map[string]string{
+			"ringShardOne": ":" + ringShard1Port,
+		},
+		NewClient: func(opt *redis.Options) *redis.Client {
+			// Simulate slow shard creation
+			time.Sleep(100 * time.Millisecond)
+			return redis.NewClient(opt)
+		},
+	})
+	defer ring.Close()

+	if _, err := ring.Ping(context.Background()).Result(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Continuously update addresses by adding and removing one address
+	updatesDone := make(chan struct{})
+	defer func() { close(updatesDone) }()
+	go func() {
+		ticker := time.NewTicker(10 * time.Millisecond)
+		defer ticker.Stop()
+		for i := 0; ; i++ {
+			select {
+			case <-ticker.C:
+				if i%2 == 0 {
+					ring.SetAddrs(map[string]string{
+						ringShard1Name: ":" + ringShard1Port,
+					})
+				} else {
+					ring.SetAddrs(map[string]string{
+						ringShard1Name: ":" + ringShard1Port,
+						ringShard2Name: ":" + ringShard2Port,
+					})
+				}
+			case <-updatesDone:
+				return
+			}
+		}
+	}()
+
+	var pings, errClosed int
+	timer := time.NewTimer(1 * time.Second)
+	for running := true; running; pings++ {
+		select {
+		case <-timer.C:
+			running = false
+		default:
+			if _, err := ring.Ping(context.Background()).Result(); err != nil {
+				if err == redis.ErrClosed {
+					// The shard client could be closed while ping command is in progress
+					errClosed++
+				} else {
+					t.Fatal(err)
+				}
+			}
+		}
+	}
+
+	t.Logf("Number of pings: %d, errClosed: %d", pings, errClosed)
+	if pings < 10_000 {
+		t.Errorf("Expected at least 10k pings, got: %d", pings)
+	}
+}
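
This test has a general throughput-floor shape: an antagonist goroutine churns shared state on a ticker while the main loop counts operations for a fixed window, then asserts a minimum count. With the 100ms simulated dial, any regression that holds the shard mutex during shard creation keeps the ring locked for most of the one-second window and collapses the count far below the 10,000 floor. A hedged sketch of that shape, with a helper and parameters invented purely for illustration:

package ringsketch

import (
	"testing"
	"time"
)

// throughputFloor runs op in a tight loop for window while churn mutates
// shared state in the background, then asserts a minimum operation count.
func throughputFloor(t *testing.T, window time.Duration, floor int, op func() error, churn func()) {
	t.Helper()

	done := make(chan struct{})
	defer close(done)
	go func() {
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				churn() // e.g. ring.SetAddrs with an alternating address set
			case <-done:
				return
			}
		}
	}()

	var n int
	deadline := time.After(window)
	for {
		select {
		case <-deadline:
			if n < floor {
				t.Errorf("throughput too low: %d < %d", n, floor)
			}
			return
		default:
			if err := op(); err != nil { // e.g. a ring.Ping round trip
				t.Fatal(err)
			}
			n++
		}
	}
}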
