Skip to content

Commit 915e8a8

Browse files
authored
Merge pull request #150 from EmmEff/oci-download
Support for pulling SIF images directly from OCI registry
2 parents fd6a3ee + ef899af commit 915e8a8

12 files changed

+1286
-174
lines changed

.circleci/config.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ executors:
99
- image: node:18-slim
1010
golangci-lint:
1111
docker:
12-
- image: golangci/golangci-lint:v1.48
12+
- image: golangci/golangci-lint:v1.50
1313
golang-previous:
1414
docker:
1515
- image: golang:1.18

.golangci.yml

-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ linters:
55
- bodyclose
66
- containedctx
77
- contextcheck
8-
- deadcode
98
- decorder
109
- depguard
1110
- dogsled
@@ -28,11 +27,8 @@ linters:
2827
- nakedret
2928
- prealloc
3029
- revive
31-
- rowserrcheck
3230
- staticcheck
33-
- structcheck
3431
- stylecheck
3532
- tenv
3633
- typecheck
3734
- unused
38-
- varcheck

client/client.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019, Sylabs Inc. All rights reserved.
1+
// Copyright (c) 2019-2022, Sylabs Inc. All rights reserved.
22
// This software is licensed under a 3-clause BSD license. Please consult the LICENSE.md file
33
// distributed with the sources of this project regarding your rights to use or distribute this
44
// software.
@@ -104,9 +104,13 @@ func (c *Client) newRequestWithURL(ctx context.Context, method, url string, body
104104
if err != nil {
105105
return nil, err
106106
}
107+
107108
if v := c.AuthToken; v != "" {
108-
r.Header.Set("Authorization", fmt.Sprintf("BEARER %s", v))
109+
if err := (bearerTokenCredentials{authToken: v}).ModifyRequest(r); err != nil {
110+
return nil, err
111+
}
109112
}
113+
110114
if v := c.UserAgent; v != "" {
111115
r.Header.Set("User-Agent", v)
112116
}

client/client_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019, Sylabs Inc. All rights reserved.
1+
// Copyright (c) 2019-2022, Sylabs Inc. All rights reserved.
22
// This software is licensed under a 3-clause BSD license. Please consult the LICENSE.md file
33
// distributed with the sources of this project regarding your rights to use or distribute this
44
// software.
@@ -161,7 +161,7 @@ func TestNewRequest(t *testing.T) {
161161
if got, want := len(authBearer), 1; got != want {
162162
t.Fatalf("got %v auth bearer(s), want %v", got, want)
163163
}
164-
if got, want := authBearer[0], tt.wantAuthBearer; got != want {
164+
if got, want := authBearer[0], tt.wantAuthBearer; !strings.EqualFold(got, want) {
165165
t.Errorf("got auth bearer %v, want %v", got, want)
166166
}
167167
}

client/downloader.go

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright (c) 2021-2022, Sylabs Inc. All rights reserved.
2+
// This software is licensed under a 3-clause BSD license. Please consult the
3+
// LICENSE.md file distributed with the sources of this project regarding your
4+
// rights to use or distribute this software.
5+
6+
package client
7+
8+
import (
9+
"context"
10+
"errors"
11+
"fmt"
12+
"io"
13+
"net/http"
14+
"strconv"
15+
"strings"
16+
17+
"golang.org/x/sync/errgroup"
18+
)
19+
20+
// filePartDescriptor defines one part of multipart download.
21+
type filePartDescriptor struct {
22+
start int64
23+
end int64
24+
cur int64
25+
26+
w io.WriterAt
27+
}
28+
29+
// Write writes buffer 'p' at offset 'start' using 'WriteAt()' to atomically seek and write.
30+
// Returns bytes written
31+
func (ps *filePartDescriptor) Write(p []byte) (n int, err error) {
32+
n, err = ps.w.WriteAt(p, ps.start+ps.cur)
33+
ps.cur += int64(n)
34+
35+
return
36+
}
37+
38+
// minInt64 returns minimum value of two arguments
39+
func minInt64(a, b int64) int64 {
40+
if a < b {
41+
return a
42+
}
43+
return b
44+
}
45+
46+
// Download performs download of contents at url by writing 'size' bytes to 'dst' using credentials 'c'.
47+
func (c *Client) multipartDownload(ctx context.Context, u string, creds credentials, w io.WriterAt, size int64, spec *Downloader, pb ProgressBar) error {
48+
if size <= 0 {
49+
return fmt.Errorf("invalid image size (%v)", size)
50+
}
51+
52+
// Initialize the progress bar using passed size
53+
pb.Init(size)
54+
55+
// Clean up (remove) progress bar after download
56+
defer pb.Wait()
57+
58+
// Calculate # of parts
59+
parts := uint(1 + (size-1)/spec.PartSize)
60+
61+
c.Logger.Logf("size: %d, parts: %d, streams: %d, partsize: %d", size, parts, spec.Concurrency, spec.PartSize)
62+
63+
g, ctx := errgroup.WithContext(ctx)
64+
65+
// Allocate channel for file part requests
66+
ch := make(chan filePartDescriptor, parts)
67+
68+
// Create download part workers
69+
for n := uint(0); n < spec.Concurrency; n++ {
70+
g.Go(c.ociDownloadWorker(ctx, u, creds, ch, pb))
71+
}
72+
73+
// Add part download requests
74+
for n := uint(0); n < parts; n++ {
75+
partSize := minInt64(spec.PartSize, size-int64(n)*spec.PartSize)
76+
77+
ch <- filePartDescriptor{start: int64(n) * spec.PartSize, end: int64(n)*spec.PartSize + partSize - 1, w: w}
78+
}
79+
80+
// Close worker queue after submitting all requests
81+
close(ch)
82+
83+
// Wait for workers to complete
84+
return g.Wait()
85+
}
86+
87+
func (c *Client) ociDownloadWorker(ctx context.Context, u string, creds credentials, ch chan filePartDescriptor, pb ProgressBar) func() error {
88+
return func() error {
89+
// Iterate on channel 'ch' to handle download part requests
90+
for ps := range ch {
91+
written, err := c.ociDownloadBlobPart(ctx, creds, u, &ps)
92+
if err != nil {
93+
// Cleanly abort progress bar on error
94+
pb.Abort(true)
95+
96+
return err
97+
}
98+
99+
// Increase progress bar by number of bytes downloaded/written
100+
pb.IncrBy(int(written))
101+
}
102+
return nil
103+
}
104+
}
105+
106+
func (c *Client) ociDownloadBlobPart(ctx context.Context, creds credentials, u string, ps *filePartDescriptor) (int64, error) {
107+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
108+
if err != nil {
109+
return 0, err
110+
}
111+
if err := creds.ModifyRequest(req); err != nil {
112+
return 0, err
113+
}
114+
115+
req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", ps.start, ps.end))
116+
117+
res, err := c.HTTPClient.Do(req)
118+
if err != nil {
119+
return 0, err
120+
}
121+
defer res.Body.Close()
122+
123+
return io.Copy(ps, res.Body)
124+
}
125+
126+
// parseContentRange parses "Content-Range" header (eg. "Content-Range: bytes 0-1000/2000") and returns size
127+
func parseContentRange(val string) (int64, error) {
128+
e := strings.Split(val, " ")
129+
130+
if !strings.EqualFold(e[0], "bytes") {
131+
return 0, errors.New("unexpected/malformed value")
132+
}
133+
134+
rangeElems := strings.Split(e[1], "/")
135+
136+
if len(rangeElems) != 2 {
137+
return 0, errors.New("unexpected/malformed value")
138+
}
139+
140+
return strconv.ParseInt(rangeElems[1], 10, 0)
141+
}

client/downloader_test.go

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Copyright (c) 2018-2022, Sylabs Inc. All rights reserved.
2+
// This software is licensed under a 3-clause BSD license. Please consult the
3+
// LICENSE.md file distributed with the sources of this project regarding your
4+
// rights to use or distribute this software.
5+
6+
package client
7+
8+
import (
9+
"bytes"
10+
"context"
11+
"fmt"
12+
"io"
13+
"log"
14+
"net/http"
15+
"net/http/httptest"
16+
"strconv"
17+
"strings"
18+
"sync"
19+
"testing"
20+
)
21+
22+
type inMemoryBuffer struct {
23+
m sync.Mutex
24+
buf []byte
25+
}
26+
27+
func (f *inMemoryBuffer) WriteAt(p []byte, ofs int64) (n int, err error) {
28+
f.m.Lock()
29+
defer f.m.Unlock()
30+
31+
n = copy(f.buf[ofs:], p)
32+
return
33+
}
34+
35+
func (f *inMemoryBuffer) Bytes() []byte {
36+
f.m.Lock()
37+
defer f.m.Unlock()
38+
39+
return f.buf
40+
}
41+
42+
type stdLogger struct{}
43+
44+
func (l *stdLogger) Log(v ...interface{}) {
45+
log.Print(v...)
46+
}
47+
48+
func (l *stdLogger) Logf(f string, v ...interface{}) {
49+
log.Printf(f, v...)
50+
}
51+
52+
func parseRangeHeader(t *testing.T, val string) (int64, int64) {
53+
if val == "" {
54+
return 0, 0
55+
}
56+
57+
var start, end int64
58+
59+
e := strings.SplitN(val, "=", 2)
60+
61+
byteRange := strings.Split(e[1], "-")
62+
63+
start, _ = strconv.ParseInt(byteRange[0], 10, 0)
64+
end, _ = strconv.ParseInt(byteRange[1], 10, 0)
65+
66+
return start, end
67+
}
68+
69+
const (
70+
basicAuthUsername = "user"
71+
basicAuthPassword = "password"
72+
)
73+
74+
var (
75+
testLogger = &stdLogger{}
76+
creds = &basicCredentials{username: basicAuthUsername, password: basicAuthPassword}
77+
)
78+
79+
func TestMultistreamDownloader(t *testing.T) {
80+
const src = "123456789012345678901234567890"
81+
size := int64(len(src))
82+
83+
defaultSpec := &Downloader{Concurrency: 10, PartSize: 3}
84+
85+
tests := []struct {
86+
name string
87+
size int64
88+
spec *Downloader
89+
expectErr bool
90+
}{
91+
{"Basic", size, defaultSpec, false},
92+
{"WithoutSize", 0, defaultSpec, true},
93+
{"SingleStream", size, &Downloader{Concurrency: 1, PartSize: 1}, false},
94+
{"SingleStreamWithoutSize", 0, &Downloader{Concurrency: 1, PartSize: 1}, true},
95+
{"ManyStreams", size, &Downloader{Concurrency: uint(size), PartSize: 1}, false},
96+
{"ManyStreamsWithoutSize", 0, &Downloader{Concurrency: uint(size), PartSize: 1}, true},
97+
}
98+
99+
for _, tt := range tests {
100+
tt := tt
101+
102+
t.Run(tt.name, func(t *testing.T) {
103+
t.Parallel()
104+
105+
// Create test http server for serving "file"
106+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
107+
start, end := parseRangeHeader(t, r.Header.Get("Range"))
108+
109+
if username, password, ok := r.BasicAuth(); ok {
110+
if got, want := username, basicAuthUsername; got != want {
111+
t.Fatalf("unexpected basic auth username: got %v, want %v", got, want)
112+
}
113+
if got, want := password, basicAuthPassword; got != want {
114+
t.Fatalf("unexpected basic auth password: got %v, want %v", got, want)
115+
}
116+
}
117+
118+
w.Header().Set("Content-Range", fmt.Sprintf("bytes %v-%v/%v", start, end+1, size))
119+
w.Header().Set("Content-Length", fmt.Sprintf("%v", end-start+1))
120+
121+
w.WriteHeader(http.StatusPartialContent)
122+
123+
if _, err := io.Copy(w, bytes.NewReader([]byte(src[start:end+1]))); err != nil {
124+
t.Fatalf("unexpected error writing http response: %v", err)
125+
}
126+
}))
127+
defer srv.Close()
128+
129+
c, err := NewClient(&Config{Logger: testLogger})
130+
if err != nil {
131+
t.Fatalf("error initializing client: %v", err)
132+
}
133+
134+
// Preallocate sink for downloaded file
135+
dst := &inMemoryBuffer{buf: make([]byte, size)}
136+
137+
// Start download
138+
err = c.multipartDownload(context.Background(), srv.URL, creds, dst, tt.size, tt.spec, &NoopProgressBar{})
139+
if tt.expectErr && err == nil {
140+
t.Fatal("unexpected success")
141+
}
142+
if !tt.expectErr && err != nil {
143+
t.Fatalf("unexpected error: %v", err)
144+
}
145+
if err != nil {
146+
return
147+
}
148+
149+
// Compare results with expectations
150+
if got, want := string(dst.Bytes()), src; got != want {
151+
t.Fatalf("unexpected data: got %v, want %v", got, want)
152+
}
153+
})
154+
}
155+
}

0 commit comments

Comments
 (0)