diff --git a/Makefile b/Makefile index 8e23a43c7..cad350418 100644 --- a/Makefile +++ b/Makefile @@ -105,6 +105,9 @@ echo-database: @echo "$(DATABASE)" +lint: + golangci-lint run -c .golangci.yml + define external_deps @echo '-- $(1)'; go list -f '{{join .Deps "\n"}}' $(1) | grep -v github.com/$(REPO_OWNER)/migrate | xargs go list -f '{{if not .Standard}}{{.ImportPath}}{{end}}' @@ -113,7 +116,8 @@ endef .PHONY: build build-docker build-cli clean test-short test test-with-flags html-coverage \ restore-import-paths rewrite-import-paths list-external-deps release \ - docs kill-docs open-docs kill-orphaned-docker-containers echo-source echo-database + docs kill-docs open-docs kill-orphaned-docker-containers echo-source echo-database \ + lint SHELL = /bin/sh RAND = $(shell echo $$RANDOM) diff --git a/database/multistmt/parse.go b/database/multistmt/parse.go index 9a045767d..b1c2f9093 100644 --- a/database/multistmt/parse.go +++ b/database/multistmt/parse.go @@ -15,7 +15,7 @@ var StartBufSize = 4096 // from the multi-statement migration should be parsed and handled. type Handler func(migration []byte) bool -func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byte, error) { +func splitWithDelimiter(delimiter []byte) bufio.SplitFunc { return func(d []byte, atEOF bool) (int, []byte, error) { // SplitFunc inspired by bufio.ScanLines() implementation if atEOF { @@ -31,11 +31,13 @@ func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byt } } -// Parse parses the given multi-statement migration -func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error { +type scanFuncFactory func(delimiter []byte) bufio.SplitFunc + +// parse parses the given multi-statement migration +func parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler, factory scanFuncFactory) error { scanner := bufio.NewScanner(reader) scanner.Buffer(make([]byte, 0, StartBufSize), maxMigrationSize) - scanner.Split(splitWithDelimiter(delimiter)) + scanner.Split(factory(delimiter)) for scanner.Scan() { cont := h(scanner.Bytes()) if !cont { @@ -44,3 +46,8 @@ func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) } return scanner.Err() } + +// Parse parses the given multi-statement migration +func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error { + return parse(reader, delimiter, maxMigrationSize, h, splitWithDelimiter) +} diff --git a/database/multistmt/parse_postgres.go b/database/multistmt/parse_postgres.go new file mode 100644 index 000000000..a2990bdca --- /dev/null +++ b/database/multistmt/parse_postgres.go @@ -0,0 +1,93 @@ +// Package multistmt provides methods for parsing multi-statement database migrations +package multistmt + +import ( + "bufio" + "bytes" + "io" + "unicode" + + "golang.org/x/exp/slices" +) + +const dollar = '$' + +func isValidTagSymbol(r rune) bool { + return unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' +} + +func pgSplitWithDelimiter(delimiter []byte) bufio.SplitFunc { + // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-DOLLAR-QUOTING + // inside the dollar-quoted string, single quotes can be used without needing to + // be escaped. Indeed, no characters inside a dollar-quoted string are ever + // escaped: the string content is always written literally. Backslashes are not + // special, and neither are dollar signs, unless they are part of a sequence + // matching the opening tag. + // + // It is possible to nest dollar-quoted string constants by choosing different + // tags at each nesting level. This is most commonly used in writing function + // definitions + return func(d []byte, atEOF bool) (int, []byte, error) { + if atEOF { + if len(d) == 0 { + return 0, nil, nil + } + + return len(d), d, nil + } + + stack := [][]byte{delimiter} + maybeDollarQuoted := false + firstDollarPosition := 0 + + reader := bufio.NewReader(bytes.NewReader(d)) + position := 0 + + for position < len(d) { + currentDelimiter := stack[len(stack)-1] + + if len(d[position:]) >= len(currentDelimiter) { + if slices.Equal(d[position:position+len(currentDelimiter)], currentDelimiter) { + // pop delimiter from stack and fast-forward cursor and reader + stack = stack[:len(stack)-1] + position += len(currentDelimiter) + _, _ = io.ReadFull(reader, currentDelimiter) + + if len(stack) != 0 { + continue + } + } + } + + if len(stack) == 0 { + return position, d[:position], nil + } + + r, size, err := reader.ReadRune() + if err != nil { + return position + size, d[:position+size], err + } + + switch { + case r == dollar && !maybeDollarQuoted: + maybeDollarQuoted = true + + firstDollarPosition = position + case r == dollar && maybeDollarQuoted: + stack = append(stack, d[firstDollarPosition:position+size]) + maybeDollarQuoted = false + case !isValidTagSymbol(r) && maybeDollarQuoted: + maybeDollarQuoted = false + } + + position += size + } + + return 0, nil, nil + } +} + +// PGParse parses the given multi-statement migration for PostgreSQL respecting the dollar-quoted strings +func PGParse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error { + return parse(reader, delimiter, maxMigrationSize, h, pgSplitWithDelimiter) +} diff --git a/database/multistmt/parse_postgres_test.go b/database/multistmt/parse_postgres_test.go new file mode 100644 index 000000000..49f920503 --- /dev/null +++ b/database/multistmt/parse_postgres_test.go @@ -0,0 +1,102 @@ +package multistmt_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/golang-migrate/migrate/v4/database/multistmt" +) + +func TestPGParse(t *testing.T) { + createFunctionEmptyTagStmt := `CREATE FUNCTION set_new_id() RETURNS TRIGGER AS +$$ +BEGIN + NEW.new_id := NEW.id; + RETURN NEW; +END +$$ LANGUAGE PLPGSQL;` + + createFunctionStmt := `CREATE FUNCTION set_new_id() RETURNS TRIGGER AS +$BODY$ +BEGIN + NEW.new_id := NEW.id; + RETURN NEW; +END +$BODY$ LANGUAGE PLPGSQL;` + + createTriggerStmt := `CREATE TRIGGER set_new_id_trigger BEFORE INSERT OR UPDATE ON mytable +FOR EACH ROW EXECUTE PROCEDURE set_new_id();` + + nestedDollarQuotes := `$function$ +BEGIN + RETURN ($1 ~ $q$[\t\r\n\v\\]$q$); +END; +$function$;` + + advancedCreateFunction := `CREATE FUNCTION check_password(uname TEXT, pass TEXT) +RETURNS BOOLEAN AS $$ +DECLARE passed BOOLEAN; +BEGIN + SELECT (pwd = $2) INTO passed + FROM pwds + WHERE username = $1; + + RETURN passed; +END; +$$ LANGUAGE plpgsql + SECURITY DEFINER + -- Set a secure search_path: trusted schema(s), then 'pg_temp'. + SET search_path = admin, pg_temp;` + + testCases := []struct { + name string + multiStmt string + delimiter string + expected []string + expectedErr error + }{ + {name: "single statement, no delimiter", multiStmt: "single statement, no delimiter", delimiter: ";", + expected: []string{"single statement, no delimiter"}, expectedErr: nil}, + {name: "single statement, one delimiter", multiStmt: "single statement, one delimiter;", delimiter: ";", + expected: []string{"single statement, one delimiter;"}, expectedErr: nil}, + {name: "two statements, no trailing delimiter", multiStmt: "statement one; statement two", delimiter: ";", + expected: []string{"statement one;", " statement two"}, expectedErr: nil}, + {name: "two statements, with trailing delimiter", multiStmt: "statement one; statement two;", delimiter: ";", + expected: []string{"statement one;", " statement two;"}, expectedErr: nil}, + {name: "singe statement with nested dollar-quoted string", multiStmt: nestedDollarQuotes, delimiter: ";", + expected: []string{nestedDollarQuotes}}, + {name: "multiple statements with dollar-quoted strings", multiStmt: strings.Join([]string{createFunctionStmt, + createFunctionEmptyTagStmt, advancedCreateFunction, createTriggerStmt, nestedDollarQuotes}, ""), + delimiter: ";", + expected: []string{createFunctionStmt, createFunctionEmptyTagStmt, advancedCreateFunction, + createTriggerStmt, nestedDollarQuotes}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + stmts := make([]string, 0, len(tc.expected)) + err := multistmt.PGParse(strings.NewReader(tc.multiStmt), []byte(tc.delimiter), maxMigrationSize, func(b []byte) bool { + stmts = append(stmts, string(b)) + return true + }) + assert.Equal(t, tc.expectedErr, err) + assert.Equal(t, tc.expected, stmts) + }) + } +} + +func TestPGParseDiscontinue(t *testing.T) { + multiStmt := "statement one; statement two" + delimiter := ";" + expected := []string{"statement one;"} + + stmts := make([]string, 0, len(expected)) + err := multistmt.PGParse(strings.NewReader(multiStmt), []byte(delimiter), maxMigrationSize, func(b []byte) bool { + stmts = append(stmts, string(b)) + return false + }) + assert.Nil(t, err) + assert.Equal(t, expected, stmts) +} diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 7e42d29c9..9d12a626c 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -363,7 +363,7 @@ func (p *Postgres) releaseTableLock() error { func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error - if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { if err = p.runStatement(m); err != nil { return false } diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index 53e8e1d86..eb4ff4d1c 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -211,7 +211,13 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text); +CREATE INDEX CONCURRENTLY idx_foo ON foo (foo); +CREATE FUNCTION baz() RETURNS integer AS $$ + BEGIN + RETURN 1; + END; +$$ LANGUAGE plpgsql;`)); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -223,6 +229,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { if !exists { t.Fatalf("expected table bar to exist") } + + // make sure procedure exists + var proc string + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil { + t.Fatal(err) + } + if proc != "baz" { + t.Fatalf("expected procedure baz to exists") + } }) } diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 1b5a6ea7a..bd2bae40e 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -254,7 +254,7 @@ func (p *Postgres) Unlock() error { func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error - if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { if err = p.runStatement(m); err != nil { return false } diff --git a/database/pgx/v5/pgx_test.go b/database/pgx/v5/pgx_test.go index c7339c4fc..d247016d8 100644 --- a/database/pgx/v5/pgx_test.go +++ b/database/pgx/v5/pgx_test.go @@ -186,7 +186,13 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text); +CREATE INDEX CONCURRENTLY idx_foo ON foo (foo); +CREATE FUNCTION baz() RETURNS integer AS $$ + BEGIN + RETURN 1; + END; +$$ LANGUAGE plpgsql;`)); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -198,6 +204,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { if !exists { t.Fatalf("expected table bar to exist") } + + // make sure procedure exists + var proc string + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil { + t.Fatal(err) + } + if proc != "baz" { + t.Fatalf("expected procedure baz to exists") + } }) } diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 9e6d6277f..0d1f602d5 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -267,7 +267,7 @@ func (p *Postgres) Unlock() error { func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error - if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { if err = p.runStatement(m); err != nil { return false } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 65395cc7e..955ad0455 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -183,7 +183,14 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text); +CREATE INDEX CONCURRENTLY idx_foo ON foo (foo); +CREATE FUNCTION baz() RETURNS integer AS $$ + BEGIN + RETURN 1; + END; +$$ LANGUAGE plpgsql; +`)); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -195,6 +202,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { if !exists { t.Fatalf("expected table bar to exist") } + + // make sure procedure exists + var proc string + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil { + t.Fatal(err) + } + if proc != "baz" { + t.Fatalf("expected procedure baz to exists") + } }) }