Skip to content

Commit ebba141

Browse files
Prometheus2677FPiety0521
authored and
FPiety0521
committed
Postgres - add x-migrations-table-quoted url query option (#95) (#533)
* Postgres and pgx - Add x-migrations-table-quoted url query option to postgres and pgx drivers (#95) By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must to quote migrations table name manually, for instance `"gomigrate"."schema_migrations"` * Work In Progress
1 parent 3c80053 commit ebba141

File tree

6 files changed

+337
-20
lines changed

6 files changed

+337
-20
lines changed

Diff for: database/pgx/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
| URL Query | WithInstance Config | Description |
66
|------------|---------------------|-------------|
77
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
8+
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
89
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
910
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
1011
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |

Diff for: database/pgx/pgx.go

+45-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"io"
1010
"io/ioutil"
1111
nurl "net/url"
12+
"regexp"
1213
"strconv"
1314
"strings"
1415
"time"
@@ -46,6 +47,7 @@ type Config struct {
4647
DatabaseName string
4748
SchemaName string
4849
StatementTimeout time.Duration
50+
MigrationsTableQuoted bool
4951
MultiStatementEnabled bool
5052
MultiStatementMaxSize int
5153
}
@@ -137,6 +139,17 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
137139
}
138140

139141
migrationsTable := purl.Query().Get("x-migrations-table")
142+
migrationsTableQuoted := false
143+
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
144+
migrationsTableQuoted, err = strconv.ParseBool(s)
145+
if err != nil {
146+
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
147+
}
148+
}
149+
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
150+
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
151+
}
152+
140153
statementTimeoutString := purl.Query().Get("x-statement-timeout")
141154
statementTimeout := 0
142155
if statementTimeoutString != "" {
@@ -168,6 +181,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
168181
px, err := WithInstance(db, &Config{
169182
DatabaseName: purl.Path,
170183
MigrationsTable: migrationsTable,
184+
MigrationsTableQuoted: migrationsTableQuoted,
171185
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
172186
MultiStatementEnabled: multiStatementEnabled,
173187
MultiStatementMaxSize: multiStatementMaxSize,
@@ -321,7 +335,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
321335
return &database.Error{OrigErr: err, Err: "transaction start failed"}
322336
}
323337

324-
query := `TRUNCATE ` + quoteIdentifier(p.config.MigrationsTable)
338+
query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable)
325339
if _, err := tx.Exec(query); err != nil {
326340
if errRollback := tx.Rollback(); errRollback != nil {
327341
err = multierror.Append(err, errRollback)
@@ -333,7 +347,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
333347
// empty schema version for failed down migration on the first migration
334348
// See: https://github.com/golang-migrate/migrate/issues/330
335349
if version >= 0 || (version == database.NilVersion && dirty) {
336-
query = `INSERT INTO ` + quoteIdentifier(p.config.MigrationsTable) +
350+
query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) +
337351
` (version, dirty) VALUES ($1, $2)`
338352
if _, err := tx.Exec(query, version, dirty); err != nil {
339353
if errRollback := tx.Rollback(); errRollback != nil {
@@ -351,7 +365,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
351365
}
352366

353367
func (p *Postgres) Version() (version int, dirty bool, err error) {
354-
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
368+
query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
355369
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
356370
switch {
357371
case err == sql.ErrNoRows:
@@ -401,7 +415,7 @@ func (p *Postgres) Drop() (err error) {
401415
if len(tableNames) > 0 {
402416
// delete one by one ...
403417
for _, t := range tableNames {
404-
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
418+
query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE`
405419
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
406420
return &database.Error{OrigErr: err, Query: []byte(query)}
407421
}
@@ -433,10 +447,29 @@ func (p *Postgres) ensureVersionTable() (err error) {
433447
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
434448
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
435449
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
436-
var count int
437-
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
438-
row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)
450+
var row *sql.Row
451+
tableName := p.config.MigrationsTable
452+
schemaName := ""
453+
if p.config.MigrationsTableQuoted {
454+
re := regexp.MustCompile(`"(.*?)"`)
455+
result := re.FindAllStringSubmatch(p.config.MigrationsTable, -1)
456+
tableName = result[len(result)-1][1]
457+
if len(result) == 2 {
458+
schemaName = result[0][1]
459+
} else if len(result) > 2 {
460+
return fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", p.config.MigrationsTable)
461+
}
462+
}
463+
var query string
464+
if len(schemaName) > 0 {
465+
query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1`
466+
row = p.conn.QueryRowContext(context.Background(), query, tableName, schemaName)
467+
} else {
468+
query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
469+
row = p.conn.QueryRowContext(context.Background(), query, tableName)
470+
}
439471

472+
var count int
440473
err = row.Scan(&count)
441474
if err != nil {
442475
return &database.Error{OrigErr: err, Query: []byte(query)}
@@ -446,7 +479,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
446479
return nil
447480
}
448481

449-
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
482+
query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
450483
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
451484
return &database.Error{OrigErr: err, Query: []byte(query)}
452485
}
@@ -455,7 +488,10 @@ func (p *Postgres) ensureVersionTable() (err error) {
455488
}
456489

457490
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
458-
func quoteIdentifier(name string) string {
491+
func (p *Postgres) quoteIdentifier(name string) string {
492+
if p.config.MigrationsTableQuoted {
493+
return name
494+
}
459495
end := strings.IndexRune(name, 0)
460496
if end > -1 {
461497
name = name[:end]

Diff for: database/pgx/pgx_test.go

+119
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,74 @@ func TestWithSchema(t *testing.T) {
318318
})
319319
}
320320

321+
func TestMigrationTableOption(t *testing.T) {
322+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
323+
ip, port, err := c.FirstPort()
324+
if err != nil {
325+
t.Fatal(err)
326+
}
327+
328+
addr := pgConnectionString(ip, port)
329+
p := &Postgres{}
330+
d, _ := p.Open(addr)
331+
defer func() {
332+
if err := d.Close(); err != nil {
333+
t.Fatal(err)
334+
}
335+
}()
336+
337+
// create migrate schema
338+
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
339+
t.Fatal(err)
340+
}
341+
342+
// bad unquoted x-migrations-table parameter
343+
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
344+
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
345+
pgPassword, ip, port))
346+
if (err != nil) && (err.Error() != wantErr) {
347+
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
348+
}
349+
350+
// too many quoted x-migrations-table parameters
351+
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
352+
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
353+
pgPassword, ip, port))
354+
if (err != nil) && (err.Error() != wantErr) {
355+
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
356+
}
357+
358+
// good quoted x-migrations-table parameter
359+
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
360+
pgPassword, ip, port))
361+
if err != nil {
362+
t.Fatal(err)
363+
}
364+
365+
// make sure migrate.schema_migrations table exists
366+
var exists bool
367+
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
368+
t.Fatal(err)
369+
}
370+
if !exists {
371+
t.Fatalf("expected table migrate.schema_migrations to exist")
372+
}
373+
374+
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
375+
pgPassword, ip, port))
376+
if err != nil {
377+
t.Fatal(err)
378+
}
379+
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
380+
t.Fatal(err)
381+
}
382+
if !exists {
383+
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
384+
}
385+
386+
})
387+
}
388+
321389
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
322390
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
323391
ip, port, err := c.FirstPort()
@@ -373,6 +441,18 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) {
373441
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
374442
t.Fatal(e)
375443
}
444+
445+
// re-connect using that x-migrations-table and x-migrations-table-quoted
446+
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
447+
pgPassword, ip, port))
448+
449+
if !errors.As(err, &e) || err == nil {
450+
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
451+
}
452+
453+
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
454+
t.Fatal(e)
455+
}
376456
})
377457
}
378458

@@ -679,5 +759,44 @@ func Test_computeLineFromPos(t *testing.T) {
679759
run(true, true)
680760
})
681761
}
762+
}
763+
764+
func Test_quoteIdentifier(t *testing.T) {
765+
testcases := []struct {
766+
migrationsTableQuoted bool
767+
migrationsTable string
768+
expected string
769+
}{
770+
{
771+
false,
772+
"schema_name.table_name",
773+
"\"schema_name.table_name\"",
774+
},
775+
{
776+
false,
777+
"table_name",
778+
"\"table_name\"",
779+
},
780+
{
781+
true,
782+
"\"schema_name\".\"table_name\"",
783+
"\"schema_name\".\"table_name\"",
784+
},
785+
{
786+
true,
787+
"\"table_name\"",
788+
"\"table_name\"",
789+
},
790+
}
791+
p := &Postgres{
792+
config: &Config{},
793+
}
682794

795+
for _, tc := range testcases {
796+
p.config.MigrationsTableQuoted = tc.migrationsTableQuoted
797+
got := p.quoteIdentifier(tc.migrationsTable)
798+
if tc.expected != got {
799+
t.Fatalf("expected %s but got %s", tc.expected, got)
800+
}
801+
}
683802
}

Diff for: database/postgres/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
| URL Query | WithInstance Config | Description |
66
|------------|---------------------|-------------|
77
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
8+
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
89
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
910
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
1011
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |

0 commit comments

Comments
 (0)