Skip to content

Commit 57aead3

Browse files
authored
Merge pull request #659 from AndreasKl/add-with-connection-to-postgres
Add WithConnection to Postgres similar to MySQL.
2 parents e1d604b + 3dfae0d commit 57aead3

File tree

2 files changed

+70
-14
lines changed

2 files changed

+70
-14
lines changed

database/postgres/postgres.go

+32-14
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"context"
88
"database/sql"
99
"fmt"
10-
"go.uber.org/atomic"
1110
"io"
1211
"io/ioutil"
1312
nurl "net/url"
@@ -16,10 +15,12 @@ import (
1615
"strings"
1716
"time"
1817

18+
"go.uber.org/atomic"
19+
1920
"github.com/golang-migrate/migrate/v4"
2021
"github.com/golang-migrate/migrate/v4/database"
2122
"github.com/golang-migrate/migrate/v4/database/multistmt"
22-
multierror "github.com/hashicorp/go-multierror"
23+
"github.com/hashicorp/go-multierror"
2324
"github.com/lib/pq"
2425
)
2526

@@ -65,19 +66,19 @@ type Postgres struct {
6566
config *Config
6667
}
6768

68-
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
69+
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) {
6970
if config == nil {
7071
return nil, ErrNilConfig
7172
}
7273

73-
if err := instance.Ping(); err != nil {
74+
if err := conn.PingContext(ctx); err != nil {
7475
return nil, err
7576
}
7677

7778
if config.DatabaseName == "" {
7879
query := `SELECT CURRENT_DATABASE()`
7980
var databaseName string
80-
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
81+
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
8182
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
8283
}
8384

@@ -91,7 +92,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
9192
if config.SchemaName == "" {
9293
query := `SELECT CURRENT_SCHEMA()`
9394
var schemaName string
94-
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
95+
if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
9596
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
9697
}
9798

@@ -119,15 +120,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
119120
}
120121
}
121122

122-
conn, err := instance.Conn(context.Background())
123-
124-
if err != nil {
125-
return nil, err
126-
}
127-
128123
px := &Postgres{
129124
conn: conn,
130-
db: instance,
131125
config: config,
132126
}
133127

@@ -138,6 +132,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
138132
return px, nil
139133
}
140134

135+
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
136+
ctx := context.Background()
137+
138+
if err := instance.Ping(); err != nil {
139+
return nil, err
140+
}
141+
142+
conn, err := instance.Conn(ctx)
143+
if err != nil {
144+
return nil, err
145+
}
146+
147+
px, err := WithConnection(ctx, conn, config)
148+
if err != nil {
149+
return nil, err
150+
}
151+
px.db = instance
152+
return px, nil
153+
}
154+
141155
func (p *Postgres) Open(url string) (database.Driver, error) {
142156
purl, err := nurl.Parse(url)
143157
if err != nil {
@@ -207,7 +221,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
207221

208222
func (p *Postgres) Close() error {
209223
connErr := p.conn.Close()
210-
dbErr := p.db.Close()
224+
var dbErr error
225+
if p.db != nil {
226+
dbErr = p.db.Close()
227+
}
228+
211229
if connErr != nil || dbErr != nil {
212230
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
213231
}

database/postgres/postgres_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,44 @@ func TestWithInstance_Concurrent(t *testing.T) {
684684
}
685685
})
686686
}
687+
688+
func TestWithConnection(t *testing.T) {
689+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
690+
ip, port, err := c.FirstPort()
691+
if err != nil {
692+
t.Fatal(err)
693+
}
694+
695+
db, err := sql.Open("postgres", pgConnectionString(ip, port))
696+
if err != nil {
697+
t.Fatal(err)
698+
}
699+
defer func() {
700+
if err := db.Close(); err != nil {
701+
t.Error(err)
702+
}
703+
}()
704+
705+
ctx := context.Background()
706+
conn, err := db.Conn(ctx)
707+
if err != nil {
708+
t.Fatal(err)
709+
}
710+
711+
p, err := WithConnection(ctx, conn, &Config{})
712+
if err != nil {
713+
t.Fatal(err)
714+
}
715+
716+
defer func() {
717+
if err := p.Close(); err != nil {
718+
t.Error(err)
719+
}
720+
}()
721+
dt.Test(t, p, []byte("SELECT 1"))
722+
})
723+
}
724+
687725
func Test_computeLineFromPos(t *testing.T) {
688726
testcases := []struct {
689727
pos int

0 commit comments

Comments
 (0)