Skip to content

Commit 8f9b789

Browse files
committed
feat: upgraded Neo4j driver to v5
1 parent 606fc17 commit 8f9b789

File tree

7 files changed

+125
-167
lines changed

7 files changed

+125
-167
lines changed

database/neo4j/README.md

+26-13
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,33 @@
11
# neo4j
2-
The Neo4j driver (bolt) does not natively support executing multiple statements in a single query. To allow for multiple statements in a single migration, you can use the `x-multi-statement` param.
3-
This mode splits the migration text into separately-executed statements by a semi-colon `;`. Thus `x-multi-statement` cannot be used when a statement in the migration contains a string with a semi-colon.
4-
The queries **should** run in a single transaction, so partial migrations should not be a concern, but this is untested.
2+
The Neo4j driver (bolt) does not natively support executing multiple statements
3+
in a single query. To allow for multiple statements in a single migration, you
4+
can use the `x-multi-statement` param. This mode splits the migration text into
5+
separately-executed statements by a semicolon `;`. Thus `x-multi-statement`
6+
cannot be used when a statement in the migration contains a string with a
7+
semicolon. The queries **should** run in a single transaction, so partial
8+
migrations should not be a concern, but this is untested.
59

10+
Here are possible connection URLs:
611

7-
`neo4j://user:password@host:port/`
12+
- `neo4j://user:password@host:port/`
13+
- `neo4j+s://user:password@host:port/`
14+
- `neo4j+ssc://user:password@host:port/`
15+
- `bolt://user:password@host:port/`
16+
- `bolt+s://user:password@host:port/`
17+
- `bolt+ssc://user:password@host:port/`
18+
19+
| URL Query | WithInstance Config | Description |
20+
|---------------------|-------------------------------|------------------------------------------------------------------------------------------------------|
21+
| `x-multi-statement` | `MultiStatement` | Enable multiple statements to be ran in a single migration (See note above) |
22+
| `user` | Contained within `AuthConfig` | The user to sign in as |
23+
| `password` | Contained within `AuthConfig` | The user's password |
24+
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
25+
| `port` | | The port to bind to. (default is 7687) |
26+
| | `MigrationsLabel` | Name of the migrations node label |
827

9-
| URL Query | WithInstance Config | Description |
10-
|------------|---------------------|-------------|
11-
| `x-multi-statement` | `MultiStatement` | Enable multiple statements to be ran in a single migration (See note above) |
12-
| `user` | Contained within `AuthConfig` | The user to sign in as |
13-
| `password` | Contained within `AuthConfig` | The user's password |
14-
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
15-
| `port` | | The port to bind to. (default is 7687) |
16-
| | `MigrationsLabel` | Name of the migrations node label |
1728

1829
## Supported versions
1930

20-
Only Neo4j v3.5+ is [supported](https://github.com/neo4j/neo4j-go-driver/issues/64#issuecomment-625133600)
31+
Neo4j v4.4 LTS and v5+ is supported.
32+
33+
Make sure to check [End Of Life dates](https://neo4j.com/developer/kb/neo4j-supported-versions/) of Neo4j versions.
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
DROP CONSTRAINT ON (m:Movie) ASSERT m.Name IS UNIQUE
1+
DROP CONSTRAINT FOR (m:Movie) REQUIRE m.Name IS UNIQUE
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
CREATE CONSTRAINT ON (m:Movie) ASSERT m.Name IS UNIQUE
1+
CREATE CONSTRAINT FOR (m:Movie) REQUIRE m.Name IS UNIQUE

database/neo4j/neo4j.go

+89-87
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@ package neo4j
22

33
import (
44
"bytes"
5+
"context"
56
"fmt"
6-
"golang.org/x/mod/semver"
77
"io"
88
neturl "net/url"
99
"strconv"
1010
"sync/atomic"
1111

12+
"golang.org/x/mod/semver"
13+
1214
"github.com/golang-migrate/migrate/v4/database"
1315
"github.com/golang-migrate/migrate/v4/database/multistmt"
1416
"github.com/hashicorp/go-multierror"
15-
"github.com/neo4j/neo4j-go-driver/v4/neo4j"
17+
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
1618
)
1719

1820
func init() {
@@ -38,14 +40,14 @@ type Config struct {
3840
}
3941

4042
type Neo4j struct {
41-
driver neo4j.Driver
43+
driver neo4j.DriverWithContext
4244
lock uint32
4345

4446
// Open and WithInstance need to guarantee that config is never nil
4547
config *Config
4648
}
4749

48-
func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
50+
func WithInstance(driver neo4j.DriverWithContext, config *Config) (database.Driver, error) {
4951
if config == nil {
5052
return nil, ErrNilConfig
5153
}
@@ -70,31 +72,16 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
7072
password, _ := uri.User.Password()
7173
authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
7274
uri.User = nil
73-
uri.Scheme = "bolt"
7475
msQuery := uri.Query().Get("x-multi-statement")
7576

76-
// Whether to turn on/off TLS encryption.
77-
tlsEncrypted := uri.Query().Get("x-tls-encrypted")
7877
multi := false
79-
encrypted := false
8078
if msQuery != "" {
8179
multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
8280
if err != nil {
8381
return nil, err
8482
}
8583
}
8684

87-
if tlsEncrypted != "" {
88-
encrypted, err = strconv.ParseBool(tlsEncrypted)
89-
if err != nil {
90-
return nil, err
91-
}
92-
}
93-
94-
if encrypted {
95-
uri.Scheme += "+s"
96-
}
97-
9885
multiStatementMaxSize := DefaultMultiStatementMaxSize
9986
if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
10087
multiStatementMaxSize, err = strconv.Atoi(s)
@@ -105,11 +92,15 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
10592

10693
uri.RawQuery = ""
10794

108-
driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {})
95+
driver, err := neo4j.NewDriverWithContext(uri.String(), authToken)
10996
if err != nil {
11097
return nil, err
11198
}
11299

100+
if err = driver.VerifyConnectivity(context.Background()); err != nil {
101+
return nil, err
102+
}
103+
113104
return WithInstance(driver, &Config{
114105
MigrationsLabel: DefaultMigrationsLabel,
115106
MultiStatement: multi,
@@ -118,7 +109,7 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
118109
}
119110

120111
func (n *Neo4j) Close() error {
121-
return n.driver.Close()
112+
return n.driver.Close(context.Background())
122113
}
123114

124115
// local locking in order to pass tests, Neo doesn't support database locking
@@ -138,60 +129,71 @@ func (n *Neo4j) Unlock() error {
138129
}
139130

140131
func (n *Neo4j) Run(migration io.Reader) (err error) {
141-
session := n.driver.NewSession(neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
132+
ctx := context.Background()
133+
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
142134
defer func() {
143-
if cerr := session.Close(); cerr != nil {
135+
if cerr := session.Close(ctx); cerr != nil {
144136
err = multierror.Append(err, cerr)
145137
}
146138
}()
147139

148140
if n.config.MultiStatement {
149-
_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
150-
var stmtRunErr error
151-
if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
152-
trimStmt := bytes.TrimSpace(stmt)
153-
if len(trimStmt) == 0 {
154-
return true
155-
}
156-
trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
157-
if len(trimStmt) == 0 {
158-
return true
159-
}
160-
161-
result, err := transaction.Run(string(trimStmt), nil)
162-
if _, err := neo4j.Collect(result, err); err != nil {
163-
stmtRunErr = err
164-
return false
165-
}
141+
tx, err := session.BeginTransaction(ctx)
142+
if err != nil {
143+
return err
144+
}
145+
defer func() {
146+
if cerr := tx.Close(ctx); cerr != nil {
147+
err = multierror.Append(err, cerr)
148+
}
149+
}()
150+
151+
var stmtRunErr error
152+
if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
153+
trimStmt := bytes.TrimSpace(stmt)
154+
if len(trimStmt) == 0 {
166155
return true
167-
}); err != nil {
168-
return nil, err
169156
}
170-
return nil, stmtRunErr
171-
})
172-
return err
157+
trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
158+
if len(trimStmt) == 0 {
159+
return true
160+
}
161+
162+
result, err := tx.Run(ctx, string(trimStmt), nil)
163+
if _, err := neo4j.CollectWithContext(ctx, result, err); err != nil {
164+
stmtRunErr = err
165+
return false
166+
}
167+
return true
168+
}); err != nil {
169+
return err
170+
}
171+
return stmtRunErr
173172
}
174173

175174
body, err := io.ReadAll(migration)
176175
if err != nil {
177176
return err
178177
}
179178

180-
_, err = neo4j.Collect(session.Run(string(body[:]), nil))
179+
res, err := session.Run(ctx, string(body[:]), nil)
180+
_, err = neo4j.CollectWithContext(ctx, res, err)
181181
return err
182182
}
183183

184184
func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
185-
session := n.driver.NewSession(neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
185+
ctx := context.Background()
186+
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
186187
defer func() {
187-
if cerr := session.Close(); cerr != nil {
188+
if cerr := session.Close(ctx); cerr != nil {
188189
err = multierror.Append(err, cerr)
189190
}
190191
}()
191192

192193
query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
193194
n.config.MigrationsLabel)
194-
_, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
195+
res, err := session.Run(ctx, query, map[string]interface{}{"version": version, "dirty": dirty})
196+
_, err = neo4j.CollectWithContext(ctx, res, err)
195197
if err != nil {
196198
return err
197199
}
@@ -204,75 +206,73 @@ type MigrationRecord struct {
204206
}
205207

206208
func (n *Neo4j) Version() (version int, dirty bool, err error) {
207-
session := n.driver.NewSession(neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead})
209+
ctx := context.Background()
210+
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead})
208211
defer func() {
209-
if cerr := session.Close(); cerr != nil {
212+
if cerr := session.Close(ctx); cerr != nil {
210213
err = multierror.Append(err, cerr)
211214
}
212215
}()
213216

214217
query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
215218
ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
216219
n.config.MigrationsLabel)
217-
result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
218-
result, err := transaction.Run(query, nil)
219-
if err != nil {
220-
return nil, err
221-
}
222-
if result.Next() {
223-
record := result.Record()
224-
mr := MigrationRecord{}
225-
versionResult, ok := record.Get("version")
226-
if !ok {
227-
mr.Version = database.NilVersion
228-
} else {
229-
mr.Version = int(versionResult.(int64))
230-
}
231220

232-
dirtyResult, ok := record.Get("dirty")
233-
if ok {
234-
mr.Dirty = dirtyResult.(bool)
235-
}
221+
tx, err := session.BeginTransaction(ctx)
236222

237-
return mr, nil
238-
}
239-
return nil, result.Err()
240-
})
223+
result, err := tx.Run(ctx, query, nil)
241224
if err != nil {
242225
return database.NilVersion, false, err
243226
}
244-
if result == nil {
245-
return database.NilVersion, false, err
227+
if result.Next(ctx) {
228+
record := result.Record()
229+
mr := MigrationRecord{}
230+
versionResult, ok := record.Get("version")
231+
if !ok {
232+
mr.Version = database.NilVersion
233+
} else {
234+
mr.Version = int(versionResult.(int64))
235+
}
236+
237+
dirtyResult, ok := record.Get("dirty")
238+
if ok {
239+
mr.Dirty = dirtyResult.(bool)
240+
}
241+
242+
return mr.Version, mr.Dirty, nil
246243
}
247-
mr := result.(MigrationRecord)
248-
return mr.Version, mr.Dirty, err
244+
245+
return database.NilVersion, false, err
249246
}
250247

251248
func (n *Neo4j) Drop() (err error) {
252-
session := n.driver.NewSession(neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
249+
ctx := context.Background()
250+
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
253251
defer func() {
254-
if cerr := session.Close(); cerr != nil {
252+
if cerr := session.Close(ctx); cerr != nil {
255253
err = multierror.Append(err, cerr)
256254
}
257255
}()
258256

259-
if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
257+
res, err := session.Run(ctx, "MATCH (n) DETACH DELETE n", nil)
258+
if _, err := neo4j.CollectWithContext(ctx, res, err); err != nil {
260259
return err
261260
}
262261
return nil
263262
}
264263

265264
func (n *Neo4j) ensureVersionConstraint() (err error) {
266-
session := n.driver.NewSession(neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
265+
ctx := context.Background()
266+
session := n.driver.NewSession(ctx, neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite})
267267
defer func() {
268-
if cerr := session.Close(); cerr != nil {
268+
if cerr := session.Close(ctx); cerr != nil {
269269
err = multierror.Append(err, cerr)
270270
}
271271
}()
272272

273273
var neo4jVersion string
274-
275-
res, err := neo4j.Collect(session.Run("call dbms.components() yield versions unwind versions as version return version", nil))
274+
result, err := session.Run(ctx, "call dbms.components() yield versions unwind versions as version return version", nil)
275+
res, err := neo4j.CollectWithContext(ctx, result, err)
276276
if err != nil {
277277
return err
278278
}
@@ -287,7 +287,8 @@ func (n *Neo4j) ensureVersionConstraint() (err error) {
287287
using db.labels() to support Neo4j 3 and 4.
288288
Neo4J 3 doesn't support db.constraints() YIELD name
289289
*/
290-
res, err = neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil))
290+
result, err = session.Run(ctx, fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil)
291+
res, err = neo4j.CollectWithContext(ctx, result, err)
291292
if err != nil {
292293
return err
293294
}
@@ -299,13 +300,14 @@ func (n *Neo4j) ensureVersionConstraint() (err error) {
299300
switch neo4jVersion {
300301
case "v5":
301302
query = fmt.Sprintf("CREATE CONSTRAINT FOR (a:%s) REQUIRE a.version IS UNIQUE", n.config.MigrationsLabel)
302-
case "v3", "v4":
303+
case "v4":
303304
query = fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
304305
default:
305306
return fmt.Errorf("unsupported neo4j version %v", neo4jVersion)
306307
}
307308

308-
if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
309+
result, err = session.Run(ctx, query, nil)
310+
if _, err := neo4j.CollectWithContext(ctx, result, err); err != nil {
309311
return err
310312
}
311313
return nil

0 commit comments

Comments
 (0)