Skip to content

Commit

Permalink
chore: add context.Context everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
joschi committed Jan 27, 2025
1 parent d477553 commit e9f5c08
Show file tree
Hide file tree
Showing 81 changed files with 1,602 additions and 1,404 deletions.
31 changes: 16 additions & 15 deletions database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cassandra

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -52,7 +53,7 @@ type Cassandra struct {
config *Config
}

func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
func WithInstance(ctx context.Context, session *gocql.Session, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
} else if len(config.KeyspaceName) == 0 {
Expand All @@ -76,14 +77,14 @@ func WithInstance(session *gocql.Session, config *Config) (database.Driver, erro
config: config,
}

if err := c.ensureVersionTable(); err != nil {
if err := c.ensureVersionTable(ctx); err != nil {
return nil, err
}

return c, nil
}

func (c *Cassandra) Open(url string) (database.Driver, error) {
func (c *Cassandra) Open(ctx context.Context, url string) (database.Driver, error) {
u, err := nurl.Parse(url)
if err != nil {
return nil, err
Expand Down Expand Up @@ -185,34 +186,34 @@ func (c *Cassandra) Open(url string) (database.Driver, error) {
}
}

return WithInstance(session, &Config{
return WithInstance(ctx, session, &Config{
KeyspaceName: strings.TrimPrefix(u.Path, "/"),
MigrationsTable: u.Query().Get("x-migrations-table"),
MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
MultiStatementMaxSize: multiStatementMaxSize,
})
}

func (c *Cassandra) Close() error {
func (c *Cassandra) Close(ctx context.Context) error {
c.session.Close()
return nil
}

func (c *Cassandra) Lock() error {
func (c *Cassandra) Lock(ctx context.Context) error {
if !c.isLocked.CAS(false, true) {
return database.ErrLocked
}
return nil
}

func (c *Cassandra) Unlock() error {
func (c *Cassandra) Unlock(ctx context.Context) error {
if !c.isLocked.CAS(true, false) {
return database.ErrNotLocked
}
return nil
}

func (c *Cassandra) Run(migration io.Reader) error {
func (c *Cassandra) Run(ctx context.Context, migration io.Reader) error {
if c.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
Expand Down Expand Up @@ -243,7 +244,7 @@ func (c *Cassandra) Run(migration io.Reader) error {
return nil
}

func (c *Cassandra) SetVersion(version int, dirty bool) error {
func (c *Cassandra) SetVersion(ctx context.Context, version int, dirty bool) error {
// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
Expand Down Expand Up @@ -273,7 +274,7 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error {
}

// Return current keyspace version
func (c *Cassandra) Version() (version int, dirty bool, err error) {
func (c *Cassandra) Version(ctx context.Context) (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
err = c.session.Query(query).Scan(&version, &dirty)
switch {
Expand All @@ -291,7 +292,7 @@ func (c *Cassandra) Version() (version int, dirty bool, err error) {
}
}

func (c *Cassandra) Drop() error {
func (c *Cassandra) Drop(ctx context.Context) error {
// select all tables in current schema
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
iter := c.session.Query(query).Iter()
Expand All @@ -309,13 +310,13 @@ func (c *Cassandra) Drop() error {
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Cassandra type.
func (c *Cassandra) ensureVersionTable() (err error) {
if err = c.Lock(); err != nil {
func (c *Cassandra) ensureVersionTable(ctx context.Context) (err error) {
if err = c.Lock(ctx); err != nil {
return err
}

defer func() {
if e := c.Unlock(); e != nil {
if e := c.Unlock(ctx); e != nil {
if err == nil {
err = e
} else {
Expand All @@ -328,7 +329,7 @@ func (c *Cassandra) ensureVersionTable() (err error) {
if err != nil {
return err
}
if _, _, err = c.Version(); err != nil {
if _, _, err = c.Version(ctx); err != nil {
return err
}
return nil
Expand Down
12 changes: 7 additions & 5 deletions database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ func Test(t *testing.T) {

func test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.Port(9042)
if err != nil {
t.Fatal("Unable to get mapped port:", err)
}
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
p := &Cassandra{}
d, err := p.Open(addr)
d, err := p.Open(ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
if err := d.Close(ctx); err != nil {
t.Error(err)
}
}()
Expand All @@ -97,23 +98,24 @@ func test(t *testing.T) {

func testMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.Port(9042)
if err != nil {
t.Fatal("Unable to get mapped port:", err)
}
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
p := &Cassandra{}
d, err := p.Open(addr)
d, err := p.Open(ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
if err := d.Close(ctx); err != nil {
t.Error(err)
}
}()

m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "testks", d)
m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "testks", d)
if err != nil {
t.Fatal(err)
}
Expand Down
47 changes: 24 additions & 23 deletions database/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clickhouse

import (
"context"
"database/sql"
"fmt"
"io"
Expand Down Expand Up @@ -40,7 +41,7 @@ func init() {
database.Register("clickhouse", &ClickHouse{})
}

func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
func WithInstance(ctx context.Context, conn *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
Expand All @@ -54,7 +55,7 @@ func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
config: config,
}

if err := ch.init(); err != nil {
if err := ch.init(ctx); err != nil {
return nil, err
}

Expand All @@ -67,7 +68,7 @@ type ClickHouse struct {
isLocked atomic.Bool
}

func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
func (ch *ClickHouse) Open(ctx context.Context, dsn string) (database.Driver, error) {
purl, err := url.Parse(dsn)
if err != nil {
return nil, err
Expand Down Expand Up @@ -104,14 +105,14 @@ func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
},
}

if err := ch.init(); err != nil {
if err := ch.init(ctx); err != nil {
return nil, err
}

return ch, nil
}

func (ch *ClickHouse) init() error {
func (ch *ClickHouse) init(ctx context.Context) error {
if len(ch.config.DatabaseName) == 0 {
if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
return err
Expand All @@ -130,18 +131,18 @@ func (ch *ClickHouse) init() error {
ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine
}

return ch.ensureVersionTable()
return ch.ensureVersionTable(ctx)
}

func (ch *ClickHouse) Run(r io.Reader) error {
func (ch *ClickHouse) Run(ctx context.Context, r io.Reader) error {
if ch.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
tq := strings.TrimSpace(string(m))
if tq == "" {
return true
}
if _, e := ch.conn.Exec(string(m)); e != nil {
if _, e := ch.conn.ExecContext(ctx, string(m)); e != nil {
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
return false
}
Expand All @@ -157,13 +158,13 @@ func (ch *ClickHouse) Run(r io.Reader) error {
return err
}

if _, err := ch.conn.Exec(string(migration)); err != nil {
if _, err := ch.conn.ExecContext(ctx, string(migration)); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migration}
}

return nil
}
func (ch *ClickHouse) Version() (int, bool, error) {
func (ch *ClickHouse) Version(ctx context.Context) (int, bool, error) {
var (
version int
dirty uint8
Expand All @@ -178,22 +179,22 @@ func (ch *ClickHouse) Version() (int, bool, error) {
return version, dirty == 1, nil
}

func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
func (ch *ClickHouse) SetVersion(ctx context.Context, version int, dirty bool) error {
var (
bool = func(v bool) uint8 {
if v {
return 1
}
return 0
}
tx, err = ch.conn.Begin()
tx, err = ch.conn.BeginTx(ctx, nil)
)
if err != nil {
return err
}

query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)"
if _, err := tx.Exec(query, version, bool(dirty), time.Now().UnixNano()); err != nil {
if _, err := tx.ExecContext(ctx, query, version, bool(dirty), time.Now().UnixNano()); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -203,13 +204,13 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the ClickHouse type.
func (ch *ClickHouse) ensureVersionTable() (err error) {
if err = ch.Lock(); err != nil {
func (ch *ClickHouse) ensureVersionTable(ctx context.Context) (err error) {
if err = ch.Lock(ctx); err != nil {
return err
}

defer func() {
if e := ch.Unlock(); e != nil {
if e := ch.Unlock(ctx); e != nil {
if err == nil {
err = e
} else {
Expand Down Expand Up @@ -252,15 +253,15 @@ func (ch *ClickHouse) ensureVersionTable() (err error) {
query = fmt.Sprintf(`%s ORDER BY sequence`, query)
}

if _, err := ch.conn.Exec(query); err != nil {
if _, err := ch.conn.ExecContext(ctx, query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}

func (ch *ClickHouse) Drop() (err error) {
func (ch *ClickHouse) Drop(ctx context.Context) (err error) {
query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName)
tables, err := ch.conn.Query(query)
tables, err := ch.conn.QueryContext(ctx, query)

if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
Expand All @@ -279,7 +280,7 @@ func (ch *ClickHouse) Drop() (err error) {

query = "DROP TABLE IF EXISTS " + quoteIdentifier(ch.config.DatabaseName) + "." + quoteIdentifier(table)

if _, err := ch.conn.Exec(query); err != nil {
if _, err := ch.conn.ExecContext(ctx, query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
Expand All @@ -290,21 +291,21 @@ func (ch *ClickHouse) Drop() (err error) {
return nil
}

func (ch *ClickHouse) Lock() error {
func (ch *ClickHouse) Lock(ctx context.Context) error {
if !ch.isLocked.CAS(false, true) {
return database.ErrLocked
}

return nil
}
func (ch *ClickHouse) Unlock() error {
func (ch *ClickHouse) Unlock(ctx context.Context) error {
if !ch.isLocked.CAS(true, false) {
return database.ErrNotLocked
}

return nil
}
func (ch *ClickHouse) Close() error { return ch.conn.Close() }
func (ch *ClickHouse) Close(ctx context.Context) error { return ch.conn.Close() }

// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
Expand Down
Loading

0 comments on commit e9f5c08

Please sign in to comment.