Skip to content

Commit e84d6a9

Browse files
Prometheus2677FPiety0521
authored andcommitted
Merge pull request #448 from andyNewman42/locking
Add advisory locking to mongodb
2 parents 085e1ba + 515d6e7 commit e84d6a9

File tree

5 files changed

+275
-13
lines changed

5 files changed

+275
-13
lines changed

database/mongodb/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
|------------|---------------------|-------------|
1414
| `x-migrations-collection` | `MigrationsCollection` | Name of the migrations collection |
1515
| `x-transaction-mode` | `TransactionMode` | If set to `true` wrap commands in [transaction](https://docs.mongodb.com/manual/core/transactions). Available only for replica set. Driver is using [strconv.ParseBool](https://golang.org/pkg/strconv/#ParseBool) for parsing|
16+
| `x-advisory-locking` | `true` | Feature flag for advisory locking, if set to false, disable advisory locking |
17+
| `x-advisory-lock-collection` | `migrate_advisory_lock` | The name of the collection to use for advisory locking.|
18+
| `x-advisory-lock-timout` | `15` | The max time in seconds that the advisory lock will wait if the db is already locked. |
19+
| `x-advisory-lock-timout-interval` | `10` | The max timeout in seconds interval that the advisory lock will wait if the db is already locked. |
1620
| `dbname` | `DatabaseName` | The name of the database to connect to |
1721
| `user` | | The user to sign in as. Can be omitted |
1822
| `password` | | The user's password. Can be omitted |

database/mongodb/mongodb.go

Lines changed: 200 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@ package mongodb
33
import (
44
"context"
55
"fmt"
6-
"io"
7-
"io/ioutil"
8-
"net/url"
9-
"strconv"
10-
6+
"github.com/cenkalti/backoff/v4"
117
"github.com/golang-migrate/migrate/v4/database"
8+
"github.com/hashicorp/go-multierror"
129
"go.mongodb.org/mongo-driver/bson"
1310
"go.mongodb.org/mongo-driver/mongo"
1411
"go.mongodb.org/mongo-driver/mongo/options"
1512
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
13+
"io"
14+
"io/ioutil"
15+
"net/url"
16+
os "os"
17+
"strconv"
18+
"time"
1619
)
1720

1821
func init() {
@@ -23,6 +26,14 @@ func init() {
2326

2427
var DefaultMigrationsCollection = "schema_migrations"
2528

29+
const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
30+
const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
31+
const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released.
32+
const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout.
33+
const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true.
34+
const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field.
35+
const contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for.
36+
2637
var (
2738
ErrNoDatabaseName = fmt.Errorf("no database name")
2839
ErrNilConfig = fmt.Errorf("no config")
@@ -31,21 +42,36 @@ var (
3142
type Mongo struct {
3243
client *mongo.Client
3344
db *mongo.Database
34-
3545
config *Config
3646
}
3747

48+
type Locking struct {
49+
CollectionName string
50+
Timeout int
51+
Enabled bool
52+
Interval int
53+
}
3854
type Config struct {
3955
DatabaseName string
4056
MigrationsCollection string
4157
TransactionMode bool
58+
Locking Locking
4259
}
43-
4460
type versionInfo struct {
4561
Version int `bson:"version"`
4662
Dirty bool `bson:"dirty"`
4763
}
4864

65+
type lockObj struct {
66+
Key int `bson:"locking_key"`
67+
Pid int `bson:"pid"`
68+
Hostname string `bson:"hostname"`
69+
CreatedAt time.Time `bson:"created_at"`
70+
}
71+
type findFilter struct {
72+
Key int `bson:"locking_key"`
73+
}
74+
4975
func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
5076
if config == nil {
5177
return nil, ErrNilConfig
@@ -56,48 +82,121 @@ func WithInstance(instance *mongo.Client, config *Config) (database.Driver, erro
5682
if len(config.MigrationsCollection) == 0 {
5783
config.MigrationsCollection = DefaultMigrationsCollection
5884
}
85+
if len(config.Locking.CollectionName) == 0 {
86+
config.Locking.CollectionName = DefaultLockingCollection
87+
}
88+
if config.Locking.Timeout <= 0 {
89+
config.Locking.Timeout = DefaultLockTimeout
90+
}
91+
if config.Locking.Interval <= 0 {
92+
config.Locking.Interval = DefaultLockTimeoutInterval
93+
}
94+
5995
mc := &Mongo{
6096
client: instance,
6197
db: instance.Database(config.DatabaseName),
6298
config: config,
6399
}
64100

101+
if mc.config.Locking.Enabled {
102+
if err := mc.ensureLockTable(); err != nil {
103+
return nil, err
104+
}
105+
}
106+
if err := mc.ensureVersionTable(); err != nil {
107+
return nil, err
108+
}
109+
65110
return mc, nil
66111
}
67112

68113
func (m *Mongo) Open(dsn string) (database.Driver, error) {
69-
//connsting is experimental package, but it used for parse connection string in mongo.Connect function
114+
//connstring is experimental package, but it used for parse connection string in mongo.Connect function
70115
uri, err := connstring.Parse(dsn)
71116
if err != nil {
72117
return nil, err
73118
}
74119
if len(uri.Database) == 0 {
75120
return nil, ErrNoDatabaseName
76121
}
77-
78122
unknown := url.Values(uri.UnknownOptions)
79123

80124
migrationsCollection := unknown.Get("x-migrations-collection")
81-
transactionMode, _ := strconv.ParseBool(unknown.Get("x-transaction-mode"))
82-
125+
lockCollection := unknown.Get("x-advisory-lock-collection")
126+
transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
127+
if err != nil {
128+
return nil, err
129+
}
130+
advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
131+
if err != nil {
132+
return nil, err
133+
}
134+
lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
135+
if err != nil {
136+
return nil, err
137+
}
138+
maxLockingIntervals, err := parseInt(unknown.Get("x-advisory-lock-timout-interval"), DefaultLockTimeoutInterval)
139+
if err != nil {
140+
return nil, err
141+
}
83142
client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
84143
if err != nil {
85144
return nil, err
86145
}
146+
87147
if err = client.Ping(context.TODO(), nil); err != nil {
88148
return nil, err
89149
}
90150
mc, err := WithInstance(client, &Config{
91151
DatabaseName: uri.Database,
92152
MigrationsCollection: migrationsCollection,
93153
TransactionMode: transactionMode,
154+
Locking: Locking{
155+
CollectionName: lockCollection,
156+
Timeout: lockingTimout,
157+
Enabled: advisoryLockingFlag,
158+
Interval: maxLockingIntervals,
159+
},
94160
})
95161
if err != nil {
96162
return nil, err
97163
}
98164
return mc, nil
99165
}
100166

167+
//Parse the url param, convert it to boolean
168+
// returns error if param invalid. returns defaultValue if param not present
169+
func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
170+
171+
// if parameter passed, parse it (otherwise return default value)
172+
if urlParam != "" {
173+
result, err := strconv.ParseBool(urlParam)
174+
if err != nil {
175+
return false, err
176+
}
177+
return result, nil
178+
}
179+
180+
// if no url Param passed, return default value
181+
return defaultValue, nil
182+
}
183+
184+
//Parse the url param, convert it to int
185+
// returns error if param invalid. returns defaultValue if param not present
186+
func parseInt(urlParam string, defaultValue int) (int, error) {
187+
188+
// if parameter passed, parse it (otherwise return default value)
189+
if urlParam != "" {
190+
result, err := strconv.Atoi(urlParam)
191+
if err != nil {
192+
return -1, err
193+
}
194+
return result, nil
195+
}
196+
197+
// if no url Param passed, return default value
198+
return defaultValue, nil
199+
}
101200
func (m *Mongo) SetVersion(version int, dirty bool) error {
102201
migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
103202
if err := migrationsCollection.Drop(context.TODO()); err != nil {
@@ -184,10 +283,99 @@ func (m *Mongo) Drop() error {
184283
return m.db.Drop(context.TODO())
185284
}
186285

187-
func (m *Mongo) Lock() error {
286+
func (m *Mongo) ensureLockTable() error {
287+
indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
288+
289+
indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
290+
_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
291+
Options: indexOptions,
292+
Keys: findFilter{Key: -1},
293+
})
294+
if err != nil {
295+
return err
296+
}
297+
return nil
298+
}
299+
300+
// ensureVersionTable checks if versions table exists and, if not, creates it.
301+
// Note that this function locks the database, which deviates from the usual
302+
// convention of "caller locks" in the MongoDb type.
303+
func (m *Mongo) ensureVersionTable() (err error) {
304+
if err = m.Lock(); err != nil {
305+
return err
306+
}
307+
308+
defer func() {
309+
if e := m.Unlock(); e != nil {
310+
if err == nil {
311+
err = e
312+
} else {
313+
err = multierror.Append(err, e)
314+
}
315+
}
316+
}()
317+
318+
if err != nil {
319+
return err
320+
}
321+
if _, _, err = m.Version(); err != nil {
322+
return err
323+
}
188324
return nil
189325
}
190326

327+
// Utilizes advisory locking on the config.LockingCollection collection
328+
// This uses a unique index on the `locking_key` field.
329+
func (m *Mongo) Lock() error {
330+
if !m.config.Locking.Enabled {
331+
return nil
332+
}
333+
pid := os.Getpid()
334+
hostname, err := os.Hostname()
335+
if err != nil {
336+
hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
337+
}
338+
339+
newLockObj := lockObj{
340+
Key: lockKeyUniqueValue,
341+
Pid: pid,
342+
Hostname: hostname,
343+
CreatedAt: time.Now(),
344+
}
345+
operation := func() error {
346+
timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
347+
_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
348+
defer cancelFunc()
349+
return err
350+
}
351+
exponentialBackOff := backoff.NewExponentialBackOff()
352+
duration := time.Duration(m.config.Locking.Timeout) * time.Second
353+
exponentialBackOff.MaxElapsedTime = duration
354+
exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
355+
356+
err = backoff.Retry(operation, exponentialBackOff)
357+
if err != nil {
358+
return database.ErrLocked
359+
}
360+
361+
return nil
362+
363+
}
191364
func (m *Mongo) Unlock() error {
365+
if !m.config.Locking.Enabled {
366+
return nil
367+
}
368+
369+
filter := findFilter{
370+
Key: lockKeyUniqueValue,
371+
}
372+
373+
ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
374+
_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
375+
defer cancel()
376+
377+
if err != nil {
378+
return err
379+
}
192380
return nil
193381
}

database/mongodb/mongodb_test.go

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func Test(t *testing.T) {
9292
}
9393
}()
9494
dt.TestNilVersion(t, d)
95-
//TestLockAndUnlock(t, d) driver doesn't support lock on database level
95+
dt.TestLockAndUnlock(t, d)
9696
dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))
9797
dt.TestSetVersion(t, d)
9898
dt.TestDrop(t, d)
@@ -180,6 +180,73 @@ func TestWithAuth(t *testing.T) {
180180
})
181181
}
182182

183+
func TestLockWorks(t *testing.T) {
184+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
185+
ip, port, err := c.FirstPort()
186+
if err != nil {
187+
t.Fatal(err)
188+
}
189+
190+
addr := mongoConnectionString(ip, port)
191+
p := &Mongo{}
192+
d, err := p.Open(addr)
193+
if err != nil {
194+
t.Fatal(err)
195+
}
196+
defer func() {
197+
if err := d.Close(); err != nil {
198+
t.Error(err)
199+
}
200+
}()
201+
202+
dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))
203+
204+
mc := d.(*Mongo)
205+
206+
err = mc.Lock()
207+
if err != nil {
208+
t.Fatal(err)
209+
}
210+
err = mc.Unlock()
211+
if err != nil {
212+
t.Fatal(err)
213+
}
214+
215+
err = mc.Lock()
216+
if err != nil {
217+
t.Fatal(err)
218+
}
219+
err = mc.Unlock()
220+
if err != nil {
221+
t.Fatal(err)
222+
}
223+
224+
// disable locking, validate wer can lock twice
225+
mc.config.Locking.Enabled = false
226+
err = mc.Lock()
227+
if err != nil {
228+
t.Fatal(err)
229+
}
230+
err = mc.Lock()
231+
if err != nil {
232+
t.Fatal(err)
233+
}
234+
235+
// re-enable locking,
236+
//try to hit a lock conflict
237+
mc.config.Locking.Enabled = true
238+
mc.config.Locking.Timeout = 1
239+
err = mc.Lock()
240+
if err != nil {
241+
t.Fatal(err)
242+
}
243+
err = mc.Lock()
244+
if err == nil {
245+
t.Fatal("should have failed, mongo should be locked already")
246+
}
247+
})
248+
}
249+
183250
func TestTransaction(t *testing.T) {
184251
transactionSpecs := []dktesting.ContainerSpec{
185252
{ImageName: "mongo:4", Options: dktest.Options{PortRequired: true, ReadyFunc: isReady,

0 commit comments

Comments
 (0)