Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Schema member to AbstractStatement struct #1

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ cmd/spirit/spirit

# Dependency directories (remove the comment below to include it)
# vendor/

.goosehints
.goose
6 changes: 3 additions & 3 deletions pkg/migration/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ func (m *Migration) normalizeOptions() (stmt *statement.AbstractStatement, err e
// This also returns the StmtNode.
stmt, err = statement.New(m.Statement)
if err != nil {
if err == statement.ErrSchemaNameIncluded {
return nil, err
}
// Omit the parser error messages, just show the statement.
return nil, errors.New("could not parse SQL statement: " + m.Statement)
}
if stmt.Schema != "" && stmt.Schema != m.Database {
return nil, errors.New("schema name in statement (`schema`.`table`) does not match --database")
}
} else {
if m.Table == "" {
return nil, errors.New("table name is required")
Expand Down
8 changes: 5 additions & 3 deletions pkg/migration/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

_ "github.com/pingcap/tidb/pkg/parser/test_driver"

"github.com/cashapp/spirit/pkg/testutils"
Expand Down Expand Up @@ -443,14 +445,15 @@ func TestCreateIndexIsRewritten(t *testing.T) {
testutils.RunSQL(t, tbl)
cfg, err := mysql.ParseDSN(testutils.DSN())
assert.NoError(t, err)
require.NotEqual(t, "", cfg.DBName)
migration := &Migration{
Host: cfg.Addr,
Username: cfg.User,
Password: cfg.Passwd,
Database: cfg.DBName,
Threads: 1,
Checksum: true,
Statement: "CREATE INDEX idx ON t1createindex (b)",
Statement: "CREATE INDEX idx ON " + cfg.DBName + ".t1createindex (b)",
}
err = migration.Run()
assert.NoError(t, err)
Expand All @@ -475,6 +478,5 @@ func TestSchemaNameIncluded(t *testing.T) {
Statement: "ALTER TABLE test.t1schemaname ADD COLUMN c int",
}
err = migration.Run()
assert.Error(t, err)
assert.ErrorContains(t, err, "schema name included in the table name is not supported")
assert.NoError(t, err)
}
37 changes: 22 additions & 15 deletions pkg/statement/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ import (
)

type AbstractStatement struct {
Table string
Schema string // this will be empty unless the table name is fully qualified (ALTER TABLE test.t1 ...)
Table string // for statements that affect multiple tables (DROP TABLE t1, t2), only the first is set here!
Alter string // may be empty.
Statement string
StmtNode *ast.StmtNode
}

var (
ErrSchemaNameIncluded = errors.New("schema name included in the table name is not supported")
ErrNotSupportedStatement = errors.New("not a supported statement type")
ErrNotAlterTable = errors.New("not an ALTER TABLE statement")
)
Expand All @@ -42,18 +42,16 @@ func New(statement string) (*AbstractStatement, error) {
// if the schema name is included it could be different from the --database
// specified, which causes all sorts of problems. The easiest way to handle this
// it just to not permit it.
if alterStmt.Table.Schema.String() != "" {
return nil, ErrSchemaNameIncluded
}
var sb strings.Builder
sb.Reset()
rCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
if err = alterStmt.Restore(rCtx); err != nil {
return nil, fmt.Errorf("could not restore alter table statement: %s", err)
return nil, fmt.Errorf("could not restore alter clause statement: %s", err)
}
normalizedStmt := sb.String()
trimLen := len(alterStmt.Table.Name.String()) + 15 // len ALTER TABLE + quotes
return &AbstractStatement{
Schema: alterStmt.Table.Schema.String(),
Table: alterStmt.Table.Name.String(),
Alter: normalizedStmt[trimLen:],
Statement: statement,
Expand All @@ -66,34 +64,41 @@ func New(statement string) (*AbstractStatement, error) {
// but it's not a spirit migration. But the table should be specified.
case *ast.CreateTableStmt:
stmt := stmtNodes[0].(*ast.CreateTableStmt)
if stmt.Table.Schema.String() != "" {
return nil, ErrSchemaNameIncluded
}
return &AbstractStatement{
Schema: stmt.Table.Schema.String(),
Table: stmt.Table.Name.String(),
Statement: statement,
StmtNode: &stmtNodes[0],
}, err
case *ast.DropTableStmt:
stmt := stmtNodes[0].(*ast.DropTableStmt)
distinctSchemas := make(map[string]struct{})
for _, table := range stmt.Tables {
if table.Schema.String() != "" {
return nil, ErrSchemaNameIncluded
}
distinctSchemas[table.Schema.String()] = struct{}{}
}
if len(distinctSchemas) > 1 {
return nil, errors.New("statement attempts to drop tables from multiple schemas")
}
return &AbstractStatement{
Schema: stmt.Tables[0].Schema.String(),
Table: stmt.Tables[0].Name.String(), // TODO: this is just used in log lines, but there could be more than one!
Statement: statement,
StmtNode: &stmtNodes[0],
}, err
case *ast.RenameTableStmt:
stmt := stmtNodes[0].(*ast.RenameTableStmt)
for _, table := range stmt.TableToTables {
if table.OldTable.Schema.String() != "" || table.NewTable.Schema.String() != "" {
return nil, ErrSchemaNameIncluded
distinctSchemas := make(map[string]struct{})
for _, clause := range stmt.TableToTables {
if clause.OldTable.Schema.String() != clause.NewTable.Schema.String() {
return nil, errors.New("statement attempts to move table between schemas")
}
distinctSchemas[clause.OldTable.Schema.String()] = struct{}{}
}
if len(distinctSchemas) > 1 {
return nil, errors.New("statement attempts to rename tables in multiple schemas")
}
return &AbstractStatement{
Schema: stmt.TableToTables[0].OldTable.Schema.String(),
Table: stmt.TableToTables[0].OldTable.Name.String(), // TODO: this is just used in log lines, but there could be more than one!
Statement: statement,
StmtNode: &stmtNodes[0],
Expand Down Expand Up @@ -238,6 +243,8 @@ func convertCreateIndexToAlterTable(stmt ast.StmtNode) (*AbstractStatement, erro
alterStmt := fmt.Sprintf("ADD %s %s (%s)", keyType, ciStmt.IndexName, strings.Join(columns, ", "))
// We hint in the statement that it's been rewritten
// and in the stmtNode we reparse from the alterStmt.
// TODO: include schema name if it's explicitly given in the CREATE INDEX statement?
// TODO: identifiers should be quoted/escaped in case a maniac includes a backtick in a table name.
statement := fmt.Sprintf("ALTER TABLE `%s` %s", ciStmt.Table.Name, alterStmt)
p := parser.New()
stmtNodes, _, err := p.Parse(statement, "", "")
Expand Down
17 changes: 14 additions & 3 deletions pkg/statement/statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ func TestExtractFromStatement(t *testing.T) {
_, err = New("ALTER TABLE t1 ADD INDEX (something); ALTER TABLE t2 ADD INDEX (something)")
assert.Error(t, err)

// Try and include the schema name.
_, err = New("ALTER TABLE test.t1 ADD INDEX (something)")
assert.Error(t, err)
// Include the schema name.
abstractStmt, err = New("ALTER TABLE test.t1 ADD INDEX (something)")
assert.NoError(t, err)
assert.Equal(t, abstractStmt.Schema, "test")

// Try and parse an invalid statement.
_, err = New("ALTER TABLE t1 yes")
Expand All @@ -48,6 +49,12 @@ func TestExtractFromStatement(t *testing.T) {
assert.Equal(t, "ADD INDEX idx (a)", abstractStmt.Alter)
assert.Equal(t, "ALTER TABLE `t1` ADD INDEX idx (a)", abstractStmt.Statement)

abstractStmt, err = New("CREATE INDEX idx ON test.`t1` (a)")
assert.NoError(t, err)
assert.Equal(t, "t1", abstractStmt.Table)
assert.Equal(t, "ADD INDEX idx (a)", abstractStmt.Alter)
assert.Equal(t, "ALTER TABLE `t1` ADD INDEX idx (a)", abstractStmt.Statement)

// test unsupported.
_, err = New("INSERT INTO t1 (a) VALUES (1)")
assert.Error(t, err)
Expand All @@ -60,6 +67,10 @@ func TestExtractFromStatement(t *testing.T) {
assert.Empty(t, abstractStmt.Alter)
assert.False(t, abstractStmt.IsAlterTable())

// drop table with multiple schemas
_, err = New("DROP TABLE test.t1, test2.t1")
assert.Error(t, err)

// rename table
abstractStmt, err = New("RENAME TABLE t1 TO t2")
assert.NoError(t, err)
Expand Down