From 8ab1f26dc6ed13f1606d5e1d9cdfff10927536b3 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Fri, 21 Feb 2025 10:58:50 -0800 Subject: [PATCH 1/2] Add Schema member to AbstractStatement --- .gitignore | 3 +++ pkg/migration/migration.go | 6 +++--- pkg/statement/statement.go | 33 +++++++++++++++++++-------------- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 88f660f6..c4a7aa2d 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,6 @@ cmd/spirit/spirit # Dependency directories (remove the comment below to include it) # vendor/ + +.goosehints +.goose \ No newline at end of file diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go index f478381c..2df31f7b 100644 --- a/pkg/migration/migration.go +++ b/pkg/migration/migration.go @@ -84,12 +84,12 @@ func (m *Migration) normalizeOptions() (stmt *statement.AbstractStatement, err e // This also returns the StmtNode. stmt, err = statement.New(m.Statement) if err != nil { - if err == statement.ErrSchemaNameIncluded { - return nil, err - } // Omit the parser error messages, just show the statement. return nil, errors.New("could not parse SQL statement: " + m.Statement) } + if stmt.Schema != "" && stmt.Schema != m.Database { + return nil, errors.New("schema name in statement (`schema`.`table`) does not match --database") + } } else { if m.Table == "" { return nil, errors.New("table name is required") diff --git a/pkg/statement/statement.go b/pkg/statement/statement.go index 122b07ed..73909d27 100644 --- a/pkg/statement/statement.go +++ b/pkg/statement/statement.go @@ -13,6 +13,7 @@ import ( ) type AbstractStatement struct { + Schema string Table string Alter string // may be empty. Statement string @@ -20,7 +21,6 @@ type AbstractStatement struct { } var ( - ErrSchemaNameIncluded = errors.New("schema name included in the table name is not supported") ErrNotSupportedStatement = errors.New("not a supported statement type") ErrNotAlterTable = errors.New("not an ALTER TABLE statement") ) @@ -42,18 +42,16 @@ func New(statement string) (*AbstractStatement, error) { // if the schema name is included it could be different from the --database // specified, which causes all sorts of problems. The easiest way to handle this // it just to not permit it. - if alterStmt.Table.Schema.String() != "" { - return nil, ErrSchemaNameIncluded - } var sb strings.Builder sb.Reset() rCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) if err = alterStmt.Restore(rCtx); err != nil { - return nil, fmt.Errorf("could not restore alter table statement: %s", err) + return nil, fmt.Errorf("could not restore alter clause statement: %s", err) } normalizedStmt := sb.String() trimLen := len(alterStmt.Table.Name.String()) + 15 // len ALTER TABLE + quotes return &AbstractStatement{ + Schema: alterStmt.Table.Schema.String(), Table: alterStmt.Table.Name.String(), Alter: normalizedStmt[trimLen:], Statement: statement, @@ -66,34 +64,41 @@ func New(statement string) (*AbstractStatement, error) { // but it's not a spirit migration. But the table should be specified. case *ast.CreateTableStmt: stmt := stmtNodes[0].(*ast.CreateTableStmt) - if stmt.Table.Schema.String() != "" { - return nil, ErrSchemaNameIncluded - } return &AbstractStatement{ + Schema: stmt.Table.Schema.String(), Table: stmt.Table.Name.String(), Statement: statement, StmtNode: &stmtNodes[0], }, err case *ast.DropTableStmt: stmt := stmtNodes[0].(*ast.DropTableStmt) + distinctSchemas := make(map[string]struct{}) for _, table := range stmt.Tables { - if table.Schema.String() != "" { - return nil, ErrSchemaNameIncluded - } + distinctSchemas[table.Schema.String()] = struct{}{} + } + if len(distinctSchemas) > 1 { + return nil, errors.New("statement attempts to drop tables from multiple schemas") } return &AbstractStatement{ + Schema: stmt.Tables[0].Schema.String(), Table: stmt.Tables[0].Name.String(), // TODO: this is just used in log lines, but there could be more than one! Statement: statement, StmtNode: &stmtNodes[0], }, err case *ast.RenameTableStmt: stmt := stmtNodes[0].(*ast.RenameTableStmt) - for _, table := range stmt.TableToTables { - if table.OldTable.Schema.String() != "" || table.NewTable.Schema.String() != "" { - return nil, ErrSchemaNameIncluded + distinctSchemas := make(map[string]struct{}) + for _, clause := range stmt.TableToTables { + if clause.OldTable.Schema.String() != clause.NewTable.Schema.String() { + return nil, errors.New("statement attempts to move table between schemas") } + distinctSchemas[clause.OldTable.Schema.String()] = struct{}{} + } + if len(distinctSchemas) > 1 { + return nil, errors.New("statement attempts to rename tables in multiple schemas") } return &AbstractStatement{ + Schema: stmt.TableToTables[0].OldTable.Schema.String(), Table: stmt.TableToTables[0].OldTable.Name.String(), // TODO: this is just used in log lines, but there could be more than one! Statement: statement, StmtNode: &stmtNodes[0], From adcb3ef2f2aaad4a744358823666dc4a8bf48011 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Fri, 21 Feb 2025 15:07:27 -0800 Subject: [PATCH 2/2] testing --- pkg/migration/migration_test.go | 8 +++++--- pkg/statement/statement.go | 6 ++++-- pkg/statement/statement_test.go | 17 ++++++++++++++--- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/pkg/migration/migration_test.go b/pkg/migration/migration_test.go index 73ee3fa5..fefc38a4 100644 --- a/pkg/migration/migration_test.go +++ b/pkg/migration/migration_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + _ "github.com/pingcap/tidb/pkg/parser/test_driver" "github.com/cashapp/spirit/pkg/testutils" @@ -443,6 +445,7 @@ func TestCreateIndexIsRewritten(t *testing.T) { testutils.RunSQL(t, tbl) cfg, err := mysql.ParseDSN(testutils.DSN()) assert.NoError(t, err) + require.NotEqual(t, "", cfg.DBName) migration := &Migration{ Host: cfg.Addr, Username: cfg.User, @@ -450,7 +453,7 @@ func TestCreateIndexIsRewritten(t *testing.T) { Database: cfg.DBName, Threads: 1, Checksum: true, - Statement: "CREATE INDEX idx ON t1createindex (b)", + Statement: "CREATE INDEX idx ON " + cfg.DBName + ".t1createindex (b)", } err = migration.Run() assert.NoError(t, err) @@ -475,6 +478,5 @@ func TestSchemaNameIncluded(t *testing.T) { Statement: "ALTER TABLE test.t1schemaname ADD COLUMN c int", } err = migration.Run() - assert.Error(t, err) - assert.ErrorContains(t, err, "schema name included in the table name is not supported") + assert.NoError(t, err) } diff --git a/pkg/statement/statement.go b/pkg/statement/statement.go index 73909d27..594ffd92 100644 --- a/pkg/statement/statement.go +++ b/pkg/statement/statement.go @@ -13,8 +13,8 @@ import ( ) type AbstractStatement struct { - Schema string - Table string + Schema string // this will be empty unless the table name is fully qualified (ALTER TABLE test.t1 ...) + Table string // for statements that affect multiple tables (DROP TABLE t1, t2), only the first is set here! Alter string // may be empty. Statement string StmtNode *ast.StmtNode @@ -243,6 +243,8 @@ func convertCreateIndexToAlterTable(stmt ast.StmtNode) (*AbstractStatement, erro alterStmt := fmt.Sprintf("ADD %s %s (%s)", keyType, ciStmt.IndexName, strings.Join(columns, ", ")) // We hint in the statement that it's been rewritten // and in the stmtNode we reparse from the alterStmt. + // TODO: include schema name if it's explicitly given in the CREATE INDEX statement? + // TODO: identifiers should be quoted/escaped in case a maniac includes a backtick in a table name. statement := fmt.Sprintf("ALTER TABLE `%s` %s", ciStmt.Table.Name, alterStmt) p := parser.New() stmtNodes, _, err := p.Parse(statement, "", "") diff --git a/pkg/statement/statement_test.go b/pkg/statement/statement_test.go index bee8670c..d2174019 100644 --- a/pkg/statement/statement_test.go +++ b/pkg/statement/statement_test.go @@ -33,9 +33,10 @@ func TestExtractFromStatement(t *testing.T) { _, err = New("ALTER TABLE t1 ADD INDEX (something); ALTER TABLE t2 ADD INDEX (something)") assert.Error(t, err) - // Try and include the schema name. - _, err = New("ALTER TABLE test.t1 ADD INDEX (something)") - assert.Error(t, err) + // Include the schema name. + abstractStmt, err = New("ALTER TABLE test.t1 ADD INDEX (something)") + assert.NoError(t, err) + assert.Equal(t, abstractStmt.Schema, "test") // Try and parse an invalid statement. _, err = New("ALTER TABLE t1 yes") @@ -48,6 +49,12 @@ func TestExtractFromStatement(t *testing.T) { assert.Equal(t, "ADD INDEX idx (a)", abstractStmt.Alter) assert.Equal(t, "ALTER TABLE `t1` ADD INDEX idx (a)", abstractStmt.Statement) + abstractStmt, err = New("CREATE INDEX idx ON test.`t1` (a)") + assert.NoError(t, err) + assert.Equal(t, "t1", abstractStmt.Table) + assert.Equal(t, "ADD INDEX idx (a)", abstractStmt.Alter) + assert.Equal(t, "ALTER TABLE `t1` ADD INDEX idx (a)", abstractStmt.Statement) + // test unsupported. _, err = New("INSERT INTO t1 (a) VALUES (1)") assert.Error(t, err) @@ -60,6 +67,10 @@ func TestExtractFromStatement(t *testing.T) { assert.Empty(t, abstractStmt.Alter) assert.False(t, abstractStmt.IsAlterTable()) + // drop table with multiple schemas + _, err = New("DROP TABLE test.t1, test2.t1") + assert.Error(t, err) + // rename table abstractStmt, err = New("RENAME TABLE t1 TO t2") assert.NoError(t, err)