7
7
"context"
8
8
"database/sql"
9
9
"fmt"
10
- "go.uber.org/atomic"
11
10
"io"
12
11
"io/ioutil"
13
12
nurl "net/url"
@@ -16,10 +15,12 @@ import (
16
15
"strings"
17
16
"time"
18
17
18
+ "go.uber.org/atomic"
19
+
19
20
"github.com/golang-migrate/migrate/v4"
20
21
"github.com/golang-migrate/migrate/v4/database"
21
22
"github.com/golang-migrate/migrate/v4/database/multistmt"
22
- multierror "github.com/hashicorp/go-multierror"
23
+ "github.com/hashicorp/go-multierror"
23
24
"github.com/lib/pq"
24
25
)
25
26
@@ -65,19 +66,19 @@ type Postgres struct {
65
66
config * Config
66
67
}
67
68
68
- func WithInstance ( instance * sql.DB , config * Config ) (database. Driver , error ) {
69
+ func WithConnection ( ctx context. Context , conn * sql.Conn , config * Config ) (* Postgres , error ) {
69
70
if config == nil {
70
71
return nil , ErrNilConfig
71
72
}
72
73
73
- if err := instance . Ping ( ); err != nil {
74
+ if err := conn . PingContext ( ctx ); err != nil {
74
75
return nil , err
75
76
}
76
77
77
78
if config .DatabaseName == "" {
78
79
query := `SELECT CURRENT_DATABASE()`
79
80
var databaseName string
80
- if err := instance . QueryRow ( query ).Scan (& databaseName ); err != nil {
81
+ if err := conn . QueryRowContext ( ctx , query ).Scan (& databaseName ); err != nil {
81
82
return nil , & database.Error {OrigErr : err , Query : []byte (query )}
82
83
}
83
84
@@ -91,7 +92,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
91
92
if config .SchemaName == "" {
92
93
query := `SELECT CURRENT_SCHEMA()`
93
94
var schemaName string
94
- if err := instance . QueryRow ( query ).Scan (& schemaName ); err != nil {
95
+ if err := conn . QueryRowContext ( ctx , query ).Scan (& schemaName ); err != nil {
95
96
return nil , & database.Error {OrigErr : err , Query : []byte (query )}
96
97
}
97
98
@@ -119,15 +120,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
119
120
}
120
121
}
121
122
122
- conn , err := instance .Conn (context .Background ())
123
-
124
- if err != nil {
125
- return nil , err
126
- }
127
-
128
123
px := & Postgres {
129
124
conn : conn ,
130
- db : instance ,
131
125
config : config ,
132
126
}
133
127
@@ -138,6 +132,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
138
132
return px , nil
139
133
}
140
134
135
+ func WithInstance (instance * sql.DB , config * Config ) (database.Driver , error ) {
136
+ ctx := context .Background ()
137
+
138
+ if err := instance .Ping (); err != nil {
139
+ return nil , err
140
+ }
141
+
142
+ conn , err := instance .Conn (ctx )
143
+ if err != nil {
144
+ return nil , err
145
+ }
146
+
147
+ px , err := WithConnection (ctx , conn , config )
148
+ if err != nil {
149
+ return nil , err
150
+ }
151
+ px .db = instance
152
+ return px , nil
153
+ }
154
+
141
155
func (p * Postgres ) Open (url string ) (database.Driver , error ) {
142
156
purl , err := nurl .Parse (url )
143
157
if err != nil {
@@ -207,7 +221,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
207
221
208
222
func (p * Postgres ) Close () error {
209
223
connErr := p .conn .Close ()
210
- dbErr := p .db .Close ()
224
+ var dbErr error
225
+ if p .db != nil {
226
+ dbErr = p .db .Close ()
227
+ }
228
+
211
229
if connErr != nil || dbErr != nil {
212
230
return fmt .Errorf ("conn: %v, db: %v" , connErr , dbErr )
213
231
}
0 commit comments