Skip to content

Commit dd3bcc7

Browse files
committed
split downloader and getter
1 parent 93559ed commit dd3bcc7

File tree

7 files changed

+618
-485
lines changed

7 files changed

+618
-485
lines changed

feature/s3/transfermanager/api_op_DownloadObject.go

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ func (o *DownloadObjectOutput) mapFromGetObjectOutput(out *s3.GetObjectOutput, c
529529
// download. These options are copies of the original Options instance, the client of which DownloadObject is called from.
530530
// Modifying the options will not impact the original Client and Options instance.
531531
func (c *Client) DownloadObject(ctx context.Context, input *DownloadObjectInput, opts ...func(*Options)) (*DownloadObjectOutput, error) {
532-
i := downloader{in: input, options: c.options.Copy(), w: input.WriterAt}
532+
i := downloader{in: input, options: c.options.Copy()}
533533
for _, opt := range opts {
534534
opt(&i.options)
535535
}
@@ -541,7 +541,6 @@ type downloader struct {
541541
options Options
542542
in *DownloadObjectInput
543543
out *DownloadObjectOutput
544-
w io.WriterAt
545544

546545
wg sync.WaitGroup
547546
m sync.Mutex
@@ -571,7 +570,7 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
571570
return d.singleDownload(ctx, clientOptions...)
572571
}
573572

574-
var output *GetObjectOutput
573+
var output *DownloadObjectOutput
575574
if d.options.MultipartDownloadType == types.MultipartDownloadTypePart {
576575
if d.in.Range != "" {
577576
return d.singleDownload(ctx, clientOptions...)
@@ -583,7 +582,7 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
583582

584583
if output.PartsCount > 1 {
585584
partSize := output.ContentLength
586-
ch := make(chan dlchunk, d.options.Concurrency)
585+
ch := make(chan dlChunk, d.options.Concurrency)
587586
for i := 0; i < d.options.Concurrency; i++ {
588587
d.wg.Add(1)
589588
go d.downloadPart(ctx, ch, clientOptions...)
@@ -594,25 +593,23 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
594593
break
595594
}
596595

597-
ch <- dlchunk{w: d.w, start: d.pos - d.offset, part: i}
596+
ch <- dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, part: i}
598597
d.pos += partSize
599598
}
600599

601600
close(ch)
602601
d.wg.Wait()
603602
}
604603
} else {
605-
var total int64
606604
if d.in.Range == "" {
607605
output = d.getChunk(ctx, 0, d.byteRange(), clientOptions...)
608-
total = d.getTotalBytes()
609606
} else {
610607
d.pos, d.totalBytes = d.getDownloadRange()
611608
d.offset = d.pos
612-
total = d.totalBytes
613609
}
610+
total := d.totalBytes
614611

615-
ch := make(chan dlchunk, d.options.Concurrency)
612+
ch := make(chan dlChunk, d.options.Concurrency)
616613
for i := 0; i < d.options.Concurrency; i++ {
617614
d.wg.Add(1)
618615
go d.downloadPart(ctx, ch, clientOptions...)
@@ -625,7 +622,7 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
625622
}
626623

627624
// Queue the next range of bytes to read.
628-
ch <- dlchunk{w: d.w, start: d.pos - d.offset, withRange: d.byteRange()}
625+
ch <- dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, withRange: d.byteRange()}
629626
d.pos += d.options.PartSizeBytes
630627
}
631628

@@ -659,17 +656,17 @@ func (d *downloader) init(ctx context.Context) error {
659656
}
660657

661658
func (d *downloader) singleDownload(ctx context.Context, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) {
662-
chunk := dlchunk{w: d.w}
663-
d.in.PartNumber = 0
659+
chunk := dlChunk{w: d.in.WriterAt}
660+
// d.in.PartNumber = 0
664661
output, err := d.downloadChunk(ctx, chunk, clientOptions...)
665662
if err != nil {
666-
return output, err
663+
return nil, err
667664
}
668665

669-
return output, err
666+
return output, nil
670667
}
671668

672-
func (d *downloader) downloadPart(ctx context.Context, ch chan dlchunk, clientOptions ...func(*s3.Options)) {
669+
func (d *downloader) downloadPart(ctx context.Context, ch chan dlChunk, clientOptions ...func(*s3.Options)) {
673670
defer d.wg.Done()
674671
for {
675672
chunk, ok := <-ch
@@ -690,8 +687,8 @@ func (d *downloader) downloadPart(ctx context.Context, ch chan dlchunk, clientOp
690687

691688
// getChunk grabs a chunk of data from the body.
692689
// Not thread safe. Should only used when grabbing data on a single thread.
693-
func (d *downloader) getChunk(ctx context.Context, part int32, rng string, clientOptions ...func(*s3.Options)) *GetObjectOutput {
694-
chunk := dlchunk{w: d.w, start: d.pos - d.offset, part: part, withRange: rng}
690+
func (d *downloader) getChunk(ctx context.Context, part int32, rng string, clientOptions ...func(*s3.Options)) *DownloadObjectOutput {
691+
chunk := dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, part: part, withRange: rng}
695692

696693
output, err := d.downloadChunk(ctx, chunk, clientOptions...)
697694
if err != nil {
@@ -705,7 +702,7 @@ func (d *downloader) getChunk(ctx context.Context, part int32, rng string, clien
705702
}
706703

707704
// downloadChunk downloads the chunk from s3
708-
func (d *downloader) downloadChunk(ctx context.Context, chunk dlchunk, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) {
705+
func (d *downloader) downloadChunk(ctx context.Context, chunk dlChunk, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) {
709706
params := d.in.mapGetObjectInput(!d.options.DisableChecksumValidation)
710707
if chunk.part != 0 {
711708
params.PartNumber = aws.Int32(chunk.part)
@@ -750,7 +747,7 @@ func (d *downloader) downloadChunk(ctx context.Context, chunk dlchunk, clientOpt
750747
return output, err
751748
}
752749

753-
func (d *downloader) tryDownloadChunk(ctx context.Context, params *s3.GetObjectInput, chunk *dlchunk, clientOptions ...func(*s3.Options)) (*s3.GetObjectOutput, int64, error) {
750+
func (d *downloader) tryDownloadChunk(ctx context.Context, params *s3.GetObjectInput, chunk *dlChunk, clientOptions ...func(*s3.Options)) (*s3.GetObjectOutput, int64, error) {
754751
out, err := d.options.S3.GetObject(ctx, params, clientOptions...)
755752
if err != nil {
756753
return nil, 0, err
@@ -865,7 +862,7 @@ func (d *downloader) setErr(e error) {
865862
d.err = e
866863
}
867864

868-
type dlchunk struct {
865+
type dlChunk struct {
869866
w io.WriterAt
870867

871868
start int64
@@ -875,7 +872,7 @@ type dlchunk struct {
875872
withRange string
876873
}
877874

878-
func (c *dlchunk) Write(p []byte) (int, error) {
875+
func (c *dlChunk) Write(p []byte) (int, error) {
879876
n, err := c.w.WriteAt(p, c.start+c.cur)
880877
c.cur += int64(n)
881878

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//go:build integration
2+
// +build integration
3+
4+
package transfermanager
5+
6+
import (
7+
"bytes"
8+
"strings"
9+
"testing"
10+
)
11+
12+
func TestInteg_DownloadObject(t *testing.T) {
13+
cases := map[string]getObjectTestData{
14+
"seekable body": {Body: strings.NewReader("hello world"), ExpectBody: []byte("hello world")},
15+
"empty string body": {Body: strings.NewReader(""), ExpectBody: []byte("")},
16+
"multipart download body": {Body: bytes.NewReader(largeObjectBuf), ExpectBody: largeObjectBuf},
17+
}
18+
19+
for name, c := range cases {
20+
t.Run(name, func(t *testing.T) {
21+
testDownloadObject(t, setupMetadata.Buckets.Source.Name, c)
22+
})
23+
}
24+
}

0 commit comments

Comments
 (0)