Skip to content

Commit 59bd95e

Browse files
author
leozeli
committed
feat(source): add embed source support
1 parent 2788339 commit 59bd95e

File tree

3 files changed

+381
-0
lines changed

3 files changed

+381
-0
lines changed

source/embed/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# embed
2+
3+
```golang
4+
//go:embed *.sql
5+
var MigrationFiles embed.FS
6+
7+
```
8+
9+
```golang
10+
embed, err := migrations.NewEmbed(migrations.MigrationFiles, ".")
11+
if err != nil {
12+
klog.Error(fmt.Sprintf("newConnectionEngine migrations.NewEmbed error:%v", err))
13+
return
14+
}
15+
m, err := migrate.NewWithInstance(
16+
"embed",
17+
embed,
18+
"mysql",
19+
dbdriver,
20+
)
21+
```

source/embed/embed.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package embed
2+
3+
import (
4+
"embed"
5+
"errors"
6+
"fmt"
7+
"io"
8+
"io/fs"
9+
"path"
10+
"strconv"
11+
12+
"github.com/golang-migrate/migrate/v4/source"
13+
)
14+
15+
type Embed struct {
16+
FS embed.FS
17+
migrations *source.Migrations
18+
path string
19+
}
20+
21+
// NewEmbed returns a new Driver using the embed.FS and a relative path.
22+
func NewEmbed(fsys embed.FS, path string) (source.Driver, error) {
23+
var e Embed
24+
if err := e.Init(fsys, path); err != nil {
25+
return nil, fmt.Errorf("failed to init embed driver with path %s: %w", path, err)
26+
}
27+
return &e, nil
28+
}
29+
30+
// Open is part of source.Driver interface implementation.
31+
// Open cannot be called on the embed driver directly as it's designed to use embed.FS.
32+
func (e *Embed) Open(url string) (source.Driver, error) {
33+
return nil, errors.New("Open() cannot be called on the embed driver")
34+
}
35+
36+
// Init prepares Embed instance to read migrations from embed.FS and a relative path.
37+
func (e *Embed) Init(fsys embed.FS, path string) error {
38+
entries, err := fs.ReadDir(fsys, path)
39+
if err != nil {
40+
return err
41+
}
42+
43+
ms := source.NewMigrations()
44+
for _, e := range entries {
45+
if e.IsDir() {
46+
continue
47+
}
48+
m, err := source.DefaultParse(e.Name())
49+
if err != nil {
50+
continue
51+
}
52+
file, err := e.Info()
53+
if err != nil {
54+
return err
55+
}
56+
if !ms.Append(m) {
57+
return source.ErrDuplicateMigration{
58+
Migration: *m,
59+
FileInfo: file,
60+
}
61+
}
62+
}
63+
64+
e.FS = fsys
65+
e.path = path
66+
e.migrations = ms
67+
return nil
68+
}
69+
70+
// Close is part of source.Driver interface implementation.
71+
func (e *Embed) Close() error {
72+
// Since embed.FS doesn't support Close(), this method is a no-op
73+
return nil
74+
}
75+
76+
// First is part of source.Driver interface implementation.
77+
func (e *Embed) First() (version uint, err error) {
78+
if version, ok := e.migrations.First(); ok {
79+
return version, nil
80+
}
81+
return 0, &fs.PathError{
82+
Op: "first",
83+
Path: e.path,
84+
Err: fs.ErrNotExist,
85+
}
86+
}
87+
88+
// Prev is part of source.Driver interface implementation.
89+
func (e *Embed) Prev(version uint) (prevVersion uint, err error) {
90+
if version, ok := e.migrations.Prev(version); ok {
91+
return version, nil
92+
}
93+
return 0, &fs.PathError{
94+
Op: "prev for version " + strconv.FormatUint(uint64(version), 10),
95+
Path: e.path,
96+
Err: fs.ErrNotExist,
97+
}
98+
}
99+
100+
// Next is part of source.Driver interface implementation.
101+
func (e *Embed) Next(version uint) (nextVersion uint, err error) {
102+
if version, ok := e.migrations.Next(version); ok {
103+
return version, nil
104+
}
105+
return 0, &fs.PathError{
106+
Op: "next for version " + strconv.FormatUint(uint64(version), 10),
107+
Path: e.path,
108+
Err: fs.ErrNotExist,
109+
}
110+
}
111+
112+
// ReadUp is part of source.Driver interface implementation.
113+
func (e *Embed) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) {
114+
if m, ok := e.migrations.Up(version); ok {
115+
body, err := e.FS.ReadFile(path.Join(e.path, m.Raw))
116+
if err != nil {
117+
return nil, "", err
118+
}
119+
return io.NopCloser(&fileReader{data: body}), m.Identifier, nil
120+
}
121+
return nil, "", &fs.PathError{
122+
Op: "read up for version " + strconv.FormatUint(uint64(version), 10),
123+
Path: e.path,
124+
Err: fs.ErrNotExist,
125+
}
126+
}
127+
128+
// ReadDown is part of source.Driver interface implementation.
129+
func (e *Embed) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) {
130+
if m, ok := e.migrations.Down(version); ok {
131+
body, err := e.FS.ReadFile(path.Join(e.path, m.Raw))
132+
if err != nil {
133+
return nil, "", err
134+
}
135+
return io.NopCloser(&fileReader{data: body}), m.Identifier, nil
136+
}
137+
return nil, "", &fs.PathError{
138+
Op: "read down for version " + strconv.FormatUint(uint64(version), 10),
139+
Path: e.path,
140+
Err: fs.ErrNotExist,
141+
}
142+
}
143+
144+
// fileReader []byte to io.ReadCloser
145+
type fileReader struct {
146+
data []byte
147+
pos int
148+
}
149+
150+
func (fr *fileReader) Read(p []byte) (n int, err error) {
151+
if fr.pos >= len(fr.data) {
152+
return 0, io.EOF
153+
}
154+
n = copy(p, fr.data[fr.pos:])
155+
fr.pos += n
156+
return n, nil
157+
}
158+
159+
func (fr *fileReader) Close() error {
160+
// do nothing, as embed.FS does not require closing
161+
return nil
162+
}

source/embed/embed_test.go

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
package embed
2+
3+
import (
4+
"embed"
5+
"errors"
6+
"io"
7+
"io/fs"
8+
"testing"
9+
10+
"github.com/golang-migrate/migrate/v4/source"
11+
st "github.com/golang-migrate/migrate/v4/source/testing"
12+
)
13+
14+
//go:embed testmigrations/*.sql
15+
var testFS embed.FS
16+
17+
const testPath = "testmigrations"
18+
19+
func Test(t *testing.T) {
20+
driver, err := NewEmbed(testFS, testPath)
21+
if err != nil {
22+
t.Fatal(err)
23+
}
24+
25+
st.Test(t, driver)
26+
}
27+
28+
func TestNewEmbed_Success(t *testing.T) {
29+
driver, err := NewEmbed(testFS, testPath)
30+
if err != nil {
31+
t.Fatalf("expected no error, got %v", err)
32+
}
33+
if driver == nil {
34+
t.Fatal("expected driver, got nil")
35+
}
36+
}
37+
38+
func TestNewEmbed_InvalidPath(t *testing.T) {
39+
_, err := NewEmbed(testFS, "doesnotexist")
40+
if err == nil {
41+
t.Fatal("expected error for invalid path, got nil")
42+
}
43+
}
44+
45+
func TestEmbed_Open(t *testing.T) {
46+
driver, _ := NewEmbed(testFS, "testmigrations")
47+
_, err := driver.(*Embed).Open("someurl")
48+
if err == nil || err.Error() != "Open() cannot be called on the embed driver" {
49+
t.Fatalf("expected Open() error, got %v", err)
50+
}
51+
}
52+
53+
func TestEmbed_First(t *testing.T) {
54+
driver, _ := NewEmbed(testFS, "testmigrations")
55+
version, err := driver.First()
56+
if err != nil {
57+
t.Fatalf("expected no error, got %v", err)
58+
}
59+
if version == 0 {
60+
t.Fatal("expected non-zero version")
61+
}
62+
}
63+
64+
func TestEmbed_First_Empty(t *testing.T) {
65+
emptyFS := embed.FS{}
66+
e := &Embed{}
67+
e.FS = emptyFS
68+
e.path = "empty"
69+
e.migrations = source.NewMigrations()
70+
_, err := e.First()
71+
if err == nil {
72+
t.Fatal("expected error for empty migrations")
73+
}
74+
}
75+
76+
func TestEmbed_PrevNext(t *testing.T) {
77+
driver, _ := NewEmbed(testFS, "testmigrations")
78+
first, _ := driver.First()
79+
_, err := driver.Prev(first)
80+
if err == nil {
81+
t.Fatal("expected error for prev of first migration")
82+
}
83+
next, err := driver.Next(first)
84+
if err != nil {
85+
t.Fatalf("expected no error for next, got %v", err)
86+
}
87+
if next == 0 {
88+
t.Fatal("expected next version to be non-zero")
89+
}
90+
}
91+
92+
func TestEmbed_ReadUpDown(t *testing.T) {
93+
driver, _ := NewEmbed(testFS, "testmigrations")
94+
first, _ := driver.First()
95+
r, id, err := driver.ReadUp(first)
96+
if err != nil {
97+
t.Fatalf("expected no error, got %v", err)
98+
}
99+
if r == nil || id == "" {
100+
t.Fatal("expected valid reader and identifier")
101+
}
102+
b, err := io.ReadAll(r)
103+
if err != nil {
104+
t.Fatalf("failed to read: %v", err)
105+
}
106+
if len(b) == 0 {
107+
t.Fatal("expected file content")
108+
}
109+
r.Close()
110+
111+
// Down migration may not exist for first, so test with next if available
112+
next, _ := driver.Next(first)
113+
rd, idd, err := driver.ReadDown(next)
114+
if err == nil {
115+
if rd == nil || idd == "" {
116+
t.Fatal("expected valid reader and identifier for down")
117+
}
118+
rd.Close()
119+
}
120+
}
121+
122+
func TestEmbed_ReadUp_NotExist(t *testing.T) {
123+
driver, _ := NewEmbed(testFS, "testmigrations")
124+
_, _, err := driver.ReadUp(999999)
125+
if err == nil {
126+
t.Fatal("expected error for non-existent migration")
127+
}
128+
var pathErr *fs.PathError
129+
if !errors.As(err, &pathErr) {
130+
t.Fatalf("expected fs.PathError, got %T", err)
131+
}
132+
}
133+
134+
func TestEmbed_Close(t *testing.T) {
135+
driver, _ := NewEmbed(testFS, "testmigrations")
136+
if err := driver.Close(); err != nil {
137+
t.Fatalf("expected no error, got %v", err)
138+
}
139+
}
140+
141+
func TestFileReader_ReadClose(t *testing.T) {
142+
data := []byte("hello world")
143+
fr := &fileReader{data: data}
144+
buf := make([]byte, 5)
145+
n, err := fr.Read(buf)
146+
if n != 5 || err != nil {
147+
t.Fatalf("expected to read 5 bytes, got %d, err %v", n, err)
148+
}
149+
n, err = fr.Read(buf)
150+
if n != 5 || err != nil {
151+
t.Fatalf("expected to read next 5 bytes, got %d, err %v", n, err)
152+
}
153+
n, err = fr.Read(buf)
154+
if n != 1 || err != nil {
155+
t.Fatalf("expected to read last byte, got %d, err %v", n, err)
156+
}
157+
n, err = fr.Read(buf)
158+
if n != 0 || err != io.EOF {
159+
t.Fatalf("expected EOF, got %d, err %v", n, err)
160+
}
161+
if err := fr.Close(); err != nil {
162+
t.Fatalf("expected no error on close, got %v", err)
163+
}
164+
}
165+
166+
// createBenchmarkEmbed creates an Embed driver with test migrations
167+
// This is a helper function for benchmarks
168+
func createBenchmarkEmbed(b *testing.B) *Embed {
169+
driver, err := NewEmbed(testFS, testPath)
170+
if err != nil {
171+
b.Fatal(err)
172+
}
173+
return driver.(*Embed)
174+
}
175+
176+
func BenchmarkFirst(b *testing.B) {
177+
e := createBenchmarkEmbed(b)
178+
b.ResetTimer()
179+
for n := 0; n < b.N; n++ {
180+
_, err := e.First()
181+
if err != nil {
182+
b.Error(err)
183+
}
184+
}
185+
b.StopTimer()
186+
}
187+
188+
func BenchmarkNext(b *testing.B) {
189+
e := createBenchmarkEmbed(b)
190+
b.ResetTimer()
191+
v, err := e.First()
192+
for n := 0; n < b.N; n++ {
193+
for !errors.Is(err, fs.ErrNotExist) {
194+
v, err = e.Next(v)
195+
}
196+
}
197+
b.StopTimer()
198+
}

0 commit comments

Comments
 (0)