Skip to content

Commit d2b1700

Browse files
Prometheus2677FPiety0521
authored andcommitted
Postgres: Add a check to determine if table already exists to elide CREATE query (#526)
* Squash commits * Format * Minor refactoring * Address PR feedback; Add mustRun * Fix a test assert
1 parent 7883204 commit d2b1700

File tree

4 files changed

+325
-12
lines changed

4 files changed

+325
-12
lines changed

database/pgx/pgx.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,24 @@ func (p *Postgres) ensureVersionTable() (err error) {
429429
}
430430
}()
431431

432-
query := `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
432+
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
433+
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
434+
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
435+
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
436+
var count int
437+
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
438+
row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)
439+
440+
err = row.Scan(&count)
441+
if err != nil {
442+
return &database.Error{OrigErr: err, Query: []byte(query)}
443+
}
444+
445+
if count == 1 {
446+
return nil
447+
}
448+
449+
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
433450
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
434451
return &database.Error{OrigErr: err, Query: []byte(query)}
435452
}

database/pgx/pgx_test.go

Lines changed: 144 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"context"
77
"database/sql"
88
sqldriver "database/sql/driver"
9+
"errors"
910
"fmt"
1011
"log"
1112

@@ -76,6 +77,14 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
7677
return true
7778
}
7879

80+
func mustRun(t *testing.T, d database.Driver, statements []string) {
81+
for _, statement := range statements {
82+
if err := d.Run(strings.NewReader(statement)); err != nil {
83+
t.Fatal(err)
84+
}
85+
}
86+
}
87+
7988
func Test(t *testing.T) {
8089
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
8190
ip, port, err := c.FirstPort()
@@ -309,6 +318,141 @@ func TestWithSchema(t *testing.T) {
309318
})
310319
}
311320

321+
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
322+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
323+
ip, port, err := c.FirstPort()
324+
if err != nil {
325+
t.Fatal(err)
326+
}
327+
328+
addr := pgConnectionString(ip, port)
329+
330+
// Check that opening the postgres connection returns NilVersion
331+
p := &Postgres{}
332+
333+
d, err := p.Open(addr)
334+
335+
if err != nil {
336+
t.Fatal(err)
337+
}
338+
339+
defer func() {
340+
if err := d.Close(); err != nil {
341+
t.Error(err)
342+
}
343+
}()
344+
345+
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
346+
// since this is a test environment and we're not expecting to the pgPassword to be malicious
347+
mustRun(t, d, []string{
348+
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
349+
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
350+
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
351+
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
352+
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
353+
})
354+
355+
// re-connect using that schema
356+
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
357+
pgPassword, ip, port))
358+
359+
defer func() {
360+
if d2 == nil {
361+
return
362+
}
363+
if err := d2.Close(); err != nil {
364+
t.Fatal(err)
365+
}
366+
}()
367+
368+
var e *database.Error
369+
if !errors.As(err, &e) || err == nil {
370+
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
371+
}
372+
373+
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
374+
t.Fatal(e)
375+
}
376+
})
377+
}
378+
379+
func TestCheckBeforeCreateTable(t *testing.T) {
380+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
381+
ip, port, err := c.FirstPort()
382+
if err != nil {
383+
t.Fatal(err)
384+
}
385+
386+
addr := pgConnectionString(ip, port)
387+
388+
// Check that opening the postgres connection returns NilVersion
389+
p := &Postgres{}
390+
391+
d, err := p.Open(addr)
392+
393+
if err != nil {
394+
t.Fatal(err)
395+
}
396+
397+
defer func() {
398+
if err := d.Close(); err != nil {
399+
t.Error(err)
400+
}
401+
}()
402+
403+
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
404+
// since this is a test environment and we're not expecting to the pgPassword to be malicious
405+
mustRun(t, d, []string{
406+
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
407+
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
408+
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
409+
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
410+
})
411+
412+
// re-connect using that schema
413+
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
414+
pgPassword, ip, port))
415+
416+
if err != nil {
417+
t.Fatal(err)
418+
}
419+
420+
if err := d2.Close(); err != nil {
421+
t.Fatal(err)
422+
}
423+
424+
// revoke privileges
425+
mustRun(t, d, []string{
426+
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
427+
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
428+
})
429+
430+
// re-connect using that schema
431+
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
432+
pgPassword, ip, port))
433+
434+
if err != nil {
435+
t.Fatal(err)
436+
}
437+
438+
version, _, err := d3.Version()
439+
440+
if err != nil {
441+
t.Fatal(err)
442+
}
443+
444+
if version != database.NilVersion {
445+
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
446+
}
447+
448+
defer func() {
449+
if err := d3.Close(); err != nil {
450+
t.Fatal(err)
451+
}
452+
}()
453+
})
454+
}
455+
312456
func TestParallelSchema(t *testing.T) {
313457
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
314458
ip, port, err := c.FirstPort()
@@ -375,10 +519,6 @@ func TestParallelSchema(t *testing.T) {
375519
})
376520
}
377521

378-
func TestWithInstance(t *testing.T) {
379-
380-
}
381-
382522
func TestPostgres_Lock(t *testing.T) {
383523
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
384524
ip, port, err := c.FirstPort()

database/postgres/postgres.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,24 @@ func (p *Postgres) ensureVersionTable() (err error) {
426426
}
427427
}()
428428

429-
query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
429+
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
430+
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
431+
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
432+
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
433+
var count int
434+
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
435+
row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)
436+
437+
err = row.Scan(&count)
438+
if err != nil {
439+
return &database.Error{OrigErr: err, Query: []byte(query)}
440+
}
441+
442+
if count == 1 {
443+
return nil
444+
}
445+
446+
query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
430447
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
431448
return &database.Error{OrigErr: err, Query: []byte(query)}
432449
}

0 commit comments

Comments
 (0)