Skip to content

Commit e5b4be7

Browse files
erikdubbelboerdhui
authored andcommitted
Let database.Open() use schemeFromURL as well (#271)
* Let database.Open() use schemeFromURL as well Otherwise it will fail on MySQL DSNs. Moved schemeFromURL into the database package. Also removed databaseSchemeFromURL and sourceSchemeFromURL as they were just calling schemeFromURL. Fixes #265 (comment) * Moved url functions into internal/url Also merged the test cases. * Add some database tests to improve coverage * Fix suggestions
1 parent d5960ad commit e5b4be7

File tree

7 files changed

+191
-152
lines changed

7 files changed

+191
-152
lines changed

database/driver.go

+6-10
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ package database
77
import (
88
"fmt"
99
"io"
10-
nurl "net/url"
1110
"sync"
11+
12+
iurl "github.com/golang-migrate/migrate/v4/internal/url"
1213
)
1314

1415
var (
@@ -81,21 +82,16 @@ type Driver interface {
8182

8283
// Open returns a new driver instance.
8384
func Open(url string) (Driver, error) {
84-
u, err := nurl.Parse(url)
85+
scheme, err := iurl.SchemeFromURL(url)
8586
if err != nil {
86-
return nil, fmt.Errorf("Unable to parse URL. Did you escape all reserved URL characters? "+
87-
"See: https://github.com/golang-migrate/migrate#database-urls Error: %v", err)
88-
}
89-
90-
if u.Scheme == "" {
91-
return nil, fmt.Errorf("database driver: invalid URL scheme")
87+
return nil, err
9288
}
9389

9490
driversMu.RLock()
95-
d, ok := drivers[u.Scheme]
91+
d, ok := drivers[scheme]
9692
driversMu.RUnlock()
9793
if !ok {
98-
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", u.Scheme)
94+
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme)
9995
}
10096

10197
return d.Open(url)

database/driver_test.go

+107
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,115 @@
11
package database
22

3+
import (
4+
"io"
5+
"testing"
6+
)
7+
38
func ExampleDriver() {
49
// see database/stub for an example
510

611
// database/stub/stub.go has the driver implementation
712
// database/stub/stub_test.go runs database/testing/test.go:Test
813
}
14+
15+
// Using database/stub here is not possible as it
16+
// results in an import cycle.
17+
type mockDriver struct {
18+
url string
19+
}
20+
21+
func (m *mockDriver) Open(url string) (Driver, error) {
22+
return &mockDriver{
23+
url: url,
24+
}, nil
25+
}
26+
27+
func (m *mockDriver) Close() error {
28+
return nil
29+
}
30+
31+
func (m *mockDriver) Lock() error {
32+
return nil
33+
}
34+
35+
func (m *mockDriver) Unlock() error {
36+
return nil
37+
}
38+
39+
func (m *mockDriver) Run(migration io.Reader) error {
40+
return nil
41+
}
42+
43+
func (m *mockDriver) SetVersion(version int, dirty bool) error {
44+
return nil
45+
}
46+
47+
func (m *mockDriver) Version() (version int, dirty bool, err error) {
48+
return 0, false, nil
49+
}
50+
51+
func (m *mockDriver) Drop() error {
52+
return nil
53+
}
54+
55+
func TestRegisterTwice(t *testing.T) {
56+
Register("mock", &mockDriver{})
57+
58+
var err interface{}
59+
func() {
60+
defer func() {
61+
err = recover()
62+
}()
63+
Register("mock", &mockDriver{})
64+
}()
65+
66+
if err == nil {
67+
t.Fatal("expected a panic when calling Register twice")
68+
}
69+
}
70+
71+
func TestOpen(t *testing.T) {
72+
// Make sure the driver is registered.
73+
// But if the previous test already registered it just ignore the panic.
74+
// If we don't do this it will be impossible to run this test standalone.
75+
func() {
76+
defer func() {
77+
_ = recover()
78+
}()
79+
Register("mock", &mockDriver{})
80+
}()
81+
82+
cases := []struct {
83+
url string
84+
err bool
85+
}{
86+
{
87+
"mock://user:pass@tcp(host:1337)/db",
88+
false,
89+
},
90+
{
91+
"unknown://bla",
92+
true,
93+
},
94+
}
95+
96+
for _, c := range cases {
97+
t.Run(c.url, func(t *testing.T) {
98+
d, err := Open(c.url)
99+
100+
if err == nil {
101+
if c.err {
102+
t.Fatal("expected an error for an unknown driver")
103+
} else {
104+
if md, ok := d.(*mockDriver); !ok {
105+
t.Fatalf("expected *mockDriver got %T", d)
106+
} else if md.url != c.url {
107+
t.Fatalf("expected %q got %q", c.url, md.url)
108+
}
109+
}
110+
} else if !c.err {
111+
t.Fatalf("did not expect %q", err)
112+
}
113+
})
114+
}
115+
}

internal/url/url.go

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package url
2+
3+
import (
4+
"errors"
5+
"strings"
6+
)
7+
8+
var errNoScheme = errors.New("no scheme")
9+
var errEmptyURL = errors.New("URL cannot be empty")
10+
11+
// schemeFromURL returns the scheme from a URL string
12+
func SchemeFromURL(url string) (string, error) {
13+
if url == "" {
14+
return "", errEmptyURL
15+
}
16+
17+
i := strings.Index(url, ":")
18+
19+
// No : or : is the first character.
20+
if i < 1 {
21+
return "", errNoScheme
22+
}
23+
24+
return url[0:i], nil
25+
}

internal/url/url_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package url
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestSchemeFromUrl(t *testing.T) {
8+
cases := []struct {
9+
name string
10+
urlStr string
11+
expected string
12+
expectErr error
13+
}{
14+
{
15+
name: "Simple",
16+
urlStr: "protocol://path",
17+
expected: "protocol",
18+
},
19+
{
20+
// See issue #264
21+
name: "MySQLWithPort",
22+
urlStr: "mysql://user:pass@tcp(host:1337)/db",
23+
expected: "mysql",
24+
},
25+
{
26+
name: "Empty",
27+
urlStr: "",
28+
expectErr: errEmptyURL,
29+
},
30+
{
31+
name: "NoScheme",
32+
urlStr: "hello",
33+
expectErr: errNoScheme,
34+
},
35+
}
36+
37+
for _, tc := range cases {
38+
t.Run(tc.name, func(t *testing.T) {
39+
s, err := SchemeFromURL(tc.urlStr)
40+
if err != tc.expectErr {
41+
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
42+
}
43+
if s != tc.expected {
44+
t.Fatalf("expected %q, but received %q", tc.expected, s)
45+
}
46+
})
47+
}
48+
}

migrate.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"time"
1414

1515
"github.com/golang-migrate/migrate/v4/database"
16+
iurl "github.com/golang-migrate/migrate/v4/internal/url"
1617
"github.com/golang-migrate/migrate/v4/source"
1718
)
1819

@@ -85,13 +86,13 @@ type Migrate struct {
8586
func New(sourceURL, databaseURL string) (*Migrate, error) {
8687
m := newCommon()
8788

88-
sourceName, err := sourceSchemeFromURL(sourceURL)
89+
sourceName, err := iurl.SchemeFromURL(sourceURL)
8990
if err != nil {
9091
return nil, err
9192
}
9293
m.sourceName = sourceName
9394

94-
databaseName, err := databaseSchemeFromURL(databaseURL)
95+
databaseName, err := iurl.SchemeFromURL(databaseURL)
9596
if err != nil {
9697
return nil, err
9798
}
@@ -119,7 +120,7 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
119120
func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
120121
m := newCommon()
121122

122-
sourceName, err := schemeFromURL(sourceURL)
123+
sourceName, err := iurl.SchemeFromURL(sourceURL)
123124
if err != nil {
124125
return nil, err
125126
}
@@ -145,7 +146,7 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst
145146
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
146147
m := newCommon()
147148

148-
databaseName, err := schemeFromURL(databaseURL)
149+
databaseName, err := iurl.SchemeFromURL(databaseURL)
149150
if err != nil {
150151
return nil, err
151152
}

util.go

-36
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package migrate
22

33
import (
4-
"errors"
54
"fmt"
65
nurl "net/url"
76
"strings"
@@ -49,41 +48,6 @@ func suint(n int) uint {
4948
return uint(n)
5049
}
5150

52-
var errNoScheme = errors.New("no scheme")
53-
var errEmptyURL = errors.New("URL cannot be empty")
54-
55-
func sourceSchemeFromURL(url string) (string, error) {
56-
u, err := schemeFromURL(url)
57-
if err != nil {
58-
return "", fmt.Errorf("source: %v", err)
59-
}
60-
return u, nil
61-
}
62-
63-
func databaseSchemeFromURL(url string) (string, error) {
64-
u, err := schemeFromURL(url)
65-
if err != nil {
66-
return "", fmt.Errorf("database: %v", err)
67-
}
68-
return u, nil
69-
}
70-
71-
// schemeFromURL returns the scheme from a URL string
72-
func schemeFromURL(url string) (string, error) {
73-
if url == "" {
74-
return "", errEmptyURL
75-
}
76-
77-
i := strings.Index(url, ":")
78-
79-
// No : or : is the first character.
80-
if i < 1 {
81-
return "", errNoScheme
82-
}
83-
84-
return url[0:i], nil
85-
}
86-
8751
// FilterCustomQuery filters all query values starting with `x-`
8852
func FilterCustomQuery(u *nurl.URL) *nurl.URL {
8953
ux := *u

0 commit comments

Comments
 (0)