Skip to content

Commit b8acfce

Browse files
author
Cody Jones
committed
Test the connection pool's concurrency safety
1 parent 06bc199 commit b8acfce

File tree

6 files changed

+194
-18
lines changed

6 files changed

+194
-18
lines changed

cluster.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,13 @@ type Cluster struct {
4646
// NewCluster creates a new cluster by connecting to the given hosts.
4747
func NewCluster(hosts []Host, opts *ConnectOpts) (*Cluster, error) {
4848
c := &Cluster{
49-
hp: newHostPool(opts),
50-
seeds: hosts,
51-
opts: opts,
52-
closed: clusterWorking,
53-
connFactory: NewConnection,
49+
hp: newHostPool(opts),
50+
seeds: hosts,
51+
opts: opts,
52+
closed: clusterWorking,
53+
connFactory: func(host string, opts *ConnectOpts) (connection, error) {
54+
return NewConnection(host, opts)
55+
},
5456
}
5557

5658
err := c.run()

cluster_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ import (
55
"encoding/binary"
66
"encoding/json"
77
"fmt"
8+
"io"
9+
"net"
10+
"time"
11+
812
"github.com/stretchr/testify/mock"
913
test "gopkg.in/check.v1"
1014
"gopkg.in/rethinkdb/rethinkdb-go.v6/encoding"
1115
p "gopkg.in/rethinkdb/rethinkdb-go.v6/ql2"
12-
"io"
13-
"net"
14-
"time"
1516
)
1617

1718
type ClusterSuite struct{}
@@ -360,20 +361,20 @@ type mockDial struct {
360361
}
361362

362363
func mockedConnectionFactory(dial *mockDial) connFactory {
363-
return func(host string, opts *ConnectOpts) (connection *Connection, err error) {
364+
return func(host string, opts *ConnectOpts) (connection, error) {
364365
args := dial.MethodCalled("Dial", host)
365-
err = args.Error(1)
366+
err := args.Error(1)
366367
if err != nil {
367368
return nil, err
368369
}
369370

370-
connection = newConnection(args.Get(0).(net.Conn), host, opts)
371-
done := runConnection(connection)
371+
conn := newConnection(args.Get(0).(net.Conn), host, opts)
372+
done := runConnection(conn)
372373

373374
m := args.Get(0).(*connMock)
374375
m.setDone(done)
375376

376-
return connection, nil
377+
return conn, nil
377378
}
378379
}
379380

connection.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,15 @@ type Connection struct {
6161
mu sync.Mutex
6262
}
6363

64-
type connFactory func(host string, opts *ConnectOpts) (*Connection, error)
64+
type connection interface {
65+
Server() (ServerResponse, error)
66+
Query(context.Context, Query) (*Response, *Cursor, error)
67+
Close() error
68+
isBad() bool
69+
isClosed() bool
70+
}
71+
72+
type connFactory func(host string, opts *ConnectOpts) (connection, error)
6573

6674
type responseAndError struct {
6775
response *Response

cursor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func newCursor(ctx context.Context, conn *Connection, cursorType string, token i
6161
type Cursor struct {
6262
releaseConn func() error
6363

64-
conn *Connection
64+
conn connection
6565
connOpts *ConnectOpts
6666
token int64
6767
cursorType string

pool.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@ type Pool struct {
4242
}
4343

4444
type countedConn struct {
45-
conn *Connection
45+
conn connection
4646
mu int64
4747
refCount int64
4848
goroutineID int64
4949
}
5050

5151
// NewPool creates a new connection pool for the given host
5252
func NewPool(host Host, opts *ConnectOpts) (*Pool, error) {
53-
return newPool(host, opts, NewConnection)
53+
return newPool(host, opts, func(host string, opts *ConnectOpts) (connection, error) {
54+
return NewConnection(host, opts)
55+
})
5456
}
5557

5658
func newPool(host Host, opts *ConnectOpts, connFactory connFactory) (*Pool, error) {
@@ -232,7 +234,7 @@ func (p *Pool) Server() (ServerResponse, error) {
232234
}
233235

234236
// getConnection returns a valid (usable) connection from the pool having the given index.
235-
func (p *Pool) getConnection(pos int) (*Connection, error) {
237+
func (p *Pool) getConnection(pos int) (connection, error) {
236238
conn := &p.countedConns[pos].conn
237239
var err error
238240

pool_test.go

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package rethinkdb
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"math/rand"
8+
"sync"
9+
"sync/atomic"
10+
"testing"
11+
"time"
12+
13+
"github.com/silentred/gid"
14+
"golang.org/x/sync/errgroup"
15+
test "gopkg.in/check.v1"
16+
)
17+
18+
type PoolSuite struct{}
19+
20+
var _ = test.Suite(&PoolSuite{})
21+
22+
func (s *PoolSuite) TestConcurrency(c *test.C) {
23+
if testing.Short() {
24+
c.Skip("-short set")
25+
}
26+
27+
const (
28+
poolCapacity = 100
29+
numGoroutines = 2000
30+
testDuration = 15 * time.Second
31+
)
32+
33+
connFactory := func(host string, opts *ConnectOpts) (connection, error) {
34+
return &fakePoolConn{badProbability: 0.001}, nil
35+
}
36+
37+
p, err := newPool(Host{}, &ConnectOpts{MaxOpen: poolCapacity}, connFactory)
38+
c.Assert(p, test.NotNil)
39+
c.Assert(err, test.IsNil)
40+
41+
localState := [poolCapacity]struct {
42+
mu sync.Mutex
43+
goroutineID int64
44+
refCount int64
45+
}{}
46+
47+
p.afterAcquire = func(pos int) {
48+
s := &localState[pos]
49+
50+
// Don't use atomics here. It's more important this test be correct than fast.
51+
s.mu.Lock()
52+
defer s.mu.Unlock()
53+
54+
if s.goroutineID != gid.Get() {
55+
if s.refCount == 0 {
56+
// First use of a connection, or reuse of an existing but unused connection
57+
s.goroutineID = gid.Get()
58+
} else {
59+
panic(fmt.Sprintf("A connection would be concurrently used by goroutines %v and %v", s.goroutineID, gid.Get()))
60+
}
61+
}
62+
63+
s.refCount++
64+
}
65+
66+
p.beforeRelease = func(pos int) {
67+
s := &localState[pos]
68+
69+
s.mu.Lock()
70+
defer s.mu.Unlock()
71+
72+
s.refCount--
73+
}
74+
75+
q := testQuery(DB("db").Table("table").Get("id"))
76+
77+
calls := []func(*Pool) error{
78+
func(p *Pool) error {
79+
return p.Ping()
80+
},
81+
82+
func(p *Pool) error {
83+
return p.Exec(context.Background(), q)
84+
},
85+
86+
func(p *Pool) error {
87+
cursor, err := p.Query(context.Background(), q)
88+
if err != nil {
89+
return err
90+
}
91+
92+
// Pretend to read from the cursor
93+
time.Sleep(1 * time.Millisecond)
94+
cursor.finished = true
95+
96+
return cursor.Close()
97+
},
98+
99+
func(p *Pool) error {
100+
_, err := p.Server()
101+
return err
102+
},
103+
}
104+
105+
testLoop := func() error {
106+
startTime := time.Now()
107+
for time.Since(startTime) < testDuration {
108+
i := rand.Intn(len(calls))
109+
err := calls[i](p)
110+
if err != nil {
111+
return err
112+
}
113+
}
114+
115+
return nil
116+
}
117+
118+
g := new(errgroup.Group)
119+
for i := 0; i < numGoroutines; i++ {
120+
g.Go(testLoop)
121+
}
122+
123+
err = g.Wait()
124+
c.Assert(err, test.IsNil)
125+
}
126+
127+
type fakePoolConn struct {
128+
bad int32
129+
badProbability float32
130+
}
131+
132+
func (c *fakePoolConn) Server() (ServerResponse, error) {
133+
if c.isBad() {
134+
return ServerResponse{}, errors.New("Server() was called even though the connection is known to be unusable")
135+
}
136+
137+
if rand.Float32() < c.badProbability {
138+
atomic.StoreInt32(&c.bad, connBad)
139+
}
140+
141+
return ServerResponse{}, nil
142+
}
143+
144+
func (c *fakePoolConn) Query(ctx context.Context, q Query) (*Response, *Cursor, error) {
145+
if c.isBad() {
146+
return nil, nil, errors.New("Query() was called even though the connection is known to be unusable")
147+
}
148+
149+
if rand.Float32() < c.badProbability {
150+
atomic.StoreInt32(&c.bad, connBad)
151+
}
152+
153+
cursor := &Cursor{
154+
ctx: context.Background(),
155+
conn: c,
156+
}
157+
158+
return nil, cursor, nil
159+
}
160+
161+
func (c *fakePoolConn) Close() error { panic("not implemented") }
162+
func (c *fakePoolConn) isBad() bool { return atomic.LoadInt32(&c.bad) == connBad }
163+
func (c *fakePoolConn) isClosed() bool { return false }

0 commit comments

Comments
 (0)