Skip to content

Commit edc32df

Browse files
committed
chore: add context.Context everywhere
1 parent d28e549 commit edc32df

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+1380
-1182
lines changed

database/cassandra/cassandra.go

+16-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cassandra
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"io"
@@ -52,7 +53,7 @@ type Cassandra struct {
5253
config *Config
5354
}
5455

55-
func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
56+
func WithInstance(ctx context.Context, session *gocql.Session, config *Config) (database.Driver, error) {
5657
if config == nil {
5758
return nil, ErrNilConfig
5859
} else if len(config.KeyspaceName) == 0 {
@@ -76,14 +77,14 @@ func WithInstance(session *gocql.Session, config *Config) (database.Driver, erro
7677
config: config,
7778
}
7879

79-
if err := c.ensureVersionTable(); err != nil {
80+
if err := c.ensureVersionTable(ctx); err != nil {
8081
return nil, err
8182
}
8283

8384
return c, nil
8485
}
8586

86-
func (c *Cassandra) Open(url string) (database.Driver, error) {
87+
func (c *Cassandra) Open(ctx context.Context, url string) (database.Driver, error) {
8788
u, err := nurl.Parse(url)
8889
if err != nil {
8990
return nil, err
@@ -185,34 +186,34 @@ func (c *Cassandra) Open(url string) (database.Driver, error) {
185186
}
186187
}
187188

188-
return WithInstance(session, &Config{
189+
return WithInstance(ctx, session, &Config{
189190
KeyspaceName: strings.TrimPrefix(u.Path, "/"),
190191
MigrationsTable: u.Query().Get("x-migrations-table"),
191192
MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
192193
MultiStatementMaxSize: multiStatementMaxSize,
193194
})
194195
}
195196

196-
func (c *Cassandra) Close() error {
197+
func (c *Cassandra) Close(ctx context.Context) error {
197198
c.session.Close()
198199
return nil
199200
}
200201

201-
func (c *Cassandra) Lock() error {
202+
func (c *Cassandra) Lock(ctx context.Context) error {
202203
if !c.isLocked.CAS(false, true) {
203204
return database.ErrLocked
204205
}
205206
return nil
206207
}
207208

208-
func (c *Cassandra) Unlock() error {
209+
func (c *Cassandra) Unlock(ctx context.Context) error {
209210
if !c.isLocked.CAS(true, false) {
210211
return database.ErrNotLocked
211212
}
212213
return nil
213214
}
214215

215-
func (c *Cassandra) Run(migration io.Reader) error {
216+
func (c *Cassandra) Run(ctx context.Context, migration io.Reader) error {
216217
if c.config.MultiStatementEnabled {
217218
var err error
218219
if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
@@ -243,7 +244,7 @@ func (c *Cassandra) Run(migration io.Reader) error {
243244
return nil
244245
}
245246

246-
func (c *Cassandra) SetVersion(version int, dirty bool) error {
247+
func (c *Cassandra) SetVersion(ctx context.Context, version int, dirty bool) error {
247248
// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
248249
// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
249250
squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
@@ -273,7 +274,7 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error {
273274
}
274275

275276
// Return current keyspace version
276-
func (c *Cassandra) Version() (version int, dirty bool, err error) {
277+
func (c *Cassandra) Version(ctx context.Context) (version int, dirty bool, err error) {
277278
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
278279
err = c.session.Query(query).Scan(&version, &dirty)
279280
switch {
@@ -291,7 +292,7 @@ func (c *Cassandra) Version() (version int, dirty bool, err error) {
291292
}
292293
}
293294

294-
func (c *Cassandra) Drop() error {
295+
func (c *Cassandra) Drop(ctx context.Context) error {
295296
// select all tables in current schema
296297
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
297298
iter := c.session.Query(query).Iter()
@@ -309,13 +310,13 @@ func (c *Cassandra) Drop() error {
309310
// ensureVersionTable checks if versions table exists and, if not, creates it.
310311
// Note that this function locks the database, which deviates from the usual
311312
// convention of "caller locks" in the Cassandra type.
312-
func (c *Cassandra) ensureVersionTable() (err error) {
313-
if err = c.Lock(); err != nil {
313+
func (c *Cassandra) ensureVersionTable(ctx context.Context) (err error) {
314+
if err = c.Lock(ctx); err != nil {
314315
return err
315316
}
316317

317318
defer func() {
318-
if e := c.Unlock(); e != nil {
319+
if e := c.Unlock(ctx); e != nil {
319320
if err == nil {
320321
err = e
321322
} else {
@@ -328,7 +329,7 @@ func (c *Cassandra) ensureVersionTable() (err error) {
328329
if err != nil {
329330
return err
330331
}
331-
if _, _, err = c.Version(); err != nil {
332+
if _, _, err = c.Version(ctx); err != nil {
332333
return err
333334
}
334335
return nil

database/cassandra/cassandra_test.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,19 @@ func Test(t *testing.T) {
7676

7777
func test(t *testing.T) {
7878
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
79+
ctx := context.Background()
7980
ip, port, err := c.Port(9042)
8081
if err != nil {
8182
t.Fatal("Unable to get mapped port:", err)
8283
}
8384
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
8485
p := &Cassandra{}
85-
d, err := p.Open(addr)
86+
d, err := p.Open(ctx, addr)
8687
if err != nil {
8788
t.Fatal(err)
8889
}
8990
defer func() {
90-
if err := d.Close(); err != nil {
91+
if err := d.Close(ctx); err != nil {
9192
t.Error(err)
9293
}
9394
}()
@@ -97,23 +98,24 @@ func test(t *testing.T) {
9798

9899
func testMigrate(t *testing.T) {
99100
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
101+
ctx := context.Background()
100102
ip, port, err := c.Port(9042)
101103
if err != nil {
102104
t.Fatal("Unable to get mapped port:", err)
103105
}
104106
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
105107
p := &Cassandra{}
106-
d, err := p.Open(addr)
108+
d, err := p.Open(ctx, addr)
107109
if err != nil {
108110
t.Fatal(err)
109111
}
110112
defer func() {
111-
if err := d.Close(); err != nil {
113+
if err := d.Close(ctx); err != nil {
112114
t.Error(err)
113115
}
114116
}()
115117

116-
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "testks", d)
118+
m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "testks", d)
117119
if err != nil {
118120
t.Fatal(err)
119121
}

database/clickhouse/clickhouse.go

+17-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package clickhouse
22

33
import (
4+
"context"
45
"database/sql"
56
"fmt"
67
"io"
@@ -40,7 +41,7 @@ func init() {
4041
database.Register("clickhouse", &ClickHouse{})
4142
}
4243

43-
func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
44+
func WithInstance(ctx context.Context, conn *sql.DB, config *Config) (database.Driver, error) {
4445
if config == nil {
4546
return nil, ErrNilConfig
4647
}
@@ -54,7 +55,7 @@ func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
5455
config: config,
5556
}
5657

57-
if err := ch.init(); err != nil {
58+
if err := ch.init(ctx); err != nil {
5859
return nil, err
5960
}
6061

@@ -67,7 +68,7 @@ type ClickHouse struct {
6768
isLocked atomic.Bool
6869
}
6970

70-
func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
71+
func (ch *ClickHouse) Open(ctx context.Context, dsn string) (database.Driver, error) {
7172
purl, err := url.Parse(dsn)
7273
if err != nil {
7374
return nil, err
@@ -104,14 +105,14 @@ func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
104105
},
105106
}
106107

107-
if err := ch.init(); err != nil {
108+
if err := ch.init(ctx); err != nil {
108109
return nil, err
109110
}
110111

111112
return ch, nil
112113
}
113114

114-
func (ch *ClickHouse) init() error {
115+
func (ch *ClickHouse) init(ctx context.Context) error {
115116
if len(ch.config.DatabaseName) == 0 {
116117
if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
117118
return err
@@ -130,10 +131,10 @@ func (ch *ClickHouse) init() error {
130131
ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine
131132
}
132133

133-
return ch.ensureVersionTable()
134+
return ch.ensureVersionTable(ctx)
134135
}
135136

136-
func (ch *ClickHouse) Run(r io.Reader) error {
137+
func (ch *ClickHouse) Run(ctx context.Context, r io.Reader) error {
137138
if ch.config.MultiStatementEnabled {
138139
var err error
139140
if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
@@ -163,7 +164,7 @@ func (ch *ClickHouse) Run(r io.Reader) error {
163164

164165
return nil
165166
}
166-
func (ch *ClickHouse) Version() (int, bool, error) {
167+
func (ch *ClickHouse) Version(ctx context.Context) (int, bool, error) {
167168
var (
168169
version int
169170
dirty uint8
@@ -178,7 +179,7 @@ func (ch *ClickHouse) Version() (int, bool, error) {
178179
return version, dirty == 1, nil
179180
}
180181

181-
func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
182+
func (ch *ClickHouse) SetVersion(ctx context.Context, version int, dirty bool) error {
182183
var (
183184
bool = func(v bool) uint8 {
184185
if v {
@@ -203,13 +204,13 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
203204
// ensureVersionTable checks if versions table exists and, if not, creates it.
204205
// Note that this function locks the database, which deviates from the usual
205206
// convention of "caller locks" in the ClickHouse type.
206-
func (ch *ClickHouse) ensureVersionTable() (err error) {
207-
if err = ch.Lock(); err != nil {
207+
func (ch *ClickHouse) ensureVersionTable(ctx context.Context) (err error) {
208+
if err = ch.Lock(ctx); err != nil {
208209
return err
209210
}
210211

211212
defer func() {
212-
if e := ch.Unlock(); e != nil {
213+
if e := ch.Unlock(ctx); e != nil {
213214
if err == nil {
214215
err = e
215216
} else {
@@ -258,7 +259,7 @@ func (ch *ClickHouse) ensureVersionTable() (err error) {
258259
return nil
259260
}
260261

261-
func (ch *ClickHouse) Drop() (err error) {
262+
func (ch *ClickHouse) Drop(ctx context.Context) (err error) {
262263
query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName)
263264
tables, err := ch.conn.Query(query)
264265

@@ -290,21 +291,21 @@ func (ch *ClickHouse) Drop() (err error) {
290291
return nil
291292
}
292293

293-
func (ch *ClickHouse) Lock() error {
294+
func (ch *ClickHouse) Lock(ctx context.Context) error {
294295
if !ch.isLocked.CAS(false, true) {
295296
return database.ErrLocked
296297
}
297298

298299
return nil
299300
}
300-
func (ch *ClickHouse) Unlock() error {
301+
func (ch *ClickHouse) Unlock(ctx context.Context) error {
301302
if !ch.isLocked.CAS(true, false) {
302303
return database.ErrNotLocked
303304
}
304305

305306
return nil
306307
}
307-
func (ch *ClickHouse) Close() error { return ch.conn.Close() }
308+
func (ch *ClickHouse) Close(ctx context.Context) error { return ch.conn.Close() }
308309

309310
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
310311
func quoteIdentifier(name string) string {

0 commit comments

Comments
 (0)