Skip to content

Commit 4f7feb8

Browse files
Prometheus2677FPiety0521
authored and
FPiety0521
committed
Correctly mark migrations as dirty when the first down migration fails
Fixes: golang-migrate/migrate#330 Other changes: - Cleanup code by using the database.NilVersion constant instead of the magic number -1
1 parent fb29412 commit 4f7feb8

File tree

14 files changed

+84
-59
lines changed

14 files changed

+84
-59
lines changed

database/cassandra/cassandra.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,11 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error {
201201
if err := c.session.Query(query).Exec(); err != nil {
202202
return &database.Error{OrigErr: err, Query: []byte(query)}
203203
}
204-
if version >= 0 {
204+
205+
// Also re-write the schema version for nil dirty versions to prevent
206+
// empty schema version for failed down migration on the first migration
207+
// See: https://github.com/golang-migrate/migrate/issues/330
208+
if version >= 0 || (version == database.NilVersion && dirty) {
205209
query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
206210
if err := c.session.Query(query, version, dirty).Exec(); err != nil {
207211
return &database.Error{OrigErr: err, Query: []byte(query)}

database/cockroachdb/cockroachdb.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,10 @@ func (c *CockroachDb) SetVersion(version int, dirty bool) error {
240240
return err
241241
}
242242

243-
if version >= 0 {
243+
// Also re-write the schema version for nil dirty versions to prevent
244+
// empty schema version for failed down migration on the first migration
245+
// See: https://github.com/golang-migrate/migrate/issues/330
246+
if version >= 0 || (version == database.NilVersion && dirty) {
244247
if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
245248
return err
246249
}

database/firebird/firebird.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ func (f *Firebird) Run(migration io.Reader) error {
134134
}
135135

136136
func (f *Firebird) SetVersion(version int, dirty bool) error {
137-
if version < 0 {
138-
return nil
139-
}
137+
// Always re-write the schema version to prevent empty schema version
138+
// for failed down migration on the first migration
139+
// See: https://github.com/golang-migrate/migrate/issues/330
140140

141141
// TODO: parameterize this SQL statement
142142
// https://firebirdsql.org/refdocs/langrefupd20-execblock.html

database/mysql/mysql.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,10 @@ func (m *Mysql) SetVersion(version int, dirty bool) error {
301301
return &database.Error{OrigErr: err, Query: []byte(query)}
302302
}
303303

304-
if version >= 0 {
304+
// Also re-write the schema version for nil dirty versions to prevent
305+
// empty schema version for failed down migration on the first migration
306+
// See: https://github.com/golang-migrate/migrate/issues/330
307+
if version >= 0 || (version == database.NilVersion && dirty) {
305308
query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
306309
if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
307310
if errRollback := tx.Rollback(); errRollback != nil {

database/neo4j/neo4j.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ type MigrationRecord struct {
183183
func (n *Neo4j) Version() (version int, dirty bool, err error) {
184184
session, err := n.driver.Session(neo4j.AccessModeRead)
185185
if err != nil {
186-
return -1, false, err
186+
return database.NilVersion, false, err
187187
}
188188
defer func() {
189189
if cerr := session.Close(); cerr != nil {
@@ -203,7 +203,7 @@ func (n *Neo4j) Version() (version int, dirty bool, err error) {
203203
mr := MigrationRecord{}
204204
versionResult, ok := record.Get("version")
205205
if !ok {
206-
mr.Version = -1
206+
mr.Version = database.NilVersion
207207
} else {
208208
mr.Version = int(versionResult.(int64))
209209
}
@@ -218,10 +218,10 @@ func (n *Neo4j) Version() (version int, dirty bool, err error) {
218218
return nil, result.Err()
219219
})
220220
if err != nil {
221-
return -1, false, err
221+
return database.NilVersion, false, err
222222
}
223223
if result == nil {
224-
return -1, false, err
224+
return database.NilVersion, false, err
225225
}
226226
mr := result.(MigrationRecord)
227227
return mr.Version, mr.Dirty, err

database/postgres/postgres.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,12 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
280280
return &database.Error{OrigErr: err, Query: []byte(query)}
281281
}
282282

283-
if version >= 0 {
284-
query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)`
283+
// Also re-write the schema version for nil dirty versions to prevent
284+
// empty schema version for failed down migration on the first migration
285+
// See: https://github.com/golang-migrate/migrate/issues/330
286+
if version >= 0 || (version == database.NilVersion && dirty) {
287+
query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) +
288+
` (version, dirty) VALUES ($1, $2)`
285289
if _, err := tx.Exec(query, version, dirty); err != nil {
286290
if errRollback := tx.Rollback(); errRollback != nil {
287291
err = multierror.Append(err, errRollback)

database/postgres/postgres_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
"github.com/dhui/dktest"
2020

21+
"github.com/golang-migrate/migrate/v4/database"
2122
dt "github.com/golang-migrate/migrate/v4/database/testing"
2223
"github.com/golang-migrate/migrate/v4/dktesting"
2324
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -248,7 +249,7 @@ func TestWithSchema(t *testing.T) {
248249
if err != nil {
249250
t.Fatal(err)
250251
}
251-
if version != -1 {
252+
if version != database.NilVersion {
252253
t.Fatal("expected NilVersion")
253254
}
254255

database/ql/ql.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,10 @@ func (m *Ql) SetVersion(version int, dirty bool) error {
210210
return &database.Error{OrigErr: err, Query: []byte(query)}
211211
}
212212

213-
if version >= 0 {
213+
// Also re-write the schema version for nil dirty versions to prevent
214+
// empty schema version for failed down migration on the first migration
215+
// See: https://github.com/golang-migrate/migrate/issues/330
216+
if version >= 0 || (version == database.NilVersion && dirty) {
214217
query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`,
215218
m.config.MigrationsTable)
216219
if _, err := tx.Exec(query, version, dirty); err != nil {

database/redshift/redshift.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,10 @@ func (p *Redshift) SetVersion(version int, dirty bool) error {
220220
return &database.Error{OrigErr: err, Query: []byte(query)}
221221
}
222222

223-
if version >= 0 {
223+
// Also re-write the schema version for nil dirty versions to prevent
224+
// empty schema version for failed down migration on the first migration
225+
// See: https://github.com/golang-migrate/migrate/issues/330
226+
if version >= 0 || (version == database.NilVersion && dirty) {
224227
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
225228
if _, err := tx.Exec(query, version, dirty); err != nil {
226229
if errRollback := tx.Rollback(); errRollback != nil {

database/redshift/redshift_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
)
2323

2424
import (
25+
"github.com/golang-migrate/migrate/v4/database"
2526
dt "github.com/golang-migrate/migrate/v4/database/testing"
2627
"github.com/golang-migrate/migrate/v4/dktesting"
2728
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -247,7 +248,7 @@ func TestWithSchema(t *testing.T) {
247248
if err != nil {
248249
t.Fatal(err)
249250
}
250-
if version != -1 {
251+
if version != database.NilVersion {
251252
t.Fatal("expected NilVersion")
252253
}
253254

database/sqlite3/sqlite3.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,10 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error {
235235
return &database.Error{OrigErr: err, Query: []byte(query)}
236236
}
237237

238-
if version >= 0 {
238+
// Also re-write the schema version for nil dirty versions to prevent
239+
// empty schema version for failed down migration on the first migration
240+
// See: https://github.com/golang-migrate/migrate/issues/330
241+
if version >= 0 || (version == database.NilVersion && dirty) {
239242
query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable)
240243
if _, err := tx.Exec(query, version, dirty); err != nil {
241244
if errRollback := tx.Rollback(); errRollback != nil {

database/sqlserver/sqlserver.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,10 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error {
239239
return &database.Error{OrigErr: err, Query: []byte(query)}
240240
}
241241

242-
if version >= 0 {
242+
// Also re-write the schema version for nil dirty versions to prevent
243+
// empty schema version for failed down migration on the first migration
244+
// See: https://github.com/golang-migrate/migrate/issues/330
245+
if version >= 0 || (version == database.NilVersion && dirty) {
243246
var dirtyBit int
244247
if dirty {
245248
dirtyBit = 1

database/stub/stub.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ type Stub struct {
2727
func (s *Stub) Open(url string) (database.Driver, error) {
2828
return &Stub{
2929
Url: url,
30-
CurrentVersion: -1,
30+
CurrentVersion: database.NilVersion,
3131
MigrationSequence: make([]string, 0),
3232
Config: &Config{},
3333
}, nil
@@ -38,7 +38,7 @@ type Config struct{}
3838
func WithInstance(instance interface{}, config *Config) (database.Driver, error) {
3939
return &Stub{
4040
Instance: instance,
41-
CurrentVersion: -1,
41+
CurrentVersion: database.NilVersion,
4242
MigrationSequence: make([]string, 0),
4343
Config: config,
4444
}, nil
@@ -84,7 +84,7 @@ func (s *Stub) Version() (version int, dirty bool, err error) {
8484
const DROP = "DROP"
8585

8686
func (s *Stub) Drop() error {
87-
s.CurrentVersion = -1
87+
s.CurrentVersion = database.NilVersion
8888
s.LastRunMigration = nil
8989
s.MigrationSequence = append(s.MigrationSequence, DROP)
9090
return nil

database/testing/testing.go

+35-38
Original file line numberDiff line numberDiff line change
@@ -116,43 +116,40 @@ func TestDrop(t *testing.T, d database.Driver) {
116116
}
117117

118118
func TestSetVersion(t *testing.T, d database.Driver) {
119-
if err := d.SetVersion(1, true); err != nil {
120-
t.Fatal(err)
121-
}
122-
123-
// call again
124-
if err := d.SetVersion(1, true); err != nil {
125-
t.Fatal(err)
126-
}
127-
128-
v, dirty, err := d.Version()
129-
if err != nil {
130-
t.Fatal(err)
131-
}
132-
if !dirty {
133-
t.Fatal("expected dirty")
134-
}
135-
if v != 1 {
136-
t.Fatal("expected version to be 1")
137-
}
138-
139-
if err := d.SetVersion(2, false); err != nil {
140-
t.Fatal(err)
141-
}
142-
143-
// call again
144-
if err := d.SetVersion(2, false); err != nil {
145-
t.Fatal(err)
146-
}
147-
148-
v, dirty, err = d.Version()
149-
if err != nil {
150-
t.Fatal(err)
151-
}
152-
if dirty {
153-
t.Fatal("expected not dirty")
154-
}
155-
if v != 2 {
156-
t.Fatal("expected version to be 2")
119+
// nolint:maligned
120+
testCases := []struct {
121+
name string
122+
version int
123+
dirty bool
124+
expectedErr error
125+
expectedReadErr error
126+
expectedVersion int
127+
expectedDirty bool
128+
}{
129+
{name: "set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true},
130+
{name: "re-set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true},
131+
{name: "set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false},
132+
{name: "re-set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false},
133+
{name: "last migration dirty", version: database.NilVersion, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: true},
134+
{name: "last migration clean", version: database.NilVersion, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: false},
135+
}
136+
137+
for _, tc := range testCases {
138+
t.Run(tc.name, func(t *testing.T) {
139+
err := d.SetVersion(tc.version, tc.dirty)
140+
if err != tc.expectedErr {
141+
t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr)
142+
}
143+
v, dirty, readErr := d.Version()
144+
if readErr != tc.expectedReadErr {
145+
t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr)
146+
}
147+
if v != tc.expectedVersion {
148+
t.Error("Got unexpected version:", v, "!=", tc.expectedVersion)
149+
}
150+
if dirty != tc.expectedDirty {
151+
t.Error("Got unexpected dirty value:", dirty, "!=", tc.dirty)
152+
}
153+
})
157154
}
158155
}

0 commit comments

Comments
 (0)