From 4ff8b070cd282e0c310295d0e4fc7d6a39488ca5 Mon Sep 17 00:00:00 2001 From: Jochen Schalanda Date: Thu, 8 Aug 2024 22:06:17 +0200 Subject: [PATCH] chore: add context.Context everywhere --- database/cassandra/cassandra.go | 31 +-- database/cassandra/cassandra_test.go | 12 +- database/clickhouse/clickhouse.go | 33 +-- database/clickhouse/clickhouse_test.go | 33 +-- database/cockroachdb/cockroachdb.go | 28 +-- database/cockroachdb/cockroachdb_test.go | 16 +- database/driver.go | 21 +- database/driver_test.go | 19 +- database/firebird/firebird.go | 28 +-- database/firebird/firebird_test.go | 37 +-- database/mongodb/mongodb.go | 30 +-- database/mongodb/mongodb_test.go | 53 ++-- database/mysql/mysql.go | 31 ++- database/mysql/mysql_test.go | 58 +++-- database/neo4j/neo4j.go | 17 +- database/neo4j/neo4j_test.go | 19 +- database/pgx/pgx.go | 28 +-- database/pgx/pgx_test.go | 156 ++++++------ database/pgx/v5/pgx.go | 28 +-- database/pgx/v5/pgx_test.go | 149 ++++++------ database/postgres/postgres.go | 30 ++- database/postgres/postgres_test.go | 151 ++++++------ database/ql/ql.go | 29 +-- database/ql/ql_test.go | 9 +- database/redshift/redshift.go | 28 +-- database/redshift/redshift_test.go | 61 ++--- database/rqlite/rqlite.go | 33 +-- database/rqlite/rqlite_test.go | 41 ++-- database/snowflake/snowflake.go | 28 +-- database/spanner/spanner.go | 37 ++- database/spanner/spanner_test.go | 9 +- database/sqlcipher/sqlcipher.go | 29 +-- database/sqlcipher/sqlcipher_test.go | 22 +- database/sqlite/sqlite.go | 29 +-- database/sqlite/sqlite_test.go | 25 +- database/sqlite3/sqlite3.go | 29 +-- database/sqlite3/sqlite3_test.go | 25 +- database/sqlserver/sqlserver.go | 28 +-- database/sqlserver/sqlserver_test.go | 50 ++-- database/stub/stub.go | 19 +- database/stub/stub_test.go | 9 +- database/testing/migrate_testing.go | 5 +- database/testing/testing.go | 22 +- database/yugabytedb/yugabytedb.go | 28 +-- database/yugabytedb/yugabytedb_test.go | 16 +- internal/cli/commands.go | 29 +-- internal/cli/main.go | 18 +- migrate.go | 191 +++++++-------- migrate_test.go | 293 +++++++++++++---------- source/aws_s3/s3.go | 25 +- source/aws_s3/s3_test.go | 3 +- source/bitbucket/bitbucket.go | 25 +- source/bitbucket/bitbucket_test.go | 3 +- source/driver.go | 19 +- source/file/file.go | 3 +- source/file/file_test.go | 31 +-- source/github/github.go | 22 +- source/github/github_test.go | 10 +- source/github_ee/github_ee.go | 5 +- source/github_ee/github_ee_test.go | 3 +- source/gitlab/gitlab.go | 23 +- source/gitlab/gitlab_test.go | 3 +- source/go_bindata/go-bindata.go | 23 +- source/go_bindata/go-bindata_test.go | 7 +- source/godoc_vfs/vfs.go | 5 +- source/godoc_vfs/vfs_example_test.go | 8 +- source/godoc_vfs/vfs_test.go | 5 +- source/google_cloud_storage/storage.go | 20 +- source/httpfs/driver.go | 3 +- source/httpfs/driver_test.go | 3 +- source/httpfs/partial_driver.go | 19 +- source/httpfs/partial_driver_test.go | 13 +- source/iofs/example_test.go | 6 +- source/iofs/iofs.go | 21 +- source/migration.go | 7 +- source/pkger/pkger.go | 7 +- source/pkger/pkger_test.go | 32 ++- source/stub/stub.go | 23 +- source/stub/stub_test.go | 3 +- source/testing/testing.go | 32 +-- 80 files changed, 1380 insertions(+), 1182 deletions(-) diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 74eecc98e..4470c9c3e 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -1,6 +1,7 @@ package cassandra import ( + "context" "errors" "fmt" "io" @@ -52,7 +53,7 @@ type Cassandra struct { config *Config } -func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, session *gocql.Session, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } else if len(config.KeyspaceName) == 0 { @@ -76,14 +77,14 @@ func WithInstance(session *gocql.Session, config *Config) (database.Driver, erro config: config, } - if err := c.ensureVersionTable(); err != nil { + if err := c.ensureVersionTable(ctx); err != nil { return nil, err } return c, nil } -func (c *Cassandra) Open(url string) (database.Driver, error) { +func (c *Cassandra) Open(ctx context.Context, url string) (database.Driver, error) { u, err := nurl.Parse(url) if err != nil { return nil, err @@ -185,7 +186,7 @@ func (c *Cassandra) Open(url string) (database.Driver, error) { } } - return WithInstance(session, &Config{ + return WithInstance(ctx, session, &Config{ KeyspaceName: strings.TrimPrefix(u.Path, "/"), MigrationsTable: u.Query().Get("x-migrations-table"), MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true", @@ -193,26 +194,26 @@ func (c *Cassandra) Open(url string) (database.Driver, error) { }) } -func (c *Cassandra) Close() error { +func (c *Cassandra) Close(ctx context.Context) error { c.session.Close() return nil } -func (c *Cassandra) Lock() error { +func (c *Cassandra) Lock(ctx context.Context) error { if !c.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (c *Cassandra) Unlock() error { +func (c *Cassandra) Unlock(ctx context.Context) error { if !c.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (c *Cassandra) Run(migration io.Reader) error { +func (c *Cassandra) Run(ctx context.Context, migration io.Reader) error { if c.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool { @@ -243,7 +244,7 @@ func (c *Cassandra) Run(migration io.Reader) error { return nil } -func (c *Cassandra) SetVersion(version int, dirty bool) error { +func (c *Cassandra) SetVersion(ctx context.Context, version int, dirty bool) error { // DELETE instead of TRUNCATE because AWS Keyspaces does not support it // see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html squery := `SELECT version FROM "` + c.config.MigrationsTable + `"` @@ -273,7 +274,7 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error { } // Return current keyspace version -func (c *Cassandra) Version() (version int, dirty bool, err error) { +func (c *Cassandra) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` err = c.session.Query(query).Scan(&version, &dirty) switch { @@ -291,7 +292,7 @@ func (c *Cassandra) Version() (version int, dirty bool, err error) { } } -func (c *Cassandra) Drop() error { +func (c *Cassandra) Drop(ctx context.Context) error { // select all tables in current schema query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) iter := c.session.Query(query).Iter() @@ -309,13 +310,13 @@ func (c *Cassandra) Drop() error { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Cassandra type. -func (c *Cassandra) ensureVersionTable() (err error) { - if err = c.Lock(); err != nil { +func (c *Cassandra) ensureVersionTable(ctx context.Context) (err error) { + if err = c.Lock(ctx); err != nil { return err } defer func() { - if e := c.Unlock(); e != nil { + if e := c.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -328,7 +329,7 @@ func (c *Cassandra) ensureVersionTable() (err error) { if err != nil { return err } - if _, _, err = c.Version(); err != nil { + if _, _, err = c.Version(ctx); err != nil { return err } return nil diff --git a/database/cassandra/cassandra_test.go b/database/cassandra/cassandra_test.go index 84af2853e..0ca3c9f82 100644 --- a/database/cassandra/cassandra_test.go +++ b/database/cassandra/cassandra_test.go @@ -76,18 +76,19 @@ func Test(t *testing.T) { func test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(9042) if err != nil { t.Fatal("Unable to get mapped port:", err) } addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port) p := &Cassandra{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -97,23 +98,24 @@ func test(t *testing.T) { func testMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(9042) if err != nil { t.Fatal("Unable to get mapped port:", err) } addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port) p := &Cassandra{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "testks", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "testks", d) if err != nil { t.Fatal(err) } diff --git a/database/clickhouse/clickhouse.go b/database/clickhouse/clickhouse.go index d2b65c0ce..c79f8b7c8 100644 --- a/database/clickhouse/clickhouse.go +++ b/database/clickhouse/clickhouse.go @@ -1,6 +1,7 @@ package clickhouse import ( + "context" "database/sql" "fmt" "io" @@ -40,7 +41,7 @@ func init() { database.Register("clickhouse", &ClickHouse{}) } -func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, conn *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -54,7 +55,7 @@ func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) { config: config, } - if err := ch.init(); err != nil { + if err := ch.init(ctx); err != nil { return nil, err } @@ -67,7 +68,7 @@ type ClickHouse struct { isLocked atomic.Bool } -func (ch *ClickHouse) Open(dsn string) (database.Driver, error) { +func (ch *ClickHouse) Open(ctx context.Context, dsn string) (database.Driver, error) { purl, err := url.Parse(dsn) if err != nil { return nil, err @@ -104,14 +105,14 @@ func (ch *ClickHouse) Open(dsn string) (database.Driver, error) { }, } - if err := ch.init(); err != nil { + if err := ch.init(ctx); err != nil { return nil, err } return ch, nil } -func (ch *ClickHouse) init() error { +func (ch *ClickHouse) init(ctx context.Context) error { if len(ch.config.DatabaseName) == 0 { if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil { return err @@ -130,10 +131,10 @@ func (ch *ClickHouse) init() error { ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine } - return ch.ensureVersionTable() + return ch.ensureVersionTable(ctx) } -func (ch *ClickHouse) Run(r io.Reader) error { +func (ch *ClickHouse) Run(ctx context.Context, r io.Reader) error { if ch.config.MultiStatementEnabled { var err error if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool { @@ -163,7 +164,7 @@ func (ch *ClickHouse) Run(r io.Reader) error { return nil } -func (ch *ClickHouse) Version() (int, bool, error) { +func (ch *ClickHouse) Version(ctx context.Context) (int, bool, error) { var ( version int dirty uint8 @@ -178,7 +179,7 @@ func (ch *ClickHouse) Version() (int, bool, error) { return version, dirty == 1, nil } -func (ch *ClickHouse) SetVersion(version int, dirty bool) error { +func (ch *ClickHouse) SetVersion(ctx context.Context, version int, dirty bool) error { var ( bool = func(v bool) uint8 { if v { @@ -203,13 +204,13 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the ClickHouse type. -func (ch *ClickHouse) ensureVersionTable() (err error) { - if err = ch.Lock(); err != nil { +func (ch *ClickHouse) ensureVersionTable(ctx context.Context) (err error) { + if err = ch.Lock(ctx); err != nil { return err } defer func() { - if e := ch.Unlock(); e != nil { + if e := ch.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -258,7 +259,7 @@ func (ch *ClickHouse) ensureVersionTable() (err error) { return nil } -func (ch *ClickHouse) Drop() (err error) { +func (ch *ClickHouse) Drop(ctx context.Context) (err error) { query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName) tables, err := ch.conn.Query(query) @@ -290,21 +291,21 @@ func (ch *ClickHouse) Drop() (err error) { return nil } -func (ch *ClickHouse) Lock() error { +func (ch *ClickHouse) Lock(ctx context.Context) error { if !ch.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (ch *ClickHouse) Unlock() error { +func (ch *ClickHouse) Unlock(ctx context.Context) error { if !ch.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (ch *ClickHouse) Close() error { return ch.conn.Close() } +func (ch *ClickHouse) Close(ctx context.Context) error { return ch.conn.Close() } // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 func quoteIdentifier(name string) string { diff --git a/database/clickhouse/clickhouse_test.go b/database/clickhouse/clickhouse_test.go index 694aa2a7b..18b67fd5a 100644 --- a/database/clickhouse/clickhouse_test.go +++ b/database/clickhouse/clickhouse_test.go @@ -85,6 +85,7 @@ func TestCases(t *testing.T) { func testSimple(t *testing.T, engine string) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -92,12 +93,12 @@ func testSimple(t *testing.T, engine string) { addr := clickhouseConnectionString(ip, port, engine) p := &clickhouse.ClickHouse{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -108,6 +109,7 @@ func testSimple(t *testing.T, engine string) { func testSimpleWithInstanceDefaultConfigValues(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -118,13 +120,13 @@ func testSimpleWithInstanceDefaultConfigValues(t *testing.T) { if err != nil { t.Fatal(err) } - d, err := clickhouse.WithInstance(conn, &clickhouse.Config{}) + d, err := clickhouse.WithInstance(ctx, conn, &clickhouse.Config{}) if err != nil { _ = conn.Close() t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -135,6 +137,7 @@ func testSimpleWithInstanceDefaultConfigValues(t *testing.T) { func testMigrate(t *testing.T, engine string) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -142,16 +145,16 @@ func testMigrate(t *testing.T, engine string) { addr := clickhouseConnectionString(ip, port, engine) p := &clickhouse.ClickHouse{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "db", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "db", d) if err != nil { t.Fatal(err) @@ -164,6 +167,7 @@ func testVersion(t *testing.T, engine string) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { expectedVersion := 1 + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -171,22 +175,22 @@ func testVersion(t *testing.T, engine string) { addr := clickhouseConnectionString(ip, port, engine) p := &clickhouse.ClickHouse{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - err = d.SetVersion(expectedVersion, false) + err = d.SetVersion(ctx, expectedVersion, false) if err != nil { t.Fatal(err) } - version, _, err := d.Version() + version, _, err := d.Version(ctx) if err != nil { t.Fatal(err) } @@ -199,6 +203,7 @@ func testVersion(t *testing.T, engine string) { func testDrop(t *testing.T, engine string) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -206,17 +211,17 @@ func testDrop(t *testing.T, engine string) { addr := clickhouseConnectionString(ip, port, engine) p := &clickhouse.ClickHouse{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - err = d.Drop() + err = d.Drop(ctx) if err != nil { t.Fatal(err) } diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index 699b3facd..19b462683 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -47,7 +47,7 @@ type CockroachDb struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -88,14 +88,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, err } - if err := px.ensureVersionTable(); err != nil { + if err := px.ensureVersionTable(ctx); err != nil { return nil, err } return px, nil } -func (c *CockroachDb) Open(url string) (database.Driver, error) { +func (c *CockroachDb) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -127,7 +127,7 @@ func (c *CockroachDb) Open(url string) (database.Driver, error) { forceLock = false } - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, LockTable: lockTable, @@ -140,13 +140,13 @@ func (c *CockroachDb) Open(url string) (database.Driver, error) { return px, nil } -func (c *CockroachDb) Close() error { +func (c *CockroachDb) Close(ctx context.Context) error { return c.db.Close() } // Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed // See: https://github.com/cockroachdb/cockroach/issues/13546 -func (c *CockroachDb) Lock() error { +func (c *CockroachDb) Lock(ctx context.Context) error { return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) { return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) { aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) @@ -183,7 +183,7 @@ func (c *CockroachDb) Lock() error { // Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed // See: https://github.com/cockroachdb/cockroach/issues/13546 -func (c *CockroachDb) Unlock() error { +func (c *CockroachDb) Unlock(ctx context.Context) error { return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) { aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) if err != nil { @@ -210,7 +210,7 @@ func (c *CockroachDb) Unlock() error { }) } -func (c *CockroachDb) Run(migration io.Reader) error { +func (c *CockroachDb) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -225,7 +225,7 @@ func (c *CockroachDb) Run(migration io.Reader) error { return nil } -func (c *CockroachDb) SetVersion(version int, dirty bool) error { +func (c *CockroachDb) SetVersion(ctx context.Context, version int, dirty bool) error { return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error { if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil { return err @@ -244,7 +244,7 @@ func (c *CockroachDb) SetVersion(version int, dirty bool) error { }) } -func (c *CockroachDb) Version() (version int, dirty bool, err error) { +func (c *CockroachDb) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` err = c.db.QueryRow(query).Scan(&version, &dirty) @@ -267,7 +267,7 @@ func (c *CockroachDb) Version() (version int, dirty bool, err error) { } } -func (c *CockroachDb) Drop() (err error) { +func (c *CockroachDb) Drop(ctx context.Context) (err error) { // select all tables in current schema query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())` tables, err := c.db.Query(query) @@ -311,13 +311,13 @@ func (c *CockroachDb) Drop() (err error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the CockroachDb type. -func (c *CockroachDb) ensureVersionTable() (err error) { - if err = c.Lock(); err != nil { +func (c *CockroachDb) ensureVersionTable(ctx context.Context) (err error) { + if err = c.Lock(ctx); err != nil { return err } defer func() { - if e := c.Unlock(); e != nil { + if e := c.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/cockroachdb/cockroachdb_test.go b/database/cockroachdb/cockroachdb_test.go index d00e27503..218c29d97 100644 --- a/database/cockroachdb/cockroachdb_test.go +++ b/database/cockroachdb/cockroachdb_test.go @@ -86,6 +86,7 @@ func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) + ctx := context.Background() ip, port, err := ci.Port(26257) if err != nil { t.Fatal(err) @@ -93,7 +94,7 @@ func Test(t *testing.T) { addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port) c := &CockroachDb{} - d, err := c.Open(addr) + d, err := c.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -105,6 +106,7 @@ func TestMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) + ctx := context.Background() ip, port, err := ci.Port(26257) if err != nil { t.Fatal(err) @@ -112,12 +114,12 @@ func TestMigrate(t *testing.T) { addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port) c := &CockroachDb{} - d, err := c.Open(addr) + d, err := c.Open(ctx, addr) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "migrate", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "migrate", d) if err != nil { t.Fatal(err) } @@ -129,6 +131,7 @@ func TestMultiStatement(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) + ctx := context.Background() ip, port, err := ci.Port(26257) if err != nil { t.Fatal(err) @@ -136,11 +139,11 @@ func TestMultiStatement(t *testing.T) { addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port) c := &CockroachDb{} - d, err := c.Open(addr) + d, err := c.Open(ctx, addr) if err != nil { t.Fatal(err) } - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -159,6 +162,7 @@ func TestFilterCustomQuery(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) + ctx := context.Background() ip, port, err := ci.Port(26257) if err != nil { t.Fatal(err) @@ -166,7 +170,7 @@ func TestFilterCustomQuery(t *testing.T) { addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable&x-custom=foobar", ip, port) c := &CockroachDb{} - _, err = c.Open(addr) + _, err = c.Open(ctx, addr) if err != nil { t.Fatal(err) } diff --git a/database/driver.go b/database/driver.go index 11268e6b9..096022397 100644 --- a/database/driver.go +++ b/database/driver.go @@ -5,6 +5,7 @@ package database import ( + "context" "fmt" "io" "sync" @@ -46,43 +47,43 @@ type Driver interface { // Open returns a new driver instance configured with parameters // coming from the URL string. Migrate will call this function // only once per instance. - Open(url string) (Driver, error) + Open(ctx context.Context, url string) (Driver, error) // Close closes the underlying database instance managed by the driver. // Migrate will call this function only once per instance. - Close() error + Close(ctx context.Context) error // Lock should acquire a database lock so that only one migration process // can run at a time. Migrate will call this function before Run is called. // If the implementation can't provide this functionality, return nil. // Return database.ErrLocked if database is already locked. - Lock() error + Lock(ctx context.Context) error // Unlock should release the lock. Migrate will call this function after // all migrations have been run. - Unlock() error + Unlock(ctx context.Context) error // Run applies a migration to the database. migration is guaranteed to be not nil. - Run(migration io.Reader) error + Run(ctx context.Context, migration io.Reader) error // SetVersion saves version and dirty state. // Migrate will call this function before and after each call to Run. // version must be >= -1. -1 means NilVersion. - SetVersion(version int, dirty bool) error + SetVersion(ctx context.Context, version int, dirty bool) error // Version returns the currently active version and if the database is dirty. // When no migration has been applied, it must return version -1. // Dirty means, a previous migration failed and user interaction is required. - Version() (version int, dirty bool, err error) + Version(ctx context.Context) (version int, dirty bool, err error) // Drop deletes everything in the database. // Note that this is a breaking action, a new call to Open() is necessary to // ensure subsequent calls work as expected. - Drop() error + Drop(ctx context.Context) error } // Open returns a new driver instance. -func Open(url string) (Driver, error) { +func Open(ctx context.Context, url string) (Driver, error) { scheme, err := iurl.SchemeFromURL(url) if err != nil { return nil, err @@ -95,7 +96,7 @@ func Open(url string) (Driver, error) { return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme) } - return d.Open(url) + return d.Open(ctx, url) } // Register globally registers a driver. diff --git a/database/driver_test.go b/database/driver_test.go index 7880f3208..c79459635 100644 --- a/database/driver_test.go +++ b/database/driver_test.go @@ -1,6 +1,7 @@ package database import ( + "context" "io" "testing" ) @@ -18,37 +19,37 @@ type mockDriver struct { url string } -func (m *mockDriver) Open(url string) (Driver, error) { +func (m *mockDriver) Open(ctx context.Context, url string) (Driver, error) { return &mockDriver{ url: url, }, nil } -func (m *mockDriver) Close() error { +func (m *mockDriver) Close(ctx context.Context) error { return nil } -func (m *mockDriver) Lock() error { +func (m *mockDriver) Lock(ctx context.Context) error { return nil } -func (m *mockDriver) Unlock() error { +func (m *mockDriver) Unlock(ctx context.Context) error { return nil } -func (m *mockDriver) Run(migration io.Reader) error { +func (m *mockDriver) Run(ctx context.Context, migration io.Reader) error { return nil } -func (m *mockDriver) SetVersion(version int, dirty bool) error { +func (m *mockDriver) SetVersion(ctx context.Context, version int, dirty bool) error { return nil } -func (m *mockDriver) Version() (version int, dirty bool, err error) { +func (m *mockDriver) Version(ctx context.Context) (version int, dirty bool, err error) { return 0, false, nil } -func (m *mockDriver) Drop() error { +func (m *mockDriver) Drop(ctx context.Context) error { return nil } @@ -95,7 +96,7 @@ func TestOpen(t *testing.T) { for _, c := range cases { t.Run(c.url, func(t *testing.T) { - d, err := Open(c.url) + d, err := Open(context.Background(), c.url) if err == nil { if c.err { diff --git a/database/firebird/firebird.go b/database/firebird/firebird.go index f564cf74f..f04cb1bda 100644 --- a/database/firebird/firebird.go +++ b/database/firebird/firebird.go @@ -44,7 +44,7 @@ type Firebird struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -68,14 +68,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } - if err := fb.ensureVersionTable(); err != nil { + if err := fb.ensureVersionTable(ctx); err != nil { return nil, err } return fb, nil } -func (f *Firebird) Open(dsn string) (database.Driver, error) { +func (f *Firebird) Open(ctx context.Context, dsn string) (database.Driver, error) { purl, err := nurl.Parse(dsn) if err != nil { return nil, err @@ -86,7 +86,7 @@ func (f *Firebird) Open(dsn string) (database.Driver, error) { return nil, err } - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ MigrationsTable: purl.Query().Get("x-migrations-table"), DatabaseName: purl.Path, }) @@ -98,7 +98,7 @@ func (f *Firebird) Open(dsn string) (database.Driver, error) { return px, nil } -func (f *Firebird) Close() error { +func (f *Firebird) Close(ctx context.Context) error { connErr := f.conn.Close() dbErr := f.db.Close() if connErr != nil || dbErr != nil { @@ -107,21 +107,21 @@ func (f *Firebird) Close() error { return nil } -func (f *Firebird) Lock() error { +func (f *Firebird) Lock(ctx context.Context) error { if !f.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (f *Firebird) Unlock() error { +func (f *Firebird) Unlock(ctx context.Context) error { if !f.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (f *Firebird) Run(migration io.Reader) error { +func (f *Firebird) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -136,7 +136,7 @@ func (f *Firebird) Run(migration io.Reader) error { return nil } -func (f *Firebird) SetVersion(version int, dirty bool) error { +func (f *Firebird) SetVersion(ctx context.Context, version int, dirty bool) error { // Always re-write the schema version to prevent empty schema version // for failed down migration on the first migration // See: https://github.com/golang-migrate/migrate/issues/330 @@ -157,7 +157,7 @@ func (f *Firebird) SetVersion(version int, dirty bool) error { return nil } -func (f *Firebird) Version() (version int, dirty bool, err error) { +func (f *Firebird) Version(ctx context.Context) (version int, dirty bool, err error) { var d int query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable) err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d) @@ -172,7 +172,7 @@ func (f *Firebird) Version() (version int, dirty bool, err error) { } } -func (f *Firebird) Drop() (err error) { +func (f *Firebird) Drop(ctx context.Context) (err error) { // select all tables query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);` tables, err := f.conn.QueryContext(context.Background(), query) @@ -217,13 +217,13 @@ func (f *Firebird) Drop() (err error) { } // ensureVersionTable checks if versions table exists and, if not, creates it. -func (f *Firebird) ensureVersionTable() (err error) { - if err = f.Lock(); err != nil { +func (f *Firebird) ensureVersionTable(ctx context.Context) (err error) { + if err = f.Lock(ctx); err != nil { return err } defer func() { - if e := f.Unlock(); e != nil { + if e := f.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/firebird/firebird_test.go b/database/firebird/firebird_test.go index 21fa8de2d..0345407cc 100644 --- a/database/firebird/firebird_test.go +++ b/database/firebird/firebird_test.go @@ -79,6 +79,7 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -86,12 +87,12 @@ func Test(t *testing.T) { addr := fbConnectionString(ip, port) p := &Firebird{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -101,6 +102,7 @@ func Test(t *testing.T) { func TestMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -108,16 +110,16 @@ func TestMigrate(t *testing.T) { addr := fbConnectionString(ip, port) p := &Firebird{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "firebirdsql", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "firebirdsql", d) if err != nil { t.Fatal(err) } @@ -127,6 +129,7 @@ func TestMigrate(t *testing.T) { func TestErrorParsing(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -134,12 +137,12 @@ func TestErrorParsing(t *testing.T) { addr := fbConnectionString(ip, port) p := &Firebird{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -150,7 +153,7 @@ Token unknown - line 1, column 8 TABLEE )` - if err := d.Run(strings.NewReader("CREATE TABLEE foo (foo varchar(40));")); err == nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLEE foo (foo varchar(40));")); err == nil { t.Fatal("expected err but got nil") } else if err.Error() != wantErr { msg := err.Error() @@ -161,6 +164,7 @@ TABLEE func TestFilterCustomQuery(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -168,12 +172,12 @@ func TestFilterCustomQuery(t *testing.T) { addr := fbConnectionString(ip, port) + "?sslmode=disable&x-custom=foobar" p := &Firebird{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -182,6 +186,7 @@ func TestFilterCustomQuery(t *testing.T) { func Test_Lock(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -189,12 +194,12 @@ func Test_Lock(t *testing.T) { addr := fbConnectionString(ip, port) p := &Firebird{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -203,22 +208,22 @@ func Test_Lock(t *testing.T) { ps := d.(*Firebird) - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } diff --git a/database/mongodb/mongodb.go b/database/mongodb/mongodb.go index 3a9a6be9e..7400b2837 100644 --- a/database/mongodb/mongodb.go +++ b/database/mongodb/mongodb.go @@ -75,7 +75,7 @@ type findFilter struct { Key int `bson:"locking_key"` } -func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *mongo.Client, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -106,14 +106,14 @@ func WithInstance(instance *mongo.Client, config *Config) (database.Driver, erro return nil, err } } - if err := mc.ensureVersionTable(); err != nil { + if err := mc.ensureVersionTable(ctx); err != nil { return nil, err } return mc, nil } -func (m *Mongo) Open(dsn string) (database.Driver, error) { +func (m *Mongo) Open(ctx context.Context, dsn string) (database.Driver, error) { // connstring is experimental package, but it used for parse connection string in mongo.Connect function uri, err := connstring.Parse(dsn) if err != nil { @@ -165,7 +165,7 @@ func (m *Mongo) Open(dsn string) (database.Driver, error) { if err = client.Ping(context.TODO(), nil); err != nil { return nil, err } - mc, err := WithInstance(client, &Config{ + mc, err := WithInstance(ctx, client, &Config{ DatabaseName: uri.Database, MigrationsCollection: migrationsCollection, TransactionMode: transactionMode, @@ -215,7 +215,7 @@ func parseInt(urlParam string, defaultValue int) (int, error) { // if no url Param passed, return default value return defaultValue, nil } -func (m *Mongo) SetVersion(version int, dirty bool) error { +func (m *Mongo) SetVersion(ctx context.Context, version int, dirty bool) error { migrationsCollection := m.db.Collection(m.config.MigrationsCollection) if err := migrationsCollection.Drop(context.TODO()); err != nil { return &database.Error{OrigErr: err, Err: "drop migrations collection failed"} @@ -227,7 +227,7 @@ func (m *Mongo) SetVersion(version int, dirty bool) error { return nil } -func (m *Mongo) Version() (version int, dirty bool, err error) { +func (m *Mongo) Version(ctx context.Context) (version int, dirty bool, err error) { var versionInfo versionInfo err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo) switch { @@ -240,7 +240,7 @@ func (m *Mongo) Version() (version int, dirty bool, err error) { } } -func (m *Mongo) Run(migration io.Reader) error { +func (m *Mongo) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -293,11 +293,11 @@ func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error { return nil } -func (m *Mongo) Close() error { +func (m *Mongo) Close(ctx context.Context) error { return m.client.Disconnect(context.TODO()) } -func (m *Mongo) Drop() error { +func (m *Mongo) Drop(ctx context.Context) error { return m.db.Drop(context.TODO()) } @@ -318,13 +318,13 @@ func (m *Mongo) ensureLockTable() error { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the MongoDb type. -func (m *Mongo) ensureVersionTable() (err error) { - if err = m.Lock(); err != nil { +func (m *Mongo) ensureVersionTable(ctx context.Context) (err error) { + if err = m.Lock(ctx); err != nil { return err } defer func() { - if e := m.Unlock(); e != nil { + if e := m.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -336,7 +336,7 @@ func (m *Mongo) ensureVersionTable() (err error) { if err != nil { return err } - if _, _, err = m.Version(); err != nil { + if _, _, err = m.Version(ctx); err != nil { return err } return nil @@ -344,7 +344,7 @@ func (m *Mongo) ensureVersionTable() (err error) { // Utilizes advisory locking on the config.LockingCollection collection // This uses a unique index on the `locking_key` field. -func (m *Mongo) Lock() error { +func (m *Mongo) Lock(ctx context.Context) error { return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { if !m.config.Locking.Enabled { return nil @@ -382,7 +382,7 @@ func (m *Mongo) Lock() error { }) } -func (m *Mongo) Unlock() error { +func (m *Mongo) Unlock(ctx context.Context) error { return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { if !m.config.Locking.Enabled { return nil diff --git a/database/mongodb/mongodb_test.go b/database/mongodb/mongodb_test.go index 8823bd9ca..f5e41db07 100644 --- a/database/mongodb/mongodb_test.go +++ b/database/mongodb/mongodb_test.go @@ -91,6 +91,7 @@ func Test(t *testing.T) { func test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -98,12 +99,12 @@ func test(t *testing.T) { addr := mongoConnectionString(ip, port) p := &Mongo{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -117,6 +118,7 @@ func test(t *testing.T) { func testMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -124,16 +126,16 @@ func testMigrate(t *testing.T) { addr := mongoConnectionString(ip, port) p := &Mongo{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "", d) if err != nil { t.Fatal(err) } @@ -143,6 +145,7 @@ func testMigrate(t *testing.T) { func testWithAuth(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -150,17 +153,17 @@ func testWithAuth(t *testing.T) { addr := mongoConnectionString(ip, port) p := &Mongo{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() createUserCMD := []byte(`[{"createUser":"deminem","pwd":"gogo","roles":[{"role":"readWrite","db":"testMigration"}]}]`) - err = d.Run(bytes.NewReader(createUserCMD)) + err = d.Run(ctx, bytes.NewReader(createUserCMD)) if err != nil { t.Fatal(err) } @@ -176,10 +179,10 @@ func testWithAuth(t *testing.T) { for _, tcase := range testcases { t.Run(tcase.name, func(t *testing.T) { mc := &Mongo{} - d, err := mc.Open(fmt.Sprintf(tcase.connectUri, ip, port)) + d, err := mc.Open(ctx, fmt.Sprintf(tcase.connectUri, ip, port)) if err == nil { defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -198,6 +201,7 @@ func testWithAuth(t *testing.T) { func testLockWorks(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -205,12 +209,12 @@ func testLockWorks(t *testing.T) { addr := mongoConnectionString(ip, port) p := &Mongo{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -219,20 +223,20 @@ func testLockWorks(t *testing.T) { mc := d.(*Mongo) - err = mc.Lock() + err = mc.Lock(ctx) if err != nil { t.Fatal(err) } - err = mc.Unlock() + err = mc.Unlock(ctx) if err != nil { t.Fatal(err) } - err = mc.Lock() + err = mc.Lock(ctx) if err != nil { t.Fatal(err) } - err = mc.Unlock() + err = mc.Unlock(ctx) if err != nil { t.Fatal(err) } @@ -241,11 +245,11 @@ func testLockWorks(t *testing.T) { //try to hit a lock conflict mc.config.Locking.Enabled = true mc.config.Locking.Timeout = 1 - err = mc.Lock() + err = mc.Lock(ctx) if err != nil { t.Fatal(err) } - err = mc.Lock() + err = mc.Lock(ctx) if err == nil { t.Fatal("should have failed, mongo should be locked already") } @@ -267,6 +271,7 @@ func TestTransaction(t *testing.T) { }) dktesting.ParallelTest(t, transactionSpecs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -289,14 +294,14 @@ func TestTransaction(t *testing.T) { if err != nil { t.Fatal(err) } - d, err := WithInstance(client, &Config{ + d, err := WithInstance(ctx, client, &Config{ DatabaseName: "testMigration", }) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -315,7 +320,7 @@ func TestTransaction(t *testing.T) { "background": true }] }]`) - err = d.Run(bytes.NewReader(insertCMD)) + err = d.Run(ctx, bytes.NewReader(insertCMD)) if err != nil { t.Fatal(err) } @@ -360,7 +365,7 @@ func TestTransaction(t *testing.T) { if err != nil { t.Fatal(err) } - d, err := WithInstance(client, &Config{ + d, err := WithInstance(ctx, client, &Config{ DatabaseName: "testMigration", TransactionMode: true, }) @@ -368,11 +373,11 @@ func TestTransaction(t *testing.T) { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - runErr := d.Run(bytes.NewReader(tcase.cmds)) + runErr := d.Run(ctx, bytes.NewReader(tcase.cmds)) if runErr != nil { if !tcase.isErrorExpected { t.Fatal(runErr) diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index c7e7ef617..f2ee2400e 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -90,7 +90,7 @@ func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql config.MigrationsTable = DefaultMigrationsTable } - if err := mx.ensureVersionTable(); err != nil { + if err := mx.ensureVersionTable(ctx); err != nil { return nil, err } @@ -98,9 +98,7 @@ func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql } // instance must have `multiStatements` set to true -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { - ctx := context.Background() - +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if err := instance.Ping(); err != nil { return nil, err } @@ -225,7 +223,7 @@ func urlToMySQLConfig(url string) (*mysql.Config, error) { return config, nil } -func (m *Mysql) Open(url string) (database.Driver, error) { +func (m *Mysql) Open(ctx context.Context, url string) (database.Driver, error) { config, err := urlToMySQLConfig(url) if err != nil { return nil, err @@ -258,7 +256,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) { return nil, err } - mx, err := WithInstance(db, &Config{ + mx, err := WithInstance(ctx, db, &Config{ DatabaseName: config.DBName, MigrationsTable: customParams["x-migrations-table"], NoLock: noLock, @@ -271,7 +269,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) { return mx, nil } -func (m *Mysql) Close() error { +func (m *Mysql) Close(ctx context.Context) error { connErr := m.conn.Close() var dbErr error if m.db != nil { @@ -284,7 +282,7 @@ func (m *Mysql) Close() error { return nil } -func (m *Mysql) Lock() error { +func (m *Mysql) Lock(ctx context.Context) error { return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { if m.config.NoLock { return nil @@ -309,7 +307,7 @@ func (m *Mysql) Lock() error { }) } -func (m *Mysql) Unlock() error { +func (m *Mysql) Unlock(ctx context.Context) error { return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { if m.config.NoLock { return nil @@ -334,13 +332,12 @@ func (m *Mysql) Unlock() error { }) } -func (m *Mysql) Run(migration io.Reader) error { +func (m *Mysql) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err } - ctx := context.Background() if m.config.StatementTimeout != 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, m.config.StatementTimeout) @@ -355,7 +352,7 @@ func (m *Mysql) Run(migration io.Reader) error { return nil } -func (m *Mysql) SetVersion(version int, dirty bool) error { +func (m *Mysql) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -389,7 +386,7 @@ func (m *Mysql) SetVersion(version int, dirty bool) error { return nil } -func (m *Mysql) Version() (version int, dirty bool, err error) { +func (m *Mysql) Version(ctx context.Context) (version int, dirty bool, err error) { query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { @@ -409,7 +406,7 @@ func (m *Mysql) Version() (version int, dirty bool, err error) { } } -func (m *Mysql) Drop() (err error) { +func (m *Mysql) Drop(ctx context.Context) (err error) { // select all tables query := `SHOW TABLES LIKE '%'` tables, err := m.conn.QueryContext(context.Background(), query) @@ -464,13 +461,13 @@ func (m *Mysql) Drop() (err error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Mysql type. -func (m *Mysql) ensureVersionTable() (err error) { - if err = m.Lock(); err != nil { +func (m *Mysql) ensureVersionTable(ctx context.Context) (err error) { + if err = m.Lock(ctx); err != nil { return err } defer func() { - if e := m.Unlock(); e != nil { + if e := m.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index c7df8162d..e2bba7ca9 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -89,6 +89,7 @@ func Test(t *testing.T) { // mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime))) dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -96,23 +97,23 @@ func Test(t *testing.T) { addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) p := &Mysql{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() dt.Test(t, d, []byte("SELECT 1")) // check ensureVersionTable - if err := d.(*Mysql).ensureVersionTable(); err != nil { + if err := d.(*Mysql).ensureVersionTable(ctx); err != nil { t.Fatal(err) } // check again - if err := d.(*Mysql).ensureVersionTable(); err != nil { + if err := d.(*Mysql).ensureVersionTable(ctx); err != nil { t.Fatal(err) } }) @@ -122,6 +123,7 @@ func TestMigrate(t *testing.T) { // mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime))) dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -129,28 +131,28 @@ func TestMigrate(t *testing.T) { addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) p := &Mysql{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "public", d) if err != nil { t.Fatal(err) } dt.TestMigrate(t, m) // check ensureVersionTable - if err := d.(*Mysql).ensureVersionTable(); err != nil { + if err := d.(*Mysql).ensureVersionTable(ctx); err != nil { t.Fatal(err) } // check again - if err := d.(*Mysql).ensureVersionTable(); err != nil { + if err := d.(*Mysql).ensureVersionTable(ctx); err != nil { t.Fatal(err) } }) @@ -160,6 +162,7 @@ func TestMigrateAnsiQuotes(t *testing.T) { // mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime))) dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -167,28 +170,28 @@ func TestMigrateAnsiQuotes(t *testing.T) { addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) p := &Mysql{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "public", d) if err != nil { t.Fatal(err) } dt.TestMigrate(t, m) // check ensureVersionTable - if err := d.(*Mysql).ensureVersionTable(); err != nil { + if err := d.(*Mysql).ensureVersionTable(ctx); err != nil { t.Fatal(err) } // check again - if err := d.(*Mysql).ensureVersionTable(); err != nil { + if err := d.(*Mysql).ensureVersionTable(ctx); err != nil { t.Fatal(err) } }) @@ -196,6 +199,7 @@ func TestMigrateAnsiQuotes(t *testing.T) { func TestLockWorks(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -203,7 +207,7 @@ func TestLockWorks(t *testing.T) { addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) p := &Mysql{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -211,21 +215,21 @@ func TestLockWorks(t *testing.T) { ms := d.(*Mysql) - err = ms.Lock() + err = ms.Lock(ctx) if err != nil { t.Fatal(err) } - err = ms.Unlock() + err = ms.Unlock(ctx) if err != nil { t.Fatal(err) } // make sure the 2nd lock works (RELEASE_LOCK is very finicky) - err = ms.Lock() + err = ms.Lock(ctx) if err != nil { t.Fatal(err) } - err = ms.Unlock() + err = ms.Unlock(ctx) if err != nil { t.Fatal(err) } @@ -233,11 +237,12 @@ func TestLockWorks(t *testing.T) { } func TestNoLockParamValidation(t *testing.T) { + ctx := context.Background() ip := "127.0.0.1" port := 3306 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) p := &Mysql{} - _, err := p.Open(addr + "?x-no-lock=not-a-bool") + _, err := p.Open(ctx, addr+"?x-no-lock=not-a-bool") if !errors.Is(err, strconv.ErrSyntax) { t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter") } @@ -245,6 +250,7 @@ func TestNoLockParamValidation(t *testing.T) { func TestNoLockWorks(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) @@ -252,7 +258,7 @@ func TestNoLockWorks(t *testing.T) { addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) p := &Mysql{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -260,7 +266,7 @@ func TestNoLockWorks(t *testing.T) { lock := d.(*Mysql) p = &Mysql{} - d, err = p.Open(addr + "?x-no-lock=true") + d, err = p.Open(ctx, addr+"?x-no-lock=true") if err != nil { t.Fatal(err) } @@ -268,16 +274,16 @@ func TestNoLockWorks(t *testing.T) { noLock := d.(*Mysql) // Should be possible to take real lock and no-lock at the same time - if err = lock.Lock(); err != nil { + if err = lock.Lock(ctx); err != nil { t.Fatal(err) } - if err = noLock.Lock(); err != nil { + if err = noLock.Lock(ctx); err != nil { t.Fatal(err) } - if err = lock.Unlock(); err != nil { + if err = lock.Unlock(ctx); err != nil { t.Fatal(err) } - if err = noLock.Unlock(); err != nil { + if err = noLock.Unlock(ctx); err != nil { t.Fatal(err) } }) diff --git a/database/neo4j/neo4j.go b/database/neo4j/neo4j.go index 179e0da60..b52547f47 100644 --- a/database/neo4j/neo4j.go +++ b/database/neo4j/neo4j.go @@ -2,6 +2,7 @@ package neo4j import ( "bytes" + "context" "fmt" "io" neturl "net/url" @@ -61,7 +62,7 @@ func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) return nDriver, nil } -func (n *Neo4j) Open(url string) (database.Driver, error) { +func (n *Neo4j) Open(ctx context.Context, url string) (database.Driver, error) { uri, err := neturl.Parse(url) if err != nil { return nil, err @@ -114,12 +115,12 @@ func (n *Neo4j) Open(url string) (database.Driver, error) { }) } -func (n *Neo4j) Close() error { +func (n *Neo4j) Close(ctx context.Context) error { return n.driver.Close() } // local locking in order to pass tests, Neo doesn't support database locking -func (n *Neo4j) Lock() error { +func (n *Neo4j) Lock(ctx context.Context) error { if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) { return database.ErrLocked } @@ -127,14 +128,14 @@ func (n *Neo4j) Lock() error { return nil } -func (n *Neo4j) Unlock() error { +func (n *Neo4j) Unlock(ctx context.Context) error { if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) { return database.ErrNotLocked } return nil } -func (n *Neo4j) Run(migration io.Reader) (err error) { +func (n *Neo4j) Run(ctx context.Context, migration io.Reader) (err error) { session, err := n.driver.Session(neo4j.AccessModeWrite) if err != nil { return err @@ -181,7 +182,7 @@ func (n *Neo4j) Run(migration io.Reader) (err error) { return err } -func (n *Neo4j) SetVersion(version int, dirty bool) (err error) { +func (n *Neo4j) SetVersion(ctx context.Context, version int, dirty bool) (err error) { session, err := n.driver.Session(neo4j.AccessModeWrite) if err != nil { return err @@ -206,7 +207,7 @@ type MigrationRecord struct { Dirty bool } -func (n *Neo4j) Version() (version int, dirty bool, err error) { +func (n *Neo4j) Version(ctx context.Context) (version int, dirty bool, err error) { session, err := n.driver.Session(neo4j.AccessModeRead) if err != nil { return database.NilVersion, false, err @@ -254,7 +255,7 @@ ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`, return mr.Version, mr.Dirty, err } -func (n *Neo4j) Drop() (err error) { +func (n *Neo4j) Drop(ctx context.Context) (err error) { session, err := n.driver.Session(neo4j.AccessModeWrite) if err != nil { return err diff --git a/database/neo4j/neo4j_test.go b/database/neo4j/neo4j_test.go index c8e914525..b66ef8560 100644 --- a/database/neo4j/neo4j_test.go +++ b/database/neo4j/neo4j_test.go @@ -67,18 +67,19 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(7687) if err != nil { t.Fatal(err) } n := &Neo4j{} - d, err := n.Open(neoConnectionString(ip, port)) + d, err := n.Open(ctx, neoConnectionString(ip, port)) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -88,6 +89,7 @@ func Test(t *testing.T) { func TestMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(7687) if err != nil { t.Fatal(err) @@ -95,16 +97,16 @@ func TestMigrate(t *testing.T) { n := &Neo4j{} neoUrl := neoConnectionString(ip, port) + "/?x-multi-statement=true" - d, err := n.Open(neoUrl) + d, err := n.Open(ctx, neoUrl) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "neo4j", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "neo4j", d) if err != nil { t.Fatal(err) } @@ -114,24 +116,25 @@ func TestMigrate(t *testing.T) { func TestMalformed(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(7687) if err != nil { t.Fatal(err) } n := &Neo4j{} - d, err := n.Open(neoConnectionString(ip, port)) + d, err := n.Open(ctx, neoConnectionString(ip, port)) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() migration := bytes.NewReader([]byte("CREATE (a {qid: 1) RETURN a")) - if err := d.Run(migration); err == nil { + if err := d.Run(ctx, migration); err == nil { t.Fatal("expected failure for malformed migration") } }) diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 7e42d29c9..e433c48ce 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -77,7 +77,7 @@ type Postgres struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -155,14 +155,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, err } - if err := px.ensureVersionTable(); err != nil { + if err := px.ensureVersionTable(ctx); err != nil { return nil, err } return px, nil } -func (p *Postgres) Open(url string) (database.Driver, error) { +func (p *Postgres) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -221,7 +221,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { lockStrategy := purl.Query().Get("x-lock-strategy") lockTable := purl.Query().Get("x-lock-table") - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, MigrationsTableQuoted: migrationsTableQuoted, @@ -239,7 +239,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { return px, nil } -func (p *Postgres) Close() error { +func (p *Postgres) Close(ctx context.Context) error { connErr := p.conn.Close() dbErr := p.db.Close() if connErr != nil || dbErr != nil { @@ -248,7 +248,7 @@ func (p *Postgres) Close() error { return nil } -func (p *Postgres) Lock() error { +func (p *Postgres) Lock(ctx context.Context) error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { switch p.config.LockStrategy { case LockStrategyAdvisory: @@ -261,7 +261,7 @@ func (p *Postgres) Lock() error { }) } -func (p *Postgres) Unlock() error { +func (p *Postgres) Unlock(ctx context.Context) error { return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { switch p.config.LockStrategy { case LockStrategyAdvisory: @@ -360,7 +360,7 @@ func (p *Postgres) releaseTableLock() error { return nil } -func (p *Postgres) Run(migration io.Reader) error { +func (p *Postgres) Run(ctx context.Context, migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { @@ -447,7 +447,7 @@ func runesLastIndex(input []rune, target rune) int { return -1 } -func (p *Postgres) SetVersion(version int, dirty bool) error { +func (p *Postgres) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -481,7 +481,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return nil } -func (p *Postgres) Version() (version int, dirty bool, err error) { +func (p *Postgres) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { @@ -501,7 +501,7 @@ func (p *Postgres) Version() (version int, dirty bool, err error) { } } -func (p *Postgres) Drop() (err error) { +func (p *Postgres) Drop(ctx context.Context) (err error) { // select all tables in current schema query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` tables, err := p.conn.QueryContext(context.Background(), query) @@ -551,13 +551,13 @@ func (p *Postgres) Drop() (err error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Postgres type. -func (p *Postgres) ensureVersionTable() (err error) { - if err = p.Lock(); err != nil { +func (p *Postgres) ensureVersionTable(ctx context.Context) (err error) { + if err = p.Lock(ctx); err != nil { return err } defer func() { - if e := p.Unlock(); e != nil { + if e := p.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index 03977973d..5ca0c1298 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -75,9 +75,9 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { return true } -func mustRun(t *testing.T, d database.Driver, statements []string) { +func mustRun(t *testing.T, ctx context.Context, d database.Driver, statements []string) { for _, statement := range statements { - if err := d.Run(strings.NewReader(statement)); err != nil { + if err := d.Run(ctx, strings.NewReader(statement)); err != nil { t.Fatal(err) } } @@ -85,6 +85,7 @@ func mustRun(t *testing.T, d database.Driver, statements []string) { func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -92,12 +93,12 @@ func Test(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -107,6 +108,7 @@ func Test(t *testing.T) { func TestMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -114,16 +116,16 @@ func TestMigrate(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "pgx", d) if err != nil { t.Fatal(err) } @@ -133,6 +135,7 @@ func TestMigrate(t *testing.T) { func TestMigrateLockTable(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -140,16 +143,16 @@ func TestMigrateLockTable(t *testing.T) { addr := pgConnectionString(ip, port, "x-lock-strategy=table", "x-lock-table=lock_table") p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "pgx", d) if err != nil { t.Fatal(err) } @@ -159,6 +162,7 @@ func TestMigrateLockTable(t *testing.T) { func TestMultipleStatements(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -166,16 +170,16 @@ func TestMultipleStatements(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -192,6 +196,7 @@ func TestMultipleStatements(t *testing.T) { func TestMultipleStatementsInMultiStatementMode(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -199,16 +204,16 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { addr := pgConnectionString(ip, port, "x-multi-statement=true") p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -225,6 +230,7 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { func TestErrorParsing(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -232,19 +238,19 @@ func TestErrorParsing(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` + `(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))` - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { t.Fatal("expected err but got nil") } else if err.Error() != wantErr { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) @@ -254,6 +260,7 @@ func TestErrorParsing(t *testing.T) { func TestFilterCustomQuery(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -261,12 +268,12 @@ func TestFilterCustomQuery(t *testing.T) { addr := pgConnectionString(ip, port, "x-custom=foobar") p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -275,6 +282,7 @@ func TestFilterCustomQuery(t *testing.T) { func TestWithSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -282,36 +290,36 @@ func TestWithSchema(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Fatal(err) } }() // create foobar schema - if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil { t.Fatal(err) } - if err := d.SetVersion(1, false); err != nil { + if err := d.SetVersion(ctx, 1, false); err != nil { t.Fatal(err) } // re-connect using that schema - d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar")) + d2, err := p.Open(ctx, pgConnectionString(ip, port, "search_path=foobar")) if err != nil { t.Fatal(err) } defer func() { - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } }() - version, _, err := d2.Version() + version, _, err := d2.Version(ctx) if err != nil { t.Fatal(err) } @@ -320,10 +328,10 @@ func TestWithSchema(t *testing.T) { } // now update version and compare - if err := d2.SetVersion(2, false); err != nil { + if err := d2.SetVersion(ctx, 2, false); err != nil { t.Fatal(err) } - version, _, err = d2.Version() + version, _, err = d2.Version(ctx) if err != nil { t.Fatal(err) } @@ -332,7 +340,7 @@ func TestWithSchema(t *testing.T) { } // meanwhile, the public schema still has the other version - version, _, err = d.Version() + version, _, err = d.Version(ctx) if err != nil { t.Fatal(err) } @@ -344,6 +352,7 @@ func TestWithSchema(t *testing.T) { func TestMigrationTableOption(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -351,21 +360,21 @@ func TestMigrationTableOption(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, _ := p.Open(addr) + d, _ := p.Open(ctx, addr) defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Fatal(err) } }() // create migrate schema - if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil { t.Fatal(err) } // bad unquoted x-migrations-table parameter wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations" - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1", pgPassword, ip, port)) if (err != nil) && (err.Error() != wantErr) { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) @@ -373,14 +382,14 @@ func TestMigrationTableOption(t *testing.T) { // too many quoted x-migrations-table parameters wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters" - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if (err != nil) && (err.Error() != wantErr) { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) } // good quoted x-migrations-table parameter - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if err != nil { t.Fatal(err) @@ -395,7 +404,7 @@ func TestMigrationTableOption(t *testing.T) { t.Fatalf("expected table migrate.schema_migrations to exist") } - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations", pgPassword, ip, port)) if err != nil { t.Fatal(err) @@ -412,6 +421,7 @@ func TestMigrationTableOption(t *testing.T) { func TestFailToCreateTableWithoutPermissions(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -422,21 +432,21 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { // Check that opening the postgres connection returns NilVersion p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine // since this is a test environment and we're not expecting to the pgPassword to be malicious - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", "CREATE SCHEMA barfoo AUTHORIZATION postgres", "GRANT USAGE ON SCHEMA barfoo TO not_owner", @@ -445,14 +455,14 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { }) // re-connect using that schema - d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d2, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) defer func() { if d2 == nil { return } - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } }() @@ -467,7 +477,7 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { } // re-connect using that x-migrations-table and x-migrations-table-quoted - d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1", + d2, err = p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if !errors.As(err, &e) || err == nil { @@ -482,6 +492,7 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { func TestCheckBeforeCreateTable(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -492,21 +503,21 @@ func TestCheckBeforeCreateTable(t *testing.T) { // Check that opening the postgres connection returns NilVersion p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine // since this is a test environment and we're not expecting to the pgPassword to be malicious - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", "CREATE SCHEMA barfoo AUTHORIZATION postgres", "GRANT USAGE ON SCHEMA barfoo TO not_owner", @@ -514,32 +525,32 @@ func TestCheckBeforeCreateTable(t *testing.T) { }) // re-connect using that schema - d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d2, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) if err != nil { t.Fatal(err) } - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } // revoke privileges - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC", "REVOKE CREATE ON SCHEMA barfoo FROM not_owner", }) // re-connect using that schema - d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d3, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) if err != nil { t.Fatal(err) } - version, _, err := d3.Version() + version, _, err := d3.Version(ctx) if err != nil { t.Fatal(err) @@ -550,7 +561,7 @@ func TestCheckBeforeCreateTable(t *testing.T) { } defer func() { - if err := d3.Close(); err != nil { + if err := d3.Close(ctx); err != nil { t.Fatal(err) } }() @@ -559,6 +570,7 @@ func TestCheckBeforeCreateTable(t *testing.T) { func TestParallelSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -566,58 +578,58 @@ func TestParallelSchema(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create foo and bar schemas - if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil { t.Fatal(err) } - if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil { t.Fatal(err) } // re-connect using that schemas - dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo")) + dfoo, err := p.Open(ctx, pgConnectionString(ip, port, "search_path=foo")) if err != nil { t.Fatal(err) } defer func() { - if err := dfoo.Close(); err != nil { + if err := dfoo.Close(ctx); err != nil { t.Error(err) } }() - dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar")) + dbar, err := p.Open(ctx, pgConnectionString(ip, port, "search_path=bar")) if err != nil { t.Fatal(err) } defer func() { - if err := dbar.Close(); err != nil { + if err := dbar.Close(ctx); err != nil { t.Error(err) } }() - if err := dfoo.Lock(); err != nil { + if err := dfoo.Lock(ctx); err != nil { t.Fatal(err) } - if err := dbar.Lock(); err != nil { + if err := dbar.Lock(ctx); err != nil { t.Fatal(err) } - if err := dbar.Unlock(); err != nil { + if err := dbar.Unlock(ctx); err != nil { t.Fatal(err) } - if err := dfoo.Unlock(); err != nil { + if err := dfoo.Unlock(ctx); err != nil { t.Fatal(err) } }) @@ -625,6 +637,7 @@ func TestParallelSchema(t *testing.T) { func TestPostgres_Lock(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -632,7 +645,7 @@ func TestPostgres_Lock(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -641,22 +654,22 @@ func TestPostgres_Lock(t *testing.T) { ps := d.(*Postgres) - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } @@ -665,6 +678,7 @@ func TestPostgres_Lock(t *testing.T) { func TestWithInstance_Concurrent(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -697,7 +711,7 @@ func TestWithInstance_Concurrent(t *testing.T) { for i := 0; i < concurrency; i++ { go func(i int) { defer wg.Done() - _, err := WithInstance(db, &Config{}) + _, err := WithInstance(ctx, db, &Config{}) if err != nil { t.Errorf("process %d error: %s", i, err) } diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 1b5a6ea7a..5ebccbda6 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -65,7 +65,7 @@ type Postgres struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -131,14 +131,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } - if err := px.ensureVersionTable(); err != nil { + if err := px.ensureVersionTable(ctx); err != nil { return nil, err } return px, nil } -func (p *Postgres) Open(url string) (database.Driver, error) { +func (p *Postgres) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -194,7 +194,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } } - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, MigrationsTableQuoted: migrationsTableQuoted, @@ -210,7 +210,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { return px, nil } -func (p *Postgres) Close() error { +func (p *Postgres) Close(ctx context.Context) error { connErr := p.conn.Close() dbErr := p.db.Close() if connErr != nil || dbErr != nil { @@ -220,7 +220,7 @@ func (p *Postgres) Close() error { } // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS -func (p *Postgres) Lock() error { +func (p *Postgres) Lock(ctx context.Context) error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) if err != nil { @@ -236,7 +236,7 @@ func (p *Postgres) Lock() error { }) } -func (p *Postgres) Unlock() error { +func (p *Postgres) Unlock(ctx context.Context) error { return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) if err != nil { @@ -251,7 +251,7 @@ func (p *Postgres) Unlock() error { }) } -func (p *Postgres) Run(migration io.Reader) error { +func (p *Postgres) Run(ctx context.Context, migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { @@ -338,7 +338,7 @@ func runesLastIndex(input []rune, target rune) int { return -1 } -func (p *Postgres) SetVersion(version int, dirty bool) error { +func (p *Postgres) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -372,7 +372,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return nil } -func (p *Postgres) Version() (version int, dirty bool, err error) { +func (p *Postgres) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { @@ -392,7 +392,7 @@ func (p *Postgres) Version() (version int, dirty bool, err error) { } } -func (p *Postgres) Drop() (err error) { +func (p *Postgres) Drop(ctx context.Context) (err error) { // select all tables in current schema query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` tables, err := p.conn.QueryContext(context.Background(), query) @@ -436,13 +436,13 @@ func (p *Postgres) Drop() (err error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Postgres type. -func (p *Postgres) ensureVersionTable() (err error) { - if err = p.Lock(); err != nil { +func (p *Postgres) ensureVersionTable(ctx context.Context) (err error) { + if err = p.Lock(ctx); err != nil { return err } defer func() { - if e := p.Unlock(); e != nil { + if e := p.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/pgx/v5/pgx_test.go b/database/pgx/v5/pgx_test.go index 3066376b9..889d7ecd7 100644 --- a/database/pgx/v5/pgx_test.go +++ b/database/pgx/v5/pgx_test.go @@ -76,9 +76,9 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { return true } -func mustRun(t *testing.T, d database.Driver, statements []string) { +func mustRun(t *testing.T, ctx context.Context, d database.Driver, statements []string) { for _, statement := range statements { - if err := d.Run(strings.NewReader(statement)); err != nil { + if err := d.Run(ctx, strings.NewReader(statement)); err != nil { t.Fatal(err) } } @@ -86,6 +86,7 @@ func mustRun(t *testing.T, d database.Driver, statements []string) { func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -93,12 +94,12 @@ func Test(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -108,6 +109,7 @@ func Test(t *testing.T) { func TestMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -115,16 +117,16 @@ func TestMigrate(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://../examples/migrations", "pgx", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://../examples/migrations", "pgx", d) if err != nil { t.Fatal(err) } @@ -134,6 +136,7 @@ func TestMigrate(t *testing.T) { func TestMultipleStatements(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -141,16 +144,16 @@ func TestMultipleStatements(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -167,6 +170,7 @@ func TestMultipleStatements(t *testing.T) { func TestMultipleStatementsInMultiStatementMode(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -174,16 +178,16 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { addr := pgConnectionString(ip, port, "x-multi-statement=true") p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -200,6 +204,7 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) { func TestErrorParsing(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -207,19 +212,19 @@ func TestErrorParsing(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` + `(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))` - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { t.Fatal("expected err but got nil") } else if err.Error() != wantErr { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) @@ -229,6 +234,7 @@ func TestErrorParsing(t *testing.T) { func TestFilterCustomQuery(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -236,12 +242,12 @@ func TestFilterCustomQuery(t *testing.T) { addr := pgConnectionString(ip, port, "x-custom=foobar") p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -250,6 +256,7 @@ func TestFilterCustomQuery(t *testing.T) { func TestWithSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -257,36 +264,36 @@ func TestWithSchema(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Fatal(err) } }() // create foobar schema - if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil { t.Fatal(err) } - if err := d.SetVersion(1, false); err != nil { + if err := d.SetVersion(ctx, 1, false); err != nil { t.Fatal(err) } // re-connect using that schema - d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar")) + d2, err := p.Open(ctx, pgConnectionString(ip, port, "search_path=foobar")) if err != nil { t.Fatal(err) } defer func() { - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } }() - version, _, err := d2.Version() + version, _, err := d2.Version(ctx) if err != nil { t.Fatal(err) } @@ -295,10 +302,10 @@ func TestWithSchema(t *testing.T) { } // now update version and compare - if err := d2.SetVersion(2, false); err != nil { + if err := d2.SetVersion(ctx, 2, false); err != nil { t.Fatal(err) } - version, _, err = d2.Version() + version, _, err = d2.Version(ctx) if err != nil { t.Fatal(err) } @@ -307,7 +314,7 @@ func TestWithSchema(t *testing.T) { } // meanwhile, the public schema still has the other version - version, _, err = d.Version() + version, _, err = d.Version(ctx) if err != nil { t.Fatal(err) } @@ -319,6 +326,7 @@ func TestWithSchema(t *testing.T) { func TestMigrationTableOption(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -326,21 +334,21 @@ func TestMigrationTableOption(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, _ := p.Open(addr) + d, _ := p.Open(ctx, addr) defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Fatal(err) } }() // create migrate schema - if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil { t.Fatal(err) } // bad unquoted x-migrations-table parameter wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations" - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1", pgPassword, ip, port)) if (err != nil) && (err.Error() != wantErr) { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) @@ -348,14 +356,14 @@ func TestMigrationTableOption(t *testing.T) { // too many quoted x-migrations-table parameters wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters" - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if (err != nil) && (err.Error() != wantErr) { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) } // good quoted x-migrations-table parameter - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if err != nil { t.Fatal(err) @@ -370,7 +378,7 @@ func TestMigrationTableOption(t *testing.T) { t.Fatalf("expected table migrate.schema_migrations to exist") } - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations", pgPassword, ip, port)) if err != nil { t.Fatal(err) @@ -387,6 +395,7 @@ func TestMigrationTableOption(t *testing.T) { func TestFailToCreateTableWithoutPermissions(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -397,21 +406,21 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { // Check that opening the postgres connection returns NilVersion p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine // since this is a test environment and we're not expecting to the pgPassword to be malicious - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", "CREATE SCHEMA barfoo AUTHORIZATION postgres", "GRANT USAGE ON SCHEMA barfoo TO not_owner", @@ -420,14 +429,14 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { }) // re-connect using that schema - d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d2, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) defer func() { if d2 == nil { return } - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } }() @@ -442,7 +451,7 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { } // re-connect using that x-migrations-table and x-migrations-table-quoted - d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1", + d2, err = p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if !errors.As(err, &e) || err == nil { @@ -457,6 +466,7 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { func TestCheckBeforeCreateTable(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -467,21 +477,21 @@ func TestCheckBeforeCreateTable(t *testing.T) { // Check that opening the postgres connection returns NilVersion p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine // since this is a test environment and we're not expecting to the pgPassword to be malicious - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", "CREATE SCHEMA barfoo AUTHORIZATION postgres", "GRANT USAGE ON SCHEMA barfoo TO not_owner", @@ -489,32 +499,32 @@ func TestCheckBeforeCreateTable(t *testing.T) { }) // re-connect using that schema - d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d2, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) if err != nil { t.Fatal(err) } - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } // revoke privileges - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC", "REVOKE CREATE ON SCHEMA barfoo FROM not_owner", }) // re-connect using that schema - d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d3, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) if err != nil { t.Fatal(err) } - version, _, err := d3.Version() + version, _, err := d3.Version(ctx) if err != nil { t.Fatal(err) @@ -525,7 +535,7 @@ func TestCheckBeforeCreateTable(t *testing.T) { } defer func() { - if err := d3.Close(); err != nil { + if err := d3.Close(ctx); err != nil { t.Fatal(err) } }() @@ -534,6 +544,7 @@ func TestCheckBeforeCreateTable(t *testing.T) { func TestParallelSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -541,58 +552,58 @@ func TestParallelSchema(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create foo and bar schemas - if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil { t.Fatal(err) } - if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil { t.Fatal(err) } // re-connect using that schemas - dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo")) + dfoo, err := p.Open(ctx, pgConnectionString(ip, port, "search_path=foo")) if err != nil { t.Fatal(err) } defer func() { - if err := dfoo.Close(); err != nil { + if err := dfoo.Close(ctx); err != nil { t.Error(err) } }() - dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar")) + dbar, err := p.Open(ctx, pgConnectionString(ip, port, "search_path=bar")) if err != nil { t.Fatal(err) } defer func() { - if err := dbar.Close(); err != nil { + if err := dbar.Close(ctx); err != nil { t.Error(err) } }() - if err := dfoo.Lock(); err != nil { + if err := dfoo.Lock(ctx); err != nil { t.Fatal(err) } - if err := dbar.Lock(); err != nil { + if err := dbar.Lock(ctx); err != nil { t.Fatal(err) } - if err := dbar.Unlock(); err != nil { + if err := dbar.Unlock(ctx); err != nil { t.Fatal(err) } - if err := dfoo.Unlock(); err != nil { + if err := dfoo.Unlock(ctx); err != nil { t.Fatal(err) } }) @@ -600,6 +611,7 @@ func TestParallelSchema(t *testing.T) { func TestPostgres_Lock(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -607,7 +619,7 @@ func TestPostgres_Lock(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -616,22 +628,22 @@ func TestPostgres_Lock(t *testing.T) { ps := d.(*Postgres) - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } @@ -640,6 +652,7 @@ func TestPostgres_Lock(t *testing.T) { func TestWithInstance_Concurrent(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -672,7 +685,7 @@ func TestWithInstance_Concurrent(t *testing.T) { for i := 0; i < concurrency; i++ { go func(i int) { defer wg.Done() - _, err := WithInstance(db, &Config{}) + _, err := WithInstance(ctx, db, &Config{}) if err != nil { t.Errorf("process %d error: %s", i, err) } diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 9e6d6277f..45b80d52b 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -124,16 +124,14 @@ func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postg config: config, } - if err := px.ensureVersionTable(); err != nil { + if err := px.ensureVersionTable(ctx); err != nil { return nil, err } return px, nil } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { - ctx := context.Background() - +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if err := instance.Ping(); err != nil { return nil, err } @@ -151,7 +149,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return px, nil } -func (p *Postgres) Open(url string) (database.Driver, error) { +func (p *Postgres) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -202,7 +200,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } } - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, MigrationsTableQuoted: migrationsTableQuoted, @@ -218,7 +216,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { return px, nil } -func (p *Postgres) Close() error { +func (p *Postgres) Close(ctx context.Context) error { connErr := p.conn.Close() var dbErr error if p.db != nil { @@ -232,7 +230,7 @@ func (p *Postgres) Close() error { } // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS -func (p *Postgres) Lock() error { +func (p *Postgres) Lock(ctx context.Context) error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) if err != nil { @@ -249,7 +247,7 @@ func (p *Postgres) Lock() error { }) } -func (p *Postgres) Unlock() error { +func (p *Postgres) Unlock(ctx context.Context) error { return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) if err != nil { @@ -264,7 +262,7 @@ func (p *Postgres) Unlock() error { }) } -func (p *Postgres) Run(migration io.Reader) error { +func (p *Postgres) Run(ctx context.Context, migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { @@ -354,7 +352,7 @@ func runesLastIndex(input []rune, target rune) int { return -1 } -func (p *Postgres) SetVersion(version int, dirty bool) error { +func (p *Postgres) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -388,7 +386,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return nil } -func (p *Postgres) Version() (version int, dirty bool, err error) { +func (p *Postgres) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { @@ -408,7 +406,7 @@ func (p *Postgres) Version() (version int, dirty bool, err error) { } } -func (p *Postgres) Drop() (err error) { +func (p *Postgres) Drop(ctx context.Context) (err error) { // select all tables in current schema query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` tables, err := p.conn.QueryContext(context.Background(), query) @@ -452,13 +450,13 @@ func (p *Postgres) Drop() (err error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Postgres type. -func (p *Postgres) ensureVersionTable() (err error) { - if err = p.Lock(); err != nil { +func (p *Postgres) ensureVersionTable(ctx context.Context) (err error) { + if err = p.Lock(ctx); err != nil { return err } defer func() { - if e := p.Unlock(); e != nil { + if e := p.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 988d086b2..5a62229fd 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -76,9 +76,9 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { return true } -func mustRun(t *testing.T, d database.Driver, statements []string) { +func mustRun(t *testing.T, ctx context.Context, d database.Driver, statements []string) { for _, statement := range statements { - if err := d.Run(strings.NewReader(statement)); err != nil { + if err := d.Run(ctx, strings.NewReader(statement)); err != nil { t.Fatal(err) } } @@ -112,6 +112,7 @@ func Test(t *testing.T) { func test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -119,12 +120,12 @@ func test(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -134,6 +135,7 @@ func test(t *testing.T) { func testMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -141,16 +143,16 @@ func testMigrate(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "postgres", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "postgres", d) if err != nil { t.Fatal(err) } @@ -160,6 +162,7 @@ func testMigrate(t *testing.T) { func testMultipleStatements(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -167,16 +170,16 @@ func testMultipleStatements(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -193,6 +196,7 @@ func testMultipleStatements(t *testing.T) { func testMultipleStatementsInMultiStatementMode(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -200,16 +204,16 @@ func testMultipleStatementsInMultiStatementMode(t *testing.T) { addr := pgConnectionString(ip, port, "x-multi-statement=true") p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -226,6 +230,7 @@ func testMultipleStatementsInMultiStatementMode(t *testing.T) { func testErrorParsing(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -233,19 +238,19 @@ func testErrorParsing(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` + `(foo text); CREATE TABLEE bar (bar text); (details: pq: syntax error at or near "TABLEE")` - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { t.Fatal("expected err but got nil") } else if err.Error() != wantErr { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) @@ -255,6 +260,7 @@ func testErrorParsing(t *testing.T) { func testFilterCustomQuery(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -263,12 +269,12 @@ func testFilterCustomQuery(t *testing.T) { addr := fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-custom=foobar", pgPassword, ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -277,6 +283,7 @@ func testFilterCustomQuery(t *testing.T) { func testWithSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -284,37 +291,37 @@ func testWithSchema(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Fatal(err) } }() // create foobar schema - if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil { t.Fatal(err) } - if err := d.SetVersion(1, false); err != nil { + if err := d.SetVersion(ctx, 1, false); err != nil { t.Fatal(err) } // re-connect using that schema - d2, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar", + d2, err := p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar", pgPassword, ip, port)) if err != nil { t.Fatal(err) } defer func() { - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } }() - version, _, err := d2.Version() + version, _, err := d2.Version(ctx) if err != nil { t.Fatal(err) } @@ -323,10 +330,10 @@ func testWithSchema(t *testing.T) { } // now update version and compare - if err := d2.SetVersion(2, false); err != nil { + if err := d2.SetVersion(ctx, 2, false); err != nil { t.Fatal(err) } - version, _, err = d2.Version() + version, _, err = d2.Version(ctx) if err != nil { t.Fatal(err) } @@ -335,7 +342,7 @@ func testWithSchema(t *testing.T) { } // meanwhile, the public schema still has the other version - version, _, err = d.Version() + version, _, err = d.Version(ctx) if err != nil { t.Fatal(err) } @@ -347,6 +354,7 @@ func testWithSchema(t *testing.T) { func testMigrationTableOption(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -354,21 +362,21 @@ func testMigrationTableOption(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, _ := p.Open(addr) + d, _ := p.Open(ctx, addr) defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Fatal(err) } }() // create migrate schema - if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil { t.Fatal(err) } // bad unquoted x-migrations-table parameter wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations" - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1", pgPassword, ip, port)) if (err != nil) && (err.Error() != wantErr) { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) @@ -376,14 +384,14 @@ func testMigrationTableOption(t *testing.T) { // too many quoted x-migrations-table parameters wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters" - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if (err != nil) && (err.Error() != wantErr) { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) } // good quoted x-migrations-table parameter - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if err != nil { t.Fatal(err) @@ -398,7 +406,7 @@ func testMigrationTableOption(t *testing.T) { t.Fatalf("expected table migrate.schema_migrations to exist") } - d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations", + d, err = p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations", pgPassword, ip, port)) if err != nil { t.Fatal(err) @@ -415,6 +423,7 @@ func testMigrationTableOption(t *testing.T) { func testFailToCreateTableWithoutPermissions(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -425,21 +434,21 @@ func testFailToCreateTableWithoutPermissions(t *testing.T) { // Check that opening the postgres connection returns NilVersion p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine // since this is a test environment and we're not expecting to the pgPassword to be malicious - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", "CREATE SCHEMA barfoo AUTHORIZATION postgres", "GRANT USAGE ON SCHEMA barfoo TO not_owner", @@ -448,14 +457,14 @@ func testFailToCreateTableWithoutPermissions(t *testing.T) { }) // re-connect using that schema - d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d2, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) defer func() { if d2 == nil { return } - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } }() @@ -470,7 +479,7 @@ func testFailToCreateTableWithoutPermissions(t *testing.T) { } // re-connect using that x-migrations-table and x-migrations-table-quoted - d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1", + d2, err = p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1", pgPassword, ip, port)) if !errors.As(err, &e) || err == nil { @@ -485,6 +494,7 @@ func testFailToCreateTableWithoutPermissions(t *testing.T) { func testCheckBeforeCreateTable(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -495,21 +505,21 @@ func testCheckBeforeCreateTable(t *testing.T) { // Check that opening the postgres connection returns NilVersion p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine // since this is a test environment and we're not expecting to the pgPassword to be malicious - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'", "CREATE SCHEMA barfoo AUTHORIZATION postgres", "GRANT USAGE ON SCHEMA barfoo TO not_owner", @@ -517,32 +527,32 @@ func testCheckBeforeCreateTable(t *testing.T) { }) // re-connect using that schema - d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d2, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) if err != nil { t.Fatal(err) } - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Fatal(err) } // revoke privileges - mustRun(t, d, []string{ + mustRun(t, ctx, d, []string{ "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC", "REVOKE CREATE ON SCHEMA barfoo FROM not_owner", }) // re-connect using that schema - d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", + d3, err := p.Open(ctx, fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo", pgPassword, ip, port)) if err != nil { t.Fatal(err) } - version, _, err := d3.Version() + version, _, err := d3.Version(ctx) if err != nil { t.Fatal(err) @@ -553,7 +563,7 @@ func testCheckBeforeCreateTable(t *testing.T) { } defer func() { - if err := d3.Close(); err != nil { + if err := d3.Close(ctx); err != nil { t.Fatal(err) } }() @@ -562,6 +572,7 @@ func testCheckBeforeCreateTable(t *testing.T) { func testParallelSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -569,60 +580,60 @@ func testParallelSchema(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create foo and bar schemas - if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil { t.Fatal(err) } - if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil { t.Fatal(err) } // re-connect using that schemas - dfoo, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foo", + dfoo, err := p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foo", pgPassword, ip, port)) if err != nil { t.Fatal(err) } defer func() { - if err := dfoo.Close(); err != nil { + if err := dfoo.Close(ctx); err != nil { t.Error(err) } }() - dbar, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=bar", + dbar, err := p.Open(ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=bar", pgPassword, ip, port)) if err != nil { t.Fatal(err) } defer func() { - if err := dbar.Close(); err != nil { + if err := dbar.Close(ctx); err != nil { t.Error(err) } }() - if err := dfoo.Lock(); err != nil { + if err := dfoo.Lock(ctx); err != nil { t.Fatal(err) } - if err := dbar.Lock(); err != nil { + if err := dbar.Lock(ctx); err != nil { t.Fatal(err) } - if err := dbar.Unlock(); err != nil { + if err := dbar.Unlock(ctx); err != nil { t.Fatal(err) } - if err := dfoo.Unlock(); err != nil { + if err := dfoo.Unlock(ctx); err != nil { t.Fatal(err) } }) @@ -630,6 +641,7 @@ func testParallelSchema(t *testing.T) { func testPostgresLock(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -637,7 +649,7 @@ func testPostgresLock(t *testing.T) { addr := pgConnectionString(ip, port) p := &Postgres{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -646,22 +658,22 @@ func testPostgresLock(t *testing.T) { ps := d.(*Postgres) - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } @@ -670,6 +682,7 @@ func testPostgresLock(t *testing.T) { func testWithInstanceConcurrent(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -702,7 +715,7 @@ func testWithInstanceConcurrent(t *testing.T) { for i := 0; i < concurrency; i++ { go func(i int) { defer wg.Done() - _, err := WithInstance(db, &Config{}) + _, err := WithInstance(ctx, db, &Config{}) if err != nil { t.Errorf("process %d error: %s", i, err) } @@ -740,7 +753,7 @@ func testWithConnection(t *testing.T) { } defer func() { - if err := p.Close(); err != nil { + if err := p.Close(ctx); err != nil { t.Error(err) } }() diff --git a/database/ql/ql.go b/database/ql/ql.go index 37c062455..f4b872c13 100644 --- a/database/ql/ql.go +++ b/database/ql/ql.go @@ -1,6 +1,7 @@ package ql import ( + "context" "database/sql" "fmt" "io" @@ -39,7 +40,7 @@ type Ql struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -56,7 +57,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { db: instance, config: config, } - if err := mx.ensureVersionTable(); err != nil { + if err := mx.ensureVersionTable(ctx); err != nil { return nil, err } return mx, nil @@ -65,13 +66,13 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Ql type. -func (m *Ql) ensureVersionTable() (err error) { - if err = m.Lock(); err != nil { +func (m *Ql) ensureVersionTable(ctx context.Context) (err error) { + if err = m.Lock(ctx); err != nil { return err } defer func() { - if e := m.Unlock(); e != nil { + if e := m.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -99,7 +100,7 @@ func (m *Ql) ensureVersionTable() (err error) { return nil } -func (m *Ql) Open(url string) (database.Driver, error) { +func (m *Ql) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -113,7 +114,7 @@ func (m *Ql) Open(url string) (database.Driver, error) { if len(migrationsTable) == 0 { migrationsTable = DefaultMigrationsTable } - mx, err := WithInstance(db, &Config{ + mx, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, }) @@ -122,10 +123,10 @@ func (m *Ql) Open(url string) (database.Driver, error) { } return mx, nil } -func (m *Ql) Close() error { +func (m *Ql) Close(ctx context.Context) error { return m.db.Close() } -func (m *Ql) Drop() (err error) { +func (m *Ql) Drop(ctx context.Context) (err error) { query := `SELECT Name FROM __Table` tables, err := m.db.Query(query) if err != nil { @@ -165,19 +166,19 @@ func (m *Ql) Drop() (err error) { return nil } -func (m *Ql) Lock() error { +func (m *Ql) Lock(ctx context.Context) error { if !m.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (m *Ql) Unlock() error { +func (m *Ql) Unlock(ctx context.Context) error { if !m.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (m *Ql) Run(migration io.Reader) error { +func (m *Ql) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -202,7 +203,7 @@ func (m *Ql) executeQuery(query string) error { } return nil } -func (m *Ql) SetVersion(version int, dirty bool) error { +func (m *Ql) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := m.db.Begin() if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -234,7 +235,7 @@ func (m *Ql) SetVersion(version int, dirty bool) error { return nil } -func (m *Ql) Version() (version int, dirty bool, err error) { +func (m *Ql) Version(ctx context.Context) (version int, dirty bool, err error) { query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" err = m.db.QueryRow(query).Scan(&version, &dirty) if err != nil { diff --git a/database/ql/ql_test.go b/database/ql/ql_test.go index 1630b9ae7..f4063b986 100644 --- a/database/ql/ql_test.go +++ b/database/ql/ql_test.go @@ -1,6 +1,7 @@ package ql import ( + "context" "database/sql" "fmt" "path/filepath" @@ -15,9 +16,10 @@ import ( func Test(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "ql.db")) + ctx := context.Background() p := &Ql{} addr := fmt.Sprintf("ql://%s", filepath.Join(dir, "ql.db")) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -38,6 +40,7 @@ func TestMigrate(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "ql.db")) + ctx := context.Background() db, err := sql.Open("ql", filepath.Join(dir, "ql.db")) if err != nil { return @@ -48,12 +51,12 @@ func TestMigrate(t *testing.T) { } }() - driver, err := WithInstance(db, &Config{}) + driver, err := WithInstance(ctx, db, &Config{}) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) if err != nil { diff --git a/database/redshift/redshift.go b/database/redshift/redshift.go index 9ba7b4311..90b681409 100644 --- a/database/redshift/redshift.go +++ b/database/redshift/redshift.go @@ -46,7 +46,7 @@ type Redshift struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -85,14 +85,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } - if err := px.ensureVersionTable(); err != nil { + if err := px.ensureVersionTable(ctx); err != nil { return nil, err } return px, nil } -func (p *Redshift) Open(url string) (database.Driver, error) { +func (p *Redshift) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -106,7 +106,7 @@ func (p *Redshift) Open(url string) (database.Driver, error) { migrationsTable := purl.Query().Get("x-migrations-table") - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, }) @@ -117,7 +117,7 @@ func (p *Redshift) Open(url string) (database.Driver, error) { return px, nil } -func (p *Redshift) Close() error { +func (p *Redshift) Close(ctx context.Context) error { connErr := p.conn.Close() dbErr := p.db.Close() if connErr != nil || dbErr != nil { @@ -127,21 +127,21 @@ func (p *Redshift) Close() error { } // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html -func (p *Redshift) Lock() error { +func (p *Redshift) Lock(ctx context.Context) error { if !p.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (p *Redshift) Unlock() error { +func (p *Redshift) Unlock(ctx context.Context) error { if !p.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (p *Redshift) Run(migration io.Reader) error { +func (p *Redshift) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -209,7 +209,7 @@ func runesLastIndex(input []rune, target rune) int { return -1 } -func (p *Redshift) SetVersion(version int, dirty bool) error { +func (p *Redshift) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -243,7 +243,7 @@ func (p *Redshift) SetVersion(version int, dirty bool) error { return nil } -func (p *Redshift) Version() (version int, dirty bool, err error) { +func (p *Redshift) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { @@ -263,7 +263,7 @@ func (p *Redshift) Version() (version int, dirty bool, err error) { } } -func (p *Redshift) Drop() (err error) { +func (p *Redshift) Drop(ctx context.Context) (err error) { // select all tables in current schema query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` tables, err := p.conn.QueryContext(context.Background(), query) @@ -307,13 +307,13 @@ func (p *Redshift) Drop() (err error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Redshift type. -func (p *Redshift) ensureVersionTable() (err error) { - if err = p.Lock(); err != nil { +func (p *Redshift) ensureVersionTable(ctx context.Context) (err error) { + if err = p.Lock(ctx); err != nil { return err } defer func() { - if e := p.Unlock(); e != nil { + if e := p.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/redshift/redshift_test.go b/database/redshift/redshift_test.go index 9ee5cbe64..d93617622 100644 --- a/database/redshift/redshift_test.go +++ b/database/redshift/redshift_test.go @@ -86,6 +86,7 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -93,12 +94,12 @@ func Test(t *testing.T) { addr := redshiftConnectionString(ip, port) p := &Redshift{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -108,6 +109,7 @@ func Test(t *testing.T) { func TestMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -115,16 +117,16 @@ func TestMigrate(t *testing.T) { addr := redshiftConnectionString(ip, port) p := &Redshift{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "postgres", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "postgres", d) if err != nil { t.Fatal(err) } @@ -134,6 +136,7 @@ func TestMigrate(t *testing.T) { func TestMultiStatement(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -141,16 +144,16 @@ func TestMultiStatement(t *testing.T) { addr := redshiftConnectionString(ip, port) p := &Redshift{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - if err := d.Run(bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"))); err != nil { + if err := d.Run(ctx, bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"))); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -167,6 +170,7 @@ func TestMultiStatement(t *testing.T) { func TestErrorParsing(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -174,19 +178,19 @@ func TestErrorParsing(t *testing.T) { addr := redshiftConnectionString(ip, port) p := &Redshift{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` + `(foo text); CREATE TABLEE bar (bar text); (details: pq: syntax error at or near "TABLEE")` - if err := d.Run(bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);"))); err == nil { + if err := d.Run(ctx, bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);"))); err == nil { t.Fatal("expected err but got nil") } else if err.Error() != wantErr { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) @@ -196,6 +200,7 @@ func TestErrorParsing(t *testing.T) { func TestFilterCustomQuery(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -203,12 +208,12 @@ func TestFilterCustomQuery(t *testing.T) { addr := fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-custom=foobar", pgPassword, ip, port) p := &Redshift{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -217,6 +222,7 @@ func TestFilterCustomQuery(t *testing.T) { func TestWithSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -224,21 +230,21 @@ func TestWithSchema(t *testing.T) { addr := redshiftConnectionString(ip, port) p := &Redshift{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() // create foobar schema - if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA foobar AUTHORIZATION postgres"))); err != nil { + if err := d.Run(ctx, bytes.NewReader([]byte("CREATE SCHEMA foobar AUTHORIZATION postgres"))); err != nil { t.Fatal(err) } - if err := d.SetVersion(1, false); err != nil { + if err := d.SetVersion(ctx, 1, false); err != nil { t.Fatal(err) } @@ -248,12 +254,12 @@ func TestWithSchema(t *testing.T) { t.Fatal(err) } defer func() { - if err := d2.Close(); err != nil { + if err := d2.Close(ctx); err != nil { t.Error(err) } }() - version, _, err := d2.Version() + version, _, err := d2.Version(ctx) if err != nil { t.Fatal(err) } @@ -262,10 +268,10 @@ func TestWithSchema(t *testing.T) { } // now update version and compare - if err := d2.SetVersion(2, false); err != nil { + if err := d2.SetVersion(ctx, 2, false); err != nil { t.Fatal(err) } - version, _, err = d2.Version() + version, _, err = d2.Version(ctx) if err != nil { t.Fatal(err) } @@ -274,7 +280,7 @@ func TestWithSchema(t *testing.T) { } // meanwhile, the public schema still has the other version - version, _, err = d.Version() + version, _, err = d.Version(ctx) if err != nil { t.Fatal(err) } @@ -290,6 +296,7 @@ func TestWithInstance(t *testing.T) { func TestRedshift_Lock(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) @@ -297,7 +304,7 @@ func TestRedshift_Lock(t *testing.T) { addr := pgConnectionString(ip, port) p := &Redshift{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -306,22 +313,22 @@ func TestRedshift_Lock(t *testing.T) { ps := d.(*Redshift) - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } - err = ps.Lock() + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - err = ps.Unlock() + err = ps.Unlock(ctx) if err != nil { t.Fatal(err) } diff --git a/database/rqlite/rqlite.go b/database/rqlite/rqlite.go index 14d3bd340..6f047529b 100644 --- a/database/rqlite/rqlite.go +++ b/database/rqlite/rqlite.go @@ -1,6 +1,7 @@ package rqlite import ( + "context" "fmt" "io" nurl "net/url" @@ -51,7 +52,7 @@ type Rqlite struct { // WithInstance creates a rqlite database driver with an existing gorqlite database connection // and a Config struct -func WithInstance(instance *gorqlite.Connection, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *gorqlite.Connection, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -70,7 +71,7 @@ func WithInstance(instance *gorqlite.Connection, config *Config) (database.Drive config: config, } - if err := driver.ensureVersionTable(); err != nil { + if err := driver.ensureVersionTable(ctx); err != nil { return nil, err } @@ -78,18 +79,18 @@ func WithInstance(instance *gorqlite.Connection, config *Config) (database.Drive } // OpenURL creates a rqlite database driver from a connect URL -func OpenURL(url string) (database.Driver, error) { +func OpenURL(ctx context.Context, url string) (database.Driver, error) { d := &Rqlite{} - return d.Open(url) + return d.Open(ctx, url) } -func (r *Rqlite) ensureVersionTable() (err error) { - if err = r.Lock(); err != nil { +func (r *Rqlite) ensureVersionTable(ctx context.Context) (err error) { + if err = r.Lock(ctx); err != nil { return err } defer func() { - if e := r.Unlock(); e != nil { + if e := r.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -113,7 +114,7 @@ func (r *Rqlite) ensureVersionTable() (err error) { // Open returns a new driver instance configured with parameters // coming from the URL string. Migrate will call this function // only once per instance. -func (r *Rqlite) Open(url string) (database.Driver, error) { +func (r *Rqlite) Open(ctx context.Context, url string) (database.Driver, error) { dburl, config, err := parseUrl(url) if err != nil { return nil, err @@ -125,7 +126,7 @@ func (r *Rqlite) Open(url string) (database.Driver, error) { return nil, err } - if err := r.ensureVersionTable(); err != nil { + if err := r.ensureVersionTable(ctx); err != nil { return nil, err } @@ -134,7 +135,7 @@ func (r *Rqlite) Open(url string) (database.Driver, error) { // Close closes the underlying database instance managed by the driver. // Migrate will call this function only once per instance. -func (r *Rqlite) Close() error { +func (r *Rqlite) Close(ctx context.Context) error { r.db.Close() return nil } @@ -143,7 +144,7 @@ func (r *Rqlite) Close() error { // can run at a time. Migrate will call this function before Run is called. // If the implementation can't provide this functionality, return nil. // Return database.ErrLocked if database is already locked. -func (r *Rqlite) Lock() error { +func (r *Rqlite) Lock(ctx context.Context) error { if !r.isLocked.CAS(false, true) { return database.ErrLocked } @@ -152,7 +153,7 @@ func (r *Rqlite) Lock() error { // Unlock should release the lock. Migrate will call this function after // all migrations have been run. -func (r *Rqlite) Unlock() error { +func (r *Rqlite) Unlock(ctx context.Context) error { if !r.isLocked.CAS(true, false) { return database.ErrNotLocked } @@ -160,7 +161,7 @@ func (r *Rqlite) Unlock() error { } // Run applies a migration to the database. migration is guaranteed to be not nil. -func (r *Rqlite) Run(migration io.Reader) error { +func (r *Rqlite) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -177,7 +178,7 @@ func (r *Rqlite) Run(migration io.Reader) error { // SetVersion saves version and dirty state. // Migrate will call this function before and after each call to Run. // version must be >= -1. -1 means NilVersion. -func (r *Rqlite) SetVersion(version int, dirty bool) error { +func (r *Rqlite) SetVersion(ctx context.Context, version int, dirty bool) error { deleteQuery := fmt.Sprintf(`DELETE FROM %s`, r.config.MigrationsTable) statements := []gorqlite.ParameterizedStatement{ { @@ -217,7 +218,7 @@ func (r *Rqlite) SetVersion(version int, dirty bool) error { // Version returns the currently active version and if the database is dirty. // When no migration has been applied, it must return version -1. // Dirty means, a previous migration failed and user interaction is required. -func (r *Rqlite) Version() (version int, dirty bool, err error) { +func (r *Rqlite) Version(ctx context.Context) (version int, dirty bool, err error) { query := "SELECT version, dirty FROM " + r.config.MigrationsTable + " LIMIT 1" qr, err := r.db.QueryOne(query) @@ -239,7 +240,7 @@ func (r *Rqlite) Version() (version int, dirty bool, err error) { // Drop deletes everything in the database. // Note that this is a breaking action, a new call to Open() is necessary to // ensure subsequent calls work as expected. -func (r *Rqlite) Drop() error { +func (r *Rqlite) Drop(ctx context.Context) error { query := `SELECT name FROM sqlite_master WHERE type = 'table'` tables, err := r.db.QueryOne(query) diff --git a/database/rqlite/rqlite_test.go b/database/rqlite/rqlite_test.go index c19f7476b..25f6c6f87 100644 --- a/database/rqlite/rqlite_test.go +++ b/database/rqlite/rqlite_test.go @@ -75,6 +75,7 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { func Test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) assert.NoError(t, err) @@ -82,7 +83,7 @@ func Test(t *testing.T) { t.Logf("DB connect string : %s\n", connectString) r := &Rqlite{} - d, err := r.Open(connectString) + d, err := r.Open(ctx, connectString) assert.NoError(t, err) dt.Test(t, d, []byte("CREATE TABLE t (Qty int, Name string);")) @@ -91,21 +92,22 @@ func Test(t *testing.T) { func TestMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) assert.NoError(t, err) connectString := fmt.Sprintf("rqlite://%s:%s?level=strong&disableClusterDiscovery=true&x-connect-insecure=true", ip, port) t.Logf("DB connect string : %s\n", connectString) - driver, err := OpenURL(connectString) + driver, err := OpenURL(ctx, connectString) assert.NoError(t, err) defer func() { - if err := driver.Close(); err != nil { + if err := driver.Close(ctx); err != nil { return } }() - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) assert.NoError(t, err) @@ -116,32 +118,35 @@ func TestMigrate(t *testing.T) { func TestBadConnectInsecureParam(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) assert.NoError(t, err) connectString := fmt.Sprintf("rqlite://%s:%s?x-connect-insecure=foo", ip, port) t.Logf("DB connect string : %s\n", connectString) - _, err = OpenURL(connectString) + _, err = OpenURL(ctx, connectString) assert.ErrorIs(t, err, ErrBadConfig) }) } func TestBadProtocol(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) assert.NoError(t, err) connectString := fmt.Sprintf("postgres://%s:%s/database", ip, port) t.Logf("DB connect string : %s\n", connectString) - _, err = OpenURL(connectString) + _, err = OpenURL(ctx, connectString) assert.ErrorIs(t, err, ErrBadConfig) }) } func TestNoConfig(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) assert.NoError(t, err) @@ -151,13 +156,14 @@ func TestNoConfig(t *testing.T) { db, err := gorqlite.Open(connectString) assert.NoError(t, err) - _, err = WithInstance(db, nil) + _, err = WithInstance(ctx, db, nil) assert.ErrorIs(t, err, ErrNilConfig) }) } func TestWithInstanceEmptyConfig(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) assert.NoError(t, err) @@ -167,35 +173,36 @@ func TestWithInstanceEmptyConfig(t *testing.T) { db, err := gorqlite.Open(connectString) assert.NoError(t, err) - driver, err := WithInstance(db, &Config{}) + driver, err := WithInstance(ctx, db, &Config{}) assert.NoError(t, err) defer func() { - if err := driver.Close(); err != nil { + if err := driver.Close(ctx); err != nil { t.Fatal(err) } }() - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) assert.NoError(t, err) t.Log("UP") - err = m.Up() + err = m.Up(ctx) assert.NoError(t, err) _, err = db.QueryOne(fmt.Sprintf("SELECT * FROM %s", DefaultMigrationsTable)) assert.NoError(t, err) t.Log("DOWN") - err = m.Down() + err = m.Down(ctx) assert.NoError(t, err) }) } func TestMigrationTable(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() ip, port, err := c.Port(defaultPort) assert.NoError(t, err) @@ -206,22 +213,22 @@ func TestMigrationTable(t *testing.T) { assert.NoError(t, err) config := Config{MigrationsTable: "my_migration_table"} - driver, err := WithInstance(db, &config) + driver, err := WithInstance(ctx, db, &config) assert.NoError(t, err) defer func() { - if err := driver.Close(); err != nil { + if err := driver.Close(ctx); err != nil { t.Fatal(err) } }() - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) assert.NoError(t, err) t.Log("UP") - err = m.Up() + err = m.Up(ctx) assert.NoError(t, err) _, err = db.QueryOne(fmt.Sprintf("SELECT * FROM %s", config.MigrationsTable)) @@ -244,7 +251,7 @@ func TestMigrationTable(t *testing.T) { assert.Equal(t, petPredator, 1) t.Log("DOWN") - err = m.Down() + err = m.Down(ctx) assert.NoError(t, err) _, err = db.QueryOne(fmt.Sprintf("SELECT * FROM %s", config.MigrationsTable)) diff --git a/database/snowflake/snowflake.go b/database/snowflake/snowflake.go index 46ce30200..837b1fa12 100644 --- a/database/snowflake/snowflake.go +++ b/database/snowflake/snowflake.go @@ -46,7 +46,7 @@ type Snowflake struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -85,14 +85,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } - if err := px.ensureVersionTable(); err != nil { + if err := px.ensureVersionTable(ctx); err != nil { return nil, err } return px, nil } -func (p *Snowflake) Open(url string) (database.Driver, error) { +func (p *Snowflake) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -138,7 +138,7 @@ func (p *Snowflake) Open(url string) (database.Driver, error) { migrationsTable := purl.Query().Get("x-migrations-table") - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ DatabaseName: database, MigrationsTable: migrationsTable, }) @@ -149,7 +149,7 @@ func (p *Snowflake) Open(url string) (database.Driver, error) { return px, nil } -func (p *Snowflake) Close() error { +func (p *Snowflake) Close(ctx context.Context) error { connErr := p.conn.Close() dbErr := p.db.Close() if connErr != nil || dbErr != nil { @@ -158,21 +158,21 @@ func (p *Snowflake) Close() error { return nil } -func (p *Snowflake) Lock() error { +func (p *Snowflake) Lock(ctx context.Context) error { if !p.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (p *Snowflake) Unlock() error { +func (p *Snowflake) Unlock(ctx context.Context) error { if !p.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (p *Snowflake) Run(migration io.Reader) error { +func (p *Snowflake) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -240,7 +240,7 @@ func runesLastIndex(input []rune, target rune) int { return -1 } -func (p *Snowflake) SetVersion(version int, dirty bool) error { +func (p *Snowflake) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -276,7 +276,7 @@ func (p *Snowflake) SetVersion(version int, dirty bool) error { return nil } -func (p *Snowflake) Version() (version int, dirty bool, err error) { +func (p *Snowflake) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { @@ -296,7 +296,7 @@ func (p *Snowflake) Version() (version int, dirty bool, err error) { } } -func (p *Snowflake) Drop() (err error) { +func (p *Snowflake) Drop(ctx context.Context) (err error) { // select all tables in current schema query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` tables, err := p.conn.QueryContext(context.Background(), query) @@ -340,13 +340,13 @@ func (p *Snowflake) Drop() (err error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Snowflake type. -func (p *Snowflake) ensureVersionTable() (err error) { - if err = p.Lock(); err != nil { +func (p *Snowflake) ensureVersionTable(ctx context.Context) (err error) { + if err = p.Lock(ctx); err != nil { return err } defer func() { - if e := p.Unlock(); e != nil { + if e := p.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/spanner/spanner.go b/database/spanner/spanner.go index b733302d5..fbe92c578 100644 --- a/database/spanner/spanner.go +++ b/database/spanner/spanner.go @@ -80,7 +80,7 @@ func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB { } // WithInstance implements database.Driver -func WithInstance(instance *DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -99,7 +99,7 @@ func WithInstance(instance *DB, config *Config) (database.Driver, error) { lock: uatomic.NewUint32(unlockedVal), } - if err := sx.ensureVersionTable(); err != nil { + if err := sx.ensureVersionTable(ctx); err != nil { return nil, err } @@ -107,14 +107,12 @@ func WithInstance(instance *DB, config *Config) (database.Driver, error) { } // Open implements database.Driver -func (s *Spanner) Open(url string) (database.Driver, error) { +func (s *Spanner) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err } - ctx := context.Background() - adminClient, err := sdb.NewDatabaseAdminClient(ctx) if err != nil { return nil, err @@ -137,7 +135,7 @@ func (s *Spanner) Open(url string) (database.Driver, error) { } db := &DB{admin: adminClient, data: dataClient} - return WithInstance(db, &Config{ + return WithInstance(ctx, db, &Config{ DatabaseName: dbname, MigrationsTable: migrationsTable, CleanStatements: clean, @@ -145,14 +143,14 @@ func (s *Spanner) Open(url string) (database.Driver, error) { } // Close implements database.Driver -func (s *Spanner) Close() error { +func (s *Spanner) Close(ctx context.Context) error { s.db.data.Close() return s.db.admin.Close() } // Lock implements database.Driver but doesn't do anything because Spanner only // enqueues the UpdateDatabaseDdlRequest. -func (s *Spanner) Lock() error { +func (s *Spanner) Lock(ctx context.Context) error { if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped { return nil } @@ -160,7 +158,7 @@ func (s *Spanner) Lock() error { } // Unlock implements database.Driver but no action required, see Lock. -func (s *Spanner) Unlock() error { +func (s *Spanner) Unlock(ctx context.Context) error { if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped { return nil } @@ -168,7 +166,7 @@ func (s *Spanner) Unlock() error { } // Run implements database.Driver -func (s *Spanner) Run(migration io.Reader) error { +func (s *Spanner) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -182,7 +180,6 @@ func (s *Spanner) Run(migration io.Reader) error { } } - ctx := context.Background() op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ Database: s.config.DatabaseName, Statements: stmts, @@ -200,9 +197,7 @@ func (s *Spanner) Run(migration io.Reader) error { } // SetVersion implements database.Driver -func (s *Spanner) SetVersion(version int, dirty bool) error { - ctx := context.Background() - +func (s *Spanner) SetVersion(ctx context.Context, version int, dirty bool) error { _, err := s.db.data.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { m := []*spanner.Mutation{ @@ -221,9 +216,7 @@ func (s *Spanner) SetVersion(version int, dirty bool) error { } // Version implements database.Driver -func (s *Spanner) Version() (version int, dirty bool, err error) { - ctx := context.Background() - +func (s *Spanner) Version(ctx context.Context) (version int, dirty bool, err error) { stmt := spanner.Statement{ SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`, } @@ -255,8 +248,7 @@ var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\ // provided in the schema. Assuming the schema describes how the database can // be "build up", it seems logical to "unbuild" the database simply by going the // opposite direction. More testing -func (s *Spanner) Drop() error { - ctx := context.Background() +func (s *Spanner) Drop(ctx context.Context) error { res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{ Database: s.config.DatabaseName, }) @@ -298,13 +290,13 @@ func (s *Spanner) Drop() error { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Spanner type. -func (s *Spanner) ensureVersionTable() (err error) { - if err = s.Lock(); err != nil { +func (s *Spanner) ensureVersionTable(ctx context.Context) (err error) { + if err = s.Lock(ctx); err != nil { return err } defer func() { - if e := s.Unlock(); e != nil { + if e := s.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -313,7 +305,6 @@ func (s *Spanner) ensureVersionTable() (err error) { } }() - ctx := context.Background() tbl := s.config.MigrationsTable iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil { diff --git a/database/spanner/spanner_test.go b/database/spanner/spanner_test.go index d6ab4db32..9c62116fe 100644 --- a/database/spanner/spanner_test.go +++ b/database/spanner/spanner_test.go @@ -1,6 +1,7 @@ package spanner import ( + "context" "fmt" "os" "testing" @@ -35,9 +36,10 @@ const db = "projects/abc/instances/def/databases/testdb" func Test(t *testing.T) { withSpannerEmulator(t, func(t *testing.T) { + ctx := context.Background() uri := fmt.Sprintf("spanner://%s", db) s := &Spanner{} - d, err := s.Open(uri) + d, err := s.Open(ctx, uri) if err != nil { t.Fatal(err) } @@ -47,13 +49,14 @@ func Test(t *testing.T) { func TestMigrate(t *testing.T) { withSpannerEmulator(t, func(t *testing.T) { + ctx := context.Background() s := &Spanner{} uri := fmt.Sprintf("spanner://%s", db) - d, err := s.Open(uri) + d, err := s.Open(ctx, uri) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", uri, d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", uri, d) if err != nil { t.Fatal(err) } diff --git a/database/sqlcipher/sqlcipher.go b/database/sqlcipher/sqlcipher.go index f98fb3a21..81e586f42 100644 --- a/database/sqlcipher/sqlcipher.go +++ b/database/sqlcipher/sqlcipher.go @@ -1,6 +1,7 @@ package sqlcipher import ( + "context" "database/sql" "fmt" "io" @@ -40,7 +41,7 @@ type Sqlite struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -57,7 +58,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { db: instance, config: config, } - if err := mx.ensureVersionTable(); err != nil { + if err := mx.ensureVersionTable(ctx); err != nil { return nil, err } return mx, nil @@ -66,13 +67,13 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Sqlite type. -func (m *Sqlite) ensureVersionTable() (err error) { - if err = m.Lock(); err != nil { +func (m *Sqlite) ensureVersionTable(ctx context.Context) (err error) { + if err = m.Lock(ctx); err != nil { return err } defer func() { - if e := m.Unlock(); e != nil { + if e := m.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -92,7 +93,7 @@ func (m *Sqlite) ensureVersionTable() (err error) { return nil } -func (m *Sqlite) Open(url string) (database.Driver, error) { +func (m *Sqlite) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -118,7 +119,7 @@ func (m *Sqlite) Open(url string) (database.Driver, error) { } } - mx, err := WithInstance(db, &Config{ + mx, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, NoTxWrap: noTxWrap, @@ -129,11 +130,11 @@ func (m *Sqlite) Open(url string) (database.Driver, error) { return mx, nil } -func (m *Sqlite) Close() error { +func (m *Sqlite) Close(ctx context.Context) error { return m.db.Close() } -func (m *Sqlite) Drop() (err error) { +func (m *Sqlite) Drop(ctx context.Context) (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) if err != nil { @@ -177,21 +178,21 @@ func (m *Sqlite) Drop() (err error) { return nil } -func (m *Sqlite) Lock() error { +func (m *Sqlite) Lock(ctx context.Context) error { if !m.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (m *Sqlite) Unlock() error { +func (m *Sqlite) Unlock(ctx context.Context) error { if !m.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (m *Sqlite) Run(migration io.Reader) error { +func (m *Sqlite) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -228,7 +229,7 @@ func (m *Sqlite) executeQueryNoTx(query string) error { return nil } -func (m *Sqlite) SetVersion(version int, dirty bool) error { +func (m *Sqlite) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := m.db.Begin() if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -259,7 +260,7 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return nil } -func (m *Sqlite) Version() (version int, dirty bool, err error) { +func (m *Sqlite) Version(ctx context.Context) (version int, dirty bool, err error) { query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" err = m.db.QueryRow(query).Scan(&version, &dirty) if err != nil { diff --git a/database/sqlcipher/sqlcipher_test.go b/database/sqlcipher/sqlcipher_test.go index 183fb4bb5..33283cb3a 100644 --- a/database/sqlcipher/sqlcipher_test.go +++ b/database/sqlcipher/sqlcipher_test.go @@ -1,6 +1,7 @@ package sqlcipher import ( + "context" "database/sql" "fmt" "path/filepath" @@ -17,9 +18,10 @@ import ( func Test(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite3://%s", filepath.Join(dir, "sqlite3.db")) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -30,6 +32,7 @@ func TestMigrate(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() db, err := sql.Open("sqlite3", filepath.Join(dir, "sqlite3.db")) if err != nil { return @@ -39,12 +42,12 @@ func TestMigrate(t *testing.T) { return } }() - driver, err := WithInstance(db, &Config{}) + driver, err := WithInstance(ctx, db, &Config{}) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) if err != nil { @@ -58,6 +61,7 @@ func TestMigrationTable(t *testing.T) { t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() db, err := sql.Open("sqlite3", filepath.Join(dir, "sqlite3.db")) if err != nil { return @@ -71,18 +75,18 @@ func TestMigrationTable(t *testing.T) { config := &Config{ MigrationsTable: "my_migration_table", } - driver, err := WithInstance(db, config) + driver, err := WithInstance(ctx, db, config) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) if err != nil { t.Fatal(err) } t.Log("UP") - err = m.Up() + err = m.Up(ctx) if err != nil { t.Fatal(err) } @@ -96,9 +100,10 @@ func TestMigrationTable(t *testing.T) { func TestNoTxWrap(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite3://%s?x-no-tx-wrap=true", filepath.Join(dir, "sqlite3.db")) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -110,9 +115,10 @@ func TestNoTxWrap(t *testing.T) { func TestNoTxWrapInvalidValue(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite3://%s?x-no-tx-wrap=yeppers", filepath.Join(dir, "sqlite3.db")) - _, err := p.Open(addr) + _, err := p.Open(ctx, addr) if assert.Error(t, err) { assert.Contains(t, err.Error(), "x-no-tx-wrap") assert.Contains(t, err.Error(), "invalid syntax") diff --git a/database/sqlite/sqlite.go b/database/sqlite/sqlite.go index ce449dfa0..896b4e719 100644 --- a/database/sqlite/sqlite.go +++ b/database/sqlite/sqlite.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "fmt" "io" @@ -40,7 +41,7 @@ type Sqlite struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -57,7 +58,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { db: instance, config: config, } - if err := mx.ensureVersionTable(); err != nil { + if err := mx.ensureVersionTable(ctx); err != nil { return nil, err } return mx, nil @@ -66,13 +67,13 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Sqlite type. -func (m *Sqlite) ensureVersionTable() (err error) { - if err = m.Lock(); err != nil { +func (m *Sqlite) ensureVersionTable(ctx context.Context) (err error) { + if err = m.Lock(ctx); err != nil { return err } defer func() { - if e := m.Unlock(); e != nil { + if e := m.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -92,7 +93,7 @@ func (m *Sqlite) ensureVersionTable() (err error) { return nil } -func (m *Sqlite) Open(url string) (database.Driver, error) { +func (m *Sqlite) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -118,7 +119,7 @@ func (m *Sqlite) Open(url string) (database.Driver, error) { } } - mx, err := WithInstance(db, &Config{ + mx, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, NoTxWrap: noTxWrap, @@ -129,11 +130,11 @@ func (m *Sqlite) Open(url string) (database.Driver, error) { return mx, nil } -func (m *Sqlite) Close() error { +func (m *Sqlite) Close(ctx context.Context) error { return m.db.Close() } -func (m *Sqlite) Drop() (err error) { +func (m *Sqlite) Drop(ctx context.Context) (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) if err != nil { @@ -177,21 +178,21 @@ func (m *Sqlite) Drop() (err error) { return nil } -func (m *Sqlite) Lock() error { +func (m *Sqlite) Lock(ctx context.Context) error { if !m.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (m *Sqlite) Unlock() error { +func (m *Sqlite) Unlock(ctx context.Context) error { if !m.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (m *Sqlite) Run(migration io.Reader) error { +func (m *Sqlite) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -228,7 +229,7 @@ func (m *Sqlite) executeQueryNoTx(query string) error { return nil } -func (m *Sqlite) SetVersion(version int, dirty bool) error { +func (m *Sqlite) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := m.db.Begin() if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -259,7 +260,7 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return nil } -func (m *Sqlite) Version() (version int, dirty bool, err error) { +func (m *Sqlite) Version(ctx context.Context) (version int, dirty bool, err error) { query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" err = m.db.QueryRow(query).Scan(&version, &dirty) if err != nil { diff --git a/database/sqlite/sqlite_test.go b/database/sqlite/sqlite_test.go index 31fd17a73..a9840834e 100644 --- a/database/sqlite/sqlite_test.go +++ b/database/sqlite/sqlite_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "fmt" "path/filepath" @@ -17,9 +18,10 @@ import ( func Test(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite://%s", filepath.Join(dir, "sqlite.db")) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -30,6 +32,7 @@ func TestMigrate(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite.db")) + ctx := context.Background() db, err := sql.Open("sqlite", filepath.Join(dir, "sqlite.db")) if err != nil { return @@ -39,12 +42,12 @@ func TestMigrate(t *testing.T) { return } }() - driver, err := WithInstance(db, &Config{}) + driver, err := WithInstance(ctx, db, &Config{}) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) if err != nil { @@ -58,6 +61,7 @@ func TestMigrationTable(t *testing.T) { t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite.db")) + ctx := context.Background() db, err := sql.Open("sqlite", filepath.Join(dir, "sqlite.db")) if err != nil { return @@ -71,18 +75,18 @@ func TestMigrationTable(t *testing.T) { config := &Config{ MigrationsTable: "my_migration_table", } - driver, err := WithInstance(db, config) + driver, err := WithInstance(ctx, db, config) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) if err != nil { t.Fatal(err) } t.Log("UP") - err = m.Up() + err = m.Up(ctx) if err != nil { t.Fatal(err) } @@ -96,9 +100,10 @@ func TestMigrationTable(t *testing.T) { func TestNoTxWrap(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite://%s?x-no-tx-wrap=true", filepath.Join(dir, "sqlite.db")) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -110,9 +115,10 @@ func TestNoTxWrap(t *testing.T) { func TestNoTxWrapInvalidValue(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite://%s?x-no-tx-wrap=yeppers", filepath.Join(dir, "sqlite.db")) - _, err := p.Open(addr) + _, err := p.Open(ctx, addr) if assert.Error(t, err) { assert.Contains(t, err.Error(), "x-no-tx-wrap") assert.Contains(t, err.Error(), "invalid syntax") @@ -123,9 +129,10 @@ func TestMigrateWithDirectoryNameContainsWhitespaces(t *testing.T) { dir := t.TempDir() dbPath := filepath.Join(dir, "sqlite.db") t.Logf("DB path : %s\n", dbPath) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite://file:%s", dbPath) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 56bb23338..23d2221b0 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "context" "database/sql" "fmt" "io" @@ -40,7 +41,7 @@ type Sqlite struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -57,7 +58,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { db: instance, config: config, } - if err := mx.ensureVersionTable(); err != nil { + if err := mx.ensureVersionTable(ctx); err != nil { return nil, err } return mx, nil @@ -66,13 +67,13 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database, which deviates from the usual // convention of "caller locks" in the Sqlite type. -func (m *Sqlite) ensureVersionTable() (err error) { - if err = m.Lock(); err != nil { +func (m *Sqlite) ensureVersionTable(ctx context.Context) (err error) { + if err = m.Lock(ctx); err != nil { return err } defer func() { - if e := m.Unlock(); e != nil { + if e := m.Unlock(ctx); e != nil { if err == nil { err = e } else { @@ -92,7 +93,7 @@ func (m *Sqlite) ensureVersionTable() (err error) { return nil } -func (m *Sqlite) Open(url string) (database.Driver, error) { +func (m *Sqlite) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -118,7 +119,7 @@ func (m *Sqlite) Open(url string) (database.Driver, error) { } } - mx, err := WithInstance(db, &Config{ + mx, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, NoTxWrap: noTxWrap, @@ -129,11 +130,11 @@ func (m *Sqlite) Open(url string) (database.Driver, error) { return mx, nil } -func (m *Sqlite) Close() error { +func (m *Sqlite) Close(ctx context.Context) error { return m.db.Close() } -func (m *Sqlite) Drop() (err error) { +func (m *Sqlite) Drop(ctx context.Context) (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) if err != nil { @@ -177,21 +178,21 @@ func (m *Sqlite) Drop() (err error) { return nil } -func (m *Sqlite) Lock() error { +func (m *Sqlite) Lock(ctx context.Context) error { if !m.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (m *Sqlite) Unlock() error { +func (m *Sqlite) Unlock(ctx context.Context) error { if !m.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (m *Sqlite) Run(migration io.Reader) error { +func (m *Sqlite) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -228,7 +229,7 @@ func (m *Sqlite) executeQueryNoTx(query string) error { return nil } -func (m *Sqlite) SetVersion(version int, dirty bool) error { +func (m *Sqlite) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := m.db.Begin() if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} @@ -259,7 +260,7 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return nil } -func (m *Sqlite) Version() (version int, dirty bool, err error) { +func (m *Sqlite) Version(ctx context.Context) (version int, dirty bool, err error) { query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" err = m.db.QueryRow(query).Scan(&version, &dirty) if err != nil { diff --git a/database/sqlite3/sqlite3_test.go b/database/sqlite3/sqlite3_test.go index 6d152e923..dc7f7581a 100644 --- a/database/sqlite3/sqlite3_test.go +++ b/database/sqlite3/sqlite3_test.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "context" "database/sql" "fmt" "path/filepath" @@ -17,9 +18,10 @@ import ( func Test(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite3://%s", filepath.Join(dir, "sqlite3.db")) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -30,6 +32,7 @@ func TestMigrate(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() db, err := sql.Open("sqlite3", filepath.Join(dir, "sqlite3.db")) if err != nil { return @@ -39,12 +42,12 @@ func TestMigrate(t *testing.T) { return } }() - driver, err := WithInstance(db, &Config{}) + driver, err := WithInstance(ctx, db, &Config{}) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) if err != nil { @@ -58,6 +61,7 @@ func TestMigrationTable(t *testing.T) { t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() db, err := sql.Open("sqlite3", filepath.Join(dir, "sqlite3.db")) if err != nil { return @@ -71,18 +75,18 @@ func TestMigrationTable(t *testing.T) { config := &Config{ MigrationsTable: "my_migration_table", } - driver, err := WithInstance(db, config) + driver, err := WithInstance(ctx, db, config) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance( + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "ql", driver) if err != nil { t.Fatal(err) } t.Log("UP") - err = m.Up() + err = m.Up(ctx) if err != nil { t.Fatal(err) } @@ -96,9 +100,10 @@ func TestMigrationTable(t *testing.T) { func TestNoTxWrap(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite3://%s?x-no-tx-wrap=true", filepath.Join(dir, "sqlite3.db")) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -110,9 +115,10 @@ func TestNoTxWrap(t *testing.T) { func TestNoTxWrapInvalidValue(t *testing.T) { dir := t.TempDir() t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite3://%s?x-no-tx-wrap=yeppers", filepath.Join(dir, "sqlite3.db")) - _, err := p.Open(addr) + _, err := p.Open(ctx, addr) if assert.Error(t, err) { assert.Contains(t, err.Error(), "x-no-tx-wrap") assert.Contains(t, err.Error(), "invalid syntax") @@ -123,9 +129,10 @@ func TestMigrateWithDirectoryNameContainsWhitespaces(t *testing.T) { dir := t.TempDir() dbPath := filepath.Join(dir, "sqlite3.db") t.Logf("DB path : %s\n", dbPath) + ctx := context.Background() p := &Sqlite{} addr := fmt.Sprintf("sqlite3://file:%s", dbPath) - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 3cfa48bf9..d6b3d48d4 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -61,7 +61,7 @@ type SQLServer struct { // WithInstance returns a database instance from an already created database connection. // // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -114,7 +114,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } - if err := ss.ensureVersionTable(); err != nil { + if err := ss.ensureVersionTable(ctx); err != nil { return nil, err } @@ -122,7 +122,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } // Open a connection to the database. -func (ss *SQLServer) Open(url string) (database.Driver, error) { +func (ss *SQLServer) Open(ctx context.Context, url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err @@ -168,7 +168,7 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) { migrationsTable := purl.Query().Get("x-migrations-table") - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, }) @@ -181,7 +181,7 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) { } // Close the database connection -func (ss *SQLServer) Close() error { +func (ss *SQLServer) Close(ctx context.Context) error { connErr := ss.conn.Close() dbErr := ss.db.Close() if connErr != nil || dbErr != nil { @@ -191,7 +191,7 @@ func (ss *SQLServer) Close() error { } // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. -func (ss *SQLServer) Lock() error { +func (ss *SQLServer) Lock(ctx context.Context) error { return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) if err != nil { @@ -215,7 +215,7 @@ func (ss *SQLServer) Lock() error { } // Unlock froms the migration lock from the database -func (ss *SQLServer) Unlock() error { +func (ss *SQLServer) Unlock(ctx context.Context) error { return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error { aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) if err != nil { @@ -233,7 +233,7 @@ func (ss *SQLServer) Unlock() error { } // Run the migrations for the database -func (ss *SQLServer) Run(migration io.Reader) error { +func (ss *SQLServer) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -256,7 +256,7 @@ func (ss *SQLServer) Run(migration io.Reader) error { } // SetVersion for the current database -func (ss *SQLServer) SetVersion(version int, dirty bool) error { +func (ss *SQLServer) SetVersion(ctx context.Context, version int, dirty bool) error { tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { @@ -296,7 +296,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error { } // Version of the current database state -func (ss *SQLServer) Version() (version int, dirty bool, err error) { +func (ss *SQLServer) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT TOP 1 version, dirty FROM ` + ss.getMigrationTable() err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { @@ -313,7 +313,7 @@ func (ss *SQLServer) Version() (version int, dirty bool, err error) { } // Drop all tables from the database. -func (ss *SQLServer) Drop() error { +func (ss *SQLServer) Drop(ctx context.Context) error { // drop all referential integrity constraints query := ` @@ -347,13 +347,13 @@ func (ss *SQLServer) Drop() error { return nil } -func (ss *SQLServer) ensureVersionTable() (err error) { - if err = ss.Lock(); err != nil { +func (ss *SQLServer) ensureVersionTable(ctx context.Context) (err error) { + if err = ss.Lock(ctx); err != nil { return err } defer func() { - if e := ss.Unlock(); e != nil { + if e := ss.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/sqlserver/sqlserver_test.go b/database/sqlserver/sqlserver_test.go index e7af5d7f9..15710e799 100644 --- a/database/sqlserver/sqlserver_test.go +++ b/database/sqlserver/sqlserver_test.go @@ -124,14 +124,15 @@ func test(t *testing.T) { } addr := msConnectionString(ip, port) + ctx := context.Background() p := &SQLServer{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatalf("%v", err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -149,19 +150,20 @@ func testMigrate(t *testing.T) { } addr := msConnectionString(ip, port) + ctx := context.Background() p := &SQLServer{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatalf("%v", err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "master", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "master", d) if err != nil { t.Fatal(err) } @@ -178,17 +180,18 @@ func testMultiStatement(t *testing.T) { } addr := msConnectionString(ip, port) + ctx := context.Background() ms := &SQLServer{} - d, err := ms.Open(addr) + d, err := ms.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -213,13 +216,14 @@ func testErrorParsing(t *testing.T) { addr := msConnectionString(ip, port) + ctx := context.Background() p := &SQLServer{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -227,7 +231,7 @@ func testErrorParsing(t *testing.T) { wantErr := `migration failed: Unknown object type 'TABLEE' used in a CREATE, DROP, or ALTER statement. in line 1:` + ` CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text); (details: mssql: Unknown object type ` + `'TABLEE' used in a CREATE, DROP, or ALTER statement.)` - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { t.Fatal("expected err but got nil") } else if err.Error() != wantErr { t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) @@ -244,8 +248,9 @@ func testLockWorks(t *testing.T) { } addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port) + ctx := context.Background() p := &SQLServer{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatalf("%v", err) } @@ -253,21 +258,21 @@ func testLockWorks(t *testing.T) { ms := d.(*SQLServer) - err = ms.Lock() + err = ms.Lock(ctx) if err != nil { t.Fatal(err) } - err = ms.Unlock() + err = ms.Unlock(ctx) if err != nil { t.Fatal(err) } // make sure the 2nd lock works (RELEASE_LOCK is very finicky) - err = ms.Lock() + err = ms.Lock(ctx) if err != nil { t.Fatal(err) } - err = ms.Unlock() + err = ms.Unlock(ctx) if err != nil { t.Fatal(err) } @@ -283,8 +288,9 @@ func testMsiTrue(t *testing.T) { } addr := msConnectionStringMsi(ip, port, true) + ctx := context.Background() p := &SQLServer{} - _, err = p.Open(addr) + _, err = p.Open(ctx, addr) if err == nil { t.Fatal("MSI should fail when not running in an Azure context.") } @@ -300,21 +306,22 @@ func testOpenWithPasswordAndMSI(t *testing.T) { } addr := msConnectionStringMsiWithPassword(ip, port, true) + ctx := context.Background() p := &SQLServer{} - _, err = p.Open(addr) + _, err = p.Open(ctx, addr) if err == nil { t.Fatal("Open should fail when both password and useMsi=true are passed.") } addr = msConnectionStringMsiWithPassword(ip, port, false) p = &SQLServer{} - d, err := p.Open(addr) + d, err := p.Open(ctx, addr) if err != nil { t.Fatal(err) } defer func() { - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Error(err) } }() @@ -332,8 +339,9 @@ func testMsiFalse(t *testing.T) { } addr := msConnectionStringMsi(ip, port, false) + ctx := context.Background() p := &SQLServer{} - _, err = p.Open(addr) + _, err = p.Open(ctx, addr) if err == nil { t.Fatal("Open should fail since no password was passed and useMsi is false.") } diff --git a/database/stub/stub.go b/database/stub/stub.go index ae502650b..9870c6b1f 100644 --- a/database/stub/stub.go +++ b/database/stub/stub.go @@ -1,6 +1,7 @@ package stub import ( + "context" "io" "reflect" @@ -25,7 +26,7 @@ type Stub struct { Config *Config } -func (s *Stub) Open(url string) (database.Driver, error) { +func (s *Stub) Open(ctx context.Context, url string) (database.Driver, error) { return &Stub{ Url: url, CurrentVersion: database.NilVersion, @@ -36,7 +37,7 @@ func (s *Stub) Open(url string) (database.Driver, error) { type Config struct{} -func WithInstance(instance interface{}, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance interface{}, config *Config) (database.Driver, error) { return &Stub{ Instance: instance, CurrentVersion: database.NilVersion, @@ -45,25 +46,25 @@ func WithInstance(instance interface{}, config *Config) (database.Driver, error) }, nil } -func (s *Stub) Close() error { +func (s *Stub) Close(ctx context.Context) error { return nil } -func (s *Stub) Lock() error { +func (s *Stub) Lock(ctx context.Context) error { if !s.isLocked.CAS(false, true) { return database.ErrLocked } return nil } -func (s *Stub) Unlock() error { +func (s *Stub) Unlock(ctx context.Context) error { if !s.isLocked.CAS(true, false) { return database.ErrNotLocked } return nil } -func (s *Stub) Run(migration io.Reader) error { +func (s *Stub) Run(ctx context.Context, migration io.Reader) error { m, err := io.ReadAll(migration) if err != nil { return err @@ -73,19 +74,19 @@ func (s *Stub) Run(migration io.Reader) error { return nil } -func (s *Stub) SetVersion(version int, state bool) error { +func (s *Stub) SetVersion(ctx context.Context, version int, state bool) error { s.CurrentVersion = version s.IsDirty = state return nil } -func (s *Stub) Version() (version int, dirty bool, err error) { +func (s *Stub) Version(ctx context.Context) (version int, dirty bool, err error) { return s.CurrentVersion, s.IsDirty, nil } const DROP = "DROP" -func (s *Stub) Drop() error { +func (s *Stub) Drop(ctx context.Context) error { s.CurrentVersion = database.NilVersion s.LastRunMigration = nil s.MigrationSequence = append(s.MigrationSequence, DROP) diff --git a/database/stub/stub_test.go b/database/stub/stub_test.go index 131c935a8..4ef6419d6 100644 --- a/database/stub/stub_test.go +++ b/database/stub/stub_test.go @@ -1,6 +1,7 @@ package stub import ( + "context" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/source" "github.com/golang-migrate/migrate/v4/source/stub" @@ -10,8 +11,9 @@ import ( ) func Test(t *testing.T) { + ctx := context.Background() s := &Stub{} - d, err := s.Open("") + d, err := s.Open(ctx, "") if err != nil { t.Fatal(err) } @@ -19,8 +21,9 @@ func Test(t *testing.T) { } func TestMigrate(t *testing.T) { + ctx := context.Background() s := &Stub{} - d, err := s.Open("") + d, err := s.Open(ctx, "") if err != nil { t.Fatal(err) } @@ -29,7 +32,7 @@ func TestMigrate(t *testing.T) { stubMigrations.Append(&source.Migration{Version: 1, Direction: source.Up, Identifier: "CREATE 1"}) stubMigrations.Append(&source.Migration{Version: 1, Direction: source.Down, Identifier: "DROP 1"}) src := &stub.Stub{} - srcDrv, err := src.Open("") + srcDrv, err := src.Open(ctx, "") if err != nil { t.Fatal(err) } diff --git a/database/testing/migrate_testing.go b/database/testing/migrate_testing.go index be8ed195f..528010896 100644 --- a/database/testing/migrate_testing.go +++ b/database/testing/migrate_testing.go @@ -4,6 +4,7 @@ package testing import ( + "context" "testing" ) @@ -21,14 +22,14 @@ func TestMigrate(t *testing.T, m *migrate.Migrate) { // Similar to TestDrop(), but tests the dropping mechanism through the Migrate logic instead, to check for // double-locking during the Drop logic. func TestMigrateDrop(t *testing.T, m *migrate.Migrate) { - if err := m.Drop(); err != nil { + if err := m.Drop(context.Background()); err != nil { t.Fatal(err) } } func TestMigrateUp(t *testing.T, m *migrate.Migrate) { t.Log("UP") - if err := m.Up(); err != nil { + if err := m.Up(context.Background()); err != nil { t.Fatal(err) } } diff --git a/database/testing/testing.go b/database/testing/testing.go index bd3294b1e..dab30a029 100644 --- a/database/testing/testing.go +++ b/database/testing/testing.go @@ -5,6 +5,7 @@ package testing import ( "bytes" + "context" "errors" "fmt" "io" @@ -29,7 +30,7 @@ func Test(t *testing.T, d database.Driver, migration []byte) { } func TestNilVersion(t *testing.T, d database.Driver) { - v, _, err := d.Version() + v, _, err := d.Version(context.Background()) if err != nil { t.Fatal(err) } @@ -57,30 +58,31 @@ func TestLockAndUnlock(t *testing.T, d database.Driver) { }() // run the locking test ... + ctx := context.Background() go func() { - if err := d.Lock(); err != nil { + if err := d.Lock(ctx); err != nil { errs <- err return } // try to acquire lock again - if err := d.Lock(); err == nil { + if err := d.Lock(ctx); err == nil { errs <- errors.New("lock: expected err not to be nil") return } // unlock - if err := d.Unlock(); err != nil { + if err := d.Unlock(ctx); err != nil { errs <- err return } // try to lock - if err := d.Lock(); err != nil { + if err := d.Lock(ctx); err != nil { errs <- err return } - if err := d.Unlock(); err != nil { + if err := d.Unlock(ctx); err != nil { errs <- err return } @@ -104,13 +106,13 @@ func TestRun(t *testing.T, d database.Driver, migration io.Reader) { t.Fatal("migration can't be nil") } - if err := d.Run(migration); err != nil { + if err := d.Run(context.Background(), migration); err != nil { t.Fatal(err) } } func TestDrop(t *testing.T, d database.Driver) { - if err := d.Drop(); err != nil { + if err := d.Drop(context.Background()); err != nil { t.Fatal(err) } } @@ -136,11 +138,11 @@ func TestSetVersion(t *testing.T, d database.Driver) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := d.SetVersion(tc.version, tc.dirty) + err := d.SetVersion(context.Background(), tc.version, tc.dirty) if err != tc.expectedErr { t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr) } - v, dirty, readErr := d.Version() + v, dirty, readErr := d.Version(context.Background()) if readErr != tc.expectedReadErr { t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr) } diff --git a/database/yugabytedb/yugabytedb.go b/database/yugabytedb/yugabytedb.go index 764d23c02..a3b12f818 100644 --- a/database/yugabytedb/yugabytedb.go +++ b/database/yugabytedb/yugabytedb.go @@ -59,7 +59,7 @@ type YugabyteDB struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithInstance(ctx context.Context, instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } @@ -112,14 +112,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, err } - if err := px.ensureVersionTable(); err != nil { + if err := px.ensureVersionTable(ctx); err != nil { return nil, err } return px, nil } -func (c *YugabyteDB) Open(dbURL string) (database.Driver, error) { +func (c *YugabyteDB) Open(ctx context.Context, dbURL string) (database.Driver, error) { purl, err := url.Parse(dbURL) if err != nil { return nil, err @@ -169,7 +169,7 @@ func (c *YugabyteDB) Open(dbURL string) (database.Driver, error) { maxRetries = DefaultMaxRetries } - px, err := WithInstance(db, &Config{ + px, err := WithInstance(ctx, db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, LockTable: lockTable, @@ -185,13 +185,13 @@ func (c *YugabyteDB) Open(dbURL string) (database.Driver, error) { return px, nil } -func (c *YugabyteDB) Close() error { +func (c *YugabyteDB) Close(ctx context.Context) error { return c.db.Close() } // Locking is done manually with a separate lock table. Implementing advisory locks in YugabyteDB is being discussed // See: https://github.com/yugabyte/yugabyte-db/issues/3642 -func (c *YugabyteDB) Lock() error { +func (c *YugabyteDB) Lock(ctx context.Context) error { return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) { return c.doTxWithRetry(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) (err error) { aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) @@ -228,7 +228,7 @@ func (c *YugabyteDB) Lock() error { // Locking is done manually with a separate lock table. Implementing advisory locks in YugabyteDB is being discussed // See: https://github.com/yugabyte/yugabyte-db/issues/3642 -func (c *YugabyteDB) Unlock() error { +func (c *YugabyteDB) Unlock(ctx context.Context) error { return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) { aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName) if err != nil { @@ -255,7 +255,7 @@ func (c *YugabyteDB) Unlock() error { }) } -func (c *YugabyteDB) Run(migration io.Reader) error { +func (c *YugabyteDB) Run(ctx context.Context, migration io.Reader) error { migr, err := io.ReadAll(migration) if err != nil { return err @@ -270,7 +270,7 @@ func (c *YugabyteDB) Run(migration io.Reader) error { return nil } -func (c *YugabyteDB) SetVersion(version int, dirty bool) error { +func (c *YugabyteDB) SetVersion(ctx context.Context, version int, dirty bool) error { return c.doTxWithRetry(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) error { if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil { return err @@ -289,7 +289,7 @@ func (c *YugabyteDB) SetVersion(version int, dirty bool) error { }) } -func (c *YugabyteDB) Version() (version int, dirty bool, err error) { +func (c *YugabyteDB) Version(ctx context.Context) (version int, dirty bool, err error) { query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` err = c.db.QueryRow(query).Scan(&version, &dirty) @@ -312,7 +312,7 @@ func (c *YugabyteDB) Version() (version int, dirty bool, err error) { } } -func (c *YugabyteDB) Drop() (err error) { +func (c *YugabyteDB) Drop(ctx context.Context) (err error) { query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` tables, err := c.db.Query(query) if err != nil { @@ -353,13 +353,13 @@ func (c *YugabyteDB) Drop() (err error) { // ensureVersionTable checks if versions table exists and, if not, creates it. // Note that this function locks the database -func (c *YugabyteDB) ensureVersionTable() (err error) { - if err = c.Lock(); err != nil { +func (c *YugabyteDB) ensureVersionTable(ctx context.Context) (err error) { + if err = c.Lock(ctx); err != nil { return err } defer func() { - if e := c.Unlock(); e != nil { + if e := c.Unlock(ctx); e != nil { if err == nil { err = e } else { diff --git a/database/yugabytedb/yugabytedb_test.go b/database/yugabytedb/yugabytedb_test.go index 05fb14fa7..6895f30c5 100644 --- a/database/yugabytedb/yugabytedb_test.go +++ b/database/yugabytedb/yugabytedb_test.go @@ -110,6 +110,7 @@ func test(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) + ctx := context.Background() ip, port, err := ci.Port(defaultPort) if err != nil { t.Fatal(err) @@ -117,7 +118,7 @@ func test(t *testing.T) { addr := getConnectionString(ip, port) c := &YugabyteDB{} - d, err := c.Open(addr) + d, err := c.Open(ctx, addr) if err != nil { t.Fatal(err) } @@ -129,6 +130,7 @@ func testMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) + ctx := context.Background() ip, port, err := ci.Port(defaultPort) if err != nil { t.Fatal(err) @@ -136,12 +138,12 @@ func testMigrate(t *testing.T) { addr := getConnectionString(ip, port) c := &YugabyteDB{} - d, err := c.Open(addr) + d, err := c.Open(ctx, addr) if err != nil { t.Fatal(err) } - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "migrate", d) + m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "migrate", d) if err != nil { t.Fatal(err) } @@ -153,6 +155,7 @@ func testMultiStatement(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) + ctx := context.Background() ip, port, err := ci.Port(defaultPort) if err != nil { t.Fatal(err) @@ -160,11 +163,11 @@ func testMultiStatement(t *testing.T) { addr := getConnectionString(ip, port) c := &YugabyteDB{} - d, err := c.Open(addr) + d, err := c.Open(ctx, addr) if err != nil { t.Fatal(err) } - if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -183,6 +186,7 @@ func testFilterCustomQuery(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) + ctx := context.Background() ip, port, err := ci.Port(defaultPort) if err != nil { t.Fatal(err) @@ -190,7 +194,7 @@ func testFilterCustomQuery(t *testing.T) { addr := getConnectionString(ip, port, "x-custom=foobar") c := &YugabyteDB{} - d, err := c.Open(addr) + d, err := c.Open(ctx, addr) if err != nil { t.Fatal(err) } diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 7adec2f84..598d23cc7 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "os" @@ -147,8 +148,8 @@ func createFile(filename string) error { return f.Close() } -func gotoCmd(m *migrate.Migrate, v uint) error { - if err := m.Migrate(v); err != nil { +func gotoCmd(ctx context.Context, m *migrate.Migrate, v uint) error { + if err := m.Migrate(ctx, v); err != nil { if err != migrate.ErrNoChange { return err } @@ -157,16 +158,16 @@ func gotoCmd(m *migrate.Migrate, v uint) error { return nil } -func upCmd(m *migrate.Migrate, limit int) error { +func upCmd(ctx context.Context, m *migrate.Migrate, limit int) error { if limit >= 0 { - if err := m.Steps(limit); err != nil { + if err := m.Steps(ctx, limit); err != nil { if err != migrate.ErrNoChange { return err } log.Println(err) } } else { - if err := m.Up(); err != nil { + if err := m.Up(ctx); err != nil { if err != migrate.ErrNoChange { return err } @@ -176,16 +177,16 @@ func upCmd(m *migrate.Migrate, limit int) error { return nil } -func downCmd(m *migrate.Migrate, limit int) error { +func downCmd(ctx context.Context, m *migrate.Migrate, limit int) error { if limit >= 0 { - if err := m.Steps(-limit); err != nil { + if err := m.Steps(ctx, -limit); err != nil { if err != migrate.ErrNoChange { return err } log.Println(err) } } else { - if err := m.Down(); err != nil { + if err := m.Down(ctx); err != nil { if err != migrate.ErrNoChange { return err } @@ -195,22 +196,22 @@ func downCmd(m *migrate.Migrate, limit int) error { return nil } -func dropCmd(m *migrate.Migrate) error { - if err := m.Drop(); err != nil { +func dropCmd(ctx context.Context, m *migrate.Migrate) error { + if err := m.Drop(ctx); err != nil { return err } return nil } -func forceCmd(m *migrate.Migrate, v int) error { - if err := m.Force(v); err != nil { +func forceCmd(ctx context.Context, m *migrate.Migrate, v int) error { + if err := m.Force(ctx, v); err != nil { return err } return nil } -func versionCmd(m *migrate.Migrate) error { - v, dirty, err := m.Version() +func versionCmd(ctx context.Context, m *migrate.Migrate) error { + v, dirty, err := m.Version(ctx) if err != nil { return err } diff --git a/internal/cli/main.go b/internal/cli/main.go index c7a3bd74a..83f9dcee0 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -1,6 +1,7 @@ package cli import ( + "context" "flag" "fmt" "os" @@ -122,10 +123,11 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU // initialize migrate // don't catch migraterErr here and let each command decide // how it wants to handle the error - migrater, migraterErr := migrate.New(*sourcePtr, *databasePtr) + ctx := context.Background() + migrater, migraterErr := migrate.New(ctx, *sourcePtr, *databasePtr) defer func() { if migraterErr == nil { - if _, err := migrater.Close(); err != nil { + if _, err := migrater.Close(ctx); err != nil { log.Println(err) } } @@ -215,7 +217,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU log.fatal("error: can't read version argument V") } - if err := gotoCmd(migrater, uint(v)); err != nil { + if err := gotoCmd(ctx, migrater, uint(v)); err != nil { log.fatalErr(err) } @@ -245,7 +247,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU limit = int(n) } - if err := upCmd(migrater, limit); err != nil { + if err := upCmd(ctx, migrater, limit); err != nil { log.fatalErr(err) } @@ -285,7 +287,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU } } - if err := downCmd(migrater, num); err != nil { + if err := downCmd(ctx, migrater, num); err != nil { log.fatalErr(err) } @@ -320,7 +322,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU log.fatalErr(migraterErr) } - if err := dropCmd(migrater); err != nil { + if err := dropCmd(ctx, migrater); err != nil { log.fatalErr(err) } @@ -354,7 +356,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU log.fatal("error: argument V must be >= -1") } - if err := forceCmd(migrater, int(v)); err != nil { + if err := forceCmd(ctx, migrater, int(v)); err != nil { log.fatalErr(err) } @@ -367,7 +369,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU log.fatalErr(migraterErr) } - if err := versionCmd(migrater); err != nil { + if err := versionCmd(ctx, migrater); err != nil { log.fatalErr(err) } diff --git a/migrate.go b/migrate.go index 7763782a0..f954855a5 100644 --- a/migrate.go +++ b/migrate.go @@ -5,6 +5,7 @@ package migrate import ( + "context" "errors" "fmt" "os" @@ -84,7 +85,7 @@ type Migrate struct { // New returns a new Migrate instance from a source URL and a database URL. // The URL scheme is defined by each driver. -func New(sourceURL, databaseURL string) (*Migrate, error) { +func New(ctx context.Context, sourceURL, databaseURL string) (*Migrate, error) { m := newCommon() sourceName, err := iurl.SchemeFromURL(sourceURL) @@ -99,13 +100,13 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { } m.databaseName = databaseName - sourceDrv, err := source.Open(sourceURL) + sourceDrv, err := source.Open(ctx, sourceURL) if err != nil { return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) } m.sourceDrv = sourceDrv - databaseDrv, err := database.Open(databaseURL) + databaseDrv, err := database.Open(ctx, databaseURL) if err != nil { return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) } @@ -118,7 +119,7 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { // and an existing database instance. The source URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as databaseName. // You are responsible for closing the underlying database client if necessary. -func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { +func NewWithDatabaseInstance(ctx context.Context, sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { m := newCommon() sourceName, err := iurl.SchemeFromURL(sourceURL) @@ -129,7 +130,7 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst m.databaseName = databaseName - sourceDrv, err := source.Open(sourceURL) + sourceDrv, err := source.Open(ctx, sourceURL) if err != nil { return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) } @@ -144,7 +145,7 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst // and a database URL. The database URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as sourceName. // You are responsible for closing the underlying source client if necessary. -func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { +func NewWithSourceInstance(ctx context.Context, sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { m := newCommon() databaseName, err := iurl.SchemeFromURL(databaseURL) @@ -155,7 +156,7 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data m.sourceName = sourceName - databaseDrv, err := database.Open(databaseURL) + databaseDrv, err := database.Open(ctx, databaseURL) if err != nil { return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) } @@ -192,18 +193,18 @@ func newCommon() *Migrate { } // Close closes the source and the database. -func (m *Migrate) Close() (source error, database error) { +func (m *Migrate) Close(ctx context.Context) (source error, database error) { databaseSrvClose := make(chan error) sourceSrvClose := make(chan error) m.logVerbosePrintf("Closing source and database\n") go func() { - databaseSrvClose <- m.databaseDrv.Close() + databaseSrvClose <- m.databaseDrv.Close(ctx) }() go func() { - sourceSrvClose <- m.sourceDrv.Close() + sourceSrvClose <- m.sourceDrv.Close(ctx) }() return <-sourceSrvClose, <-databaseSrvClose @@ -211,131 +212,131 @@ func (m *Migrate) Close() (source error, database error) { // Migrate looks at the currently active migration version, // then migrates either up or down to the specified version. -func (m *Migrate) Migrate(version uint) error { - if err := m.lock(); err != nil { +func (m *Migrate) Migrate(ctx context.Context, version uint) error { + if err := m.lock(ctx); err != nil { return err } - curVersion, dirty, err := m.databaseDrv.Version() + curVersion, dirty, err := m.databaseDrv.Version(ctx) if err != nil { - return m.unlockErr(err) + return m.unlockErr(ctx, err) } if dirty { - return m.unlockErr(ErrDirty{curVersion}) + return m.unlockErr(ctx, ErrDirty{curVersion}) } ret := make(chan interface{}, m.PrefetchMigrations) - go m.read(curVersion, int(version), ret) + go m.read(ctx, curVersion, int(version), ret) - return m.unlockErr(m.runMigrations(ret)) + return m.unlockErr(ctx, m.runMigrations(ctx, ret)) } // Steps looks at the currently active migration version. // It will migrate up if n > 0, and down if n < 0. -func (m *Migrate) Steps(n int) error { +func (m *Migrate) Steps(ctx context.Context, n int) error { if n == 0 { return ErrNoChange } - if err := m.lock(); err != nil { + if err := m.lock(ctx); err != nil { return err } - curVersion, dirty, err := m.databaseDrv.Version() + curVersion, dirty, err := m.databaseDrv.Version(ctx) if err != nil { - return m.unlockErr(err) + return m.unlockErr(ctx, err) } if dirty { - return m.unlockErr(ErrDirty{curVersion}) + return m.unlockErr(ctx, ErrDirty{curVersion}) } ret := make(chan interface{}, m.PrefetchMigrations) if n > 0 { - go m.readUp(curVersion, n, ret) + go m.readUp(ctx, curVersion, n, ret) } else { - go m.readDown(curVersion, -n, ret) + go m.readDown(ctx, curVersion, -n, ret) } - return m.unlockErr(m.runMigrations(ret)) + return m.unlockErr(ctx, m.runMigrations(ctx, ret)) } // Up looks at the currently active migration version // and will migrate all the way up (applying all up migrations). -func (m *Migrate) Up() error { - if err := m.lock(); err != nil { +func (m *Migrate) Up(ctx context.Context) error { + if err := m.lock(ctx); err != nil { return err } - curVersion, dirty, err := m.databaseDrv.Version() + curVersion, dirty, err := m.databaseDrv.Version(ctx) if err != nil { - return m.unlockErr(err) + return m.unlockErr(ctx, err) } if dirty { - return m.unlockErr(ErrDirty{curVersion}) + return m.unlockErr(ctx, ErrDirty{curVersion}) } ret := make(chan interface{}, m.PrefetchMigrations) - go m.readUp(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) + go m.readUp(ctx, curVersion, -1, ret) + return m.unlockErr(ctx, m.runMigrations(ctx, ret)) } // Down looks at the currently active migration version // and will migrate all the way down (applying all down migrations). -func (m *Migrate) Down() error { - if err := m.lock(); err != nil { +func (m *Migrate) Down(ctx context.Context) error { + if err := m.lock(ctx); err != nil { return err } - curVersion, dirty, err := m.databaseDrv.Version() + curVersion, dirty, err := m.databaseDrv.Version(ctx) if err != nil { - return m.unlockErr(err) + return m.unlockErr(ctx, err) } if dirty { - return m.unlockErr(ErrDirty{curVersion}) + return m.unlockErr(ctx, ErrDirty{curVersion}) } ret := make(chan interface{}, m.PrefetchMigrations) - go m.readDown(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) + go m.readDown(ctx, curVersion, -1, ret) + return m.unlockErr(ctx, m.runMigrations(ctx, ret)) } // Drop deletes everything in the database. -func (m *Migrate) Drop() error { - if err := m.lock(); err != nil { +func (m *Migrate) Drop(ctx context.Context) error { + if err := m.lock(ctx); err != nil { return err } - if err := m.databaseDrv.Drop(); err != nil { - return m.unlockErr(err) + if err := m.databaseDrv.Drop(ctx); err != nil { + return m.unlockErr(ctx, err) } - return m.unlock() + return m.unlock(ctx) } // Run runs any migration provided by you against the database. // It does not check any currently active version in database. // Usually you don't need this function at all. Use Migrate, // Steps, Up or Down instead. -func (m *Migrate) Run(migration ...*Migration) error { +func (m *Migrate) Run(ctx context.Context, migration ...*Migration) error { if len(migration) == 0 { return ErrNoChange } - if err := m.lock(); err != nil { + if err := m.lock(ctx); err != nil { return err } - curVersion, dirty, err := m.databaseDrv.Version() + curVersion, dirty, err := m.databaseDrv.Version(ctx) if err != nil { - return m.unlockErr(err) + return m.unlockErr(ctx, err) } if dirty { - return m.unlockErr(ErrDirty{curVersion}) + return m.unlockErr(ctx, ErrDirty{curVersion}) } ret := make(chan interface{}, m.PrefetchMigrations) @@ -358,32 +359,32 @@ func (m *Migrate) Run(migration ...*Migration) error { } }() - return m.unlockErr(m.runMigrations(ret)) + return m.unlockErr(ctx, m.runMigrations(ctx, ret)) } // Force sets a migration version. // It does not check any currently active version in database. // It resets the dirty state to false. -func (m *Migrate) Force(version int) error { +func (m *Migrate) Force(ctx context.Context, version int) error { if version < -1 { return ErrInvalidVersion } - if err := m.lock(); err != nil { + if err := m.lock(ctx); err != nil { return err } - if err := m.databaseDrv.SetVersion(version, false); err != nil { - return m.unlockErr(err) + if err := m.databaseDrv.SetVersion(ctx, version, false); err != nil { + return m.unlockErr(ctx, err) } - return m.unlock() + return m.unlock(ctx) } // Version returns the currently active migration version. // If no migration has been applied, yet, it will return ErrNilVersion. -func (m *Migrate) Version() (version uint, dirty bool, err error) { - v, d, err := m.databaseDrv.Version() +func (m *Migrate) Version(ctx context.Context) (version uint, dirty bool, err error) { + v, d, err := m.databaseDrv.Version(ctx) if err != nil { return 0, false, err } @@ -399,12 +400,12 @@ func (m *Migrate) Version() (version uint, dirty bool, err error) { // Each migration is then written to the ret channel. // If an error occurs during reading, that error is written to the ret channel, too. // Once read is done reading it will close the ret channel. -func (m *Migrate) read(from int, to int, ret chan<- interface{}) { +func (m *Migrate) read(ctx context.Context, from int, to int, ret chan<- interface{}) { defer close(ret) // check if from version exists if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { + if err := m.versionExists(ctx, suint(from)); err != nil { ret <- err return } @@ -412,7 +413,7 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { // check if to version exists if to >= 0 { - if err := m.versionExists(suint(to)); err != nil { + if err := m.versionExists(ctx, suint(to)); err != nil { ret <- err return } @@ -428,13 +429,13 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { // it's going up // apply first migration if from is nil version if from == -1 { - firstVersion, err := m.sourceDrv.First() + firstVersion, err := m.sourceDrv.First(ctx) if err != nil { ret <- err return } - migr, err := m.newMigration(firstVersion, int(firstVersion)) + migr, err := m.newMigration(ctx, firstVersion, int(firstVersion)) if err != nil { ret <- err return @@ -456,13 +457,13 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { return } - next, err := m.sourceDrv.Next(suint(from)) + next, err := m.sourceDrv.Next(ctx, suint(from)) if err != nil { ret <- err return } - migr, err := m.newMigration(next, int(next)) + migr, err := m.newMigration(ctx, next, int(next)) if err != nil { ret <- err return @@ -486,10 +487,10 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { return } - prev, err := m.sourceDrv.Prev(suint(from)) + prev, err := m.sourceDrv.Prev(ctx, suint(from)) if errors.Is(err, os.ErrNotExist) && to == -1 { // apply nil migration - migr, err := m.newMigration(suint(from), -1) + migr, err := m.newMigration(ctx, suint(from), -1) if err != nil { ret <- err return @@ -508,7 +509,7 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { return } - migr, err := m.newMigration(suint(from), int(prev)) + migr, err := m.newMigration(ctx, suint(from), int(prev)) if err != nil { ret <- err return @@ -531,12 +532,12 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { // Each migration is then written to the ret channel. // If an error occurs during reading, that error is written to the ret channel, too. // Once readUp is done reading it will close the ret channel. -func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { +func (m *Migrate) readUp(ctx context.Context, from int, limit int, ret chan<- interface{}) { defer close(ret) // check if from version exists if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { + if err := m.versionExists(ctx, suint(from)); err != nil { ret <- err return } @@ -555,13 +556,13 @@ func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { // apply first migration if from is nil version if from == -1 { - firstVersion, err := m.sourceDrv.First() + firstVersion, err := m.sourceDrv.First(ctx) if err != nil { ret <- err return } - migr, err := m.newMigration(firstVersion, int(firstVersion)) + migr, err := m.newMigration(ctx, firstVersion, int(firstVersion)) if err != nil { ret <- err return @@ -579,7 +580,7 @@ func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { } // apply next migration - next, err := m.sourceDrv.Next(suint(from)) + next, err := m.sourceDrv.Next(ctx, suint(from)) if errors.Is(err, os.ErrNotExist) { // no limit, but no migrations applied? if limit == -1 && count == 0 { @@ -609,7 +610,7 @@ func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { return } - migr, err := m.newMigration(next, int(next)) + migr, err := m.newMigration(ctx, next, int(next)) if err != nil { ret <- err return @@ -631,12 +632,12 @@ func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { // Each migration is then written to the ret channel. // If an error occurs during reading, that error is written to the ret channel, too. // Once readDown is done reading it will close the ret channel. -func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { +func (m *Migrate) readDown(ctx context.Context, from int, limit int, ret chan<- interface{}) { defer close(ret) // check if from version exists if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { + if err := m.versionExists(ctx, suint(from)); err != nil { ret <- err return } @@ -665,17 +666,17 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { return } - prev, err := m.sourceDrv.Prev(suint(from)) + prev, err := m.sourceDrv.Prev(ctx, suint(from)) if errors.Is(err, os.ErrNotExist) { // no limit or haven't reached limit, apply "first" migration if limit == -1 || limit-count > 0 { - firstVersion, err := m.sourceDrv.First() + firstVersion, err := m.sourceDrv.First(ctx) if err != nil { ret <- err return } - migr, err := m.newMigration(firstVersion, -1) + migr, err := m.newMigration(ctx, firstVersion, -1) if err != nil { ret <- err return @@ -699,7 +700,7 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { return } - migr, err := m.newMigration(suint(from), int(prev)) + migr, err := m.newMigration(ctx, suint(from), int(prev)) if err != nil { ret <- err return @@ -722,7 +723,7 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // Before running a newly received migration it will check if it's supposed // to stop execution because it might have received a stop signal on the // GracefulStop channel. -func (m *Migrate) runMigrations(ret <-chan interface{}) error { +func (m *Migrate) runMigrations(ctx context.Context, ret <-chan interface{}) error { for r := range ret { if m.stop() { @@ -737,19 +738,19 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { migr := r // set version with dirty state - if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil { + if err := m.databaseDrv.SetVersion(ctx, migr.TargetVersion, true); err != nil { return err } if migr.Body != nil { m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) - if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { + if err := m.databaseDrv.Run(ctx, migr.BufferedBody); err != nil { return err } } // set clean state - if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { + if err := m.databaseDrv.SetVersion(ctx, migr.TargetVersion, false); err != nil { return err } @@ -775,9 +776,9 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { // versionExists checks the source if either the up or down migration for // the specified migration version exists. -func (m *Migrate) versionExists(version uint) (result error) { +func (m *Migrate) versionExists(ctx context.Context, version uint) (result error) { // try up migration first - up, _, err := m.sourceDrv.ReadUp(version) + up, _, err := m.sourceDrv.ReadUp(ctx, version) if err == nil { defer func() { if errClose := up.Close(); errClose != nil { @@ -792,7 +793,7 @@ func (m *Migrate) versionExists(version uint) (result error) { } // then try down migration - down, _, err := m.sourceDrv.ReadDown(version) + down, _, err := m.sourceDrv.ReadDown(ctx, version) if err == nil { defer func() { if errClose := down.Close(); errClose != nil { @@ -831,11 +832,11 @@ func (m *Migrate) stop() bool { // newMigration is a helper func that returns a *Migration for the // specified version and targetVersion. -func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, error) { +func (m *Migrate) newMigration(ctx context.Context, version uint, targetVersion int) (*Migration, error) { var migr *Migration if targetVersion >= int(version) { - r, identifier, err := m.sourceDrv.ReadUp(version) + r, identifier, err := m.sourceDrv.ReadUp(ctx, version) if errors.Is(err, os.ErrNotExist) { // create "empty" migration migr, err = NewMigration(nil, "", version, targetVersion) @@ -855,7 +856,7 @@ func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, err } } else { - r, identifier, err := m.sourceDrv.ReadDown(version) + r, identifier, err := m.sourceDrv.ReadDown(ctx, version) if errors.Is(err, os.ErrNotExist) { // create "empty" migration migr, err = NewMigration(nil, "", version, targetVersion) @@ -886,7 +887,7 @@ func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, err // lock is a thread safe helper function to lock the database. // It should be called as late as possible when running migrations. -func (m *Migrate) lock() error { +func (m *Migrate) lock(ctx context.Context) error { m.isLockedMu.Lock() defer m.isLockedMu.Unlock() @@ -919,7 +920,7 @@ func (m *Migrate) lock() error { // now try to acquire the lock go func() { - if err := m.databaseDrv.Lock(); err != nil { + if err := m.databaseDrv.Lock(ctx); err != nil { errchan <- err } else { errchan <- nil @@ -937,11 +938,11 @@ func (m *Migrate) lock() error { // unlock is a thread safe helper function to unlock the database. // It should be called as early as possible when no more migrations are // expected to be executed. -func (m *Migrate) unlock() error { +func (m *Migrate) unlock(ctx context.Context) error { m.isLockedMu.Lock() defer m.isLockedMu.Unlock() - if err := m.databaseDrv.Unlock(); err != nil { + if err := m.databaseDrv.Unlock(ctx); err != nil { // BUG: Can potentially create a deadlock. Add a timeout. return err } @@ -952,8 +953,8 @@ func (m *Migrate) unlock() error { // unlockErr calls unlock and returns a combined error // if a prevErr is not nil. -func (m *Migrate) unlockErr(prevErr error) error { - if err := m.unlock(); err != nil { +func (m *Migrate) unlockErr(ctx context.Context, prevErr error) error { + if err := m.unlock(ctx); err != nil { return multierror.Append(prevErr, err) } return prevErr diff --git a/migrate_test.go b/migrate_test.go index f2728179e..4ae8f7a75 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -2,6 +2,7 @@ package migrate import ( "bytes" + "context" "database/sql" "errors" "io" @@ -42,7 +43,8 @@ func init() { type DummyInstance struct{ Name string } func TestNew(t *testing.T) { - m, err := New("stub://", "stub://") + ctx := context.Background() + m, err := New(ctx, "stub://", "stub://") if err != nil { t.Fatal(err) } @@ -64,25 +66,27 @@ func TestNew(t *testing.T) { func ExampleNew() { // Read migrations from /home/mattes/migrations and connect to a local postgres database. - m, err := New("file:///home/mattes/migrations", "postgres://mattes:secret@localhost:5432/database?sslmode=disable") + ctx := context.Background() + m, err := New(ctx, "file:///home/mattes/migrations", "postgres://mattes:secret@localhost:5432/database?sslmode=disable") if err != nil { log.Fatal(err) } // Migrate all the way up ... - if err := m.Up(); err != nil && err != ErrNoChange { + if err := m.Up(ctx); err != nil && err != ErrNoChange { log.Fatal(err) } } func TestNewWithDatabaseInstance(t *testing.T) { + ctx := context.Background() dummyDb := &DummyInstance{"database"} - dbInst, err := dStub.WithInstance(dummyDb, &dStub.Config{}) + dbInst, err := dStub.WithInstance(ctx, dummyDb, &dStub.Config{}) if err != nil { t.Fatal(err) } - m, err := NewWithDatabaseInstance("stub://", dbDrvNameStub, dbInst) + m, err := NewWithDatabaseInstance(ctx, "stub://", dbDrvNameStub, dbInst) if err != nil { t.Fatal(err) } @@ -103,6 +107,8 @@ func TestNewWithDatabaseInstance(t *testing.T) { } func ExampleNewWithDatabaseInstance() { + ctx := context.Background() + // Create and use an existing database instance. db, err := sql.Open("postgres", "postgres://mattes:secret@localhost:5432/database?sslmode=disable") if err != nil { @@ -117,31 +123,32 @@ func ExampleNewWithDatabaseInstance() { // Create driver instance from db. // Check each driver if it supports the WithInstance function. // `import "github.com/golang-migrate/migrate/v4/database/postgres"` - instance, err := dStub.WithInstance(db, &dStub.Config{}) + instance, err := dStub.WithInstance(ctx, db, &dStub.Config{}) if err != nil { log.Fatal(err) } // Read migrations from /home/mattes/migrations and connect to a local postgres database. - m, err := NewWithDatabaseInstance("file:///home/mattes/migrations", "postgres", instance) + m, err := NewWithDatabaseInstance(ctx, "file:///home/mattes/migrations", "postgres", instance) if err != nil { log.Fatal(err) } // Migrate all the way up ... - if err := m.Up(); err != nil { + if err := m.Up(ctx); err != nil { log.Fatal(err) } } func TestNewWithSourceInstance(t *testing.T) { + ctx := context.Background() dummySource := &DummyInstance{"source"} - sInst, err := sStub.WithInstance(dummySource, &sStub.Config{}) + sInst, err := sStub.WithInstance(ctx, dummySource, &sStub.Config{}) if err != nil { t.Fatal(err) } - m, err := NewWithSourceInstance(srcDrvNameStub, sInst, "stub://") + m, err := NewWithSourceInstance(ctx, srcDrvNameStub, sInst, "stub://") if err != nil { t.Fatal(err) } @@ -162,37 +169,39 @@ func TestNewWithSourceInstance(t *testing.T) { } func ExampleNewWithSourceInstance() { + ctx := context.Background() di := &DummyInstance{"think any client required for a source here"} // Create driver instance from DummyInstance di. // Check each driver if it support the WithInstance function. // `import "github.com/golang-migrate/migrate/v4/source/stub"` - instance, err := sStub.WithInstance(di, &sStub.Config{}) + instance, err := sStub.WithInstance(ctx, di, &sStub.Config{}) if err != nil { log.Fatal(err) } // Read migrations from Stub and connect to a local postgres database. - m, err := NewWithSourceInstance(srcDrvNameStub, instance, "postgres://mattes:secret@localhost:5432/database?sslmode=disable") + m, err := NewWithSourceInstance(ctx, srcDrvNameStub, instance, "postgres://mattes:secret@localhost:5432/database?sslmode=disable") if err != nil { log.Fatal(err) } // Migrate all the way up ... - if err := m.Up(); err != nil { + if err := m.Up(ctx); err != nil { log.Fatal(err) } } func TestNewWithInstance(t *testing.T) { + ctx := context.Background() dummyDb := &DummyInstance{"database"} - dbInst, err := dStub.WithInstance(dummyDb, &dStub.Config{}) + dbInst, err := dStub.WithInstance(ctx, dummyDb, &dStub.Config{}) if err != nil { t.Fatal(err) } dummySource := &DummyInstance{"source"} - sInst, err := sStub.WithInstance(dummySource, &sStub.Config{}) + sInst, err := sStub.WithInstance(ctx, dummySource, &sStub.Config{}) if err != nil { t.Fatal(err) } @@ -222,8 +231,9 @@ func ExampleNewWithInstance() { } func TestClose(t *testing.T) { - m, _ := New("stub://", "stub://") - sourceErr, databaseErr := m.Close() + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") + sourceErr, databaseErr := m.Close(ctx) if sourceErr != nil { t.Error(sourceErr) } @@ -233,7 +243,8 @@ func TestClose(t *testing.T) { } func TestMigrate(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations dbDrv := m.databaseDrv.(*dStub.Stub) @@ -467,13 +478,13 @@ func TestMigrate(t *testing.T) { } for i, v := range tt { - err := m.Migrate(v.version) + err := m.Migrate(ctx, v.version) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) || (v.expectErr != os.ErrNotExist && err != v.expectErr) { t.Errorf("expected err %v, got %v, in %v", v.expectErr, err, i) } else if err == nil { - version, _, err := m.Version() + version, _, err := m.Version(ctx) if err != nil { t.Error(err) } @@ -486,20 +497,22 @@ func TestMigrate(t *testing.T) { } func TestMigrateDirty(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) - if err := dbDrv.SetVersion(0, true); err != nil { + if err := dbDrv.SetVersion(ctx, 0, true); err != nil { t.Fatal(err) } - err := m.Migrate(1) + err := m.Migrate(ctx, 1) if _, ok := err.(ErrDirty); !ok { t.Fatalf("expected ErrDirty, got %v", err) } } func TestSteps(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations dbDrv := m.databaseDrv.(*dStub.Stub) @@ -730,13 +743,13 @@ func TestSteps(t *testing.T) { } for i, v := range tt { - err := m.Steps(v.steps) + err := m.Steps(ctx, v.steps) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) || (v.expectErr != os.ErrNotExist && err != v.expectErr) { t.Errorf("expected err %v, got %v, in %v", v.expectErr, err, i) } else if err == nil { - version, _, err := m.Version() + version, _, err := m.Version(ctx) if err != ErrNilVersion && err != nil { t.Error(err) } @@ -752,25 +765,27 @@ func TestSteps(t *testing.T) { } func TestStepsDirty(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) - if err := dbDrv.SetVersion(0, true); err != nil { + if err := dbDrv.SetVersion(ctx, 0, true); err != nil { t.Fatal(err) } - err := m.Steps(1) + err := m.Steps(ctx, 1) if _, ok := err.(ErrDirty); !ok { t.Fatalf("expected ErrDirty, got %v", err) } } func TestUpAndDown(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations dbDrv := m.databaseDrv.(*dStub.Stub) // go Up first - if err := m.Up(); err != nil { + if err := m.Up(ctx); err != nil { t.Fatal(err) } expectedSequence := migrationSequence{ @@ -782,7 +797,7 @@ func TestUpAndDown(t *testing.T) { equalDbSeq(t, 0, expectedSequence, dbDrv) // go Down - if err := m.Down(); err != nil { + if err := m.Down(ctx); err != nil { t.Fatal(err) } expectedSequence = migrationSequence{ @@ -798,7 +813,7 @@ func TestUpAndDown(t *testing.T) { equalDbSeq(t, 1, expectedSequence, dbDrv) // go 1 Up and then all the way Up - if err := m.Steps(1); err != nil { + if err := m.Steps(ctx, 1); err != nil { t.Fatal(err) } expectedSequence = migrationSequence{ @@ -814,7 +829,7 @@ func TestUpAndDown(t *testing.T) { } equalDbSeq(t, 2, expectedSequence, dbDrv) - if err := m.Up(); err != nil { + if err := m.Up(ctx); err != nil { t.Fatal(err) } expectedSequence = migrationSequence{ @@ -834,7 +849,7 @@ func TestUpAndDown(t *testing.T) { equalDbSeq(t, 3, expectedSequence, dbDrv) // go 1 Down and then all the way Down - if err := m.Steps(-1); err != nil { + if err := m.Steps(ctx, -1); err != nil { t.Fatal(err) } expectedSequence = migrationSequence{ @@ -854,7 +869,7 @@ func TestUpAndDown(t *testing.T) { } equalDbSeq(t, 1, expectedSequence, dbDrv) - if err := m.Down(); err != nil { + if err := m.Down(ctx); err != nil { t.Fatal(err) } expectedSequence = migrationSequence{ @@ -879,37 +894,40 @@ func TestUpAndDown(t *testing.T) { } func TestUpDirty(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) - if err := dbDrv.SetVersion(0, true); err != nil { + if err := dbDrv.SetVersion(ctx, 0, true); err != nil { t.Fatal(err) } - err := m.Up() + err := m.Up(ctx) if _, ok := err.(ErrDirty); !ok { t.Fatalf("expected ErrDirty, got %v", err) } } func TestDownDirty(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) - if err := dbDrv.SetVersion(0, true); err != nil { + if err := dbDrv.SetVersion(ctx, 0, true); err != nil { t.Fatal(err) } - err := m.Down() + err := m.Down(ctx) if _, ok := err.(ErrDirty); !ok { t.Fatalf("expected ErrDirty, got %v", err) } } func TestDrop(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations dbDrv := m.databaseDrv.(*dStub.Stub) - if err := m.Drop(); err != nil { + if err := m.Drop(ctx); err != nil { t.Fatal(err) } @@ -919,23 +937,24 @@ func TestDrop(t *testing.T) { } func TestVersion(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) - _, _, err := m.Version() + _, _, err := m.Version(ctx) if err != ErrNilVersion { t.Fatalf("expected ErrNilVersion, got %v", err) } - if err := dbDrv.Run(bytes.NewBufferString("1_up")); err != nil { + if err := dbDrv.Run(ctx, bytes.NewBufferString("1_up")); err != nil { t.Fatal(err) } - if err := dbDrv.SetVersion(1, false); err != nil { + if err := dbDrv.SetVersion(ctx, 1, false); err != nil { t.Fatal(err) } - v, _, err := m.Version() + v, _, err := m.Version(ctx) if err != nil { t.Fatal(err) } @@ -946,18 +965,19 @@ func TestVersion(t *testing.T) { } func TestRun(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") mx, err := NewMigration(nil, "", 1, 2) if err != nil { t.Fatal(err) } - if err := m.Run(mx); err != nil { + if err := m.Run(ctx, mx); err != nil { t.Fatal(err) } - v, _, err := m.Version() + v, _, err := m.Version(ctx) if err != nil { t.Fatal(err) } @@ -968,9 +988,10 @@ func TestRun(t *testing.T) { } func TestRunDirty(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) - if err := dbDrv.SetVersion(0, true); err != nil { + if err := dbDrv.SetVersion(ctx, 0, true); err != nil { t.Fatal(err) } @@ -979,21 +1000,22 @@ func TestRunDirty(t *testing.T) { t.Fatal(err) } - err = m.Run(migr) + err = m.Run(ctx, migr) if _, ok := err.(ErrDirty); !ok { t.Fatalf("expected ErrDirty, got %v", err) } } func TestForce(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations - if err := m.Force(7); err != nil { + if err := m.Force(ctx, 7); err != nil { t.Fatal(err) } - v, dirty, err := m.Version() + v, dirty, err := m.Version(ctx) if err != nil { t.Fatal(err) } @@ -1006,19 +1028,21 @@ func TestForce(t *testing.T) { } func TestForceDirty(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) - if err := dbDrv.SetVersion(0, true); err != nil { + if err := dbDrv.SetVersion(ctx, 0, true); err != nil { t.Fatal(err) } - if err := m.Force(1); err != nil { + if err := m.Force(ctx, 1); err != nil { t.Fatal(err) } } func TestRead(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations tt := []struct { @@ -1029,13 +1053,13 @@ func TestRead(t *testing.T) { }{ {from: -1, to: -1, expectErr: ErrNoChange}, {from: -1, to: 0, expectErr: os.ErrNotExist}, - {from: -1, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(1))}, + {from: -1, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1))}, {from: -1, to: 2, expectErr: os.ErrNotExist}, - {from: -1, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(1), M(3))}, - {from: -1, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(1), M(3), M(4))}, - {from: -1, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(1), M(3), M(4), M(5))}, + {from: -1, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1), M(ctx, 3))}, + {from: -1, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1), M(ctx, 3), M(ctx, 4))}, + {from: -1, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1), M(ctx, 3), M(ctx, 4), M(ctx, 5))}, {from: -1, to: 6, expectErr: os.ErrNotExist}, - {from: -1, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(1), M(3), M(4), M(5), M(7))}, + {from: -1, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1), M(ctx, 3), M(ctx, 4), M(ctx, 5), M(ctx, 7))}, {from: -1, to: 8, expectErr: os.ErrNotExist}, {from: 0, to: -1, expectErr: os.ErrNotExist}, @@ -1049,15 +1073,15 @@ func TestRead(t *testing.T) { {from: 0, to: 7, expectErr: os.ErrNotExist}, {from: 0, to: 8, expectErr: os.ErrNotExist}, - {from: 1, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(1, -1))}, + {from: 1, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1, -1))}, {from: 1, to: 0, expectErr: os.ErrNotExist}, {from: 1, to: 1, expectErr: ErrNoChange}, {from: 1, to: 2, expectErr: os.ErrNotExist}, - {from: 1, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(3))}, - {from: 1, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(3), M(4))}, - {from: 1, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(3), M(4), M(5))}, + {from: 1, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3))}, + {from: 1, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3), M(ctx, 4))}, + {from: 1, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3), M(ctx, 4), M(ctx, 5))}, {from: 1, to: 6, expectErr: os.ErrNotExist}, - {from: 1, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(3), M(4), M(5), M(7))}, + {from: 1, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3), M(ctx, 4), M(ctx, 5), M(ctx, 7))}, {from: 1, to: 8, expectErr: os.ErrNotExist}, {from: 2, to: -1, expectErr: os.ErrNotExist}, @@ -1071,37 +1095,37 @@ func TestRead(t *testing.T) { {from: 2, to: 7, expectErr: os.ErrNotExist}, {from: 2, to: 8, expectErr: os.ErrNotExist}, - {from: 3, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(3, 1), M(1, -1))}, + {from: 3, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3, 1), M(ctx, 1, -1))}, {from: 3, to: 0, expectErr: os.ErrNotExist}, - {from: 3, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(3, 1))}, + {from: 3, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3, 1))}, {from: 3, to: 2, expectErr: os.ErrNotExist}, {from: 3, to: 3, expectErr: ErrNoChange}, - {from: 3, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(4))}, - {from: 3, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(4), M(5))}, + {from: 3, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4))}, + {from: 3, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4), M(ctx, 5))}, {from: 3, to: 6, expectErr: os.ErrNotExist}, - {from: 3, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(4), M(5), M(7))}, + {from: 3, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4), M(ctx, 5), M(ctx, 7))}, {from: 3, to: 8, expectErr: os.ErrNotExist}, - {from: 4, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(4, 3), M(3, 1), M(1, -1))}, + {from: 4, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4, 3), M(ctx, 3, 1), M(ctx, 1, -1))}, {from: 4, to: 0, expectErr: os.ErrNotExist}, - {from: 4, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(4, 3), M(3, 1))}, + {from: 4, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4, 3), M(ctx, 3, 1))}, {from: 4, to: 2, expectErr: os.ErrNotExist}, - {from: 4, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(4, 3))}, + {from: 4, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4, 3))}, {from: 4, to: 4, expectErr: ErrNoChange}, - {from: 4, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(5))}, + {from: 4, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5))}, {from: 4, to: 6, expectErr: os.ErrNotExist}, - {from: 4, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(5), M(7))}, + {from: 4, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5), M(ctx, 7))}, {from: 4, to: 8, expectErr: os.ErrNotExist}, - {from: 5, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(5, 4), M(4, 3), M(3, 1), M(1, -1))}, + {from: 5, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5, 4), M(ctx, 4, 3), M(ctx, 3, 1), M(ctx, 1, -1))}, {from: 5, to: 0, expectErr: os.ErrNotExist}, - {from: 5, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(5, 4), M(4, 3), M(3, 1))}, + {from: 5, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5, 4), M(ctx, 4, 3), M(ctx, 3, 1))}, {from: 5, to: 2, expectErr: os.ErrNotExist}, - {from: 5, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(5, 4), M(4, 3))}, - {from: 5, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(5, 4))}, + {from: 5, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5, 4), M(ctx, 4, 3))}, + {from: 5, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5, 4))}, {from: 5, to: 5, expectErr: ErrNoChange}, {from: 5, to: 6, expectErr: os.ErrNotExist}, - {from: 5, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(7))}, + {from: 5, to: 7, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7))}, {from: 5, to: 8, expectErr: os.ErrNotExist}, {from: 6, to: -1, expectErr: os.ErrNotExist}, @@ -1115,13 +1139,13 @@ func TestRead(t *testing.T) { {from: 6, to: 7, expectErr: os.ErrNotExist}, {from: 6, to: 8, expectErr: os.ErrNotExist}, - {from: 7, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(7, 5), M(5, 4), M(4, 3), M(3, 1), M(1, -1))}, + {from: 7, to: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7, 5), M(ctx, 5, 4), M(ctx, 4, 3), M(ctx, 3, 1), M(ctx, 1, -1))}, {from: 7, to: 0, expectErr: os.ErrNotExist}, - {from: 7, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(7, 5), M(5, 4), M(4, 3), M(3, 1))}, + {from: 7, to: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7, 5), M(ctx, 5, 4), M(ctx, 4, 3), M(ctx, 3, 1))}, {from: 7, to: 2, expectErr: os.ErrNotExist}, - {from: 7, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(7, 5), M(5, 4), M(4, 3))}, - {from: 7, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(7, 5), M(5, 4))}, - {from: 7, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(7, 5))}, + {from: 7, to: 3, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7, 5), M(ctx, 5, 4), M(ctx, 4, 3))}, + {from: 7, to: 4, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7, 5), M(ctx, 5, 4))}, + {from: 7, to: 5, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7, 5))}, {from: 7, to: 6, expectErr: os.ErrNotExist}, {from: 7, to: 7, expectErr: ErrNoChange}, {from: 7, to: 8, expectErr: os.ErrNotExist}, @@ -1140,7 +1164,7 @@ func TestRead(t *testing.T) { for i, v := range tt { ret := make(chan interface{}) - go m.read(v.from, v.to, ret) + go m.read(ctx, v.from, v.to, ret) migrations, err := migrationsFromChannel(ret) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) || @@ -1155,7 +1179,8 @@ func TestRead(t *testing.T) { } func TestReadUp(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations tt := []struct { @@ -1164,40 +1189,40 @@ func TestReadUp(t *testing.T) { expectErr error expectMigrations migrationSequence }{ - {from: -1, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(1), M(3), M(4), M(5), M(7))}, + {from: -1, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1), M(ctx, 3), M(ctx, 4), M(ctx, 5), M(ctx, 7))}, {from: -1, limit: 0, expectErr: ErrNoChange}, - {from: -1, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(1))}, - {from: -1, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(1), M(3))}, + {from: -1, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1))}, + {from: -1, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1), M(ctx, 3))}, {from: 0, limit: -1, expectErr: os.ErrNotExist}, {from: 0, limit: 0, expectErr: os.ErrNotExist}, {from: 0, limit: 1, expectErr: os.ErrNotExist}, {from: 0, limit: 2, expectErr: os.ErrNotExist}, - {from: 1, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(3), M(4), M(5), M(7))}, + {from: 1, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3), M(ctx, 4), M(ctx, 5), M(ctx, 7))}, {from: 1, limit: 0, expectErr: ErrNoChange}, - {from: 1, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(3))}, - {from: 1, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(3), M(4))}, + {from: 1, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3))}, + {from: 1, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3), M(ctx, 4))}, {from: 2, limit: -1, expectErr: os.ErrNotExist}, {from: 2, limit: 0, expectErr: os.ErrNotExist}, {from: 2, limit: 1, expectErr: os.ErrNotExist}, {from: 2, limit: 2, expectErr: os.ErrNotExist}, - {from: 3, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(4), M(5), M(7))}, + {from: 3, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4), M(ctx, 5), M(ctx, 7))}, {from: 3, limit: 0, expectErr: ErrNoChange}, - {from: 3, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(4))}, - {from: 3, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(4), M(5))}, + {from: 3, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4))}, + {from: 3, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4), M(ctx, 5))}, - {from: 4, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(5), M(7))}, + {from: 4, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5), M(ctx, 7))}, {from: 4, limit: 0, expectErr: ErrNoChange}, - {from: 4, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(5))}, - {from: 4, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(5), M(7))}, + {from: 4, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5))}, + {from: 4, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5), M(ctx, 7))}, - {from: 5, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(7))}, + {from: 5, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7))}, {from: 5, limit: 0, expectErr: ErrNoChange}, - {from: 5, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(7))}, - {from: 5, limit: 2, expectErr: ErrShortLimit{1}, expectMigrations: newMigSeq(M(7))}, + {from: 5, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7))}, + {from: 5, limit: 2, expectErr: ErrShortLimit{1}, expectMigrations: newMigSeq(M(ctx, 7))}, {from: 6, limit: -1, expectErr: os.ErrNotExist}, {from: 6, limit: 0, expectErr: os.ErrNotExist}, @@ -1217,7 +1242,7 @@ func TestReadUp(t *testing.T) { for i, v := range tt { ret := make(chan interface{}) - go m.readUp(v.from, v.limit, ret) + go m.readUp(ctx, v.from, v.limit, ret) migrations, err := migrationsFromChannel(ret) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) || @@ -1232,7 +1257,8 @@ func TestReadUp(t *testing.T) { } func TestReadDown(t *testing.T) { - m, _ := New("stub://", "stub://") + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations tt := []struct { @@ -1251,40 +1277,40 @@ func TestReadDown(t *testing.T) { {from: 0, limit: 1, expectErr: os.ErrNotExist}, {from: 0, limit: 2, expectErr: os.ErrNotExist}, - {from: 1, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(1, -1))}, + {from: 1, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1, -1))}, {from: 1, limit: 0, expectErr: ErrNoChange}, - {from: 1, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(1, -1))}, - {from: 1, limit: 2, expectErr: ErrShortLimit{1}, expectMigrations: newMigSeq(M(1, -1))}, + {from: 1, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 1, -1))}, + {from: 1, limit: 2, expectErr: ErrShortLimit{1}, expectMigrations: newMigSeq(M(ctx, 1, -1))}, {from: 2, limit: -1, expectErr: os.ErrNotExist}, {from: 2, limit: 0, expectErr: os.ErrNotExist}, {from: 2, limit: 1, expectErr: os.ErrNotExist}, {from: 2, limit: 2, expectErr: os.ErrNotExist}, - {from: 3, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(3, 1), M(1, -1))}, + {from: 3, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3, 1), M(ctx, 1, -1))}, {from: 3, limit: 0, expectErr: ErrNoChange}, - {from: 3, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(3, 1))}, - {from: 3, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(3, 1), M(1, -1))}, + {from: 3, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3, 1))}, + {from: 3, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 3, 1), M(ctx, 1, -1))}, - {from: 4, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(4, 3), M(3, 1), M(1, -1))}, + {from: 4, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4, 3), M(ctx, 3, 1), M(ctx, 1, -1))}, {from: 4, limit: 0, expectErr: ErrNoChange}, - {from: 4, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(4, 3))}, - {from: 4, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(4, 3), M(3, 1))}, + {from: 4, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4, 3))}, + {from: 4, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 4, 3), M(ctx, 3, 1))}, - {from: 5, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(5, 4), M(4, 3), M(3, 1), M(1, -1))}, + {from: 5, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5, 4), M(ctx, 4, 3), M(ctx, 3, 1), M(ctx, 1, -1))}, {from: 5, limit: 0, expectErr: ErrNoChange}, - {from: 5, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(5, 4))}, - {from: 5, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(5, 4), M(4, 3))}, + {from: 5, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5, 4))}, + {from: 5, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 5, 4), M(ctx, 4, 3))}, {from: 6, limit: -1, expectErr: os.ErrNotExist}, {from: 6, limit: 0, expectErr: os.ErrNotExist}, {from: 6, limit: 1, expectErr: os.ErrNotExist}, {from: 6, limit: 2, expectErr: os.ErrNotExist}, - {from: 7, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(7, 5), M(5, 4), M(4, 3), M(3, 1), M(1, -1))}, + {from: 7, limit: -1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7, 5), M(ctx, 5, 4), M(ctx, 4, 3), M(ctx, 3, 1), M(ctx, 1, -1))}, {from: 7, limit: 0, expectErr: ErrNoChange}, - {from: 7, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(7, 5))}, - {from: 7, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(7, 5), M(5, 4))}, + {from: 7, limit: 1, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7, 5))}, + {from: 7, limit: 2, expectErr: nil, expectMigrations: newMigSeq(M(ctx, 7, 5), M(ctx, 5, 4))}, {from: 8, limit: -1, expectErr: os.ErrNotExist}, {from: 8, limit: 0, expectErr: os.ErrNotExist}, @@ -1294,7 +1320,7 @@ func TestReadDown(t *testing.T) { for i, v := range tt { ret := make(chan interface{}) - go m.readDown(v.from, v.limit, ret) + go m.readDown(ctx, v.from, v.limit, ret) migrations, err := migrationsFromChannel(ret) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) || @@ -1309,12 +1335,13 @@ func TestReadDown(t *testing.T) { } func TestLock(t *testing.T) { - m, _ := New("stub://", "stub://") - if err := m.lock(); err != nil { + ctx := context.Background() + m, _ := New(ctx, "stub://", "stub://") + if err := m.lock(ctx); err != nil { t.Fatal(err) } - if err := m.lock(); err == nil { + if err := m.lock(ctx); err == nil { t.Fatal("should be locked already") } } @@ -1366,7 +1393,7 @@ func (m *migrationSequence) bodySequence() []string { } // M is a convenience func to create a new *Migration -func M(version uint, targetVersion ...int) *Migration { +func M(ctx context.Context, version uint, targetVersion ...int) *Migration { if len(targetVersion) > 1 { panic("only one targetVersion allowed") } @@ -1375,9 +1402,9 @@ func M(version uint, targetVersion ...int) *Migration { ts = targetVersion[0] } - m, _ := New("stub://", "stub://") + m, _ := New(ctx, "stub://", "stub://") m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations - migr, err := m.newMigration(version, ts) + migr, err := m.newMigration(ctx, version, ts) if err != nil { panic(err) } diff --git a/source/aws_s3/s3.go b/source/aws_s3/s3.go index cec87ae35..7aee0ef76 100644 --- a/source/aws_s3/s3.go +++ b/source/aws_s3/s3.go @@ -1,6 +1,7 @@ package awss3 import ( + "context" "fmt" "io" "net/url" @@ -30,7 +31,7 @@ type Config struct { Prefix string } -func (s *s3Driver) Open(folder string) (source.Driver, error) { +func (s *s3Driver) Open(ctx context.Context, folder string) (source.Driver, error) { config, err := parseURI(folder) if err != nil { return nil, err @@ -41,10 +42,10 @@ func (s *s3Driver) Open(folder string) (source.Driver, error) { return nil, err } - return WithInstance(s3.New(sess), config) + return WithInstance(ctx, s3.New(sess), config) } -func WithInstance(s3client s3iface.S3API, config *Config) (source.Driver, error) { +func WithInstance(ctx context.Context, s3client s3iface.S3API, config *Config) (source.Driver, error) { driver := &s3Driver{ config: config, s3client: s3client, @@ -97,42 +98,42 @@ func (s *s3Driver) loadMigrations() error { return nil } -func (s *s3Driver) Close() error { +func (s *s3Driver) Close(ctx context.Context) error { return nil } -func (s *s3Driver) First() (uint, error) { - v, ok := s.migrations.First() +func (s *s3Driver) First(ctx context.Context) (uint, error) { + v, ok := s.migrations.First(ctx) if !ok { return 0, os.ErrNotExist } return v, nil } -func (s *s3Driver) Prev(version uint) (uint, error) { - v, ok := s.migrations.Prev(version) +func (s *s3Driver) Prev(ctx context.Context, version uint) (uint, error) { + v, ok := s.migrations.Prev(ctx, version) if !ok { return 0, os.ErrNotExist } return v, nil } -func (s *s3Driver) Next(version uint) (uint, error) { - v, ok := s.migrations.Next(version) +func (s *s3Driver) Next(ctx context.Context, version uint) (uint, error) { + v, ok := s.migrations.Next(ctx, version) if !ok { return 0, os.ErrNotExist } return v, nil } -func (s *s3Driver) ReadUp(version uint) (io.ReadCloser, string, error) { +func (s *s3Driver) ReadUp(ctx context.Context, version uint) (io.ReadCloser, string, error) { if m, ok := s.migrations.Up(version); ok { return s.open(m) } return nil, "", os.ErrNotExist } -func (s *s3Driver) ReadDown(version uint) (io.ReadCloser, string, error) { +func (s *s3Driver) ReadDown(ctx context.Context, version uint) (io.ReadCloser, string, error) { if m, ok := s.migrations.Down(version); ok { return s.open(m) } diff --git a/source/aws_s3/s3_test.go b/source/aws_s3/s3_test.go index 12cc85646..27aae5322 100644 --- a/source/aws_s3/s3_test.go +++ b/source/aws_s3/s3_test.go @@ -1,6 +1,7 @@ package awss3 import ( + "context" "errors" "io" "strings" @@ -30,7 +31,7 @@ func Test(t *testing.T) { "prod/migrations/0-random-stuff/whatever.txt": "", }, } - driver, err := WithInstance(&s3Client, &Config{ + driver, err := WithInstance(context.Background(), &s3Client, &Config{ Bucket: "some-bucket", Prefix: "prod/migrations/", }) diff --git a/source/bitbucket/bitbucket.go b/source/bitbucket/bitbucket.go index 95e0a4226..cbfb7461a 100644 --- a/source/bitbucket/bitbucket.go +++ b/source/bitbucket/bitbucket.go @@ -1,6 +1,7 @@ package bitbucket import ( + "context" "fmt" "io" nurl "net/url" @@ -38,7 +39,7 @@ type Config struct { Ref string } -func (b *Bitbucket) Open(url string) (source.Driver, error) { +func (b *Bitbucket) Open(ctx context.Context, url string) (source.Driver, error) { u, err := nurl.Parse(url) if err != nil { return nil, err @@ -68,7 +69,7 @@ func (b *Bitbucket) Open(url string) (source.Driver, error) { } cfg.Ref = u.Fragment - bi, err := WithInstance(cl, cfg) + bi, err := WithInstance(ctx, cl, cfg) if err != nil { return nil, err } @@ -76,7 +77,7 @@ func (b *Bitbucket) Open(url string) (source.Driver, error) { return bi, nil } -func WithInstance(client *bitbucket.Client, config *Config) (source.Driver, error) { +func WithInstance(ctx context.Context, client *bitbucket.Client, config *Config) (source.Driver, error) { bi := &Bitbucket{ client: client, config: config, @@ -126,41 +127,41 @@ func (b *Bitbucket) ensureFields() { } } -func (b *Bitbucket) Close() error { +func (b *Bitbucket) Close(ctx context.Context) error { return nil } -func (b *Bitbucket) First() (version uint, er error) { +func (b *Bitbucket) First(ctx context.Context) (version uint, er error) { b.ensureFields() - if v, ok := b.migrations.First(); !ok { + if v, ok := b.migrations.First(ctx); !ok { return 0, &os.PathError{Op: "first", Path: b.config.Path, Err: os.ErrNotExist} } else { return v, nil } } -func (b *Bitbucket) Prev(version uint) (prevVersion uint, err error) { +func (b *Bitbucket) Prev(ctx context.Context, version uint) (prevVersion uint, err error) { b.ensureFields() - if v, ok := b.migrations.Prev(version); !ok { + if v, ok := b.migrations.Prev(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("prev for version %v", version), Path: b.config.Path, Err: os.ErrNotExist} } else { return v, nil } } -func (b *Bitbucket) Next(version uint) (nextVersion uint, err error) { +func (b *Bitbucket) Next(ctx context.Context, version uint) (nextVersion uint, err error) { b.ensureFields() - if v, ok := b.migrations.Next(version); !ok { + if v, ok := b.migrations.Next(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("next for version %v", version), Path: b.config.Path, Err: os.ErrNotExist} } else { return v, nil } } -func (b *Bitbucket) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) { +func (b *Bitbucket) ReadUp(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { b.ensureFields() if m, ok := b.migrations.Up(version); ok { @@ -182,7 +183,7 @@ func (b *Bitbucket) ReadUp(version uint) (r io.ReadCloser, identifier string, er return nil, "", &os.PathError{Op: fmt.Sprintf("read version %v", version), Path: b.config.Path, Err: os.ErrNotExist} } -func (b *Bitbucket) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) { +func (b *Bitbucket) ReadDown(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { b.ensureFields() if m, ok := b.migrations.Down(version); ok { diff --git a/source/bitbucket/bitbucket_test.go b/source/bitbucket/bitbucket_test.go index 075ec6506..cc0a0461a 100644 --- a/source/bitbucket/bitbucket_test.go +++ b/source/bitbucket/bitbucket_test.go @@ -2,6 +2,7 @@ package bitbucket import ( "bytes" + "context" "os" "testing" @@ -24,7 +25,7 @@ func Test(t *testing.T) { b := &Bitbucket{} - d, err := b.Open("bitbucket://" + BitbucketTestSecret + "@abhishekbipp/test-migration/migrations/test#master") + d, err := b.Open(context.Background(), "bitbucket://"+BitbucketTestSecret+"@abhishekbipp/test-migration/migrations/test#master") if err != nil { t.Fatal(err) } diff --git a/source/driver.go b/source/driver.go index 396eabfae..7a1b954fb 100644 --- a/source/driver.go +++ b/source/driver.go @@ -5,6 +5,7 @@ package source import ( + "context" "fmt" "io" nurl "net/url" @@ -36,44 +37,44 @@ type Driver interface { // Open returns a new driver instance configured with parameters // coming from the URL string. Migrate will call this function // only once per instance. - Open(url string) (Driver, error) + Open(ctx context.Context, url string) (Driver, error) // Close closes the underlying source instance managed by the driver. // Migrate will call this function only once per instance. - Close() error + Close(ctx context.Context) error // First returns the very first migration version available to the driver. // Migrate will call this function multiple times. // If there is no version available, it must return os.ErrNotExist. - First() (version uint, err error) + First(ctx context.Context) (version uint, err error) // Prev returns the previous version for a given version available to the driver. // Migrate will call this function multiple times. // If there is no previous version available, it must return os.ErrNotExist. - Prev(version uint) (prevVersion uint, err error) + Prev(ctx context.Context, version uint) (prevVersion uint, err error) // Next returns the next version for a given version available to the driver. // Migrate will call this function multiple times. // If there is no next version available, it must return os.ErrNotExist. - Next(version uint) (nextVersion uint, err error) + Next(ctx context.Context, version uint) (nextVersion uint, err error) // ReadUp returns the UP migration body and an identifier that helps // finding this migration in the source for a given version. // If there is no up migration available for this version, // it must return os.ErrNotExist. // Do not start reading, just return the ReadCloser! - ReadUp(version uint) (r io.ReadCloser, identifier string, err error) + ReadUp(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) // ReadDown returns the DOWN migration body and an identifier that helps // finding this migration in the source for a given version. // If there is no down migration available for this version, // it must return os.ErrNotExist. // Do not start reading, just return the ReadCloser! - ReadDown(version uint) (r io.ReadCloser, identifier string, err error) + ReadDown(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) } // Open returns a new driver instance. -func Open(url string) (Driver, error) { +func Open(ctx context.Context, url string) (Driver, error) { u, err := nurl.Parse(url) if err != nil { return nil, err @@ -90,7 +91,7 @@ func Open(url string) (Driver, error) { return nil, fmt.Errorf("source driver: unknown driver '%s' (forgotten import?)", u.Scheme) } - return d.Open(url) + return d.Open(ctx, url) } // Register globally registers a driver. diff --git a/source/file/file.go b/source/file/file.go index d8b21dffb..6d31c5242 100644 --- a/source/file/file.go +++ b/source/file/file.go @@ -1,6 +1,7 @@ package file import ( + "context" nurl "net/url" "os" "path/filepath" @@ -19,7 +20,7 @@ type File struct { path string } -func (f *File) Open(url string) (source.Driver, error) { +func (f *File) Open(ctx context.Context, url string) (source.Driver, error) { p, err := parseURL(url) if err != nil { return nil, err diff --git a/source/file/file_test.go b/source/file/file_test.go index 5680aa2a3..712c4b8c1 100644 --- a/source/file/file_test.go +++ b/source/file/file_test.go @@ -1,6 +1,7 @@ package file import ( + "context" "errors" "fmt" "os" @@ -31,7 +32,7 @@ func Test(t *testing.T) { mustWriteFile(t, tmpDir, "7_foobar.down.sql", "7 down") f := &File{} - d, err := f.Open(scheme + tmpDir) + d, err := f.Open(context.Background(), scheme+tmpDir) if err != nil { t.Fatal(err) } @@ -50,7 +51,7 @@ func TestOpen(t *testing.T) { } f := &File{} - _, err := f.Open(scheme + tmpDir) // absolute path + _, err := f.Open(context.Background(), scheme+tmpDir) // absolute path if err != nil { t.Fatal(err) } @@ -80,24 +81,25 @@ func TestOpenWithRelativePath(t *testing.T) { mustWriteFile(t, filepath.Join(tmpDir, "foo"), "1_foobar.up.sql", "") + ctx := context.Background() f := &File{} // dir: foo - d, err := f.Open("file://foo") + d, err := f.Open(ctx, "file://foo") if err != nil { t.Fatal(err) } - _, err = d.First() + _, err = d.First(ctx) if err != nil { t.Fatalf("expected first file in working dir %v for foo", tmpDir) } // dir: ./foo - d, err = f.Open("file://./foo") + d, err = f.Open(ctx, "file://./foo") if err != nil { t.Fatal(err) } - _, err = d.First() + _, err = d.First(ctx) if err != nil { t.Fatalf("expected first file in working dir %v for ./foo", tmpDir) } @@ -110,7 +112,7 @@ func TestOpenDefaultsToCurrentDirectory(t *testing.T) { } f := &File{} - d, err := f.Open(scheme) + d, err := f.Open(context.Background(), scheme) if err != nil { t.Fatal(err) } @@ -127,7 +129,7 @@ func TestOpenWithDuplicateVersion(t *testing.T) { mustWriteFile(t, tmpDir, "1_bar.up.sql", "") // 1 up f := &File{} - _, err := f.Open(scheme + tmpDir) + _, err := f.Open(context.Background(), scheme+tmpDir) if err == nil { t.Fatal("expected err") } @@ -137,12 +139,12 @@ func TestClose(t *testing.T) { tmpDir := t.TempDir() f := &File{} - d, err := f.Open(scheme + tmpDir) + d, err := f.Open(context.Background(), scheme+tmpDir) if err != nil { t.Fatal(err) } - if d.Close() != nil { + if d.Close(context.Background()) != nil { t.Fatal("expected nil") } } @@ -174,7 +176,7 @@ func BenchmarkOpen(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { f := &File{} - _, err := f.Open(scheme + dir) + _, err := f.Open(context.Background(), scheme+dir) if err != nil { b.Error(err) } @@ -189,13 +191,14 @@ func BenchmarkNext(b *testing.B) { b.Error(err) } }() + ctx := context.Background() f := &File{} - d, _ := f.Open(scheme + dir) + d, _ := f.Open(ctx, scheme+dir) b.ResetTimer() - v, err := d.First() + v, err := d.First(ctx) for n := 0; n < b.N; n++ { for !errors.Is(err, os.ErrNotExist) { - v, err = d.Next(v) + v, err = d.Next(ctx, v) } } b.StopTimer() diff --git a/source/github/github.go b/source/github/github.go index cea429d8c..13750d033 100644 --- a/source/github/github.go +++ b/source/github/github.go @@ -42,7 +42,7 @@ type Config struct { Ref string } -func (g *Github) Open(url string) (source.Driver, error) { +func (g *Github) Open(ctx context.Context, url string) (source.Driver, error) { u, err := nurl.Parse(url) if err != nil { return nil, err @@ -88,7 +88,7 @@ func (g *Github) Open(url string) (source.Driver, error) { return gn, nil } -func WithInstance(client *github.Client, config *Config) (source.Driver, error) { +func WithInstance(ctx context.Context, client *github.Client, config *Config) (source.Driver, error) { gn := &Github{ client: client, config: config, @@ -140,41 +140,41 @@ func (g *Github) ensureFields() { } } -func (g *Github) Close() error { +func (g *Github) Close(ctx context.Context) error { return nil } -func (g *Github) First() (version uint, err error) { +func (g *Github) First(ctx context.Context) (version uint, err error) { g.ensureFields() - if v, ok := g.migrations.First(); !ok { + if v, ok := g.migrations.First(ctx); !ok { return 0, &os.PathError{Op: "first", Path: g.config.Path, Err: os.ErrNotExist} } else { return v, nil } } -func (g *Github) Prev(version uint) (prevVersion uint, err error) { +func (g *Github) Prev(ctx context.Context, version uint) (prevVersion uint, err error) { g.ensureFields() - if v, ok := g.migrations.Prev(version); !ok { + if v, ok := g.migrations.Prev(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("prev for version %v", version), Path: g.config.Path, Err: os.ErrNotExist} } else { return v, nil } } -func (g *Github) Next(version uint) (nextVersion uint, err error) { +func (g *Github) Next(ctx context.Context, version uint) (nextVersion uint, err error) { g.ensureFields() - if v, ok := g.migrations.Next(version); !ok { + if v, ok := g.migrations.Next(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("next for version %v", version), Path: g.config.Path, Err: os.ErrNotExist} } else { return v, nil } } -func (g *Github) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) { +func (g *Github) ReadUp(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { g.ensureFields() if m, ok := g.migrations.Up(version); ok { @@ -194,7 +194,7 @@ func (g *Github) ReadUp(version uint) (r io.ReadCloser, identifier string, err e return nil, "", &os.PathError{Op: fmt.Sprintf("read version %v", version), Path: g.config.Path, Err: os.ErrNotExist} } -func (g *Github) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) { +func (g *Github) ReadDown(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { g.ensureFields() if m, ok := g.migrations.Down(version); ok { diff --git a/source/github/github_test.go b/source/github/github_test.go index daa6de9ea..1719d3934 100644 --- a/source/github/github_test.go +++ b/source/github/github_test.go @@ -2,6 +2,7 @@ package github import ( "bytes" + "context" "fmt" "os" "testing" @@ -25,7 +26,7 @@ func Test(t *testing.T) { } g := &Github{} - d, err := g.Open("github://" + GithubTestSecret + "@mattes/migrate_test_tmp/test#452b8003e7") + d, err := g.Open(context.Background(), "github://"+GithubTestSecret+"@mattes/migrate_test_tmp/test#452b8003e7") if err != nil { t.Fatal(err) } @@ -34,24 +35,25 @@ func Test(t *testing.T) { } func TestDefaultClient(t *testing.T) { + ctx := context.Background() g := &Github{} owner := "golang-migrate" repo := "migrate" path := "source/github/examples/migrations" url := fmt.Sprintf("github://%s/%s/%s", owner, repo, path) - d, err := g.Open(url) + d, err := g.Open(ctx, url) if err != nil { t.Fatal(err) } - ver, err := d.First() + ver, err := d.First(ctx) if err != nil { t.Fatal(err) } assert.Equal(t, uint(1085649617), ver) - ver, err = d.Next(ver) + ver, err = d.Next(ctx, ver) if err != nil { t.Fatal(err) } diff --git a/source/github_ee/github_ee.go b/source/github_ee/github_ee.go index 57e41b12e..9dbd61a20 100644 --- a/source/github_ee/github_ee.go +++ b/source/github_ee/github_ee.go @@ -1,6 +1,7 @@ package github_ee import ( + "context" "crypto/tls" "fmt" "net/http" @@ -22,7 +23,7 @@ type GithubEE struct { source.Driver } -func (g *GithubEE) Open(url string) (source.Driver, error) { +func (g *GithubEE) Open(ctx context.Context, url string) (source.Driver, error) { verifyTLS := true u, err := nurl.Parse(url) @@ -64,7 +65,7 @@ func (g *GithubEE) Open(url string) (source.Driver, error) { cfg.Path = strings.Join(pe[2:], "/") } - i, err := gh.WithInstance(ghc, cfg) + i, err := gh.WithInstance(ctx, ghc, cfg) if err != nil { return nil, err } diff --git a/source/github_ee/github_ee_test.go b/source/github_ee/github_ee_test.go index 3a8224912..dadbd5863 100644 --- a/source/github_ee/github_ee_test.go +++ b/source/github_ee/github_ee_test.go @@ -1,6 +1,7 @@ package github_ee import ( + "context" "net/http" "net/http/httptest" nurl "net/url" @@ -36,7 +37,7 @@ func Test(t *testing.T) { } g := &GithubEE{} - _, err = g.Open("github-ee://foo:bar@" + u.Host + "/mattes/migrate_test_tmp/test?verify-tls=false#452b8003e7") + _, err = g.Open(context.Background(), "github-ee://foo:bar@"+u.Host+"/mattes/migrate_test_tmp/test?verify-tls=false#452b8003e7") if err != nil { t.Fatal(err) diff --git a/source/gitlab/gitlab.go b/source/gitlab/gitlab.go index 674062e27..75475583f 100644 --- a/source/gitlab/gitlab.go +++ b/source/gitlab/gitlab.go @@ -1,6 +1,7 @@ package gitlab import ( + "context" "encoding/base64" "fmt" "io" @@ -42,7 +43,7 @@ type Gitlab struct { type Config struct { } -func (g *Gitlab) Open(url string) (source.Driver, error) { +func (g *Gitlab) Open(ctx context.Context, url string) (source.Driver, error) { u, err := nurl.Parse(url) if err != nil { return nil, err @@ -103,7 +104,7 @@ func (g *Gitlab) Open(url string) (source.Driver, error) { return gn, nil } -func WithInstance(client *gitlab.Client, config *Config) (source.Driver, error) { +func WithInstance(ctx context.Context, client *gitlab.Client, config *Config) (source.Driver, error) { gn := &Gitlab{ client: client, migrations: source.NewMigrations(), @@ -164,35 +165,35 @@ func (g *Gitlab) nodeToMigration(node *gitlab.TreeNode) (*source.Migration, erro return nil, source.ErrParse } -func (g *Gitlab) Close() error { +func (g *Gitlab) Close(ctx context.Context) error { return nil } -func (g *Gitlab) First() (version uint, er error) { - if v, ok := g.migrations.First(); !ok { +func (g *Gitlab) First(ctx context.Context) (version uint, er error) { + if v, ok := g.migrations.First(ctx); !ok { return 0, &os.PathError{Op: "first", Path: g.path, Err: os.ErrNotExist} } else { return v, nil } } -func (g *Gitlab) Prev(version uint) (prevVersion uint, err error) { - if v, ok := g.migrations.Prev(version); !ok { +func (g *Gitlab) Prev(ctx context.Context, version uint) (prevVersion uint, err error) { + if v, ok := g.migrations.Prev(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("prev for version %v", version), Path: g.path, Err: os.ErrNotExist} } else { return v, nil } } -func (g *Gitlab) Next(version uint) (nextVersion uint, err error) { - if v, ok := g.migrations.Next(version); !ok { +func (g *Gitlab) Next(ctx context.Context, version uint) (nextVersion uint, err error) { + if v, ok := g.migrations.Next(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("next for version %v", version), Path: g.path, Err: os.ErrNotExist} } else { return v, nil } } -func (g *Gitlab) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) { +func (g *Gitlab) ReadUp(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := g.migrations.Up(version); ok { f, response, err := g.client.RepositoryFiles.GetFile(g.projectID, m.Raw, g.getOptions) if err != nil { @@ -214,7 +215,7 @@ func (g *Gitlab) ReadUp(version uint) (r io.ReadCloser, identifier string, err e return nil, "", &os.PathError{Op: fmt.Sprintf("read version %v", version), Path: g.path, Err: os.ErrNotExist} } -func (g *Gitlab) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) { +func (g *Gitlab) ReadDown(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := g.migrations.Down(version); ok { f, response, err := g.client.RepositoryFiles.GetFile(g.projectID, m.Raw, g.getOptions) if err != nil { diff --git a/source/gitlab/gitlab_test.go b/source/gitlab/gitlab_test.go index b305f5ddd..d5b173f62 100644 --- a/source/gitlab/gitlab_test.go +++ b/source/gitlab/gitlab_test.go @@ -2,6 +2,7 @@ package gitlab import ( "bytes" + "context" "os" "testing" @@ -23,7 +24,7 @@ func Test(t *testing.T) { } g := &Gitlab{} - d, err := g.Open("gitlab://" + GitlabTestSecret + "@gitlab.com/11197284/migrations") + d, err := g.Open(context.Background(), "gitlab://"+GitlabTestSecret+"@gitlab.com/11197284/migrations") if err != nil { t.Fatal(err) } diff --git a/source/go_bindata/go-bindata.go b/source/go_bindata/go-bindata.go index d0d42f5af..5c2f2e962 100644 --- a/source/go_bindata/go-bindata.go +++ b/source/go_bindata/go-bindata.go @@ -2,6 +2,7 @@ package bindata import ( "bytes" + "context" "fmt" "io" "os" @@ -33,7 +34,7 @@ type Bindata struct { migrations *source.Migrations } -func (b *Bindata) Open(url string) (source.Driver, error) { +func (b *Bindata) Open(ctx context.Context, url string) (source.Driver, error) { return nil, fmt.Errorf("not yet implemented") } @@ -41,7 +42,7 @@ var ( ErrNoAssetSource = fmt.Errorf("expects *AssetSource") ) -func WithInstance(instance interface{}) (source.Driver, error) { +func WithInstance(ctx context.Context, instance interface{}) (source.Driver, error) { if _, ok := instance.(*AssetSource); !ok { return nil, ErrNoAssetSource } @@ -67,35 +68,35 @@ func WithInstance(instance interface{}) (source.Driver, error) { return bn, nil } -func (b *Bindata) Close() error { +func (b *Bindata) Close(ctx context.Context) error { return nil } -func (b *Bindata) First() (version uint, err error) { - if v, ok := b.migrations.First(); !ok { +func (b *Bindata) First(ctx context.Context) (version uint, err error) { + if v, ok := b.migrations.First(ctx); !ok { return 0, &os.PathError{Op: "first", Path: b.path, Err: os.ErrNotExist} } else { return v, nil } } -func (b *Bindata) Prev(version uint) (prevVersion uint, err error) { - if v, ok := b.migrations.Prev(version); !ok { +func (b *Bindata) Prev(ctx context.Context, version uint) (prevVersion uint, err error) { + if v, ok := b.migrations.Prev(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("prev for version %v", version), Path: b.path, Err: os.ErrNotExist} } else { return v, nil } } -func (b *Bindata) Next(version uint) (nextVersion uint, err error) { - if v, ok := b.migrations.Next(version); !ok { +func (b *Bindata) Next(ctx context.Context, version uint) (nextVersion uint, err error) { + if v, ok := b.migrations.Next(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("next for version %v", version), Path: b.path, Err: os.ErrNotExist} } else { return v, nil } } -func (b *Bindata) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) { +func (b *Bindata) ReadUp(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := b.migrations.Up(version); ok { body, err := b.assetSource.AssetFunc(m.Raw) if err != nil { @@ -106,7 +107,7 @@ func (b *Bindata) ReadUp(version uint) (r io.ReadCloser, identifier string, err return nil, "", &os.PathError{Op: fmt.Sprintf("read version %v", version), Path: b.path, Err: os.ErrNotExist} } -func (b *Bindata) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) { +func (b *Bindata) ReadDown(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := b.migrations.Down(version); ok { body, err := b.assetSource.AssetFunc(m.Raw) if err != nil { diff --git a/source/go_bindata/go-bindata_test.go b/source/go_bindata/go-bindata_test.go index ac3bd0f4d..82b9c7293 100644 --- a/source/go_bindata/go-bindata_test.go +++ b/source/go_bindata/go-bindata_test.go @@ -1,6 +1,7 @@ package bindata import ( + "context" "testing" "github.com/golang-migrate/migrate/v4/source/go_bindata/testdata" @@ -14,7 +15,7 @@ func Test(t *testing.T) { return testdata.Asset(name) }) - d, err := WithInstance(s) + d, err := WithInstance(context.Background(), s) if err != nil { t.Fatal(err) } @@ -28,7 +29,7 @@ func TestWithInstance(t *testing.T) { return testdata.Asset(name) }) - _, err := WithInstance(s) + _, err := WithInstance(context.Background(), s) if err != nil { t.Fatal(err) } @@ -36,7 +37,7 @@ func TestWithInstance(t *testing.T) { func TestOpen(t *testing.T) { b := &Bindata{} - _, err := b.Open("") + _, err := b.Open(context.Background(), "") if err == nil { t.Fatal("expected err, because it's not implemented yet") } diff --git a/source/godoc_vfs/vfs.go b/source/godoc_vfs/vfs.go index b9d31eca7..3d6825cc8 100644 --- a/source/godoc_vfs/vfs.go +++ b/source/godoc_vfs/vfs.go @@ -7,6 +7,7 @@ package godoc_vfs import ( + "context" "github.com/golang-migrate/migrate/v4/source" "github.com/golang-migrate/migrate/v4/source/httpfs" @@ -30,7 +31,7 @@ type VFS struct { // // Calling this function panics, instead use the WithInstance function. // See the package level documentation for an example. -func (b *VFS) Open(url string) (source.Driver, error) { +func (b *VFS) Open(ctx context.Context, url string) (source.Driver, error) { panic("not implemented") } @@ -38,7 +39,7 @@ func (b *VFS) Open(url string) (source.Driver, error) { // If a tree named searchPath exists in the virtual filesystem, WithInstance // searches for migration files there. // It defaults to "/". -func WithInstance(fs vfs.FileSystem, searchPath string) (source.Driver, error) { +func WithInstance(ctx context.Context, fs vfs.FileSystem, searchPath string) (source.Driver, error) { if searchPath == "" { searchPath = "/" } diff --git a/source/godoc_vfs/vfs_example_test.go b/source/godoc_vfs/vfs_example_test.go index ae178294d..b129ff972 100644 --- a/source/godoc_vfs/vfs_example_test.go +++ b/source/godoc_vfs/vfs_example_test.go @@ -1,6 +1,7 @@ package godoc_vfs_test import ( + "context" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/source/godoc_vfs" "golang.org/x/tools/godoc/vfs/mapfs" @@ -18,15 +19,16 @@ func Example_mapfs() { "7_foobar.down.sql": "7 down", }) - d, err := godoc_vfs.WithInstance(fs, "") + ctx := context.Background() + d, err := godoc_vfs.WithInstance(ctx, fs, "") if err != nil { panic("bad migrations found!") } - m, err := migrate.NewWithSourceInstance("godoc-vfs", d, "database://foobar") + m, err := migrate.NewWithSourceInstance(ctx, "godoc-vfs", d, "database://foobar") if err != nil { panic("error creating the migrations") } - err = m.Up() + err = m.Up(ctx) if err != nil { panic("up failed") } diff --git a/source/godoc_vfs/vfs_test.go b/source/godoc_vfs/vfs_test.go index 30bced1ed..800bfeebd 100644 --- a/source/godoc_vfs/vfs_test.go +++ b/source/godoc_vfs/vfs_test.go @@ -1,6 +1,7 @@ package godoc_vfs_test import ( + "context" "testing" "github.com/golang-migrate/migrate/v4/source/godoc_vfs" @@ -20,7 +21,7 @@ func TestVFS(t *testing.T) { "7_foobar.down.sql": "7 down", }) - d, err := godoc_vfs.WithInstance(fs, "") + d, err := godoc_vfs.WithInstance(context.Background(), fs, "") if err != nil { t.Fatal(err) } @@ -34,7 +35,7 @@ func TestOpen(t *testing.T) { } }() b := &godoc_vfs.VFS{} - if _, err := b.Open(""); err != nil { + if _, err := b.Open(context.Background(), ""); err != nil { t.Error(err) } } diff --git a/source/google_cloud_storage/storage.go b/source/google_cloud_storage/storage.go index 9ec3e7a71..76155ad05 100644 --- a/source/google_cloud_storage/storage.go +++ b/source/google_cloud_storage/storage.go @@ -24,7 +24,7 @@ type gcs struct { migrations *source.Migrations } -func (g *gcs) Open(folder string) (source.Driver, error) { +func (g *gcs) Open(ctx context.Context, folder string) (source.Driver, error) { u, err := url.Parse(folder) if err != nil { return nil, err @@ -67,42 +67,42 @@ func (g *gcs) loadMigrations() error { return nil } -func (g *gcs) Close() error { +func (g *gcs) Close(ctx context.Context) error { return nil } -func (g *gcs) First() (uint, error) { - v, ok := g.migrations.First() +func (g *gcs) First(ctx context.Context) (uint, error) { + v, ok := g.migrations.First(ctx) if !ok { return 0, os.ErrNotExist } return v, nil } -func (g *gcs) Prev(version uint) (uint, error) { - v, ok := g.migrations.Prev(version) +func (g *gcs) Prev(ctx context.Context, version uint) (uint, error) { + v, ok := g.migrations.Prev(ctx, version) if !ok { return 0, os.ErrNotExist } return v, nil } -func (g *gcs) Next(version uint) (uint, error) { - v, ok := g.migrations.Next(version) +func (g *gcs) Next(ctx context.Context, version uint) (uint, error) { + v, ok := g.migrations.Next(ctx, version) if !ok { return 0, os.ErrNotExist } return v, nil } -func (g *gcs) ReadUp(version uint) (io.ReadCloser, string, error) { +func (g *gcs) ReadUp(ctx context.Context, version uint) (io.ReadCloser, string, error) { if m, ok := g.migrations.Up(version); ok { return g.open(m) } return nil, "", os.ErrNotExist } -func (g *gcs) ReadDown(version uint) (io.ReadCloser, string, error) { +func (g *gcs) ReadDown(ctx context.Context, version uint) (io.ReadCloser, string, error) { if m, ok := g.migrations.Down(version); ok { return g.open(m) } diff --git a/source/httpfs/driver.go b/source/httpfs/driver.go index e0cdbaa00..62af98c7d 100644 --- a/source/httpfs/driver.go +++ b/source/httpfs/driver.go @@ -1,6 +1,7 @@ package httpfs import ( + "context" "errors" "net/http" @@ -26,6 +27,6 @@ func New(fs http.FileSystem, path string) (source.Driver, error) { // Open completes the implementetion of source.Driver interface. Other methods // are implemented by the embedded PartialDriver struct. -func (d *driver) Open(url string) (source.Driver, error) { +func (d *driver) Open(ctx context.Context, url string) (source.Driver, error) { return nil, errors.New("Open() cannot be called on the httpfs passthrough driver") } diff --git a/source/httpfs/driver_test.go b/source/httpfs/driver_test.go index d0cf786f6..6cb07f6f6 100644 --- a/source/httpfs/driver_test.go +++ b/source/httpfs/driver_test.go @@ -1,6 +1,7 @@ package httpfs_test import ( + "context" "net/http" "testing" @@ -32,7 +33,7 @@ func TestOpen(t *testing.T) { t.Error("New() expected no error") return } - d, err = d.Open("") + d, err = d.Open(context.Background(), "") if d != nil { t.Error("Open() expected to return nil driver") } diff --git a/source/httpfs/partial_driver.go b/source/httpfs/partial_driver.go index 5ddb79883..40df65b43 100644 --- a/source/httpfs/partial_driver.go +++ b/source/httpfs/partial_driver.go @@ -1,6 +1,7 @@ package httpfs import ( + "context" "errors" "io" "net/http" @@ -66,13 +67,13 @@ func (p *PartialDriver) Init(fs http.FileSystem, path string) error { } // Close is part of source.Driver interface implementation. This is a no-op. -func (p *PartialDriver) Close() error { +func (p *PartialDriver) Close(ctx context.Context) error { return nil } // First is part of source.Driver interface implementation. -func (p *PartialDriver) First() (version uint, err error) { - if version, ok := p.migrations.First(); ok { +func (p *PartialDriver) First(ctx context.Context) (version uint, err error) { + if version, ok := p.migrations.First(ctx); ok { return version, nil } return 0, &os.PathError{ @@ -83,8 +84,8 @@ func (p *PartialDriver) First() (version uint, err error) { } // Prev is part of source.Driver interface implementation. -func (p *PartialDriver) Prev(version uint) (prevVersion uint, err error) { - if version, ok := p.migrations.Prev(version); ok { +func (p *PartialDriver) Prev(ctx context.Context, version uint) (prevVersion uint, err error) { + if version, ok := p.migrations.Prev(ctx, version); ok { return version, nil } return 0, &os.PathError{ @@ -95,8 +96,8 @@ func (p *PartialDriver) Prev(version uint) (prevVersion uint, err error) { } // Next is part of source.Driver interface implementation. -func (p *PartialDriver) Next(version uint) (nextVersion uint, err error) { - if version, ok := p.migrations.Next(version); ok { +func (p *PartialDriver) Next(ctx context.Context, version uint) (nextVersion uint, err error) { + if version, ok := p.migrations.Next(ctx, version); ok { return version, nil } return 0, &os.PathError{ @@ -107,7 +108,7 @@ func (p *PartialDriver) Next(version uint) (nextVersion uint, err error) { } // ReadUp is part of source.Driver interface implementation. -func (p *PartialDriver) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) { +func (p *PartialDriver) ReadUp(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := p.migrations.Up(version); ok { body, err := p.open(path.Join(p.path, m.Raw)) if err != nil { @@ -123,7 +124,7 @@ func (p *PartialDriver) ReadUp(version uint) (r io.ReadCloser, identifier string } // ReadDown is part of source.Driver interface implementation. -func (p *PartialDriver) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) { +func (p *PartialDriver) ReadDown(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := p.migrations.Down(version); ok { body, err := p.open(path.Join(p.path, m.Raw)) if err != nil { diff --git a/source/httpfs/partial_driver_test.go b/source/httpfs/partial_driver_test.go index 94c6ed14a..5b0ad4586 100644 --- a/source/httpfs/partial_driver_test.go +++ b/source/httpfs/partial_driver_test.go @@ -1,6 +1,7 @@ package httpfs_test import ( + "context" "errors" "net/http" "strings" @@ -13,13 +14,15 @@ import ( type driver struct{ httpfs.PartialDriver } -func (d *driver) Open(url string) (source.Driver, error) { return nil, errors.New("X") } +func (d *driver) Open(ctx context.Context, url string) (source.Driver, error) { + return nil, errors.New("X") +} type driverExample struct { httpfs.PartialDriver } -func (d *driverExample) Open(url string) (source.Driver, error) { +func (d *driverExample) Open(ctx context.Context, url string) (source.Driver, error) { parts := strings.Split(url, ":") dir := parts[0] path := "" @@ -32,7 +35,7 @@ func (d *driverExample) Open(url string) (source.Driver, error) { } func TestDriverExample(t *testing.T) { - d, err := (*driverExample)(nil).Open("testdata:sql") + d, err := (*driverExample)(nil).Open(context.Background(), "testdata:sql") if err != nil { t.Errorf("Open() returned error: %s", err) } @@ -80,7 +83,7 @@ func TestPartialDriverInit(t *testing.T) { t.Errorf("Init() returned error %s", err) } st.Test(t, &d) - if err = d.Close(); err != nil { + if err = d.Close(context.Background()); err != nil { t.Errorf("Init().Close() returned error %s", err) } } else { @@ -101,7 +104,7 @@ func TestFirstWithNoMigrations(t *testing.T) { t.Errorf("No error on Init() expected, got: %v", err) } - if _, err := d.First(); err == nil { + if _, err := d.First(context.Background()); err == nil { t.Errorf("Expected error on First(), got: %v", err) } } diff --git a/source/iofs/example_test.go b/source/iofs/example_test.go index 474fc633c..a6ec5fb26 100644 --- a/source/iofs/example_test.go +++ b/source/iofs/example_test.go @@ -4,6 +4,7 @@ package iofs_test import ( + "context" "embed" "log" @@ -16,15 +17,16 @@ import ( var fs embed.FS func Example() { + ctx := context.Background() d, err := iofs.New(fs, "testdata/migrations") if err != nil { log.Fatal(err) } - m, err := migrate.NewWithSourceInstance("iofs", d, "postgres://postgres@localhost/postgres?sslmode=disable") + m, err := migrate.NewWithSourceInstance(ctx, "iofs", d, "postgres://postgres@localhost/postgres?sslmode=disable") if err != nil { log.Fatal(err) } - err = m.Up() + err = m.Up(ctx) if err != nil { // ... } diff --git a/source/iofs/iofs.go b/source/iofs/iofs.go index dc934a5fe..830cfee6e 100644 --- a/source/iofs/iofs.go +++ b/source/iofs/iofs.go @@ -4,6 +4,7 @@ package iofs import ( + "context" "errors" "fmt" "io" @@ -29,7 +30,7 @@ func New(fsys fs.FS, path string) (source.Driver, error) { // Open is part of source.Driver interface implementation. // Open cannot be called on the iofs passthrough driver. -func (d *driver) Open(url string) (source.Driver, error) { +func (d *driver) Open(ctx context.Context, url string) (source.Driver, error) { return nil, errors.New("Open() cannot be called on the iofs passthrough driver") } @@ -82,7 +83,7 @@ func (d *PartialDriver) Init(fsys fs.FS, path string) error { // Close is part of source.Driver interface implementation. // Closes the file system if possible. -func (d *PartialDriver) Close() error { +func (d *PartialDriver) Close(ctx context.Context) error { c, ok := d.fsys.(io.Closer) if !ok { return nil @@ -91,8 +92,8 @@ func (d *PartialDriver) Close() error { } // First is part of source.Driver interface implementation. -func (d *PartialDriver) First() (version uint, err error) { - if version, ok := d.migrations.First(); ok { +func (d *PartialDriver) First(ctx context.Context) (version uint, err error) { + if version, ok := d.migrations.First(ctx); ok { return version, nil } return 0, &fs.PathError{ @@ -103,8 +104,8 @@ func (d *PartialDriver) First() (version uint, err error) { } // Prev is part of source.Driver interface implementation. -func (d *PartialDriver) Prev(version uint) (prevVersion uint, err error) { - if version, ok := d.migrations.Prev(version); ok { +func (d *PartialDriver) Prev(ctx context.Context, version uint) (prevVersion uint, err error) { + if version, ok := d.migrations.Prev(ctx, version); ok { return version, nil } return 0, &fs.PathError{ @@ -115,8 +116,8 @@ func (d *PartialDriver) Prev(version uint) (prevVersion uint, err error) { } // Next is part of source.Driver interface implementation. -func (d *PartialDriver) Next(version uint) (nextVersion uint, err error) { - if version, ok := d.migrations.Next(version); ok { +func (d *PartialDriver) Next(ctx context.Context, version uint) (nextVersion uint, err error) { + if version, ok := d.migrations.Next(ctx, version); ok { return version, nil } return 0, &fs.PathError{ @@ -127,7 +128,7 @@ func (d *PartialDriver) Next(version uint) (nextVersion uint, err error) { } // ReadUp is part of source.Driver interface implementation. -func (d *PartialDriver) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) { +func (d *PartialDriver) ReadUp(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := d.migrations.Up(version); ok { body, err := d.open(path.Join(d.path, m.Raw)) if err != nil { @@ -143,7 +144,7 @@ func (d *PartialDriver) ReadUp(version uint) (r io.ReadCloser, identifier string } // ReadDown is part of source.Driver interface implementation. -func (d *PartialDriver) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) { +func (d *PartialDriver) ReadDown(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := d.migrations.Down(version); ok { body, err := d.open(path.Join(d.path, m.Raw)) if err != nil { diff --git a/source/migration.go b/source/migration.go index 74f6523cb..2d1723ea8 100644 --- a/source/migration.go +++ b/source/migration.go @@ -1,6 +1,7 @@ package source import ( + "context" "sort" ) @@ -75,14 +76,14 @@ func (i *Migrations) buildIndex() { }) } -func (i *Migrations) First() (version uint, ok bool) { +func (i *Migrations) First(ctx context.Context) (version uint, ok bool) { if len(i.index) == 0 { return 0, false } return i.index[0], true } -func (i *Migrations) Prev(version uint) (prevVersion uint, ok bool) { +func (i *Migrations) Prev(ctx context.Context, version uint) (prevVersion uint, ok bool) { pos := i.findPos(version) if pos >= 1 && len(i.index) > pos-1 { return i.index[pos-1], true @@ -90,7 +91,7 @@ func (i *Migrations) Prev(version uint) (prevVersion uint, ok bool) { return 0, false } -func (i *Migrations) Next(version uint) (nextVersion uint, ok bool) { +func (i *Migrations) Next(ctx context.Context, version uint) (nextVersion uint, ok bool) { pos := i.findPos(version) if pos >= 0 && len(i.index) > pos+1 { return i.index[pos+1], true diff --git a/source/pkger/pkger.go b/source/pkger/pkger.go index f5f2132d6..7ee1e2c20 100644 --- a/source/pkger/pkger.go +++ b/source/pkger/pkger.go @@ -1,14 +1,15 @@ package pkger import ( + "context" "fmt" + "github.com/markbates/pkger/pkging" "net/http" stdurl "net/url" "github.com/golang-migrate/migrate/v4/source" "github.com/golang-migrate/migrate/v4/source/httpfs" "github.com/markbates/pkger" - "github.com/markbates/pkger/pkging" ) func init() { @@ -26,7 +27,7 @@ type Pkger struct { // scoped pkger.Open to access migrations. The relative root and any // migrations must be added to the global pkger.Pkger instance by calling // pkger.Apply. Refer to Pkger documentation for more information. -func (p *Pkger) Open(url string) (source.Driver, error) { +func (p *Pkger) Open(ctx context.Context, url string) (source.Driver, error) { u, err := stdurl.Parse(url) if err != nil { return nil, err @@ -52,7 +53,7 @@ func (p *Pkger) Open(url string) (source.Driver, error) { // pkging.Pkger. The relative location of migrations is indicated by path. The // path must exist on the pkging.Pkger instance for the driver to initialize // successfully. -func WithInstance(instance pkging.Pkger, path string) (source.Driver, error) { +func WithInstance(ctx context.Context, instance pkging.Pkger, path string) (source.Driver, error) { if instance == nil { return nil, fmt.Errorf("expected instance of pkging.Pkger") } diff --git a/source/pkger/pkger_test.go b/source/pkger/pkger_test.go index bb3d561b6..3c845e1bf 100644 --- a/source/pkger/pkger_test.go +++ b/source/pkger/pkger_test.go @@ -1,6 +1,7 @@ package pkger import ( + "context" "errors" "os" "testing" @@ -14,6 +15,7 @@ import ( func Test(t *testing.T) { t.Run("WithInstance", func(t *testing.T) { + ctx := context.Background() i := testInstance(t) createPkgerFile(t, i, "/1_foobar.up.sql") @@ -25,7 +27,7 @@ func Test(t *testing.T) { createPkgerFile(t, i, "/7_foobar.up.sql") createPkgerFile(t, i, "/7_foobar.down.sql") - d, err := WithInstance(i, "/") + d, err := WithInstance(ctx, i, "/") if err != nil { t.Fatal(err) } @@ -34,6 +36,7 @@ func Test(t *testing.T) { }) t.Run("Open", func(t *testing.T) { + ctx := context.Background() i := testInstance(t) createPkgerFile(t, i, "/1_foobar.up.sql") @@ -47,7 +50,7 @@ func Test(t *testing.T) { registerPackageLevelInstance(t, i) - d, err := (&Pkger{}).Open("pkger:///") + d, err := (&Pkger{}).Open(ctx, "pkger:///") if err != nil { t.Fatal(err) } @@ -58,6 +61,7 @@ func Test(t *testing.T) { } func TestWithInstance(t *testing.T) { + ctx := context.Background() t.Run("Subdir", func(t *testing.T) { i := testInstance(t) @@ -65,14 +69,14 @@ func TestWithInstance(t *testing.T) { // initialize. createPkgerSubdir(t, i, "/subdir") - _, err := WithInstance(i, "/subdir") + _, err := WithInstance(ctx, i, "/subdir") if err != nil { t.Fatal("") } }) t.Run("NilInstance", func(t *testing.T) { - _, err := WithInstance(nil, "") + _, err := WithInstance(ctx, nil, "") if err == nil { t.Fatal(err) } @@ -81,7 +85,7 @@ func TestWithInstance(t *testing.T) { t.Run("FailInit", func(t *testing.T) { i := testInstance(t) - _, err := WithInstance(i, "/fail") + _, err := WithInstance(ctx, i, "/fail") if err == nil { t.Fatal(err) } @@ -92,12 +96,12 @@ func TestWithInstance(t *testing.T) { createPkgerSubdir(t, i, "/") - d, err := WithInstance(i, "/") + d, err := WithInstance(ctx, i, "/") if err != nil { t.Fatal(err) } - if _, err := d.First(); !errors.Is(err, os.ErrNotExist) { + if _, err := d.First(ctx); !errors.Is(err, os.ErrNotExist) { t.Fatal(err) } @@ -105,23 +109,24 @@ func TestWithInstance(t *testing.T) { } func TestOpen(t *testing.T) { + ctx := context.Background() t.Run("InvalidURL", func(t *testing.T) { - _, err := (&Pkger{}).Open(":///") + _, err := (&Pkger{}).Open(ctx, ":///") if err == nil { t.Fatal(err) } }) t.Run("Root", func(t *testing.T) { - _, err := (&Pkger{}).Open("pkger:///") + _, err := (&Pkger{}).Open(ctx, "pkger:///") if err != nil { t.Fatal(err) } }) t.Run("FailInit", func(t *testing.T) { - _, err := (&Pkger{}).Open("pkger:///subdir") + _, err := (&Pkger{}).Open(ctx, "pkger:///subdir") if err == nil { t.Fatal(err) } @@ -136,7 +141,7 @@ func TestOpen(t *testing.T) { registerPackageLevelInstance(t, i) t.Run("Subdir", func(t *testing.T) { - _, err := (&Pkger{}).Open("pkger:///subdir") + _, err := (&Pkger{}).Open(ctx, "pkger:///subdir") if err != nil { t.Fatal(err) } @@ -144,11 +149,12 @@ func TestOpen(t *testing.T) { } func TestClose(t *testing.T) { - d, err := (&Pkger{}).Open("pkger:///") + ctx := context.Background() + d, err := (&Pkger{}).Open(ctx, "pkger:///") if err != nil { t.Fatal(err) } - if err := d.Close(); err != nil { + if err := d.Close(ctx); err != nil { t.Fatal(err) } } diff --git a/source/stub/stub.go b/source/stub/stub.go index ad2620ff7..8d6d061a0 100644 --- a/source/stub/stub.go +++ b/source/stub/stub.go @@ -2,6 +2,7 @@ package stub import ( "bytes" + "context" "fmt" "io" "os" @@ -25,7 +26,7 @@ type Stub struct { Config *Config } -func (s *Stub) Open(url string) (source.Driver, error) { +func (s *Stub) Open(ctx context.Context, url string) (source.Driver, error) { return &Stub{ Url: url, Migrations: source.NewMigrations(), @@ -33,7 +34,7 @@ func (s *Stub) Open(url string) (source.Driver, error) { }, nil } -func WithInstance(instance interface{}, config *Config) (source.Driver, error) { +func WithInstance(ctx context.Context, instance interface{}, config *Config) (source.Driver, error) { return &Stub{ Instance: instance, Migrations: source.NewMigrations(), @@ -41,42 +42,42 @@ func WithInstance(instance interface{}, config *Config) (source.Driver, error) { }, nil } -func (s *Stub) Close() error { +func (s *Stub) Close(ctx context.Context) error { return nil } -func (s *Stub) First() (version uint, err error) { - if v, ok := s.Migrations.First(); !ok { +func (s *Stub) First(ctx context.Context) (version uint, err error) { + if v, ok := s.Migrations.First(ctx); !ok { return 0, &os.PathError{Op: "first", Path: s.Url, Err: os.ErrNotExist} // TODO: s.Url can be empty when called with WithInstance } else { return v, nil } } -func (s *Stub) Prev(version uint) (prevVersion uint, err error) { - if v, ok := s.Migrations.Prev(version); !ok { +func (s *Stub) Prev(ctx context.Context, version uint) (prevVersion uint, err error) { + if v, ok := s.Migrations.Prev(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("prev for version %v", version), Path: s.Url, Err: os.ErrNotExist} } else { return v, nil } } -func (s *Stub) Next(version uint) (nextVersion uint, err error) { - if v, ok := s.Migrations.Next(version); !ok { +func (s *Stub) Next(ctx context.Context, version uint) (nextVersion uint, err error) { + if v, ok := s.Migrations.Next(ctx, version); !ok { return 0, &os.PathError{Op: fmt.Sprintf("next for version %v", version), Path: s.Url, Err: os.ErrNotExist} } else { return v, nil } } -func (s *Stub) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) { +func (s *Stub) ReadUp(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := s.Migrations.Up(version); ok { return io.NopCloser(bytes.NewBufferString(m.Identifier)), fmt.Sprintf("%v.up.stub", version), nil } return nil, "", &os.PathError{Op: fmt.Sprintf("read up version %v", version), Path: s.Url, Err: os.ErrNotExist} } -func (s *Stub) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) { +func (s *Stub) ReadDown(ctx context.Context, version uint) (r io.ReadCloser, identifier string, err error) { if m, ok := s.Migrations.Down(version); ok { return io.NopCloser(bytes.NewBufferString(m.Identifier)), fmt.Sprintf("%v.down.stub", version), nil } diff --git a/source/stub/stub_test.go b/source/stub/stub_test.go index 62f70ec15..b452b4c4a 100644 --- a/source/stub/stub_test.go +++ b/source/stub/stub_test.go @@ -1,6 +1,7 @@ package stub import ( + "context" "testing" "github.com/golang-migrate/migrate/v4/source" @@ -9,7 +10,7 @@ import ( func Test(t *testing.T) { s := &Stub{} - d, err := s.Open("") + d, err := s.Open(context.Background(), "") if err != nil { t.Fatal(err) } diff --git a/source/testing/testing.go b/source/testing/testing.go index 2b5505c4b..51f5391b5 100644 --- a/source/testing/testing.go +++ b/source/testing/testing.go @@ -4,6 +4,7 @@ package testing import ( + "context" "errors" "os" "testing" @@ -21,15 +22,16 @@ import ( // // See source/stub/stub_test.go or source/file/file_test.go for an example. func Test(t *testing.T, d source.Driver) { - TestFirst(t, d) - TestPrev(t, d) - TestNext(t, d) - TestReadUp(t, d) - TestReadDown(t, d) + ctx := context.Background() + TestFirst(t, ctx, d) + TestPrev(t, ctx, d) + TestNext(t, ctx, d) + TestReadUp(t, ctx, d) + TestReadDown(t, ctx, d) } -func TestFirst(t *testing.T, d source.Driver) { - version, err := d.First() +func TestFirst(t *testing.T, ctx context.Context, d source.Driver) { + version, err := d.First(ctx) if err != nil { t.Fatalf("First: expected err to be nil, got %v", err) } @@ -38,7 +40,7 @@ func TestFirst(t *testing.T, d source.Driver) { } } -func TestPrev(t *testing.T, d source.Driver) { +func TestPrev(t *testing.T, ctx context.Context, d source.Driver) { tt := []struct { version uint expectErr error @@ -57,7 +59,7 @@ func TestPrev(t *testing.T, d source.Driver) { } for i, v := range tt { - pv, err := d.Prev(v.version) + pv, err := d.Prev(ctx, v.version) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) && v.expectErr != err { t.Errorf("Prev: expected %v, got %v, in %v", v.expectErr, err, i) } @@ -67,7 +69,7 @@ func TestPrev(t *testing.T, d source.Driver) { } } -func TestNext(t *testing.T, d source.Driver) { +func TestNext(t *testing.T, ctx context.Context, d source.Driver) { tt := []struct { version uint expectErr error @@ -86,7 +88,7 @@ func TestNext(t *testing.T, d source.Driver) { } for i, v := range tt { - nv, err := d.Next(v.version) + nv, err := d.Next(ctx, v.version) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) && v.expectErr != err { t.Errorf("Next: expected %v, got %v, in %v", v.expectErr, err, i) } @@ -96,7 +98,7 @@ func TestNext(t *testing.T, d source.Driver) { } } -func TestReadUp(t *testing.T, d source.Driver) { +func TestReadUp(t *testing.T, ctx context.Context, d source.Driver) { tt := []struct { version uint expectErr error @@ -114,7 +116,7 @@ func TestReadUp(t *testing.T, d source.Driver) { } for i, v := range tt { - up, identifier, err := d.ReadUp(v.version) + up, identifier, err := d.ReadUp(ctx, v.version) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) || (v.expectErr != os.ErrNotExist && err != v.expectErr) { t.Errorf("expected %v, got %v, in %v", v.expectErr, err, i) @@ -138,7 +140,7 @@ func TestReadUp(t *testing.T, d source.Driver) { } } -func TestReadDown(t *testing.T, d source.Driver) { +func TestReadDown(t *testing.T, ctx context.Context, d source.Driver) { tt := []struct { version uint expectErr error @@ -156,7 +158,7 @@ func TestReadDown(t *testing.T, d source.Driver) { } for i, v := range tt { - down, identifier, err := d.ReadDown(v.version) + down, identifier, err := d.ReadDown(ctx, v.version) if (v.expectErr == os.ErrNotExist && !errors.Is(err, os.ErrNotExist)) || (v.expectErr != os.ErrNotExist && err != v.expectErr) { t.Errorf("expected %v, got %v, in %v", v.expectErr, err, i)