diff --git a/.gitignore b/.gitignore index 88f660f..c4a7aa2 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,6 @@ cmd/spirit/spirit # Dependency directories (remove the comment below to include it) # vendor/ + +.goosehints +.goose \ No newline at end of file diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go index f478381..2df31f7 100644 --- a/pkg/migration/migration.go +++ b/pkg/migration/migration.go @@ -84,12 +84,12 @@ func (m *Migration) normalizeOptions() (stmt *statement.AbstractStatement, err e // This also returns the StmtNode. stmt, err = statement.New(m.Statement) if err != nil { - if err == statement.ErrSchemaNameIncluded { - return nil, err - } // Omit the parser error messages, just show the statement. return nil, errors.New("could not parse SQL statement: " + m.Statement) } + if stmt.Schema != "" && stmt.Schema != m.Database { + return nil, errors.New("schema name in statement (`schema`.`table`) does not match --database") + } } else { if m.Table == "" { return nil, errors.New("table name is required") diff --git a/pkg/migration/migration_test.go b/pkg/migration/migration_test.go index 73ee3fa..fefc38a 100644 --- a/pkg/migration/migration_test.go +++ b/pkg/migration/migration_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + _ "github.com/pingcap/tidb/pkg/parser/test_driver" "github.com/cashapp/spirit/pkg/testutils" @@ -443,6 +445,7 @@ func TestCreateIndexIsRewritten(t *testing.T) { testutils.RunSQL(t, tbl) cfg, err := mysql.ParseDSN(testutils.DSN()) assert.NoError(t, err) + require.NotEqual(t, "", cfg.DBName) migration := &Migration{ Host: cfg.Addr, Username: cfg.User, @@ -450,7 +453,7 @@ func TestCreateIndexIsRewritten(t *testing.T) { Database: cfg.DBName, Threads: 1, Checksum: true, - Statement: "CREATE INDEX idx ON t1createindex (b)", + Statement: "CREATE INDEX idx ON " + cfg.DBName + ".t1createindex (b)", } err = migration.Run() assert.NoError(t, err) @@ -475,6 +478,5 @@ func TestSchemaNameIncluded(t *testing.T) { Statement: "ALTER TABLE test.t1schemaname ADD COLUMN c int", } err = migration.Run() - assert.Error(t, err) - assert.ErrorContains(t, err, "schema name included in the table name is not supported") + assert.NoError(t, err) } diff --git a/pkg/statement/statement.go b/pkg/statement/statement.go index 122b07e..594ffd9 100644 --- a/pkg/statement/statement.go +++ b/pkg/statement/statement.go @@ -13,14 +13,14 @@ import ( ) type AbstractStatement struct { - Table string + Schema string // this will be empty unless the table name is fully qualified (ALTER TABLE test.t1 ...) + Table string // for statements that affect multiple tables (DROP TABLE t1, t2), only the first is set here! Alter string // may be empty. Statement string StmtNode *ast.StmtNode } var ( - ErrSchemaNameIncluded = errors.New("schema name included in the table name is not supported") ErrNotSupportedStatement = errors.New("not a supported statement type") ErrNotAlterTable = errors.New("not an ALTER TABLE statement") ) @@ -42,18 +42,16 @@ func New(statement string) (*AbstractStatement, error) { // if the schema name is included it could be different from the --database // specified, which causes all sorts of problems. The easiest way to handle this // it just to not permit it. - if alterStmt.Table.Schema.String() != "" { - return nil, ErrSchemaNameIncluded - } var sb strings.Builder sb.Reset() rCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) if err = alterStmt.Restore(rCtx); err != nil { - return nil, fmt.Errorf("could not restore alter table statement: %s", err) + return nil, fmt.Errorf("could not restore alter clause statement: %s", err) } normalizedStmt := sb.String() trimLen := len(alterStmt.Table.Name.String()) + 15 // len ALTER TABLE + quotes return &AbstractStatement{ + Schema: alterStmt.Table.Schema.String(), Table: alterStmt.Table.Name.String(), Alter: normalizedStmt[trimLen:], Statement: statement, @@ -66,34 +64,41 @@ func New(statement string) (*AbstractStatement, error) { // but it's not a spirit migration. But the table should be specified. case *ast.CreateTableStmt: stmt := stmtNodes[0].(*ast.CreateTableStmt) - if stmt.Table.Schema.String() != "" { - return nil, ErrSchemaNameIncluded - } return &AbstractStatement{ + Schema: stmt.Table.Schema.String(), Table: stmt.Table.Name.String(), Statement: statement, StmtNode: &stmtNodes[0], }, err case *ast.DropTableStmt: stmt := stmtNodes[0].(*ast.DropTableStmt) + distinctSchemas := make(map[string]struct{}) for _, table := range stmt.Tables { - if table.Schema.String() != "" { - return nil, ErrSchemaNameIncluded - } + distinctSchemas[table.Schema.String()] = struct{}{} + } + if len(distinctSchemas) > 1 { + return nil, errors.New("statement attempts to drop tables from multiple schemas") } return &AbstractStatement{ + Schema: stmt.Tables[0].Schema.String(), Table: stmt.Tables[0].Name.String(), // TODO: this is just used in log lines, but there could be more than one! Statement: statement, StmtNode: &stmtNodes[0], }, err case *ast.RenameTableStmt: stmt := stmtNodes[0].(*ast.RenameTableStmt) - for _, table := range stmt.TableToTables { - if table.OldTable.Schema.String() != "" || table.NewTable.Schema.String() != "" { - return nil, ErrSchemaNameIncluded + distinctSchemas := make(map[string]struct{}) + for _, clause := range stmt.TableToTables { + if clause.OldTable.Schema.String() != clause.NewTable.Schema.String() { + return nil, errors.New("statement attempts to move table between schemas") } + distinctSchemas[clause.OldTable.Schema.String()] = struct{}{} + } + if len(distinctSchemas) > 1 { + return nil, errors.New("statement attempts to rename tables in multiple schemas") } return &AbstractStatement{ + Schema: stmt.TableToTables[0].OldTable.Schema.String(), Table: stmt.TableToTables[0].OldTable.Name.String(), // TODO: this is just used in log lines, but there could be more than one! Statement: statement, StmtNode: &stmtNodes[0], @@ -238,6 +243,8 @@ func convertCreateIndexToAlterTable(stmt ast.StmtNode) (*AbstractStatement, erro alterStmt := fmt.Sprintf("ADD %s %s (%s)", keyType, ciStmt.IndexName, strings.Join(columns, ", ")) // We hint in the statement that it's been rewritten // and in the stmtNode we reparse from the alterStmt. + // TODO: include schema name if it's explicitly given in the CREATE INDEX statement? + // TODO: identifiers should be quoted/escaped in case a maniac includes a backtick in a table name. statement := fmt.Sprintf("ALTER TABLE `%s` %s", ciStmt.Table.Name, alterStmt) p := parser.New() stmtNodes, _, err := p.Parse(statement, "", "") diff --git a/pkg/statement/statement_test.go b/pkg/statement/statement_test.go index bee8670..d217401 100644 --- a/pkg/statement/statement_test.go +++ b/pkg/statement/statement_test.go @@ -33,9 +33,10 @@ func TestExtractFromStatement(t *testing.T) { _, err = New("ALTER TABLE t1 ADD INDEX (something); ALTER TABLE t2 ADD INDEX (something)") assert.Error(t, err) - // Try and include the schema name. - _, err = New("ALTER TABLE test.t1 ADD INDEX (something)") - assert.Error(t, err) + // Include the schema name. + abstractStmt, err = New("ALTER TABLE test.t1 ADD INDEX (something)") + assert.NoError(t, err) + assert.Equal(t, abstractStmt.Schema, "test") // Try and parse an invalid statement. _, err = New("ALTER TABLE t1 yes") @@ -48,6 +49,12 @@ func TestExtractFromStatement(t *testing.T) { assert.Equal(t, "ADD INDEX idx (a)", abstractStmt.Alter) assert.Equal(t, "ALTER TABLE `t1` ADD INDEX idx (a)", abstractStmt.Statement) + abstractStmt, err = New("CREATE INDEX idx ON test.`t1` (a)") + assert.NoError(t, err) + assert.Equal(t, "t1", abstractStmt.Table) + assert.Equal(t, "ADD INDEX idx (a)", abstractStmt.Alter) + assert.Equal(t, "ALTER TABLE `t1` ADD INDEX idx (a)", abstractStmt.Statement) + // test unsupported. _, err = New("INSERT INTO t1 (a) VALUES (1)") assert.Error(t, err) @@ -60,6 +67,10 @@ func TestExtractFromStatement(t *testing.T) { assert.Empty(t, abstractStmt.Alter) assert.False(t, abstractStmt.IsAlterTable()) + // drop table with multiple schemas + _, err = New("DROP TABLE test.t1, test2.t1") + assert.Error(t, err) + // rename table abstractStmt, err = New("RENAME TABLE t1 TO t2") assert.NoError(t, err)