Skip to content

Commit

Permalink
fix --statement bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
morgo committed Feb 24, 2025
1 parent 6274c62 commit dd2cdf3
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 13 deletions.
2 changes: 2 additions & 0 deletions pkg/migration/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func (m *Migration) normalizeOptions() (stmt *statement.AbstractStatement, err e
if stmt.Schema != "" && stmt.Schema != m.Database {
return nil, errors.New("schema name in statement (`schema`.`table`) does not match --database")
}
stmt.Schema = m.Database
} else {
if m.Table == "" {
return nil, errors.New("table name is required")
Expand All @@ -104,6 +105,7 @@ func (m *Migration) normalizeOptions() (stmt *statement.AbstractStatement, err e
return nil, errors.New("could not parse SQL statement: " + fullStatement)
}
stmt = &statement.AbstractStatement{
Schema: m.Database,
Table: m.Table,
Alter: m.Alter,
Statement: fullStatement,
Expand Down
26 changes: 13 additions & 13 deletions pkg/migration/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func (r *Runner) Run(originalCtx context.Context) error {
ctx, cancel := context.WithCancel(originalCtx)
defer cancel()
r.startTime = time.Now()
r.logger.Infof("Starting spirit migration: concurrency=%d target-chunk-size=%s table='%s.%s' alter=\"%s\"",
r.migration.Threads, r.migration.TargetChunkTime, r.migration.Database, r.stmt.Table, r.stmt.Alter,
r.logger.Infof("Starting spirit migration: concurrency=%d target-chunk-size=%s table='%s.%s' alter=%s ",
r.migration.Threads, r.migration.TargetChunkTime, r.stmt.Schema, r.stmt.Table, r.stmt.Alter,
)

// Create a database connection
Expand Down Expand Up @@ -186,7 +186,7 @@ func (r *Runner) Run(originalCtx context.Context) error {
}

// Get Table Info
r.table = table.NewTableInfo(r.db, r.migration.Database, r.stmt.Table)
r.table = table.NewTableInfo(r.db, r.stmt.Schema, r.stmt.Table)
if err := r.table.SetInfo(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -438,7 +438,7 @@ func (r *Runner) attemptMySQLDDL(ctx context.Context) error {
}

func (r *Runner) dsn() string {
return fmt.Sprintf("%s:%s@tcp(%s)/%s", r.migration.Username, r.migration.Password, r.migration.Host, r.migration.Database)
return fmt.Sprintf("%s:%s@tcp(%s)/%s", r.migration.Username, r.migration.Password, r.migration.Host, r.stmt.Schema)
}

func (r *Runner) setup(ctx context.Context) error {
Expand Down Expand Up @@ -572,7 +572,7 @@ func (r *Runner) createNewTable(ctx context.Context) error {
r.table.SchemaName, newName, r.table.SchemaName, r.table.TableName); err != nil {
return err
}
r.newTable = table.NewTableInfo(r.db, r.migration.Database, newName)
r.newTable = table.NewTableInfo(r.db, r.stmt.Schema, newName)
if err := r.newTable.SetInfo(ctx); err != nil {
return err
}
Expand All @@ -589,7 +589,7 @@ func (r *Runner) alterNewTable(ctx context.Context) error {
// Retry without the ALGORITHM=COPY. If there is a second error, then the DDL itself
// is not supported. It could be a syntax error, in which case we return the second error,
// which will probably be easier to read because it is unaltered.
if err := dbconn.Exec(ctx, r.db, "ALTER TABLE %n.%n "+r.migration.Alter, r.newTable.SchemaName, r.newTable.TableName); err != nil {
if err := dbconn.Exec(ctx, r.db, "ALTER TABLE %n.%n "+r.stmt.Alter, r.newTable.SchemaName, r.newTable.TableName); err != nil {
return err
}
}
Expand All @@ -612,11 +612,11 @@ func (r *Runner) oldTableName() string {
}

func (r *Runner) attemptInstantDDL(ctx context.Context) error {
return dbconn.Exec(ctx, r.db, "ALTER TABLE %n.%n "+r.migration.Alter+", ALGORITHM=INSTANT", r.table.SchemaName, r.table.TableName)
return dbconn.Exec(ctx, r.db, "ALTER TABLE %n.%n "+r.stmt.Alter+", ALGORITHM=INSTANT", r.table.SchemaName, r.table.TableName)
}

func (r *Runner) attemptInplaceDDL(ctx context.Context) error {
return dbconn.Exec(ctx, r.db, "ALTER TABLE %n.%n "+r.migration.Alter+", ALGORITHM=INPLACE, LOCK=NONE", r.table.SchemaName, r.table.TableName)
return dbconn.Exec(ctx, r.db, "ALTER TABLE %n.%n "+r.stmt.Alter+", ALGORITHM=INPLACE, LOCK=NONE", r.table.SchemaName, r.table.TableName)
}

func (r *Runner) createCheckpointTable(ctx context.Context) error {
Expand Down Expand Up @@ -743,7 +743,7 @@ func (r *Runner) resumeFromCheckpoint(ctx context.Context) error {

// Make sure we can read from the new table.
if err := dbconn.Exec(ctx, r.db, "SELECT * FROM %n.%n LIMIT 1",
r.migration.Database, newName); err != nil {
r.stmt.Schema, newName); err != nil {
return fmt.Errorf("could not find any checkpoints in table '%s'", newName)
}

Expand All @@ -752,19 +752,19 @@ func (r *Runner) resumeFromCheckpoint(ctx context.Context) error {
// was created by either an earlier or later version of spirit, in which case
// we do not support recovery.
query := fmt.Sprintf("SELECT * FROM `%s`.`%s` ORDER BY id DESC LIMIT 1",
r.migration.Database, cpName)
r.stmt.Schema, cpName)
var copierWatermark, binlogName, alterStatement string
var id, binlogPos int
var rowsCopied, rowsCopiedLogical uint64
err := r.db.QueryRow(query).Scan(&id, &copierWatermark, &r.checksumWatermark, &binlogName, &binlogPos, &rowsCopied, &rowsCopiedLogical, &alterStatement)
if err != nil {
return fmt.Errorf("could not read from table '%s', err:%v", cpName, err)
}
if r.migration.Alter != alterStatement {
if r.stmt.Alter != alterStatement {
return ErrMismatchedAlter
}
// Populate the objects that would have been set in the other funcs.
r.newTable = table.NewTableInfo(r.db, r.migration.Database, newName)
r.newTable = table.NewTableInfo(r.db, r.stmt.Schema, newName)
if err := r.newTable.SetInfo(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -931,7 +931,7 @@ func (r *Runner) dumpCheckpoint(ctx context.Context) error {
binlog.Pos,
copyRows,
logicalCopyRows,
r.migration.Alter,
r.stmt.Alter,
)
}

Expand Down
29 changes: 29 additions & 0 deletions pkg/migration/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3057,3 +3057,32 @@ func TestPreventConcurrentRuns(t *testing.T) {
assert.Error(t, err)
assert.ErrorContains(t, err, "could not acquire metadata lock")
}

func TestStatementWorkflowStillInstant(t *testing.T) {
cfg, err := mysql.ParseDSN(testutils.DSN())
assert.NoError(t, err)

testutils.RunSQL(t, `DROP TABLE IF EXISTS stmtworkflow`)
table := `CREATE TABLE stmtworkflow (
id int(11) NOT NULL AUTO_INCREMENT,
b INT NOT NULL,
c INT NOT NULL,
PRIMARY KEY (id),
INDEX (b)
)`
testutils.RunSQL(t, table)
m, err := NewRunner(&Migration{
Host: cfg.Addr,
Username: cfg.User,
Password: cfg.Passwd,
Database: cfg.DBName,
Threads: 1,
Statement: "ALTER TABLE stmtworkflow ADD newcol INT",
})
assert.NoError(t, err)
err = m.Run(context.Background())
assert.NoError(t, err)

assert.True(t, m.usedInstantDDL) // expected to count as instant.
assert.NoError(t, m.Close())
}
3 changes: 3 additions & 0 deletions pkg/statement/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ func New(statement string) (*AbstractStatement, error) {
}
normalizedStmt := sb.String()
trimLen := len(alterStmt.Table.Name.String()) + 15 // len ALTER TABLE + quotes
if len(alterStmt.Table.Schema.String()) > 0 {
trimLen += len(alterStmt.Table.Schema.String()) + 3 // len schema + quotes and dot.
}
return &AbstractStatement{
Schema: alterStmt.Table.Schema.String(),
Table: alterStmt.Table.Name.String(),
Expand Down
6 changes: 6 additions & 0 deletions pkg/statement/statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ func TestExtractFromStatement(t *testing.T) {
assert.Equal(t, "t1", abstractStmt.Table)
assert.Equal(t, "ADD INDEX(`something`)", abstractStmt.Alter)

abstractStmt, err = New("ALTER TABLE test.t1 ADD INDEX (something)")
assert.NoError(t, err)
assert.Equal(t, "test", abstractStmt.Schema)
assert.Equal(t, "t1", abstractStmt.Table)
assert.Equal(t, "ADD INDEX(`something`)", abstractStmt.Alter)

abstractStmt, err = New("ALTER TABLE t1aaaa ADD COLUMN newcol int")
assert.NoError(t, err)
assert.Equal(t, "t1aaaa", abstractStmt.Table)
Expand Down

0 comments on commit dd2cdf3

Please sign in to comment.