@@ -2,17 +2,19 @@ package neo4j
2
2
3
3
import (
4
4
"bytes"
5
+ "context"
5
6
"fmt"
6
- "golang.org/x/mod/semver"
7
7
"io"
8
8
neturl "net/url"
9
9
"strconv"
10
10
"sync/atomic"
11
11
12
+ "golang.org/x/mod/semver"
13
+
12
14
"github.com/golang-migrate/migrate/v4/database"
13
15
"github.com/golang-migrate/migrate/v4/database/multistmt"
14
16
"github.com/hashicorp/go-multierror"
15
- "github.com/neo4j/neo4j-go-driver/v4 /neo4j"
17
+ "github.com/neo4j/neo4j-go-driver/v5 /neo4j"
16
18
)
17
19
18
20
func init () {
@@ -38,14 +40,14 @@ type Config struct {
38
40
}
39
41
40
42
type Neo4j struct {
41
- driver neo4j.Driver
43
+ driver neo4j.DriverWithContext
42
44
lock uint32
43
45
44
46
// Open and WithInstance need to guarantee that config is never nil
45
47
config * Config
46
48
}
47
49
48
- func WithInstance (driver neo4j.Driver , config * Config ) (database.Driver , error ) {
50
+ func WithInstance (driver neo4j.DriverWithContext , config * Config ) (database.Driver , error ) {
49
51
if config == nil {
50
52
return nil , ErrNilConfig
51
53
}
@@ -70,31 +72,16 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
70
72
password , _ := uri .User .Password ()
71
73
authToken := neo4j .BasicAuth (uri .User .Username (), password , "" )
72
74
uri .User = nil
73
- uri .Scheme = "bolt"
74
75
msQuery := uri .Query ().Get ("x-multi-statement" )
75
76
76
- // Whether to turn on/off TLS encryption.
77
- tlsEncrypted := uri .Query ().Get ("x-tls-encrypted" )
78
77
multi := false
79
- encrypted := false
80
78
if msQuery != "" {
81
79
multi , err = strconv .ParseBool (uri .Query ().Get ("x-multi-statement" ))
82
80
if err != nil {
83
81
return nil , err
84
82
}
85
83
}
86
84
87
- if tlsEncrypted != "" {
88
- encrypted , err = strconv .ParseBool (tlsEncrypted )
89
- if err != nil {
90
- return nil , err
91
- }
92
- }
93
-
94
- if encrypted {
95
- uri .Scheme += "+s"
96
- }
97
-
98
85
multiStatementMaxSize := DefaultMultiStatementMaxSize
99
86
if s := uri .Query ().Get ("x-multi-statement-max-size" ); s != "" {
100
87
multiStatementMaxSize , err = strconv .Atoi (s )
@@ -105,11 +92,15 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
105
92
106
93
uri .RawQuery = ""
107
94
108
- driver , err := neo4j .NewDriver (uri .String (), authToken , func ( config * neo4j. Config ) {} )
95
+ driver , err := neo4j .NewDriverWithContext (uri .String (), authToken )
109
96
if err != nil {
110
97
return nil , err
111
98
}
112
99
100
+ if err = driver .VerifyConnectivity (context .Background ()); err != nil {
101
+ return nil , err
102
+ }
103
+
113
104
return WithInstance (driver , & Config {
114
105
MigrationsLabel : DefaultMigrationsLabel ,
115
106
MultiStatement : multi ,
@@ -118,7 +109,7 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
118
109
}
119
110
120
111
func (n * Neo4j ) Close () error {
121
- return n .driver .Close ()
112
+ return n .driver .Close (context . Background () )
122
113
}
123
114
124
115
// local locking in order to pass tests, Neo doesn't support database locking
@@ -138,60 +129,71 @@ func (n *Neo4j) Unlock() error {
138
129
}
139
130
140
131
func (n * Neo4j ) Run (migration io.Reader ) (err error ) {
141
- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
132
+ ctx := context .Background ()
133
+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
142
134
defer func () {
143
- if cerr := session .Close (); cerr != nil {
135
+ if cerr := session .Close (ctx ); cerr != nil {
144
136
err = multierror .Append (err , cerr )
145
137
}
146
138
}()
147
139
148
140
if n .config .MultiStatement {
149
- _ , err = session .WriteTransaction (func (transaction neo4j.Transaction ) (interface {}, error ) {
150
- var stmtRunErr error
151
- if err := multistmt .Parse (migration , StatementSeparator , n .config .MultiStatementMaxSize , func (stmt []byte ) bool {
152
- trimStmt := bytes .TrimSpace (stmt )
153
- if len (trimStmt ) == 0 {
154
- return true
155
- }
156
- trimStmt = bytes .TrimSuffix (trimStmt , StatementSeparator )
157
- if len (trimStmt ) == 0 {
158
- return true
159
- }
160
-
161
- result , err := transaction .Run (string (trimStmt ), nil )
162
- if _ , err := neo4j .Collect (result , err ); err != nil {
163
- stmtRunErr = err
164
- return false
165
- }
141
+ tx , err := session .BeginTransaction (ctx )
142
+ if err != nil {
143
+ return err
144
+ }
145
+ defer func () {
146
+ if cerr := tx .Close (ctx ); cerr != nil {
147
+ err = multierror .Append (err , cerr )
148
+ }
149
+ }()
150
+
151
+ var stmtRunErr error
152
+ if err := multistmt .Parse (migration , StatementSeparator , n .config .MultiStatementMaxSize , func (stmt []byte ) bool {
153
+ trimStmt := bytes .TrimSpace (stmt )
154
+ if len (trimStmt ) == 0 {
166
155
return true
167
- }); err != nil {
168
- return nil , err
169
156
}
170
- return nil , stmtRunErr
171
- })
172
- return err
157
+ trimStmt = bytes .TrimSuffix (trimStmt , StatementSeparator )
158
+ if len (trimStmt ) == 0 {
159
+ return true
160
+ }
161
+
162
+ result , err := tx .Run (ctx , string (trimStmt ), nil )
163
+ if _ , err := neo4j .CollectWithContext (ctx , result , err ); err != nil {
164
+ stmtRunErr = err
165
+ return false
166
+ }
167
+ return true
168
+ }); err != nil {
169
+ return err
170
+ }
171
+ return stmtRunErr
173
172
}
174
173
175
174
body , err := io .ReadAll (migration )
176
175
if err != nil {
177
176
return err
178
177
}
179
178
180
- _ , err = neo4j .Collect (session .Run (string (body [:]), nil ))
179
+ res , err := session .Run (ctx , string (body [:]), nil )
180
+ _ , err = neo4j .CollectWithContext (ctx , res , err )
181
181
return err
182
182
}
183
183
184
184
func (n * Neo4j ) SetVersion (version int , dirty bool ) (err error ) {
185
- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
185
+ ctx := context .Background ()
186
+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
186
187
defer func () {
187
- if cerr := session .Close (); cerr != nil {
188
+ if cerr := session .Close (ctx ); cerr != nil {
188
189
err = multierror .Append (err , cerr )
189
190
}
190
191
}()
191
192
192
193
query := fmt .Sprintf ("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()" ,
193
194
n .config .MigrationsLabel )
194
- _ , err = neo4j .Collect (session .Run (query , map [string ]interface {}{"version" : version , "dirty" : dirty }))
195
+ res , err := session .Run (ctx , query , map [string ]interface {}{"version" : version , "dirty" : dirty })
196
+ _ , err = neo4j .CollectWithContext (ctx , res , err )
195
197
if err != nil {
196
198
return err
197
199
}
@@ -204,75 +206,73 @@ type MigrationRecord struct {
204
206
}
205
207
206
208
func (n * Neo4j ) Version () (version int , dirty bool , err error ) {
207
- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeRead })
209
+ ctx := context .Background ()
210
+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeRead })
208
211
defer func () {
209
- if cerr := session .Close (); cerr != nil {
212
+ if cerr := session .Close (ctx ); cerr != nil {
210
213
err = multierror .Append (err , cerr )
211
214
}
212
215
}()
213
216
214
217
query := fmt .Sprintf (`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
215
218
ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1` ,
216
219
n .config .MigrationsLabel )
217
- result , err := session .ReadTransaction (func (transaction neo4j.Transaction ) (interface {}, error ) {
218
- result , err := transaction .Run (query , nil )
219
- if err != nil {
220
- return nil , err
221
- }
222
- if result .Next () {
223
- record := result .Record ()
224
- mr := MigrationRecord {}
225
- versionResult , ok := record .Get ("version" )
226
- if ! ok {
227
- mr .Version = database .NilVersion
228
- } else {
229
- mr .Version = int (versionResult .(int64 ))
230
- }
231
220
232
- dirtyResult , ok := record .Get ("dirty" )
233
- if ok {
234
- mr .Dirty = dirtyResult .(bool )
235
- }
221
+ tx , err := session .BeginTransaction (ctx )
236
222
237
- return mr , nil
238
- }
239
- return nil , result .Err ()
240
- })
223
+ result , err := tx .Run (ctx , query , nil )
241
224
if err != nil {
242
225
return database .NilVersion , false , err
243
226
}
244
- if result == nil {
245
- return database .NilVersion , false , err
227
+ if result .Next (ctx ) {
228
+ record := result .Record ()
229
+ mr := MigrationRecord {}
230
+ versionResult , ok := record .Get ("version" )
231
+ if ! ok {
232
+ mr .Version = database .NilVersion
233
+ } else {
234
+ mr .Version = int (versionResult .(int64 ))
235
+ }
236
+
237
+ dirtyResult , ok := record .Get ("dirty" )
238
+ if ok {
239
+ mr .Dirty = dirtyResult .(bool )
240
+ }
241
+
242
+ return mr .Version , mr .Dirty , nil
246
243
}
247
- mr := result .( MigrationRecord )
248
- return mr . Version , mr . Dirty , err
244
+
245
+ return database . NilVersion , false , err
249
246
}
250
247
251
248
func (n * Neo4j ) Drop () (err error ) {
252
- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
249
+ ctx := context .Background ()
250
+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
253
251
defer func () {
254
- if cerr := session .Close (); cerr != nil {
252
+ if cerr := session .Close (ctx ); cerr != nil {
255
253
err = multierror .Append (err , cerr )
256
254
}
257
255
}()
258
256
259
- if _ , err := neo4j .Collect (session .Run ("MATCH (n) DETACH DELETE n" , nil )); err != nil {
257
+ res , err := session .Run (ctx , "MATCH (n) DETACH DELETE n" , nil )
258
+ if _ , err := neo4j .CollectWithContext (ctx , res , err ); err != nil {
260
259
return err
261
260
}
262
261
return nil
263
262
}
264
263
265
264
func (n * Neo4j ) ensureVersionConstraint () (err error ) {
266
- session := n .driver .NewSession (neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
265
+ ctx := context .Background ()
266
+ session := n .driver .NewSession (ctx , neo4j.SessionConfig {AccessMode : neo4j .AccessModeWrite })
267
267
defer func () {
268
- if cerr := session .Close (); cerr != nil {
268
+ if cerr := session .Close (ctx ); cerr != nil {
269
269
err = multierror .Append (err , cerr )
270
270
}
271
271
}()
272
272
273
273
var neo4jVersion string
274
-
275
- res , err := neo4j .Collect ( session . Run ( "call dbms.components() yield versions unwind versions as version return version" , nil ) )
274
+ result , err := session . Run ( ctx , "call dbms.components() yield versions unwind versions as version return version" , nil )
275
+ res , err := neo4j .CollectWithContext ( ctx , result , err )
276
276
if err != nil {
277
277
return err
278
278
}
@@ -287,7 +287,8 @@ func (n *Neo4j) ensureVersionConstraint() (err error) {
287
287
using db.labels() to support Neo4j 3 and 4.
288
288
Neo4J 3 doesn't support db.constraints() YIELD name
289
289
*/
290
- res , err = neo4j .Collect (session .Run (fmt .Sprintf ("CALL db.labels() YIELD label WHERE label=\" %s\" RETURN label" , n .config .MigrationsLabel ), nil ))
290
+ result , err = session .Run (ctx , fmt .Sprintf ("CALL db.labels() YIELD label WHERE label=\" %s\" RETURN label" , n .config .MigrationsLabel ), nil )
291
+ res , err = neo4j .CollectWithContext (ctx , result , err )
291
292
if err != nil {
292
293
return err
293
294
}
@@ -299,13 +300,14 @@ func (n *Neo4j) ensureVersionConstraint() (err error) {
299
300
switch neo4jVersion {
300
301
case "v5" :
301
302
query = fmt .Sprintf ("CREATE CONSTRAINT FOR (a:%s) REQUIRE a.version IS UNIQUE" , n .config .MigrationsLabel )
302
- case "v3" , " v4" :
303
+ case "v4" :
303
304
query = fmt .Sprintf ("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE" , n .config .MigrationsLabel )
304
305
default :
305
306
return fmt .Errorf ("unsupported neo4j version %v" , neo4jVersion )
306
307
}
307
308
308
- if _ , err := neo4j .Collect (session .Run (query , nil )); err != nil {
309
+ result , err = session .Run (ctx , query , nil )
310
+ if _ , err := neo4j .CollectWithContext (ctx , result , err ); err != nil {
309
311
return err
310
312
}
311
313
return nil
0 commit comments