|
| 1 | +package rqlite |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "io" |
| 6 | + nurl "net/url" |
| 7 | + "strconv" |
| 8 | + "strings" |
| 9 | + |
| 10 | + "go.uber.org/atomic" |
| 11 | + |
| 12 | + "github.com/golang-migrate/migrate/v4" |
| 13 | + "github.com/golang-migrate/migrate/v4/database" |
| 14 | + "github.com/hashicorp/go-multierror" |
| 15 | + "github.com/pkg/errors" |
| 16 | + "github.com/rqlite/gorqlite" |
| 17 | +) |
| 18 | + |
| 19 | +func init() { |
| 20 | + database.Register("rqlite", &Rqlite{}) |
| 21 | +} |
| 22 | + |
| 23 | +const ( |
| 24 | + // DefaultMigrationsTable defines the default rqlite migrations table |
| 25 | + DefaultMigrationsTable = "schema_migrations" |
| 26 | + |
| 27 | + // DefaultConnectInsecure defines the default setting for connect insecure |
| 28 | + DefaultConnectInsecure = false |
| 29 | +) |
| 30 | + |
| 31 | +// ErrNilConfig is returned if no configuration was passed to WithInstance |
| 32 | +var ErrNilConfig = fmt.Errorf("no config") |
| 33 | + |
| 34 | +// ErrBadConfig is returned if configuration was invalid |
| 35 | +var ErrBadConfig = fmt.Errorf("bad parameter") |
| 36 | + |
| 37 | +// Config defines the driver configuration |
| 38 | +type Config struct { |
| 39 | + // ConnectInsecure sets whether the connection uses TLS. Ineffectual when using WithInstance |
| 40 | + ConnectInsecure bool |
| 41 | + // MigrationsTable configures the migrations table name |
| 42 | + MigrationsTable string |
| 43 | +} |
| 44 | + |
| 45 | +type Rqlite struct { |
| 46 | + db *gorqlite.Connection |
| 47 | + isLocked atomic.Bool |
| 48 | + |
| 49 | + config *Config |
| 50 | +} |
| 51 | + |
| 52 | +// WithInstance creates a rqlite database driver with an existing gorqlite database connection |
| 53 | +// and a Config struct |
| 54 | +func WithInstance(instance *gorqlite.Connection, config *Config) (database.Driver, error) { |
| 55 | + if config == nil { |
| 56 | + return nil, ErrNilConfig |
| 57 | + } |
| 58 | + |
| 59 | + // we use the consistency level check as a database ping |
| 60 | + if _, err := instance.ConsistencyLevel(); err != nil { |
| 61 | + return nil, err |
| 62 | + } |
| 63 | + |
| 64 | + if len(config.MigrationsTable) == 0 { |
| 65 | + config.MigrationsTable = DefaultMigrationsTable |
| 66 | + } |
| 67 | + |
| 68 | + driver := &Rqlite{ |
| 69 | + db: instance, |
| 70 | + config: config, |
| 71 | + } |
| 72 | + |
| 73 | + if err := driver.ensureVersionTable(); err != nil { |
| 74 | + return nil, err |
| 75 | + } |
| 76 | + |
| 77 | + return driver, nil |
| 78 | +} |
| 79 | + |
| 80 | +// OpenURL creates a rqlite database driver from a connect URL |
| 81 | +func OpenURL(url string) (database.Driver, error) { |
| 82 | + d := &Rqlite{} |
| 83 | + return d.Open(url) |
| 84 | +} |
| 85 | + |
| 86 | +func (r *Rqlite) ensureVersionTable() (err error) { |
| 87 | + if err = r.Lock(); err != nil { |
| 88 | + return err |
| 89 | + } |
| 90 | + |
| 91 | + defer func() { |
| 92 | + if e := r.Unlock(); e != nil { |
| 93 | + if err == nil { |
| 94 | + err = e |
| 95 | + } else { |
| 96 | + err = multierror.Append(err, e) |
| 97 | + } |
| 98 | + } |
| 99 | + }() |
| 100 | + |
| 101 | + stmts := []string{ |
| 102 | + fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool)`, r.config.MigrationsTable), |
| 103 | + fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version)`, r.config.MigrationsTable), |
| 104 | + } |
| 105 | + |
| 106 | + if _, err := r.db.Write(stmts); err != nil { |
| 107 | + return err |
| 108 | + } |
| 109 | + |
| 110 | + return nil |
| 111 | +} |
| 112 | + |
| 113 | +// Open returns a new driver instance configured with parameters |
| 114 | +// coming from the URL string. Migrate will call this function |
| 115 | +// only once per instance. |
| 116 | +func (r *Rqlite) Open(url string) (database.Driver, error) { |
| 117 | + dburl, config, err := parseUrl(url) |
| 118 | + if err != nil { |
| 119 | + return nil, err |
| 120 | + } |
| 121 | + r.config = config |
| 122 | + |
| 123 | + r.db, err = gorqlite.Open(dburl.String()) |
| 124 | + if err != nil { |
| 125 | + return nil, err |
| 126 | + } |
| 127 | + |
| 128 | + if err := r.ensureVersionTable(); err != nil { |
| 129 | + return nil, err |
| 130 | + } |
| 131 | + |
| 132 | + return r, nil |
| 133 | +} |
| 134 | + |
| 135 | +// Close closes the underlying database instance managed by the driver. |
| 136 | +// Migrate will call this function only once per instance. |
| 137 | +func (r *Rqlite) Close() error { |
| 138 | + r.db.Close() |
| 139 | + return nil |
| 140 | +} |
| 141 | + |
| 142 | +// Lock should acquire a database lock so that only one migration process |
| 143 | +// can run at a time. Migrate will call this function before Run is called. |
| 144 | +// If the implementation can't provide this functionality, return nil. |
| 145 | +// Return database.ErrLocked if database is already locked. |
| 146 | +func (r *Rqlite) Lock() error { |
| 147 | + if !r.isLocked.CAS(false, true) { |
| 148 | + return database.ErrLocked |
| 149 | + } |
| 150 | + return nil |
| 151 | +} |
| 152 | + |
| 153 | +// Unlock should release the lock. Migrate will call this function after |
| 154 | +// all migrations have been run. |
| 155 | +func (r *Rqlite) Unlock() error { |
| 156 | + if !r.isLocked.CAS(true, false) { |
| 157 | + return database.ErrNotLocked |
| 158 | + } |
| 159 | + return nil |
| 160 | +} |
| 161 | + |
| 162 | +// Run applies a migration to the database. migration is guaranteed to be not nil. |
| 163 | +func (r *Rqlite) Run(migration io.Reader) error { |
| 164 | + migr, err := io.ReadAll(migration) |
| 165 | + if err != nil { |
| 166 | + return err |
| 167 | + } |
| 168 | + |
| 169 | + query := string(migr[:]) |
| 170 | + if _, err := r.db.WriteOne(query); err != nil { |
| 171 | + return &database.Error{OrigErr: err, Query: []byte(query)} |
| 172 | + } |
| 173 | + |
| 174 | + return nil |
| 175 | +} |
| 176 | + |
| 177 | +// SetVersion saves version and dirty state. |
| 178 | +// Migrate will call this function before and after each call to Run. |
| 179 | +// version must be >= -1. -1 means NilVersion. |
| 180 | +func (r *Rqlite) SetVersion(version int, dirty bool) error { |
| 181 | + deleteQuery := fmt.Sprintf(`DELETE FROM %s`, r.config.MigrationsTable) |
| 182 | + statements := []gorqlite.ParameterizedStatement{ |
| 183 | + { |
| 184 | + Query: deleteQuery, |
| 185 | + }, |
| 186 | + } |
| 187 | + |
| 188 | + // Also re-write the schema version for nil dirty versions to prevent |
| 189 | + // empty schema version for failed down migration on the first migration |
| 190 | + // See: https://github.com/golang-migrate/migrate/issues/330 |
| 191 | + insertQuery := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, r.config.MigrationsTable) |
| 192 | + if version >= 0 || (version == database.NilVersion && dirty) { |
| 193 | + statements = append(statements, gorqlite.ParameterizedStatement{ |
| 194 | + Query: insertQuery, |
| 195 | + Arguments: []interface{}{ |
| 196 | + version, |
| 197 | + dirty, |
| 198 | + }, |
| 199 | + }) |
| 200 | + } |
| 201 | + |
| 202 | + wr, err := r.db.WriteParameterized(statements) |
| 203 | + if err != nil { |
| 204 | + for i, res := range wr { |
| 205 | + if res.Err != nil { |
| 206 | + return &database.Error{OrigErr: err, Query: []byte(statements[i].Query)} |
| 207 | + } |
| 208 | + } |
| 209 | + |
| 210 | + // if somehow we're still here, return the original error with combined queries |
| 211 | + return &database.Error{OrigErr: err, Query: []byte(deleteQuery + "\n" + insertQuery)} |
| 212 | + } |
| 213 | + |
| 214 | + return nil |
| 215 | +} |
| 216 | + |
| 217 | +// Version returns the currently active version and if the database is dirty. |
| 218 | +// When no migration has been applied, it must return version -1. |
| 219 | +// Dirty means, a previous migration failed and user interaction is required. |
| 220 | +func (r *Rqlite) Version() (version int, dirty bool, err error) { |
| 221 | + query := "SELECT version, dirty FROM " + r.config.MigrationsTable + " LIMIT 1" |
| 222 | + |
| 223 | + qr, err := r.db.QueryOne(query) |
| 224 | + if err != nil { |
| 225 | + return database.NilVersion, false, nil |
| 226 | + } |
| 227 | + |
| 228 | + if !qr.Next() { |
| 229 | + return database.NilVersion, false, nil |
| 230 | + } |
| 231 | + |
| 232 | + if err := qr.Scan(&version, &dirty); err != nil { |
| 233 | + return database.NilVersion, false, &database.Error{OrigErr: err, Query: []byte(query)} |
| 234 | + } |
| 235 | + |
| 236 | + return version, dirty, nil |
| 237 | +} |
| 238 | + |
| 239 | +// Drop deletes everything in the database. |
| 240 | +// Note that this is a breaking action, a new call to Open() is necessary to |
| 241 | +// ensure subsequent calls work as expected. |
| 242 | +func (r *Rqlite) Drop() error { |
| 243 | + query := `SELECT name FROM sqlite_master WHERE type = 'table'` |
| 244 | + |
| 245 | + tables, err := r.db.QueryOne(query) |
| 246 | + if err != nil { |
| 247 | + return &database.Error{OrigErr: err, Query: []byte(query)} |
| 248 | + } |
| 249 | + |
| 250 | + statements := make([]string, 0) |
| 251 | + for tables.Next() { |
| 252 | + var tableName string |
| 253 | + if err := tables.Scan(&tableName); err != nil { |
| 254 | + return err |
| 255 | + } |
| 256 | + |
| 257 | + if len(tableName) > 0 { |
| 258 | + statement := fmt.Sprintf(`DROP TABLE %s`, tableName) |
| 259 | + statements = append(statements, statement) |
| 260 | + } |
| 261 | + } |
| 262 | + |
| 263 | + // return if nothing to do |
| 264 | + if len(statements) <= 0 { |
| 265 | + return nil |
| 266 | + } |
| 267 | + |
| 268 | + wr, err := r.db.Write(statements) |
| 269 | + if err != nil { |
| 270 | + for i, res := range wr { |
| 271 | + if res.Err != nil { |
| 272 | + return &database.Error{OrigErr: err, Query: []byte(statements[i])} |
| 273 | + } |
| 274 | + } |
| 275 | + |
| 276 | + // if somehow we're still here, return the original error with combined queries |
| 277 | + return &database.Error{OrigErr: err, Query: []byte(strings.Join(statements, "\n"))} |
| 278 | + } |
| 279 | + |
| 280 | + return nil |
| 281 | +} |
| 282 | + |
| 283 | +func parseUrl(url string) (*nurl.URL, *Config, error) { |
| 284 | + parsedUrl, err := nurl.Parse(url) |
| 285 | + if err != nil { |
| 286 | + return nil, nil, err |
| 287 | + } |
| 288 | + |
| 289 | + config, err := parseConfigFromQuery(parsedUrl.Query()) |
| 290 | + if err != nil { |
| 291 | + return nil, nil, err |
| 292 | + } |
| 293 | + |
| 294 | + if parsedUrl.Scheme != "rqlite" { |
| 295 | + return nil, nil, errors.Wrap(ErrBadConfig, "bad scheme") |
| 296 | + } |
| 297 | + |
| 298 | + // adapt from rqlite to http/https schemes |
| 299 | + if config.ConnectInsecure { |
| 300 | + parsedUrl.Scheme = "http" |
| 301 | + } else { |
| 302 | + parsedUrl.Scheme = "https" |
| 303 | + } |
| 304 | + |
| 305 | + filteredUrl := migrate.FilterCustomQuery(parsedUrl) |
| 306 | + |
| 307 | + return filteredUrl, config, nil |
| 308 | +} |
| 309 | + |
| 310 | +func parseConfigFromQuery(queryVals nurl.Values) (*Config, error) { |
| 311 | + c := Config{ |
| 312 | + ConnectInsecure: DefaultConnectInsecure, |
| 313 | + MigrationsTable: DefaultMigrationsTable, |
| 314 | + } |
| 315 | + |
| 316 | + migrationsTable := queryVals.Get("x-migrations-table") |
| 317 | + if migrationsTable != "" { |
| 318 | + if strings.HasPrefix(migrationsTable, "sqlite_") { |
| 319 | + return nil, errors.Wrap(ErrBadConfig, "invalid value for x-migrations-table") |
| 320 | + } |
| 321 | + c.MigrationsTable = migrationsTable |
| 322 | + } |
| 323 | + |
| 324 | + connectInsecureStr := queryVals.Get("x-connect-insecure") |
| 325 | + if connectInsecureStr != "" { |
| 326 | + connectInsecure, err := strconv.ParseBool(connectInsecureStr) |
| 327 | + if err != nil { |
| 328 | + return nil, errors.Wrap(ErrBadConfig, "invalid value for x-connect-insecure") |
| 329 | + } |
| 330 | + c.ConnectInsecure = connectInsecure |
| 331 | + } |
| 332 | + |
| 333 | + return &c, nil |
| 334 | +} |
0 commit comments