@@ -3,16 +3,19 @@ package mongodb
33import (
44 "context"
55 "fmt"
6- "io"
7- "io/ioutil"
8- "net/url"
9- "strconv"
10-
6+ "github.com/cenkalti/backoff/v4"
117 "github.com/golang-migrate/migrate/v4/database"
8+ "github.com/hashicorp/go-multierror"
129 "go.mongodb.org/mongo-driver/bson"
1310 "go.mongodb.org/mongo-driver/mongo"
1411 "go.mongodb.org/mongo-driver/mongo/options"
1512 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
13+ "io"
14+ "io/ioutil"
15+ "net/url"
16+ os "os"
17+ "strconv"
18+ "time"
1619)
1720
1821func init () {
@@ -23,6 +26,14 @@ func init() {
2326
2427var DefaultMigrationsCollection = "schema_migrations"
2528
29+ const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
30+ const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
31+ const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released.
32+ const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout.
33+ const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true.
34+ const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field.
35+ const contextWaitTimeout = 5 * time .Second // how long to wait for the request to mongo to block/wait for.
36+
2637var (
2738 ErrNoDatabaseName = fmt .Errorf ("no database name" )
2839 ErrNilConfig = fmt .Errorf ("no config" )
@@ -31,21 +42,36 @@ var (
3142type Mongo struct {
3243 client * mongo.Client
3344 db * mongo.Database
34-
3545 config * Config
3646}
3747
48+ type Locking struct {
49+ CollectionName string
50+ Timeout int
51+ Enabled bool
52+ Interval int
53+ }
3854type Config struct {
3955 DatabaseName string
4056 MigrationsCollection string
4157 TransactionMode bool
58+ Locking Locking
4259}
43-
4460type versionInfo struct {
4561 Version int `bson:"version"`
4662 Dirty bool `bson:"dirty"`
4763}
4864
65+ type lockObj struct {
66+ Key int `bson:"locking_key"`
67+ Pid int `bson:"pid"`
68+ Hostname string `bson:"hostname"`
69+ CreatedAt time.Time `bson:"created_at"`
70+ }
71+ type findFilter struct {
72+ Key int `bson:"locking_key"`
73+ }
74+
4975func WithInstance (instance * mongo.Client , config * Config ) (database.Driver , error ) {
5076 if config == nil {
5177 return nil , ErrNilConfig
@@ -56,48 +82,121 @@ func WithInstance(instance *mongo.Client, config *Config) (database.Driver, erro
5682 if len (config .MigrationsCollection ) == 0 {
5783 config .MigrationsCollection = DefaultMigrationsCollection
5884 }
85+ if len (config .Locking .CollectionName ) == 0 {
86+ config .Locking .CollectionName = DefaultLockingCollection
87+ }
88+ if config .Locking .Timeout <= 0 {
89+ config .Locking .Timeout = DefaultLockTimeout
90+ }
91+ if config .Locking .Interval <= 0 {
92+ config .Locking .Interval = DefaultLockTimeoutInterval
93+ }
94+
5995 mc := & Mongo {
6096 client : instance ,
6197 db : instance .Database (config .DatabaseName ),
6298 config : config ,
6399 }
64100
101+ if mc .config .Locking .Enabled {
102+ if err := mc .ensureLockTable (); err != nil {
103+ return nil , err
104+ }
105+ }
106+ if err := mc .ensureVersionTable (); err != nil {
107+ return nil , err
108+ }
109+
65110 return mc , nil
66111}
67112
68113func (m * Mongo ) Open (dsn string ) (database.Driver , error ) {
69- //connsting is experimental package, but it used for parse connection string in mongo.Connect function
114+ //connstring is experimental package, but it used for parse connection string in mongo.Connect function
70115 uri , err := connstring .Parse (dsn )
71116 if err != nil {
72117 return nil , err
73118 }
74119 if len (uri .Database ) == 0 {
75120 return nil , ErrNoDatabaseName
76121 }
77-
78122 unknown := url .Values (uri .UnknownOptions )
79123
80124 migrationsCollection := unknown .Get ("x-migrations-collection" )
81- transactionMode , _ := strconv .ParseBool (unknown .Get ("x-transaction-mode" ))
82-
125+ lockCollection := unknown .Get ("x-advisory-lock-collection" )
126+ transactionMode , err := parseBoolean (unknown .Get ("x-transaction-mode" ), false )
127+ if err != nil {
128+ return nil , err
129+ }
130+ advisoryLockingFlag , err := parseBoolean (unknown .Get ("x-advisory-locking" ), DefaultAdvisoryLockingFlag )
131+ if err != nil {
132+ return nil , err
133+ }
134+ lockingTimout , err := parseInt (unknown .Get ("x-advisory-lock-timeout" ), DefaultLockTimeout )
135+ if err != nil {
136+ return nil , err
137+ }
138+ maxLockingIntervals , err := parseInt (unknown .Get ("x-advisory-lock-timout-interval" ), DefaultLockTimeoutInterval )
139+ if err != nil {
140+ return nil , err
141+ }
83142 client , err := mongo .Connect (context .TODO (), options .Client ().ApplyURI (dsn ))
84143 if err != nil {
85144 return nil , err
86145 }
146+
87147 if err = client .Ping (context .TODO (), nil ); err != nil {
88148 return nil , err
89149 }
90150 mc , err := WithInstance (client , & Config {
91151 DatabaseName : uri .Database ,
92152 MigrationsCollection : migrationsCollection ,
93153 TransactionMode : transactionMode ,
154+ Locking : Locking {
155+ CollectionName : lockCollection ,
156+ Timeout : lockingTimout ,
157+ Enabled : advisoryLockingFlag ,
158+ Interval : maxLockingIntervals ,
159+ },
94160 })
95161 if err != nil {
96162 return nil , err
97163 }
98164 return mc , nil
99165}
100166
167+ //Parse the url param, convert it to boolean
168+ // returns error if param invalid. returns defaultValue if param not present
169+ func parseBoolean (urlParam string , defaultValue bool ) (bool , error ) {
170+
171+ // if parameter passed, parse it (otherwise return default value)
172+ if urlParam != "" {
173+ result , err := strconv .ParseBool (urlParam )
174+ if err != nil {
175+ return false , err
176+ }
177+ return result , nil
178+ }
179+
180+ // if no url Param passed, return default value
181+ return defaultValue , nil
182+ }
183+
184+ //Parse the url param, convert it to int
185+ // returns error if param invalid. returns defaultValue if param not present
186+ func parseInt (urlParam string , defaultValue int ) (int , error ) {
187+
188+ // if parameter passed, parse it (otherwise return default value)
189+ if urlParam != "" {
190+ result , err := strconv .Atoi (urlParam )
191+ if err != nil {
192+ return - 1 , err
193+ }
194+ return result , nil
195+ }
196+
197+ // if no url Param passed, return default value
198+ return defaultValue , nil
199+ }
101200func (m * Mongo ) SetVersion (version int , dirty bool ) error {
102201 migrationsCollection := m .db .Collection (m .config .MigrationsCollection )
103202 if err := migrationsCollection .Drop (context .TODO ()); err != nil {
@@ -184,10 +283,99 @@ func (m *Mongo) Drop() error {
184283 return m .db .Drop (context .TODO ())
185284}
186285
187- func (m * Mongo ) Lock () error {
286+ func (m * Mongo ) ensureLockTable () error {
287+ indexes := m .db .Collection (m .config .Locking .CollectionName ).Indexes ()
288+
289+ indexOptions := options .Index ().SetUnique (true ).SetName (LockIndexName )
290+ _ , err := indexes .CreateOne (context .TODO (), mongo.IndexModel {
291+ Options : indexOptions ,
292+ Keys : findFilter {Key : - 1 },
293+ })
294+ if err != nil {
295+ return err
296+ }
297+ return nil
298+ }
299+
300+ // ensureVersionTable checks if versions table exists and, if not, creates it.
301+ // Note that this function locks the database, which deviates from the usual
302+ // convention of "caller locks" in the MongoDb type.
303+ func (m * Mongo ) ensureVersionTable () (err error ) {
304+ if err = m .Lock (); err != nil {
305+ return err
306+ }
307+
308+ defer func () {
309+ if e := m .Unlock (); e != nil {
310+ if err == nil {
311+ err = e
312+ } else {
313+ err = multierror .Append (err , e )
314+ }
315+ }
316+ }()
317+
318+ if err != nil {
319+ return err
320+ }
321+ if _ , _ , err = m .Version (); err != nil {
322+ return err
323+ }
188324 return nil
189325}
190326
327+ // Utilizes advisory locking on the config.LockingCollection collection
328+ // This uses a unique index on the `locking_key` field.
329+ func (m * Mongo ) Lock () error {
330+ if ! m .config .Locking .Enabled {
331+ return nil
332+ }
333+ pid := os .Getpid ()
334+ hostname , err := os .Hostname ()
335+ if err != nil {
336+ hostname = fmt .Sprintf ("Could not determine hostname. Error: %s" , err .Error ())
337+ }
338+
339+ newLockObj := lockObj {
340+ Key : lockKeyUniqueValue ,
341+ Pid : pid ,
342+ Hostname : hostname ,
343+ CreatedAt : time .Now (),
344+ }
345+ operation := func () error {
346+ timeout , cancelFunc := context .WithTimeout (context .Background (), contextWaitTimeout )
347+ _ , err := m .db .Collection (m .config .Locking .CollectionName ).InsertOne (timeout , newLockObj )
348+ defer cancelFunc ()
349+ return err
350+ }
351+ exponentialBackOff := backoff .NewExponentialBackOff ()
352+ duration := time .Duration (m .config .Locking .Timeout ) * time .Second
353+ exponentialBackOff .MaxElapsedTime = duration
354+ exponentialBackOff .MaxInterval = time .Duration (m .config .Locking .Interval ) * time .Second
355+
356+ err = backoff .Retry (operation , exponentialBackOff )
357+ if err != nil {
358+ return database .ErrLocked
359+ }
360+
361+ return nil
362+
363+ }
191364func (m * Mongo ) Unlock () error {
365+ if ! m .config .Locking .Enabled {
366+ return nil
367+ }
368+
369+ filter := findFilter {
370+ Key : lockKeyUniqueValue ,
371+ }
372+
373+ ctx , cancel := context .WithTimeout (context .Background (), contextWaitTimeout )
374+ _ , err := m .db .Collection (m .config .Locking .CollectionName ).DeleteMany (ctx , filter )
375+ defer cancel ()
376+
377+ if err != nil {
378+ return err
379+ }
192380 return nil
193381}
0 commit comments