From a20bd53957401bb047a80e7c81a9da9d4e7deaac Mon Sep 17 00:00:00 2001 From: Roman Bolkhovitin Date: Fri, 22 Dec 2023 00:43:39 +0300 Subject: [PATCH 1/5] add custom parser for postgres multistatements --- database/multistmt/parse.go | 15 +++- database/multistmt/parse_postgres.go | 93 +++++++++++++++++++++++ database/multistmt/parse_postgres_test.go | 85 +++++++++++++++++++++ database/pgx/pgx.go | 2 +- database/pgx/v5/pgx.go | 2 +- database/postgres/postgres.go | 2 +- 6 files changed, 192 insertions(+), 7 deletions(-) create mode 100644 database/multistmt/parse_postgres.go create mode 100644 database/multistmt/parse_postgres_test.go diff --git a/database/multistmt/parse.go b/database/multistmt/parse.go index 9a045767d..b1c2f9093 100644 --- a/database/multistmt/parse.go +++ b/database/multistmt/parse.go @@ -15,7 +15,7 @@ var StartBufSize = 4096 // from the multi-statement migration should be parsed and handled. type Handler func(migration []byte) bool -func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byte, error) { +func splitWithDelimiter(delimiter []byte) bufio.SplitFunc { return func(d []byte, atEOF bool) (int, []byte, error) { // SplitFunc inspired by bufio.ScanLines() implementation if atEOF { @@ -31,11 +31,13 @@ func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byt } } -// Parse parses the given multi-statement migration -func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error { +type scanFuncFactory func(delimiter []byte) bufio.SplitFunc + +// parse parses the given multi-statement migration +func parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler, factory scanFuncFactory) error { scanner := bufio.NewScanner(reader) scanner.Buffer(make([]byte, 0, StartBufSize), maxMigrationSize) - scanner.Split(splitWithDelimiter(delimiter)) + scanner.Split(factory(delimiter)) for scanner.Scan() { cont := h(scanner.Bytes()) if !cont { @@ -44,3 +46,8 @@ func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) } return scanner.Err() } + +// Parse parses the given multi-statement migration +func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error { + return parse(reader, delimiter, maxMigrationSize, h, splitWithDelimiter) +} diff --git a/database/multistmt/parse_postgres.go b/database/multistmt/parse_postgres.go new file mode 100644 index 000000000..a2990bdca --- /dev/null +++ b/database/multistmt/parse_postgres.go @@ -0,0 +1,93 @@ +// Package multistmt provides methods for parsing multi-statement database migrations +package multistmt + +import ( + "bufio" + "bytes" + "io" + "unicode" + + "golang.org/x/exp/slices" +) + +const dollar = '$' + +func isValidTagSymbol(r rune) bool { + return unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' +} + +func pgSplitWithDelimiter(delimiter []byte) bufio.SplitFunc { + // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-DOLLAR-QUOTING + // inside the dollar-quoted string, single quotes can be used without needing to + // be escaped. Indeed, no characters inside a dollar-quoted string are ever + // escaped: the string content is always written literally. Backslashes are not + // special, and neither are dollar signs, unless they are part of a sequence + // matching the opening tag. + // + // It is possible to nest dollar-quoted string constants by choosing different + // tags at each nesting level. This is most commonly used in writing function + // definitions + return func(d []byte, atEOF bool) (int, []byte, error) { + if atEOF { + if len(d) == 0 { + return 0, nil, nil + } + + return len(d), d, nil + } + + stack := [][]byte{delimiter} + maybeDollarQuoted := false + firstDollarPosition := 0 + + reader := bufio.NewReader(bytes.NewReader(d)) + position := 0 + + for position < len(d) { + currentDelimiter := stack[len(stack)-1] + + if len(d[position:]) >= len(currentDelimiter) { + if slices.Equal(d[position:position+len(currentDelimiter)], currentDelimiter) { + // pop delimiter from stack and fast-forward cursor and reader + stack = stack[:len(stack)-1] + position += len(currentDelimiter) + _, _ = io.ReadFull(reader, currentDelimiter) + + if len(stack) != 0 { + continue + } + } + } + + if len(stack) == 0 { + return position, d[:position], nil + } + + r, size, err := reader.ReadRune() + if err != nil { + return position + size, d[:position+size], err + } + + switch { + case r == dollar && !maybeDollarQuoted: + maybeDollarQuoted = true + + firstDollarPosition = position + case r == dollar && maybeDollarQuoted: + stack = append(stack, d[firstDollarPosition:position+size]) + maybeDollarQuoted = false + case !isValidTagSymbol(r) && maybeDollarQuoted: + maybeDollarQuoted = false + } + + position += size + } + + return 0, nil, nil + } +} + +// PGParse parses the given multi-statement migration for PostgreSQL respecting the dollar-quoted strings +func PGParse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error { + return parse(reader, delimiter, maxMigrationSize, h, pgSplitWithDelimiter) +} diff --git a/database/multistmt/parse_postgres_test.go b/database/multistmt/parse_postgres_test.go new file mode 100644 index 000000000..a202494f1 --- /dev/null +++ b/database/multistmt/parse_postgres_test.go @@ -0,0 +1,85 @@ +package multistmt_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/golang-migrate/migrate/v4/database/multistmt" +) + +func TestPGParse(t *testing.T) { + createFunctionEmptyTagStmt := `CREATE FUNCTION set_new_id() RETURNS TRIGGER AS +$$ +BEGIN + NEW.new_id := NEW.id; + RETURN NEW; +END +$$ LANGUAGE PLPGSQL;` + + createFunctionStmt := `CREATE FUNCTION set_new_id() RETURNS TRIGGER AS +$BODY$ +BEGIN + NEW.new_id := NEW.id; + RETURN NEW; +END +$BODY$ LANGUAGE PLPGSQL;` + + createTriggerStmt := `CREATE TRIGGER set_new_id_trigger BEFORE INSERT OR UPDATE ON mytable +FOR EACH ROW EXECUTE PROCEDURE set_new_id();` + + nestedDollarQuotes := `$function$ +BEGIN + RETURN ($1 ~ $q$[\t\r\n\v\\]$q$); +END; +$function$;` + + testCases := []struct { + name string + multiStmt string + delimiter string + expected []string + expectedErr error + }{ + {name: "single statement, no delimiter", multiStmt: "single statement, no delimiter", delimiter: ";", + expected: []string{"single statement, no delimiter"}, expectedErr: nil}, + {name: "single statement, one delimiter", multiStmt: "single statement, one delimiter;", delimiter: ";", + expected: []string{"single statement, one delimiter;"}, expectedErr: nil}, + {name: "two statements, no trailing delimiter", multiStmt: "statement one; statement two", delimiter: ";", + expected: []string{"statement one;", " statement two"}, expectedErr: nil}, + {name: "two statements, with trailing delimiter", multiStmt: "statement one; statement two;", delimiter: ";", + expected: []string{"statement one;", " statement two;"}, expectedErr: nil}, + {name: "singe statement with nested dollar-quoted string", multiStmt: nestedDollarQuotes, delimiter: ";", + expected: []string{nestedDollarQuotes}}, + {name: "three statements with dollar-quoted strings", multiStmt: strings.Join([]string{createFunctionStmt, + createFunctionEmptyTagStmt, createTriggerStmt, nestedDollarQuotes}, ""), delimiter: ";", + expected: []string{createFunctionStmt, createFunctionEmptyTagStmt, createTriggerStmt, nestedDollarQuotes}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + stmts := make([]string, 0, len(tc.expected)) + err := multistmt.PGParse(strings.NewReader(tc.multiStmt), []byte(tc.delimiter), maxMigrationSize, func(b []byte) bool { + stmts = append(stmts, string(b)) + return true + }) + assert.Equal(t, tc.expectedErr, err) + assert.Equal(t, tc.expected, stmts) + }) + } +} + +func TestPGParseDiscontinue(t *testing.T) { + multiStmt := "statement one; statement two" + delimiter := ";" + expected := []string{"statement one;"} + + stmts := make([]string, 0, len(expected)) + err := multistmt.PGParse(strings.NewReader(multiStmt), []byte(delimiter), maxMigrationSize, func(b []byte) bool { + stmts = append(stmts, string(b)) + return false + }) + assert.Nil(t, err) + assert.Equal(t, expected, stmts) +} diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 7e42d29c9..9d12a626c 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -363,7 +363,7 @@ func (p *Postgres) releaseTableLock() error { func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error - if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { if err = p.runStatement(m); err != nil { return false } diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 1b5a6ea7a..bd2bae40e 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -254,7 +254,7 @@ func (p *Postgres) Unlock() error { func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error - if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { if err = p.runStatement(m); err != nil { return false } diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 9e6d6277f..0d1f602d5 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -267,7 +267,7 @@ func (p *Postgres) Unlock() error { func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error - if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { if err = p.runStatement(m); err != nil { return false } From 2e20d5bb8161cf1f014fb6b22ac2e86540d94fc8 Mon Sep 17 00:00:00 2001 From: Roman Bolkhovitin Date: Fri, 22 Dec 2023 12:33:08 +0300 Subject: [PATCH 2/5] add advanced function in test --- database/multistmt/parse_postgres_test.go | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/database/multistmt/parse_postgres_test.go b/database/multistmt/parse_postgres_test.go index a202494f1..64a7f15fb 100644 --- a/database/multistmt/parse_postgres_test.go +++ b/database/multistmt/parse_postgres_test.go @@ -35,6 +35,21 @@ BEGIN END; $function$;` + advancedCreateFunction := `CREATE FUNCTION check_password(uname TEXT, pass TEXT) +RETURNS BOOLEAN AS $$ +DECLARE passed BOOLEAN; +BEGIN + SELECT (pwd = $2) INTO passed + FROM pwds + WHERE username = $1; + + RETURN passed; +END; +$$ LANGUAGE plpgsql + SECURITY DEFINER + -- Set a secure search_path: trusted schema(s), then 'pg_temp'. + SET search_path = admin, pg_temp;` + testCases := []struct { name string multiStmt string @@ -53,8 +68,9 @@ $function$;` {name: "singe statement with nested dollar-quoted string", multiStmt: nestedDollarQuotes, delimiter: ";", expected: []string{nestedDollarQuotes}}, {name: "three statements with dollar-quoted strings", multiStmt: strings.Join([]string{createFunctionStmt, - createFunctionEmptyTagStmt, createTriggerStmt, nestedDollarQuotes}, ""), delimiter: ";", - expected: []string{createFunctionStmt, createFunctionEmptyTagStmt, createTriggerStmt, nestedDollarQuotes}}, + createFunctionEmptyTagStmt, advancedCreateFunction, createTriggerStmt, nestedDollarQuotes}, ""), + delimiter: ";", expected: []string{createFunctionStmt, createFunctionEmptyTagStmt, advancedCreateFunction, + createTriggerStmt, nestedDollarQuotes}}, } for _, tc := range testCases { From 698bd9f66d3d9d82bc8e8239445d12f20c6b6f85 Mon Sep 17 00:00:00 2001 From: Roman Bolkhovitin Date: Fri, 22 Dec 2023 13:02:25 +0300 Subject: [PATCH 3/5] lint --- Makefile | 6 +++++- database/multistmt/parse_postgres_test.go | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 8e23a43c7..cad350418 100644 --- a/Makefile +++ b/Makefile @@ -105,6 +105,9 @@ echo-database: @echo "$(DATABASE)" +lint: + golangci-lint run -c .golangci.yml + define external_deps @echo '-- $(1)'; go list -f '{{join .Deps "\n"}}' $(1) | grep -v github.com/$(REPO_OWNER)/migrate | xargs go list -f '{{if not .Standard}}{{.ImportPath}}{{end}}' @@ -113,7 +116,8 @@ endef .PHONY: build build-docker build-cli clean test-short test test-with-flags html-coverage \ restore-import-paths rewrite-import-paths list-external-deps release \ - docs kill-docs open-docs kill-orphaned-docker-containers echo-source echo-database + docs kill-docs open-docs kill-orphaned-docker-containers echo-source echo-database \ + lint SHELL = /bin/sh RAND = $(shell echo $$RANDOM) diff --git a/database/multistmt/parse_postgres_test.go b/database/multistmt/parse_postgres_test.go index 64a7f15fb..8816872db 100644 --- a/database/multistmt/parse_postgres_test.go +++ b/database/multistmt/parse_postgres_test.go @@ -69,8 +69,9 @@ $$ LANGUAGE plpgsql expected: []string{nestedDollarQuotes}}, {name: "three statements with dollar-quoted strings", multiStmt: strings.Join([]string{createFunctionStmt, createFunctionEmptyTagStmt, advancedCreateFunction, createTriggerStmt, nestedDollarQuotes}, ""), - delimiter: ";", expected: []string{createFunctionStmt, createFunctionEmptyTagStmt, advancedCreateFunction, - createTriggerStmt, nestedDollarQuotes}}, + delimiter: ";", + expected: []string{createFunctionStmt, createFunctionEmptyTagStmt, advancedCreateFunction, + createTriggerStmt, nestedDollarQuotes}}, } for _, tc := range testCases { From 08305518eeed8e3288fade15968b240e67c68a5a Mon Sep 17 00:00:00 2001 From: Roman Bolkhovitin Date: Fri, 22 Dec 2023 15:43:44 +0300 Subject: [PATCH 4/5] add create function to integration tests --- database/pgx/pgx_test.go | 17 ++++++++++++++++- database/pgx/v5/pgx_test.go | 17 ++++++++++++++++- database/postgres/postgres_test.go | 18 +++++++++++++++++- 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index 53e8e1d86..eb4ff4d1c 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -211,7 +211,13 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text); +CREATE INDEX CONCURRENTLY idx_foo ON foo (foo); +CREATE FUNCTION baz() RETURNS integer AS $$ + BEGIN + RETURN 1; + END; +$$ LANGUAGE plpgsql;`)); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -223,6 +229,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { if !exists { t.Fatalf("expected table bar to exist") } + + // make sure procedure exists + var proc string + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil { + t.Fatal(err) + } + if proc != "baz" { + t.Fatalf("expected procedure baz to exists") + } }) } diff --git a/database/pgx/v5/pgx_test.go b/database/pgx/v5/pgx_test.go index c7339c4fc..d247016d8 100644 --- a/database/pgx/v5/pgx_test.go +++ b/database/pgx/v5/pgx_test.go @@ -186,7 +186,13 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text); +CREATE INDEX CONCURRENTLY idx_foo ON foo (foo); +CREATE FUNCTION baz() RETURNS integer AS $$ + BEGIN + RETURN 1; + END; +$$ LANGUAGE plpgsql;`)); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -198,6 +204,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { if !exists { t.Fatalf("expected table bar to exist") } + + // make sure procedure exists + var proc string + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil { + t.Fatal(err) + } + if proc != "baz" { + t.Fatalf("expected procedure baz to exists") + } }) } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 65395cc7e..955ad0455 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -183,7 +183,14 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text); +CREATE INDEX CONCURRENTLY idx_foo ON foo (foo); +CREATE FUNCTION baz() RETURNS integer AS $$ + BEGIN + RETURN 1; + END; +$$ LANGUAGE plpgsql; +`)); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -195,6 +202,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { if !exists { t.Fatalf("expected table bar to exist") } + + // make sure procedure exists + var proc string + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil { + t.Fatal(err) + } + if proc != "baz" { + t.Fatalf("expected procedure baz to exists") + } }) } From d84929031ed5a2d4a4c7c96212dae67210a20812 Mon Sep 17 00:00:00 2001 From: Roman Bolkhovitin Date: Fri, 22 Dec 2023 17:49:58 +0300 Subject: [PATCH 5/5] rename test --- database/multistmt/parse_postgres_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/database/multistmt/parse_postgres_test.go b/database/multistmt/parse_postgres_test.go index 8816872db..49f920503 100644 --- a/database/multistmt/parse_postgres_test.go +++ b/database/multistmt/parse_postgres_test.go @@ -67,7 +67,7 @@ $$ LANGUAGE plpgsql expected: []string{"statement one;", " statement two;"}, expectedErr: nil}, {name: "singe statement with nested dollar-quoted string", multiStmt: nestedDollarQuotes, delimiter: ";", expected: []string{nestedDollarQuotes}}, - {name: "three statements with dollar-quoted strings", multiStmt: strings.Join([]string{createFunctionStmt, + {name: "multiple statements with dollar-quoted strings", multiStmt: strings.Join([]string{createFunctionStmt, createFunctionEmptyTagStmt, advancedCreateFunction, createTriggerStmt, nestedDollarQuotes}, ""), delimiter: ";", expected: []string{createFunctionStmt, createFunctionEmptyTagStmt, advancedCreateFunction,