diff --git a/go.mod b/go.mod index 177e673..2b125d1 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/siddontang/loggers v1.0.3 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 - golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/sync v0.4.0 ) @@ -34,6 +33,7 @@ require ( go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect + golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index e0962b4..e5d4c35 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,6 @@ github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-mysql-org/go-mysql v1.8.1-0.20240805131754-ccf204bf2b2a h1:VO6kiE9ex1uNaCCgDz/q0EhTueLrr3vmxkjJpU2x6pk= -github.com/go-mysql-org/go-mysql v1.8.1-0.20240805131754-ccf204bf2b2a/go.mod h1:+SgFgTlqjqOQoMc98n9oyUWEgn2KkOL1VmXDoq2ONOs= github.com/go-mysql-org/go-mysql v1.9.1 h1:W2ZKkHkoM4mmkasJCoSYfaE4RQNxXTb6VqiaMpKFrJc= github.com/go-mysql-org/go-mysql v1.9.1/go.mod h1:+SgFgTlqjqOQoMc98n9oyUWEgn2KkOL1VmXDoq2ONOs= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= diff --git a/pkg/check/check.go b/pkg/check/check.go index 91bc087..029c542 100644 --- a/pkg/check/check.go +++ b/pkg/check/check.go @@ -29,7 +29,8 @@ type Resources struct { DB *sql.DB Replica *sql.DB Table *table.TableInfo - Alter string + Alter string // as ALTER + Statement string // as full SQL statement TargetChunkTime time.Duration Threads int ReplicaMaxLag time.Duration diff --git a/pkg/check/dropadd.go b/pkg/check/dropadd.go index fc64761..49efce7 100644 --- a/pkg/check/dropadd.go +++ b/pkg/check/dropadd.go @@ -21,11 +21,10 @@ func init() { // - We only allow a column name to be mentioned once across all // DROP and ADD parts of the alter statement. func dropAddCheck(ctx context.Context, r Resources, logger loggers.Advanced) error { - sql := fmt.Sprintf("ALTER TABLE %s %s", r.Table.TableName, r.Alter) p := parser.New() - stmtNodes, _, err := p.Parse(sql, "", "") + stmtNodes, _, err := p.Parse(r.Statement, "", "") if err != nil { - return fmt.Errorf("could not parse alter table statement: %s", sql) + return fmt.Errorf("could not parse alter table statement: %s", r.Statement) } stmt := &stmtNodes[0] alterStmt, ok := (*stmt).(*ast.AlterTableStmt) diff --git a/pkg/check/dropadd_test.go b/pkg/check/dropadd_test.go index 05e0301..65741ff 100644 --- a/pkg/check/dropadd_test.go +++ b/pkg/check/dropadd_test.go @@ -4,25 +4,23 @@ import ( "context" "testing" - "github.com/cashapp/spirit/pkg/table" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestDropAdd(t *testing.T) { r := Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "DROP b, ADD b INT", + Statement: "ALTER TABLE t.t1 DROP b, ADD b INT", } err := dropAddCheck(context.Background(), r, logrus.New()) assert.Error(t, err) assert.ErrorContains(t, err, "column b is mentioned 2 times in the same statement") - r.Alter = "DROP b1, ADD b2 INT" + r.Statement = "ALTER TABLE t.t1 DROP b1, ADD b2 INT" err = dropAddCheck(context.Background(), r, logrus.New()) assert.NoError(t, err) - r.Alter = "bogus" + r.Statement = "bogus" err = dropAddCheck(context.Background(), r, logrus.New()) assert.Error(t, err) } diff --git a/pkg/check/foreignkey.go b/pkg/check/foreignkey.go index be84258..e1ab048 100644 --- a/pkg/check/foreignkey.go +++ b/pkg/check/foreignkey.go @@ -38,11 +38,10 @@ func hasForeignKeysCheck(ctx context.Context, r Resources, logger loggers.Advanc } func addForeignKeyCheck(ctx context.Context, r Resources, logger loggers.Advanced) error { - sql := fmt.Sprintf("ALTER TABLE %s %s", r.Table.TableName, r.Alter) p := parser.New() - stmtNodes, _, err := p.Parse(sql, "", "") + stmtNodes, _, err := p.Parse(r.Statement, "", "") if err != nil { - return fmt.Errorf("could not parse alter table statement: %s", sql) + return fmt.Errorf("could not parse alter table statement: %s", r.Statement) } stmt := &stmtNodes[0] alterStmt, ok := (*stmt).(*ast.AlterTableStmt) diff --git a/pkg/check/foreignkey_test.go b/pkg/check/foreignkey_test.go index b69cba9..0726f56 100644 --- a/pkg/check/foreignkey_test.go +++ b/pkg/check/foreignkey_test.go @@ -13,18 +13,17 @@ import ( func TestAddForeignKey(t *testing.T) { r := Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "ADD FOREIGN KEY (customer_id) REFERENCES customers (id)", + Statement: "ALTER TABLE t1 ADD FOREIGN KEY (customer_id) REFERENCES customers (id)", } err := addForeignKeyCheck(context.Background(), r, logrus.New()) assert.Error(t, err) // add foreign key assert.ErrorContains(t, err, "adding foreign key constraints is not supported") - r.Alter = "DROP COLUMN foo" + r.Statement = "ALTER TABLE t1 DROP COLUMN foo" err = addForeignKeyCheck(context.Background(), r, logrus.New()) assert.NoError(t, err) // regular DDL - r.Alter = "bogus" + r.Statement = "bogus" err = addForeignKeyCheck(context.Background(), r, logrus.New()) assert.Error(t, err) // not a valid ddl } diff --git a/pkg/check/illegalclause.go b/pkg/check/illegalclause.go index cde1909..35bfe6f 100644 --- a/pkg/check/illegalclause.go +++ b/pkg/check/illegalclause.go @@ -8,12 +8,11 @@ import ( ) func init() { - registerCheck("illegalClause", illegalClauseCheck, ScopePreRun) + registerCheck("illegalClause", illegalClauseCheck, ScopePreflight) } // illegalClauseCheck checks for the presence of specific, unsupported // clauses in the ALTER statement, such as ALGORITHM= and LOCK=. func illegalClauseCheck(ctx context.Context, r Resources, logger loggers.Advanced) error { - sql := "ALTER TABLE x.x " + r.Alter - return utils.AlterContainsUnsupportedClause(sql) + return utils.AlterContainsUnsupportedClause(r.Statement) } diff --git a/pkg/check/illegalclause_test.go b/pkg/check/illegalclause_test.go index 893c4c2..0e10b16 100644 --- a/pkg/check/illegalclause_test.go +++ b/pkg/check/illegalclause_test.go @@ -4,41 +4,30 @@ import ( "context" "testing" - "github.com/cashapp/spirit/pkg/table" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestIllegalClauseCheck(t *testing.T) { r := Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "ALGORITHM=INPLACE", + Statement: "ALTER TABLE t1 ADD INDEX (b), ALGORITHM=INPLACE", } err := illegalClauseCheck(context.Background(), r, logrus.New()) assert.Error(t, err) - assert.ErrorContains(t, err, "ALTER contains unsupported clause") + assert.ErrorContains(t, err, "contains unsupported clause") - r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "ALGORITHM=INPLACE, LOCK=shared", - } + r.Statement = "ALTER TABLE t1 ADD c INT, ALGORITHM=INPLACE, LOCK=shared" err = illegalClauseCheck(context.Background(), r, logrus.New()) assert.Error(t, err) - assert.ErrorContains(t, err, "ALTER contains unsupported clause") + assert.ErrorContains(t, err, "contains unsupported clause") - r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "lock=none", - } + r.Statement = "ALTER TABLE t1 ADD c INT, lock=none" err = illegalClauseCheck(context.Background(), r, logrus.New()) assert.Error(t, err) - assert.ErrorContains(t, err, "ALTER contains unsupported clause") + assert.ErrorContains(t, err, "contains unsupported clause") - r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "engine=innodb, algorithm=copy", - } + r.Statement = "ALTER TABLE t1 engine=innodb, algorithm=copy" err = illegalClauseCheck(context.Background(), r, logrus.New()) assert.Error(t, err) - assert.ErrorContains(t, err, "ALTER contains unsupported clause") + assert.ErrorContains(t, err, "contains unsupported clause") } diff --git a/pkg/check/primarykey.go b/pkg/check/primarykey.go index 38aef3e..05263f8 100644 --- a/pkg/check/primarykey.go +++ b/pkg/check/primarykey.go @@ -16,11 +16,10 @@ func init() { } func primaryKeyCheck(ctx context.Context, r Resources, logger loggers.Advanced) error { - sql := fmt.Sprintf("ALTER TABLE %s %s", r.Table.TableName, r.Alter) p := parser.New() - stmtNodes, _, err := p.Parse(sql, "", "") + stmtNodes, _, err := p.Parse(r.Statement, "", "") if err != nil { - return fmt.Errorf("could not parse alter table statement: %s", sql) + return fmt.Errorf("could not parse alter table statement: %s", r.Statement) } stmt := &stmtNodes[0] alterStmt, ok := (*stmt).(*ast.AlterTableStmt) diff --git a/pkg/check/primarykey_test.go b/pkg/check/primarykey_test.go index 30b0a13..290ddc7 100644 --- a/pkg/check/primarykey_test.go +++ b/pkg/check/primarykey_test.go @@ -4,29 +4,25 @@ import ( "context" "testing" - "github.com/cashapp/spirit/pkg/table" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestPrimaryKey(t *testing.T) { r := Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "DROP PRIMARY KEY, ADD PRIMARY KEY (anothercol)", + Statement: "ALTER TABLE t.t1 DROP PRIMARY KEY, ADD PRIMARY KEY (anothercol)", } err := primaryKeyCheck(context.Background(), r, logrus.New()) assert.Error(t, err) // drop primary key r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "ADD INDEX (anothercol)", + Statement: "ALTER TABLE t.t1 ADD INDEX (anothercol)", } err = primaryKeyCheck(context.Background(), r, logrus.New()) assert.NoError(t, err) // safe modification r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "gibberish", + Statement: "gibberish", } err = primaryKeyCheck(context.Background(), r, logrus.New()) assert.Error(t, err) // gibberish diff --git a/pkg/check/rename.go b/pkg/check/rename.go index df190b3..d0b6e80 100644 --- a/pkg/check/rename.go +++ b/pkg/check/rename.go @@ -17,11 +17,10 @@ func init() { // renameCheck checks for any renames, which are not supported. func renameCheck(ctx context.Context, r Resources, logger loggers.Advanced) error { - sql := fmt.Sprintf("ALTER TABLE %s %s", r.Table.TableName, r.Alter) p := parser.New() - stmtNodes, _, err := p.Parse(sql, "", "") + stmtNodes, _, err := p.Parse(r.Statement, "", "") if err != nil { - return fmt.Errorf("could not parse alter table statement: %s", sql) + return fmt.Errorf("could not parse alter table statement: %s", r.Statement) } stmt := &stmtNodes[0] alterStmt, ok := (*stmt).(*ast.AlterTableStmt) diff --git a/pkg/check/rename_test.go b/pkg/check/rename_test.go index 007ae5a..f03abc7 100644 --- a/pkg/check/rename_test.go +++ b/pkg/check/rename_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - "github.com/cashapp/spirit/pkg/table" _ "github.com/pingcap/tidb/pkg/parser/test_driver" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -12,46 +11,40 @@ import ( func TestRename(t *testing.T) { r := Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "RENAME TO newtablename", + Statement: "ALTER TABLE t.t1 RENAME TO newtablename", } err := renameCheck(context.Background(), r, logrus.New()) assert.Error(t, err) assert.ErrorContains(t, err, "renames are not supported") r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "RENAME COLUMN c1 TO c2", + Statement: "ALTER TABLE t.t1 RENAME COLUMN c1 TO c2", } err = renameCheck(context.Background(), r, logrus.New()) assert.Error(t, err) assert.ErrorContains(t, err, "renames are not supported") r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "CHANGE c1 c2 VARCHAR(100)", + Statement: "ALTER TABLE t.t1 CHANGE c1 c2 VARCHAR(100)", } err = renameCheck(context.Background(), r, logrus.New()) assert.Error(t, err) assert.ErrorContains(t, err, "renames are not supported") r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "CHANGE c1 c1 VARCHAR(100)", //nolint: dupword + Statement: "ALTER TABLE t.t1 CHANGE c1 c1 VARCHAR(100)", //nolint: dupword } err = renameCheck(context.Background(), r, logrus.New()) assert.NoError(t, err) r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "ADD INDEX (anothercol)", + Statement: "ALTER TABLE t.t1 ADD INDEX (anothercol)", } err = renameCheck(context.Background(), r, logrus.New()) assert.NoError(t, err) // safe modification r = Resources{ - Table: &table.TableInfo{TableName: "test"}, - Alter: "gibberish", + Statement: "gibberish", } err = renameCheck(context.Background(), r, logrus.New()) assert.Error(t, err) // gibberish diff --git a/pkg/check/triggers.go b/pkg/check/triggers.go index 5901511..9e07e97 100644 --- a/pkg/check/triggers.go +++ b/pkg/check/triggers.go @@ -3,11 +3,8 @@ package check import ( "context" "errors" - "fmt" "strings" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" _ "github.com/pingcap/tidb/pkg/parser/test_driver" "github.com/siddontang/loggers" ) @@ -21,6 +18,9 @@ func init() { func hasTriggersCheck(ctx context.Context, r Resources, logger loggers.Advanced) error { sql := `SELECT * FROM information_schema.triggers WHERE (event_object_schema=? AND event_object_table=?)` + if r.DB == nil { + return errors.New("no database connection") + } rows, err := r.DB.QueryContext(ctx, sql, r.Table.SchemaName, r.Table.TableName) if err != nil { return err @@ -36,21 +36,11 @@ func hasTriggersCheck(ctx context.Context, r Resources, logger loggers.Advanced) } // addTriggersCheck checks for trigger creations, which is not supported +// Since the TiDB parser doesn't support this, we are using strings.Contains :( func addTriggersCheck(ctx context.Context, r Resources, logger loggers.Advanced) error { - isAddingTrigger := strings.Contains(strings.ToUpper(r.Alter), "CREATE TRIGGER") + isAddingTrigger := strings.Contains(strings.ToUpper(strings.TrimSpace(r.Statement)), "CREATE TRIGGER") if isAddingTrigger { return errors.New("adding triggers is not supported") } - sql := fmt.Sprintf("ALTER TABLE %s %s", r.Table.TableName, r.Alter) - p := parser.New() - stmtNodes, _, err := p.Parse(sql, "", "") - if err != nil { - return fmt.Errorf("could not parse alter table statement: %s", sql) - } - stmt := &stmtNodes[0] - _, ok := (*stmt).(*ast.AlterTableStmt) - if !ok { - return errors.New("not a valid alter table statement") - } return nil // no problems } diff --git a/pkg/check/triggers_test.go b/pkg/check/triggers_test.go index b273059..e651055 100644 --- a/pkg/check/triggers_test.go +++ b/pkg/check/triggers_test.go @@ -16,20 +16,20 @@ import ( func TestAddTriggers(t *testing.T) { r := Resources{ - Table: &table.TableInfo{TableName: "account"}, - Alter: "CREATE TRIGGER ins_sum BEFORE INSERT ON account FOR EACH ROW SET @sum = @sum + NEW.amount;", + Table: &table.TableInfo{TableName: "account"}, + Statement: "CREATE TRIGGER ins_sum BEFORE INSERT ON account FOR EACH ROW SET @sum = @sum + NEW.amount;", } err := addTriggersCheck(context.Background(), r, logrus.New()) assert.Error(t, err) // add triggers assert.ErrorContains(t, err, "adding triggers is not supported") - r.Alter = "DROP COLUMN foo" - err = addForeignKeyCheck(context.Background(), r, logrus.New()) + r.Statement = "ALTER TABLE t.t1 DROP COLUMN foo" + err = addTriggersCheck(context.Background(), r, logrus.New()) assert.NoError(t, err) // regular DDL - r.Alter = "bogus" - err = addForeignKeyCheck(context.Background(), r, logrus.New()) - assert.Error(t, err) // not a valid ddl + r.Statement = "bogus" + err = addTriggersCheck(context.Background(), r, logrus.New()) + assert.NoError(t, err) // not a valid ddl, but thats ok } func TestHasTriggers(t *testing.T) { diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go index eb9e143..d322deb 100644 --- a/pkg/migration/migration.go +++ b/pkg/migration/migration.go @@ -4,9 +4,13 @@ package migration import ( "context" "errors" + "fmt" + "strings" "time" "github.com/cashapp/spirit/pkg/check" + "github.com/cashapp/spirit/pkg/table" + "github.com/cashapp/spirit/pkg/utils" ) var ( @@ -31,6 +35,7 @@ type Migration struct { DeferCutOver bool `name:"defer-cutover" help:"Defer cutover (and checksum) until sentinel table is dropped" optional:"" default:"false"` Strict bool `name:"strict" help:"Exit on --alter mismatch when incomplete migration is detected" optional:"" default:"false"` InterpolateParams bool `name:"interpolate-params" help:"Enable interpolate params for DSN" optional:"" default:"false" hidden:""` + Statement string `name:"statement" help:"The SQL statement to run (replaces --table and --alter)" optional:"" default:""` } func (m *Migration) Run() error { @@ -47,3 +52,52 @@ func (m *Migration) Run() error { } return nil } + +// normalizeOptions does some validation and sets defaults. +// for example, it validates that only --statement or --table and --alter are specified. +func (m *Migration) normalizeOptions() error { + if m.TargetChunkTime == 0 { + m.TargetChunkTime = table.ChunkerDefaultTarget + } + if m.Threads == 0 { + m.Threads = 4 + } + if m.ReplicaMaxLag == 0 { + m.ReplicaMaxLag = 120 * time.Second + } + if m.Host == "" { + return errors.New("host is required") + } + if !strings.Contains(m.Host, ":") { + m.Host = fmt.Sprintf("%s:%d", m.Host, 3306) + } + if m.Database == "" { + return errors.New("schema name is required") + } + if m.Statement != "" { // statement is specified + if m.Table != "" || m.Alter != "" { + return errors.New("only --statement or --table and --alter can be specified") + } + // extract the table and alter from the statement. + // if it is a CREATE INDEX statement, we rewrite it to an alter statement. + // This also returns the StmtNode, which we do nothing with so far. + // TODO: if it's a CREATE INDEX statement, do we store the rewritten + // form in m.Statement? + var err error + m.Table, m.Alter, err = utils.ExtractFromStatement(m.Statement) + if err != nil { + return errors.New("could not parse statement") + } + } else { + if m.Table == "" { + return errors.New("table name is required") + } + if m.Alter == "" { + return errors.New("alter statement is required") + } + // Add the statement so that it can be used in some contexts + // Where the statement is preferred over the table and alter. + m.Statement = fmt.Sprintf("ALTER TABLE `%s`.`%s` %s", m.Database, m.Table, m.Alter) + } + return nil +} diff --git a/pkg/migration/migration_test.go b/pkg/migration/migration_test.go index dfa548d..0907289 100644 --- a/pkg/migration/migration_test.go +++ b/pkg/migration/migration_test.go @@ -332,3 +332,36 @@ func TestConvertCharset(t *testing.T) { err = migration.Run() assert.NoError(t, err) } + +func TestStmtWorkflow(t *testing.T) { + testutils.RunSQL(t, `DROP TABLE IF EXISTS t1s, _t1s_new`) + table := `CREATE TABLE t1s ( + id int not null primary key auto_increment, + b varchar(100) not null + )` + cfg, err := mysql.ParseDSN(testutils.DSN()) + assert.NoError(t, err) + migration := &Migration{ + Host: cfg.Addr, + Username: cfg.User, + Password: cfg.Passwd, + Database: cfg.DBName, + Threads: 1, + Checksum: true, + Statement: table, // CREATE TABLE. + } + err = migration.Run() + assert.NoError(t, err) + // We can also specify ALTER options in the statement. + migration = &Migration{ + Host: cfg.Addr, + Username: cfg.User, + Password: cfg.Passwd, + Database: cfg.DBName, + Threads: 1, + Checksum: true, + Statement: "ALTER TABLE t1s ADD COLUMN c int", // ALTER TABLE. + } + err = migration.Run() + assert.NoError(t, err) +} diff --git a/pkg/migration/runner.go b/pkg/migration/runner.go index df1080c..76f1360 100644 --- a/pkg/migration/runner.go +++ b/pkg/migration/runner.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "fmt" - "strings" "sync" "sync/atomic" "time" @@ -121,37 +120,14 @@ type Progress struct { } func NewRunner(m *Migration) (*Runner, error) { - r := &Runner{ + if err := m.normalizeOptions(); err != nil { + return nil, err + } + return &Runner{ migration: m, logger: logrus.New(), metricsSink: &metrics.NoopSink{}, - } - - if r.migration.TargetChunkTime == 0 { - r.migration.TargetChunkTime = table.ChunkerDefaultTarget - } - if r.migration.Threads == 0 { - r.migration.Threads = 4 - } - if r.migration.ReplicaMaxLag == 0 { - r.migration.ReplicaMaxLag = 120 * time.Second - } - if r.migration.Host == "" { - return nil, errors.New("host is required") - } - if !strings.Contains(r.migration.Host, ":") { - r.migration.Host = fmt.Sprintf("%s:%d", r.migration.Host, 3306) - } - if r.migration.Database == "" { - return nil, errors.New("schema name is required") - } - if r.migration.Table == "" { - return nil, errors.New("table name is required") - } - if r.migration.Alter == "" { - return nil, errors.New("alter statement is required") - } - return r, nil + }, nil } func (r *Runner) SetMetricsSink(sink metrics.Sink) { @@ -166,7 +142,7 @@ func (r *Runner) Run(originalCtx context.Context) error { ctx, cancel := context.WithCancel(originalCtx) defer cancel() r.startTime = time.Now() - r.logger.Infof("Starting spirit migration: concurrency=%d target-chunk-size=%s table=%s.%s alter=\"%s\"", + r.logger.Infof("Starting spirit migration: concurrency=%d target-chunk-size=%s table='%s.%s' alter=\"%s\"", r.migration.Threads, r.migration.TargetChunkTime, r.migration.Database, r.migration.Table, r.migration.Alter, ) @@ -195,6 +171,17 @@ func (r *Runner) Run(originalCtx context.Context) error { return err } + if r.migration.Alter == "" { + // It's a CREATE TABLE, DROP TABLE, or RENAME table. + // These are always instant. + err := dbconn.Exec(ctx, r.db, r.migration.Statement) + if err != nil { + return err + } + r.logger.Infof("apply complete") + return nil + } + // Get Table Info r.table = table.NewTableInfo(r.db, r.migration.Database, r.migration.Table) if err := r.table.SetInfo(ctx); err != nil { @@ -232,7 +219,7 @@ func (r *Runner) Run(originalCtx context.Context) error { // Force enable the checksum if it's an ADD UNIQUE INDEX operation // https://github.com/cashapp/spirit/issues/266 if !r.migration.Checksum { - if err := utils.AlterContainsAddUnique("ALTER TABLE unused " + r.migration.Alter); err != nil { + if err := utils.AlterContainsAddUnique(r.migration.Statement); err != nil { r.logger.Warnf("force enabling checksum: %v", err) r.migration.Checksum = true } @@ -243,7 +230,7 @@ func (r *Runner) Run(originalCtx context.Context) error { // It likely means the user is combining this operation with other unsafe operations, // which is not a good idea. We need to protect them by not allowing it. // https://github.com/cashapp/spirit/issues/283 - if err := utils.AlterContainsIndexVisibility("ALTER TABLE unused " + r.migration.Alter); err != nil { + if err := utils.AlterContainsIndexVisibility(r.migration.Statement); err != nil { return err } @@ -395,6 +382,7 @@ func (r *Runner) runChecks(ctx context.Context, scope check.ScopeFlag) error { Replica: r.replica, Table: r.table, Alter: r.migration.Alter, + Statement: r.migration.Statement, TargetChunkTime: r.migration.TargetChunkTime, Threads: r.migration.Threads, ReplicaMaxLag: r.migration.ReplicaMaxLag, @@ -429,8 +417,7 @@ func (r *Runner) attemptMySQLDDL(ctx context.Context) error { // If the operator has specified that they want to attempt // an inplace add index, we will attempt inplace regardless // of the statement. - alterStmt := fmt.Sprintf("ALTER TABLE %s %s", r.migration.Table, r.migration.Alter) - err = utils.AlgorithmInplaceConsideredSafe(alterStmt) + err = utils.AlgorithmInplaceConsideredSafe(r.migration.Statement) if err != nil { r.logger.Infof("unable to use INPLACE: %v", err) } @@ -876,7 +863,7 @@ func (r *Runner) checksum(ctx context.Context) error { // then the checksum will fail. This is entirely expected, and not considered a bug. We should // do our best-case to differentiate that we believe this ALTER statement is lossy, and // customize the returned error based on it. - if err := utils.AlterContainsAddUnique("ALTER TABLE unused " + r.migration.Alter); err != nil { + if err := utils.AlterContainsAddUnique(r.migration.Statement); err != nil { return errors.New("checksum failed after 3 attempts. Check that the ALTER statement is not adding a UNIQUE INDEX to non-unique data") } return errors.New("checksum failed after 3 attempts. This likely indicates either a bug in Spirit, or a manual modification to the _new table outside of Spirit. Please report @ github.com/cashapp/spirit") diff --git a/pkg/migration/runner_test.go b/pkg/migration/runner_test.go index 1a5a094..4964dcb 100644 --- a/pkg/migration/runner_test.go +++ b/pkg/migration/runner_test.go @@ -1976,7 +1976,8 @@ func TestResumeFromCheckpointStrict(t *testing.T) { // by disabling --strict migrationB.Strict = false - migrationB.Threads = 4 // to make the test run faster + migrationB.Threads = 4 // to make the test run faster + migrationB.Statement = "" // reset runner, err = NewRunner(migrationB) assert.NoError(t, err) @@ -2830,20 +2831,6 @@ func TestPreRunChecksE2E(t *testing.T) { defer db.Close() err = m.runChecks(context.TODO(), check.ScopePreRun) assert.NoError(t, err) - - m, err = NewRunner(&Migration{ - Host: cfg.Addr, - Username: cfg.User, - Password: cfg.Passwd, - Database: cfg.DBName, - Threads: 1, - Table: "test_checks_e2e", - Alter: "ALGORITHM=inplace", - }) - assert.NoError(t, err) - err = m.runChecks(context.TODO(), check.ScopePreRun) - assert.Error(t, err) - assert.ErrorContains(t, err, "unsupported clause") } // From https://github.com/cashapp/spirit/issues/241 diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 37fd228..625a81c 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -10,6 +10,7 @@ import ( "github.com/cashapp/spirit/pkg/table" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" ) const ( @@ -186,3 +187,74 @@ func AlterContainsIndexVisibility(sql string) error { func TrimAlter(alter string) string { return strings.TrimSuffix(strings.TrimSpace(alter), ";") } + +func ExtractFromStatement(statement string) (table string, alter string, err error) { + p := parser.New() + stmtNodes, _, err := p.Parse(statement, "", "") + if err != nil { + return "", "", err + } + if len(stmtNodes) != 1 { + return "", "", errors.New("only one statement may be specified at once") + } + switch stmtNodes[0].(type) { + case *ast.AlterTableStmt: + // type assert stmtNodes[0] as an AlterTableStmt and then + // extract the table and alter from it. + // TODO: handle the database name correctly, as it might differ from + // what was specified as --database. + alterStmt := stmtNodes[0].(*ast.AlterTableStmt) + var sb strings.Builder + sb.Reset() + rCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) + if err = alterStmt.Restore(rCtx); err != nil { + return "", "", fmt.Errorf("could not restore alter table statement: %s", err) + } + normalizedStmt := sb.String() + trimLen := len(alterStmt.Table.Name.String()) + 15 // len ALTER TABLE + quotes + if len(alterStmt.Table.Schema.String()) > 0 { + trimLen += len(alterStmt.Table.Schema.String()) + 3 // len schema + quotes and dot. + } + return alterStmt.Table.Name.String(), normalizedStmt[trimLen:], nil + case *ast.CreateIndexStmt: + // Need to rewrite to a corresponding ALTER TABLE statement + table, alter, err = convertCreateIndexToAlterTable(stmtNodes[0]) + return table, alter, err + // returning an empty alter means we are allowed to run it + // but it's not a spirit migration. But the table should be specified. + case *ast.CreateTableStmt: + createStmt := stmtNodes[0].(*ast.CreateTableStmt) + return createStmt.Table.Name.String(), "", nil + case *ast.DropTableStmt: + dropStmt := stmtNodes[0].(*ast.DropTableStmt) + return dropStmt.Tables[0].Name.String(), "", nil + case *ast.RenameTableStmt: + renameStmt := stmtNodes[0].(*ast.RenameTableStmt) + return renameStmt.TableToTables[0].OldTable.Name.String(), "", nil + } + // default: + return "", "", errors.New("not a supported statement type") +} + +func convertCreateIndexToAlterTable(stmt ast.StmtNode) (table string, alter string, err error) { + ciStmt, isCreateIndexStmt := stmt.(*ast.CreateIndexStmt) + if !isCreateIndexStmt { + return "", "", errors.New("not a CREATE INDEX statement") + } + var columns []string + var keyType string + for _, part := range ciStmt.IndexPartSpecifications { + columns = append(columns, part.Column.Name.String()) + } + switch ciStmt.KeyType { + case ast.IndexKeyTypeUnique: + keyType = "UNIQUE INDEX" + case ast.IndexKeyTypeFullText: + keyType = "FULLTEXT INDEX" + case ast.IndexKeyTypeSpatial: + keyType = "SPATIAL INDEX" + default: + keyType = "INDEX" + } + return ciStmt.Table.Name.String(), fmt.Sprintf("ADD %s %s (%s)", keyType, ciStmt.IndexName, strings.Join(columns, ", ")), nil +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index e3f051b..8475dc2 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -111,3 +111,25 @@ func TestTrimAlter(t *testing.T) { assert.Equal(t, "add column a, add column b", TrimAlter("add column a, add column b;")) assert.Equal(t, "add column a, add column b", TrimAlter("add column a, add column b")) } + +func TestExtractFromStatement(t *testing.T) { + table, alter, err := ExtractFromStatement("ALTER TABLE t1 ADD INDEX (something)") + assert.NoError(t, err) + assert.Equal(t, "t1", table) + assert.Equal(t, "ADD INDEX(`something`)", alter) + + table, alter, err = ExtractFromStatement("ALTER TABLE t.t1aaaa ADD COLUMN newcol int") + assert.NoError(t, err) + assert.Equal(t, "t1aaaa", table) + assert.Equal(t, "ADD COLUMN `newcol` INT", alter) + + table, alter, err = ExtractFromStatement("ALTER TABLE t.t1 DROP COLUMN foo") + assert.NoError(t, err) + assert.Equal(t, "t1", table) + assert.Equal(t, "DROP COLUMN `foo`", alter) + + table, alter, err = ExtractFromStatement("CREATE TABLE t1 (a int)") + assert.NoError(t, err) + assert.Equal(t, "t1", table) + assert.Empty(t, alter) +}