From 016f649f7955fa75aa436e645c9cc720d130eeba Mon Sep 17 00:00:00 2001 From: Svetlin Ralchev Date: Tue, 27 Jul 2021 16:32:21 +0100 Subject: [PATCH] Allow developers to change the migrations table Let's allow engineers to change the table depending on their preferences or naming convention. --- sqlmigr/executor.go | 13 +++++-------- sqlmigr/generator.go | 6 ++++-- sqlmigr/model.go | 16 +++++++++++++--- sqlmigr/provider.go | 30 +++++++++++++++++++++++++----- sqlmigr/runner.go | 5 ++++- 5 files changed, 51 insertions(+), 19 deletions(-) diff --git a/sqlmigr/executor.go b/sqlmigr/executor.go index 0cf0f71..2810cbb 100644 --- a/sqlmigr/executor.go +++ b/sqlmigr/executor.go @@ -3,6 +3,7 @@ package sqlmigr import ( "bytes" "fmt" + "regexp" "strings" "time" @@ -10,6 +11,8 @@ import ( "github.com/phogolabs/log" ) +var migrationRgxp = regexp.MustCompile(`CREATE TABLE IF NOT EXISTS\s*([a-z]+)\s*`) + // Executor provides a group of operations that works with migrations. type Executor struct { // Logger logs each execution step @@ -25,14 +28,8 @@ type Executor struct { // Setup setups the current project for database migrations by creating // migration directory and related database. func (m *Executor) Setup() error { - migration := &Migration{ - ID: min.Format(format), - Description: "setup", - Drivers: []string{every}, - CreatedAt: time.Now(), - } - up := &bytes.Buffer{} + fmt.Fprintln(up, "CREATE TABLE IF NOT EXISTS migrations (") fmt.Fprintln(up, " id VARCHAR(15) NOT NULL PRIMARY KEY,") fmt.Fprintln(up, " description TEXT NOT NULL,") @@ -48,7 +45,7 @@ func (m *Executor) Setup() error { DownCommand: down, } - return m.Generator.Write(migration, content) + return m.Generator.Write(setup, content) } // Create creates a migration script successfully if the project has already diff --git a/sqlmigr/generator.go b/sqlmigr/generator.go index f4b221f..105b10e 100644 --- a/sqlmigr/generator.go +++ b/sqlmigr/generator.go @@ -69,8 +69,10 @@ func (g *Generator) write(filename string, data []byte, perm os.FileMode) error err = io.ErrShortWrite } } - if err1 := f.Close(); err == nil { - err = err1 + + if xerr := f.Close(); err == nil { + err = xerr } + return err } diff --git a/sqlmigr/model.go b/sqlmigr/model.go index da1cd53..9a26c0d 100644 --- a/sqlmigr/model.go +++ b/sqlmigr/model.go @@ -23,6 +23,16 @@ var ( every = "sql" ) +var ( + // setup migration + setup = &Migration{ + ID: min.Format(format), + Description: "setup", + Drivers: []string{every}, + CreatedAt: time.Now(), + } +) + // FileSystem provides with primitives to work with the underlying file system type FileSystem = fs.FS @@ -163,13 +173,13 @@ func IsNotExist(err error) bool { switch { // SQLite - case msg == "no such table: migrations": + case strings.HasPrefix(msg, "no such table"): return true // PostgreSQL - case msg == `pq: relation "migrations" does not exist`: + case strings.HasSuffix(msg, "does not exist"): return true // MySQL - case strings.HasSuffix(msg, "migrations' doesn't exist"): + case strings.HasSuffix(msg, "doesn't exist"): return true default: return false diff --git a/sqlmigr/provider.go b/sqlmigr/provider.go index eb633b9..6edb22b 100644 --- a/sqlmigr/provider.go +++ b/sqlmigr/provider.go @@ -3,6 +3,7 @@ package sqlmigr import ( "bytes" "fmt" + "io" "io/fs" "os" "path/filepath" @@ -39,7 +40,7 @@ func (m *Provider) Migrations() ([]*Migration, error) { func (m *Provider) files() ([]*Migration, error) { local := []*Migration{} - err := fs.WalkDir(m.FileSystem, ".", func(path string, info os.DirEntry, err error) error { + err := fs.WalkDir(m.FileSystem, ".", func(path string, info os.DirEntry, xerr error) error { if ferr := m.filter(info); ferr != nil { if ferr.Error() == "skip" { ferr = nil @@ -109,7 +110,7 @@ func (m *Provider) supported(drivers []string) bool { func (m *Provider) query() ([]*Migration, error) { query := &bytes.Buffer{} query.WriteString("SELECT id, description, created_at ") - query.WriteString("FROM migrations ") + query.WriteString("FROM " + m.table() + " ") query.WriteString("ORDER BY id ASC") remote := []*Migration{} @@ -126,7 +127,7 @@ func (m *Provider) Insert(item *Migration) error { item.CreatedAt = time.Now() builder := &bytes.Buffer{} - builder.WriteString("INSERT INTO migrations(id, description, created_at) ") + builder.WriteString("INSERT INTO " + m.table() + "(id, description, created_at) ") builder.WriteString("VALUES (?, ?, ?)") query := m.DB.Rebind(builder.String()) @@ -140,7 +141,7 @@ func (m *Provider) Insert(item *Migration) error { // Delete deletes applied sqlmigr item from sqlmigrs table. func (m *Provider) Delete(item *Migration) error { builder := &bytes.Buffer{} - builder.WriteString("DELETE FROM migrations ") + builder.WriteString("DELETE FROM " + m.table() + " ") builder.WriteString("WHERE id = ?") query := m.DB.Rebind(builder.String()) @@ -155,7 +156,7 @@ func (m *Provider) Delete(item *Migration) error { func (m *Provider) Exists(item *Migration) bool { count := 0 - if err := m.DB.Get(&count, "SELECT count(id) FROM migrations WHERE id = ?", item.ID); err != nil { + if err := m.DB.Get(&count, "SELECT count(id) FROM "+m.table()+" WHERE id = ?", item.ID); err != nil { return false } @@ -183,3 +184,22 @@ func (m *Provider) merge(remote, local []*Migration) ([]*Migration, error) { return result, nil } + +func (m *Provider) table() string { + for _, path := range setup.Filenames() { + file, err := m.FileSystem.Open(path) + if err != nil { + continue + } + // close the file + defer file.Close() + + if data, err := io.ReadAll(file); err == nil { + if match := migrationRgxp.FindSubmatch(data); len(match) == 2 { + return string(match[1]) + } + } + } + + return "migrations" +} diff --git a/sqlmigr/runner.go b/sqlmigr/runner.go index 206f1fc..ea7c3aa 100644 --- a/sqlmigr/runner.go +++ b/sqlmigr/runner.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/jmoiron/sqlx" + "github.com/phogolabs/log" "github.com/phogolabs/prana/sqlexec" ) @@ -41,7 +42,9 @@ func (r *Runner) exec(step string, m *Migration) error { for _, query := range statements { if _, err := tx.Exec(query); err != nil { - tx.Rollback() + if xerr := tx.Rollback(); xerr != nil { + log.WithError(xerr).Error("rollback failure") + } return &RunnerError{ Err: err,