diff --git a/feature/s3/transfermanager/api.go b/feature/s3/transfermanager/api.go index dbe430dadb4..526cbf96ff0 100644 --- a/feature/s3/transfermanager/api.go +++ b/feature/s3/transfermanager/api.go @@ -13,4 +13,6 @@ type S3APIClient interface { CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) + GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error) + HeadObject(context.Context, *s3.HeadObjectInput, ...func(*s3.Options)) (*s3.HeadObjectOutput, error) } diff --git a/feature/s3/transfermanager/api_client.go b/feature/s3/transfermanager/api_client.go index 84ce16db425..bbee94d986e 100644 --- a/feature/s3/transfermanager/api_client.go +++ b/feature/s3/transfermanager/api_client.go @@ -20,6 +20,10 @@ const defaultMultipartUploadThreshold = 1024 * 1024 * 16 // using PutObject(). const defaultTransferConcurrency = 5 +const defaultPartBodyMaxRetries = 3 + +const defaultGetBufferSize = 1024 * 1024 * 50 + // Client provides the API client to make operations call for Amazon Simple // Storage Service's Transfer Manager // It is safe to call Client methods concurrently across goroutines. @@ -39,6 +43,9 @@ func New(s3Client S3APIClient, opts Options, optFns ...func(*Options)) *Client { resolvePartSizeBytes(&opts) resolveChecksumAlgorithm(&opts) resolveMultipartUploadThreshold(&opts) + resolveGetObjectType(&opts) + resolvePartBodyMaxRetries(&opts) + resolveGetBufferSize(&opts) return &Client{ options: opts, diff --git a/feature/s3/transfermanager/api_op_DownloadObject.go b/feature/s3/transfermanager/api_op_DownloadObject.go new file mode 100644 index 00000000000..76ae1b12692 --- /dev/null +++ b/feature/s3/transfermanager/api_op_DownloadObject.go @@ -0,0 +1,874 @@ +package transfermanager + +import ( + "context" + "fmt" + "io" + "math" + "strconv" + "strings" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" + "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" + smithymiddleware "github.com/aws/smithy-go/middleware" +) + +// DownloadObjectInput represents a request to the DownloadObject() call. It contains common fields +// of s3 GetObject input and destination WriterAt of object +type DownloadObjectInput struct { + // Bucket where the object is downloaded from + Bucket string + + // Key of the object to get. + Key string + + // Destination WriterAt which object parts are written to + WriterAt io.WriterAt + + // To retrieve the checksum, this mode must be enabled. + // + // General purpose buckets - In addition, if you enable checksum mode and the + // object is uploaded with a [checksum]and encrypted with an Key Management Service (KMS) + // key, you must have permission to use the kms:Decrypt action to retrieve the + // checksum. + // + // [checksum]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_Checksum.html + ChecksumMode types.ChecksumMode + + // The account ID of the expected bucket owner. 
If the account ID that you provide + // does not match the actual owner of the bucket, the request fails with the HTTP + // status code 403 Forbidden (access denied). + ExpectedBucketOwner string + + // Return the object only if its entity tag (ETag) is the same as the one + // specified in this header; otherwise, return a 412 Precondition Failed error. + // + // If both of the If-Match and If-Unmodified-Since headers are present in the + // request as follows: If-Match condition evaluates to true , and; + // If-Unmodified-Since condition evaluates to false ; then, S3 returns 200 OK and + // the data requested. + // + // For more information about conditional requests, see [RFC 7232]. + // + // [RFC 7232]: https://tools.ietf.org/html/rfc7232 + IfMatch string + + // Return the object only if it has been modified since the specified time; + // otherwise, return a 304 Not Modified error. + // + // If both of the If-None-Match and If-Modified-Since headers are present in the + // request as follows: If-None-Match condition evaluates to false , and; + // If-Modified-Since condition evaluates to true ; then, S3 returns 304 Not + // Modified status code. + // + // For more information about conditional requests, see [RFC 7232]. + // + // [RFC 7232]: https://tools.ietf.org/html/rfc7232 + IfModifiedSince time.Time + + // Return the object only if its entity tag (ETag) is different from the one + // specified in this header; otherwise, return a 304 Not Modified error. + // + // If both of the If-None-Match and If-Modified-Since headers are present in the + // request as follows: If-None-Match condition evaluates to false , and; + // If-Modified-Since condition evaluates to true ; then, S3 returns 304 Not + // Modified HTTP status code. + // + // For more information about conditional requests, see [RFC 7232]. + // + // [RFC 7232]: https://tools.ietf.org/html/rfc7232 + IfNoneMatch string + + // Return the object only if it has not been modified since the specified time; + // otherwise, return a 412 Precondition Failed error. + // + // If both of the If-Match and If-Unmodified-Since headers are present in the + // request as follows: If-Match condition evaluates to true , and; + // If-Unmodified-Since condition evaluates to false ; then, S3 returns 200 OK and + // the data requested. + // + // For more information about conditional requests, see [RFC 7232]. + // + // [RFC 7232]: https://tools.ietf.org/html/rfc7232 + IfUnmodifiedSince time.Time + + // Part number of the object being read. This is a positive integer between 1 and + // 10,000. Effectively performs a 'ranged' GET request for the part specified. + // Useful for downloading just a part of an object. + PartNumber int32 + + // Downloads the specified byte range of an object. For more information about the + // HTTP Range header, see [https://www.rfc-editor.org/rfc/rfc9110.html#name-range]. + // + // Amazon S3 doesn't support retrieving multiple ranges of data per GET request. + // + // [https://www.rfc-editor.org/rfc/rfc9110.html#name-range]: https://www.rfc-editor.org/rfc/rfc9110.html#name-range + Range string + + // Confirms that the requester knows that they will be charged for the request. + // Bucket owners need not specify this parameter in their requests. If either the + // source or destination S3 bucket has Requester Pays enabled, the requester will + // pay for corresponding charges to copy the object. 
For information about + // downloading objects from Requester Pays buckets, see [Downloading Objects in Requester Pays Buckets]in the Amazon S3 User + // Guide. + // + // This functionality is not supported for directory buckets. + // + // [Downloading Objects in Requester Pays Buckets]: https://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectsinRequesterPaysBuckets.html + RequestPayer types.RequestPayer + + // Sets the Cache-Control header of the response. + ResponseCacheControl string + + // Sets the Content-Disposition header of the response. + ResponseContentDisposition string + + // Sets the Content-Encoding header of the response. + ResponseContentEncoding string + + // Sets the Content-Language header of the response. + ResponseContentLanguage string + + // Sets the Content-Type header of the response. + ResponseContentType string + + // Sets the Expires header of the response. + ResponseExpires time.Time + + // Specifies the algorithm to use when decrypting the object (for example, AES256 ). + // + // If you encrypt an object by using server-side encryption with customer-provided + // encryption keys (SSE-C) when you store the object in Amazon S3, then when you + // GET the object, you must use the following headers: + // + // - x-amz-server-side-encryption-customer-algorithm + // + // - x-amz-server-side-encryption-customer-key + // + // - x-amz-server-side-encryption-customer-key-MD5 + // + // For more information about SSE-C, see [Server-Side Encryption (Using Customer-Provided Encryption Keys)] in the Amazon S3 User Guide. + // + // This functionality is not supported for directory buckets. + // + // [Server-Side Encryption (Using Customer-Provided Encryption Keys)]: https://docs.aws.amazon.com/AmazonS3/latest/dev/ServerSideEncryptionCustomerKeys.html + SSECustomerAlgorithm string + + // Specifies the customer-provided encryption key that you originally provided for + // Amazon S3 to encrypt the data before storing it. This value is used to decrypt + // the object when recovering it and must match the one used when storing the data. + // The key must be appropriate for use with the algorithm specified in the + // x-amz-server-side-encryption-customer-algorithm header. + // + // If you encrypt an object by using server-side encryption with customer-provided + // encryption keys (SSE-C) when you store the object in Amazon S3, then when you + // GET the object, you must use the following headers: + // + // - x-amz-server-side-encryption-customer-algorithm + // + // - x-amz-server-side-encryption-customer-key + // + // - x-amz-server-side-encryption-customer-key-MD5 + // + // For more information about SSE-C, see [Server-Side Encryption (Using Customer-Provided Encryption Keys)] in the Amazon S3 User Guide. + // + // This functionality is not supported for directory buckets. + // + // [Server-Side Encryption (Using Customer-Provided Encryption Keys)]: https://docs.aws.amazon.com/AmazonS3/latest/dev/ServerSideEncryptionCustomerKeys.html + SSECustomerKey string + + // Specifies the 128-bit MD5 digest of the customer-provided encryption key + // according to RFC 1321. Amazon S3 uses this header for a message integrity check + // to ensure that the encryption key was transmitted without error. 
+ // + // If you encrypt an object by using server-side encryption with customer-provided + // encryption keys (SSE-C) when you store the object in Amazon S3, then when you + // GET the object, you must use the following headers: + // + // - x-amz-server-side-encryption-customer-algorithm + // + // - x-amz-server-side-encryption-customer-key + // + // - x-amz-server-side-encryption-customer-key-MD5 + // + // For more information about SSE-C, see [Server-Side Encryption (Using Customer-Provided Encryption Keys)] in the Amazon S3 User Guide. + // + // This functionality is not supported for directory buckets. + // + // [Server-Side Encryption (Using Customer-Provided Encryption Keys)]: https://docs.aws.amazon.com/AmazonS3/latest/dev/ServerSideEncryptionCustomerKeys.html + SSECustomerKeyMD5 string + + // Version ID used to reference a specific version of the object. + // + // By default, the GetObject operation returns the current version of an object. + // To return a different version, use the versionId subresource. + // + // - If you include a versionId in your request header, you must have the + // s3:GetObjectVersion permission to access a specific version of an object. The + // s3:GetObject permission is not required in this scenario. + // + // - If you request the current version of an object without a specific versionId + // in the request header, only the s3:GetObject permission is required. The + // s3:GetObjectVersion permission is not required in this scenario. + // + // - Directory buckets - S3 Versioning isn't enabled and supported for directory + // buckets. For this API operation, only the null value of the version ID is + // supported by directory buckets. You can only specify null to the versionId + // query parameter in the request. + // + // For more information about versioning, see [PutBucketVersioning]. + // + // [PutBucketVersioning]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketVersioning.html + VersionID string +} + +func (i DownloadObjectInput) mapGetObjectInput(enableChecksumValidation bool) *s3.GetObjectInput { + input := &s3.GetObjectInput{ + Bucket: aws.String(i.Bucket), + Key: aws.String(i.Key), + } + + if i.ChecksumMode != "" { + input.ChecksumMode = s3types.ChecksumMode(i.ChecksumMode) + } else if enableChecksumValidation { + input.ChecksumMode = s3types.ChecksumModeEnabled + } + + if i.RequestPayer != "" { + input.RequestPayer = s3types.RequestPayer(i.RequestPayer) + } + + input.ExpectedBucketOwner = nzstring(i.ExpectedBucketOwner) + input.IfMatch = nzstring(i.IfMatch) + input.IfNoneMatch = nzstring(i.IfNoneMatch) + input.IfModifiedSince = nztime(i.IfModifiedSince) + input.IfUnmodifiedSince = nztime(i.IfUnmodifiedSince) + input.ResponseCacheControl = nzstring(i.ResponseCacheControl) + input.ResponseContentDisposition = nzstring(i.ResponseContentDisposition) + input.ResponseContentEncoding = nzstring(i.ResponseContentEncoding) + input.ResponseContentLanguage = nzstring(i.ResponseContentLanguage) + input.ResponseContentType = nzstring(i.ResponseContentType) + input.ResponseExpires = nztime(i.ResponseExpires) + input.SSECustomerAlgorithm = nzstring(i.SSECustomerAlgorithm) + input.SSECustomerKey = nzstring(i.SSECustomerKey) + input.SSECustomerKeyMD5 = nzstring(i.SSECustomerKeyMD5) + input.VersionId = nzstring(i.VersionID) + + return input +} + +// DownloadObjectOutput represents a response from DownloadObject() call. 
It contains common fields +// of s3 GetObject output except Body which is replaced by WriterAt of input +type DownloadObjectOutput struct { + // Indicates that a range of bytes was specified in the request. + AcceptRanges string + + // Indicates whether the object uses an S3 Bucket Key for server-side encryption + // with Key Management Service (KMS) keys (SSE-KMS). + BucketKeyEnabled bool + + // Specifies caching behavior along the request/reply chain. + CacheControl string + + // Specifies if the response checksum validation is enabled + ChecksumMode types.ChecksumMode + + // The base64-encoded, 32-bit CRC-32 checksum of the object. This will only be + // present if it was uploaded with the object. For more information, see [Checking object integrity]in the + // Amazon S3 User Guide. + // + // [Checking object integrity]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + ChecksumCRC32 string + + // The base64-encoded, 32-bit CRC-32C checksum of the object. This will only be + // present if it was uploaded with the object. For more information, see [Checking object integrity]in the + // Amazon S3 User Guide. + // + // [Checking object integrity]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + ChecksumCRC32C string + + // The base64-encoded, 160-bit SHA-1 digest of the object. This will only be + // present if it was uploaded with the object. For more information, see [Checking object integrity]in the + // Amazon S3 User Guide. + // + // [Checking object integrity]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + ChecksumSHA1 string + + // The base64-encoded, 256-bit SHA-256 digest of the object. This will only be + // present if it was uploaded with the object. For more information, see [Checking object integrity]in the + // Amazon S3 User Guide. + // + // [Checking object integrity]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + ChecksumSHA256 string + + // Specifies presentational information for the object. + ContentDisposition string + + // Indicates what content encodings have been applied to the object and thus what + // decoding mechanisms must be applied to obtain the media-type referenced by the + // Content-Type header field. + ContentEncoding string + + // The language the content is in. + ContentLanguage string + + // Size of the body in bytes. + ContentLength int64 + + // The portion of the object returned in the response. + ContentRange string + + // A standard MIME type describing the format of the object data. + ContentType string + + // Indicates whether the object retrieved was (true) or was not (false) a Delete + // Marker. If false, this response header does not appear in the response. + // + // - If the current version of the object is a delete marker, Amazon S3 behaves + // as if the object was deleted and includes x-amz-delete-marker: true in the + // response. + // + // - If the specified version in the request is a delete marker, the response + // returns a 405 Method Not Allowed error and the Last-Modified: timestamp + // response header. + DeleteMarker bool + + // An entity tag (ETag) is an opaque identifier assigned by a web server to a + // specific version of a resource found at a URL. + ETag string + + // If the object expiration is configured (see [PutBucketLifecycleConfiguration]PutBucketLifecycleConfiguration ), + // the response includes this header. 
It includes the expiry-date and rule-id + // key-value pairs providing object expiration information. The value of the + // rule-id is URL-encoded. + // + // This functionality is not supported for directory buckets. + // + // [PutBucketLifecycleConfiguration]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketLifecycleConfiguration.html + Expiration string + + // The date and time at which the object is no longer cacheable. + // + // Deprecated: This field is handled inconsistently across AWS SDKs. Prefer using + // the ExpiresString field which contains the unparsed value from the service + // response. + Expires time.Time + + // The unparsed value of the Expires field from the service response. Prefer use + // of this value over the normal Expires response field where possible. + ExpiresString string + + // Date and time when the object was last modified. + // + // General purpose buckets - When you specify a versionId of the object in your + // request, if the specified version in the request is a delete marker, the + // response returns a 405 Method Not Allowed error and the Last-Modified: timestamp + // response header. + LastModified time.Time + + // A map of metadata to store with the object in S3. + // + // Map keys will be normalized to lower-case. + Metadata map[string]string + + // This is set to the number of metadata entries not returned in the headers that + // are prefixed with x-amz-meta- . This can happen if you create metadata using an + // API like SOAP that supports more flexible metadata than the REST API. For + // example, using SOAP, you can create metadata whose values are not legal HTTP + // headers. + // + // This functionality is not supported for directory buckets. + MissingMeta int32 + + // Indicates whether this object has an active legal hold. This field is only + // returned if you have permission to view an object's legal hold status. + // + // This functionality is not supported for directory buckets. + ObjectLockLegalHoldStatus types.ObjectLockLegalHoldStatus + + // The Object Lock mode that's currently in place for this object. + // + // This functionality is not supported for directory buckets. + ObjectLockMode types.ObjectLockMode + + // The date and time when this object's Object Lock will expire. + // + // This functionality is not supported for directory buckets. + ObjectLockRetainUntilDate time.Time + + // The count of parts this object has. This value is only returned if you specify + // partNumber in your request and the object was uploaded as a multipart upload. + PartsCount int32 + + // Amazon S3 can return this if your request involves a bucket that is either a + // source or destination in a replication rule. + // + // This functionality is not supported for directory buckets. + ReplicationStatus types.ReplicationStatus + + // If present, indicates that the requester was successfully charged for the + // request. + // + // This functionality is not supported for directory buckets. + RequestCharged types.RequestCharged + + // Provides information about object restoration action and expiration time of the + // restored object copy. + // + // This functionality is not supported for directory buckets. Only the S3 Express + // One Zone storage class is supported by directory buckets to store objects. + Restore string + + // If server-side encryption with a customer-provided encryption key was + // requested, the response will include this header to confirm the encryption + // algorithm that's used. 
+ // + // This functionality is not supported for directory buckets. + SSECustomerAlgorithm string + + // If server-side encryption with a customer-provided encryption key was + // requested, the response will include this header to provide the round-trip + // message integrity verification of the customer-provided encryption key. + // + // This functionality is not supported for directory buckets. + SSECustomerKeyMD5 string + + // If present, indicates the ID of the KMS key that was used for object encryption. + SSEKMSKeyID string + + // The server-side encryption algorithm used when you store this object in Amazon + // S3. + ServerSideEncryption types.ServerSideEncryption + + // Provides storage class information of the object. Amazon S3 returns this header + // for all objects except for S3 Standard storage class objects. + // + // Directory buckets - Only the S3 Express One Zone storage class is supported by + // directory buckets to store objects. + StorageClass types.StorageClass + + // The number of tags, if any, on the object, when you have the relevant + // permission to read object tags. + // + // You can use [GetObjectTagging] to retrieve the tag set associated with an object. + // + // This functionality is not supported for directory buckets. + // + // [GetObjectTagging]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObjectTagging.html + TagCount int32 + + // Version ID of the object. + // + // This functionality is not supported for directory buckets. + VersionID string + + // If the bucket is configured as a website, redirects requests for this object to + // another object in the same bucket or to an external URL. Amazon S3 stores the + // value of this header in the object metadata. + // + // This functionality is not supported for directory buckets. + WebsiteRedirectLocation string + + // Metadata pertaining to the operation's result. 
+ ResultMetadata smithymiddleware.Metadata +} + +func (o *DownloadObjectOutput) mapFromGetObjectOutput(out *s3.GetObjectOutput, checksumMode s3types.ChecksumMode) { + o.AcceptRanges = aws.ToString(out.AcceptRanges) + o.CacheControl = aws.ToString(out.CacheControl) + o.ChecksumMode = types.ChecksumMode(checksumMode) + o.ChecksumCRC32 = aws.ToString(out.ChecksumCRC32) + o.ChecksumCRC32C = aws.ToString(out.ChecksumCRC32C) + o.ChecksumSHA1 = aws.ToString(out.ChecksumSHA1) + o.ChecksumSHA256 = aws.ToString(out.ChecksumSHA256) + o.ContentDisposition = aws.ToString(out.ContentDisposition) + o.ContentEncoding = aws.ToString(out.ContentEncoding) + o.ContentLanguage = aws.ToString(out.ContentLanguage) + o.ContentRange = aws.ToString(out.ContentRange) + o.ContentType = aws.ToString(out.ContentType) + o.ETag = aws.ToString(out.ETag) + o.Expiration = aws.ToString(out.Expiration) + o.ExpiresString = aws.ToString(out.ExpiresString) + o.Restore = aws.ToString(out.Restore) + o.SSECustomerAlgorithm = aws.ToString(out.SSECustomerAlgorithm) + o.SSECustomerKeyMD5 = aws.ToString(out.SSECustomerKeyMD5) + o.SSEKMSKeyID = aws.ToString(out.SSEKMSKeyId) + o.VersionID = aws.ToString(out.VersionId) + o.WebsiteRedirectLocation = aws.ToString(out.WebsiteRedirectLocation) + o.BucketKeyEnabled = aws.ToBool(out.BucketKeyEnabled) + o.DeleteMarker = aws.ToBool(out.DeleteMarker) + o.MissingMeta = aws.ToInt32(out.MissingMeta) + o.PartsCount = aws.ToInt32(out.PartsCount) + o.TagCount = aws.ToInt32(out.TagCount) + o.ContentLength = aws.ToInt64(out.ContentLength) + o.Expires = aws.ToTime(out.Expires) + o.LastModified = aws.ToTime(out.LastModified) + o.ObjectLockRetainUntilDate = aws.ToTime(out.ObjectLockRetainUntilDate) + o.Metadata = out.Metadata + o.ObjectLockLegalHoldStatus = types.ObjectLockLegalHoldStatus(out.ObjectLockLegalHoldStatus) + o.ObjectLockMode = types.ObjectLockMode(out.ObjectLockMode) + o.ReplicationStatus = types.ReplicationStatus(out.ReplicationStatus) + o.RequestCharged = types.RequestCharged(out.RequestCharged) + o.ServerSideEncryption = types.ServerSideEncryption(out.ServerSideEncryption) + o.StorageClass = types.StorageClass(out.StorageClass) + o.ResultMetadata = out.ResultMetadata.Clone() +} + +// DownloadObject downloads an object from S3, intelligently splitting large +// files into smaller parts/ranges according to config and getting them in parallel across +// multiple goroutines. You can configure the download type, chunk size and concurrency +// through the Options parameters. +// +// Additional functional options can be provided to configure the individual +// download. These options are copies of the original Options instance, the client of which DownloadObject is called from. +// Modifying the options will not impact the original Client and Options instance. 
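+//
+// For illustration only, a minimal sketch of calling DownloadObject with a
+// local file as the destination (an *os.File satisfies io.WriterAt). It assumes
+// svc is a transfer manager Client created with New, and placeholder bucket,
+// key, and file names:
+//
+//	file, err := os.Create("your filename")
+//	if err != nil {
+//		log.Fatal("error when creating local file: ", err)
+//	}
+//	defer file.Close()
+//
+//	out, err := svc.DownloadObject(context.Background(), &transfermanager.DownloadObjectInput{
+//		Bucket:   "your-bucket",
+//		Key:      "your-key",
+//		WriterAt: file,
+//	}, func(o *transfermanager.Options) {
+//		o.Concurrency = 10                      // number of goroutines fetching chunks
+//		o.GetObjectType = types.GetObjectRanges // download by byte ranges instead of part numbers
+//	})
+//	if err != nil {
+//		log.Fatal("error when downloading file: ", err)
+//	}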
+func (c *Client) DownloadObject(ctx context.Context, input *DownloadObjectInput, opts ...func(*Options)) (*DownloadObjectOutput, error) { + i := downloader{in: input, options: c.options.Copy()} + for _, opt := range opts { + opt(&i.options) + } + + return i.download(ctx) +} + +type downloader struct { + options Options + in *DownloadObjectInput + out *DownloadObjectOutput + + wg sync.WaitGroup + m sync.Mutex + + offset int64 + pos int64 + totalBytes int64 + written int64 + + err error +} + +func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error) { + if err := d.init(ctx); err != nil { + return nil, fmt.Errorf("unable to initialize download: %w", err) + } + + clientOptions := []func(*s3.Options){ + func(o *s3.Options) { + o.APIOptions = append(o.APIOptions, + middleware.AddSDKAgentKey(middleware.FeatureMetadata, userAgentKey), + addFeatureUserAgent, + ) + }} + + if d.in.PartNumber > 0 { + return d.singleDownload(ctx, clientOptions...) + } + + var output *DownloadObjectOutput + if d.options.GetObjectType == types.GetObjectParts { + if d.in.Range != "" { + return d.singleDownload(ctx, clientOptions...) + } + output = d.getChunk(ctx, 1, "", clientOptions...) + if d.getErr() != nil { + return output, d.err + } + + if output.PartsCount > 1 { + partSize := output.ContentLength + ch := make(chan dlChunk, d.options.Concurrency) + for i := 0; i < d.options.Concurrency; i++ { + d.wg.Add(1) + go d.downloadPart(ctx, ch, clientOptions...) + } + + for i := int32(2); i <= output.PartsCount; i++ { + if d.getErr() != nil { + break + } + + ch <- dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, part: i} + d.pos += partSize + } + + close(ch) + d.wg.Wait() + } + } else { + if d.in.Range == "" { + output = d.getChunk(ctx, 0, d.byteRange(), clientOptions...) + } else { + d.pos, d.totalBytes = d.getDownloadRange() + d.offset = d.pos + } + total := d.totalBytes + + ch := make(chan dlChunk, d.options.Concurrency) + for i := 0; i < d.options.Concurrency; i++ { + d.wg.Add(1) + go d.downloadPart(ctx, ch, clientOptions...) + } + + // Assign work + for d.getErr() == nil { + if d.pos >= total { + break // We're finished queuing chunks + } + + // Queue the next range of bytes to read. + ch <- dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, withRange: d.byteRange()} + d.pos += d.options.PartSizeBytes + } + + // Wait for completion + close(ch) + d.wg.Wait() + } + + if d.err != nil { + return nil, d.err + } + + d.out.ContentLength = d.written + d.out.ContentRange = fmt.Sprintf("bytes=%d-%d", d.offset, d.totalBytes-1) + return d.out, nil +} + +func (d *downloader) init(ctx context.Context) error { + if d.options.PartSizeBytes < minPartSizeBytes { + return fmt.Errorf("part size must be at least %d bytes", minPartSizeBytes) + } + + if d.options.PartBodyMaxRetries < 0 { + return fmt.Errorf("part body retry must be non-negative") + } + + d.totalBytes = -1 + + return nil +} + +func (d *downloader) singleDownload(ctx context.Context, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) { + chunk := dlChunk{w: d.in.WriterAt} + // d.in.PartNumber = 0 + output, err := d.downloadChunk(ctx, chunk, clientOptions...) + if err != nil { + return nil, err + } + + return output, nil +} + +func (d *downloader) downloadPart(ctx context.Context, ch chan dlChunk, clientOptions ...func(*s3.Options)) { + defer d.wg.Done() + for { + chunk, ok := <-ch + if !ok { + break + } + if d.getErr() != nil { + continue + } + out, err := d.downloadChunk(ctx, chunk, clientOptions...) 
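+			// Record the first error encountered; workers keep draining the channel
+			// but skip remaining chunks once an error has been set.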
+ if err != nil { + d.setErr(err) + } else { + d.setOutput(out) + } + } +} + +// getChunk grabs a chunk of data from the body. +// Not thread safe. Should only used when grabbing data on a single thread. +func (d *downloader) getChunk(ctx context.Context, part int32, rng string, clientOptions ...func(*s3.Options)) *DownloadObjectOutput { + chunk := dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, part: part, withRange: rng} + + output, err := d.downloadChunk(ctx, chunk, clientOptions...) + if err != nil { + d.setErr(err) + return output + } + + d.setOutput(output) + d.pos += output.ContentLength + return output +} + +// downloadChunk downloads the chunk from s3 +func (d *downloader) downloadChunk(ctx context.Context, chunk dlChunk, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) { + params := d.in.mapGetObjectInput(!d.options.DisableChecksumValidation) + if chunk.part != 0 { + params.PartNumber = aws.Int32(chunk.part) + } + if chunk.withRange != "" { + params.Range = aws.String(chunk.withRange) + } + + var out *s3.GetObjectOutput + var n int64 + var err error + for retry := 0; retry < d.options.PartBodyMaxRetries; retry++ { + out, n, err = d.tryDownloadChunk(ctx, params, &chunk, clientOptions...) + if err == nil { + break + } + // Check if the returned error is an errReadingBody. + // If err is errReadingBody this indicates that an error + // occurred while copying the http response body. + // If this occurs we unwrap the err to set the underlying error + // and attempt any remaining retries. + if bodyErr, ok := err.(*errReadingBody); ok { + err = bodyErr + } else { + return nil, err + } + + chunk.cur = 0 + } + + d.incrWritten(n) + + var output *DownloadObjectOutput + if out != nil { + output = &DownloadObjectOutput{} + output.mapFromGetObjectOutput(out, params.ChecksumMode) + } + return output, err +} + +func (d *downloader) tryDownloadChunk(ctx context.Context, params *s3.GetObjectInput, chunk *dlChunk, clientOptions ...func(*s3.Options)) (*s3.GetObjectOutput, int64, error) { + out, err := d.options.S3.GetObject(ctx, params, clientOptions...) + if err != nil { + return nil, 0, err + } + + d.setTotalBytes(out) // Set total if not yet set. + + var n int64 + defer out.Body.Close() + n, err = io.Copy(chunk, out.Body) + if err != nil { + return nil, 0, &errReadingBody{err: err} + } + + return out, n, nil +} + +func (d *downloader) incrWritten(n int64) { + d.m.Lock() + defer d.m.Unlock() + + d.written += n +} + +// getTotalBytes is a thread-safe getter for retrieving the total byte status. +func (d *downloader) getTotalBytes() int64 { + d.m.Lock() + defer d.m.Unlock() + + return d.totalBytes +} + +// setTotalBytes is a thread-safe setter for setting the total byte status. +// Will extract the object's total bytes from the Content-Range if the file +// will be chunked, or Content-Length. Content-Length is used when the response +// does not include a Content-Range. Meaning the object was not chunked. This +// occurs when the full file fits within the PartSize directive. +func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) { + d.m.Lock() + defer d.m.Unlock() + + if d.totalBytes >= 0 { + return + } + + if resp.ContentRange == nil { + // ContentRange is nil when the full file contents is provided, and + // is not chunked. Use ContentLength instead. 
+ d.totalBytes = aws.ToInt64(resp.ContentLength) + } else { + parts := strings.Split(*resp.ContentRange, "/") + totalStr := parts[len(parts)-1] + total, err := strconv.ParseInt(totalStr, 10, 64) + if err != nil { + d.err = err + return + } + + d.totalBytes = total + } +} + +func (d *downloader) setOutput(resp *DownloadObjectOutput) { + d.m.Lock() + defer d.m.Unlock() + + if d.out != nil { + return + } + d.out = resp +} + +// TODO this might be shared beteen get and download +func (d *downloader) getDownloadRange() (int64, int64) { + parts := strings.Split(strings.Split(d.in.Range, "=")[1], "-") + + start, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + d.err = err + return 0, 0 + } + + end, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + d.err = err + return 0, 0 + } + + return start, end + 1 +} + +// byteRange returns a HTTP Byte-Range header value that should be used by the +// client to request a chunk range. +func (d *downloader) byteRange() string { + if d.totalBytes >= 0 { + return fmt.Sprintf("bytes=%d-%d", d.pos, int64(math.Min(float64(d.totalBytes-1), float64(d.pos+d.options.PartSizeBytes-1)))) + } + return fmt.Sprintf("bytes=%d-%d", d.pos, d.pos+d.options.PartSizeBytes-1) +} + +func (d *downloader) getErr() error { + d.m.Lock() + defer d.m.Unlock() + + return d.err +} + +func (d *downloader) setErr(e error) { + d.m.Lock() + defer d.m.Unlock() + + d.err = e +} + +type dlChunk struct { + w io.WriterAt + + start int64 + cur int64 + + part int32 + withRange string +} + +func (c *dlChunk) Write(p []byte) (int, error) { + n, err := c.w.WriteAt(p, c.start+c.cur) + c.cur += int64(n) + + return n, err +} diff --git a/feature/s3/transfermanager/api_op_DownloadObject_integ_test.go b/feature/s3/transfermanager/api_op_DownloadObject_integ_test.go new file mode 100644 index 00000000000..250552c3415 --- /dev/null +++ b/feature/s3/transfermanager/api_op_DownloadObject_integ_test.go @@ -0,0 +1,52 @@ +//go:build integration +// +build integration + +package transfermanager + +import ( + "bytes" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" + "strings" + "testing" +) + +func TestInteg_DownloadObject(t *testing.T) { + cases := map[string]getObjectTestData{ + "part get seekable body": {Body: strings.NewReader("hello world"), ExpectBody: []byte("hello world")}, + "part get empty string body": {Body: strings.NewReader(""), ExpectBody: []byte("")}, + "part get multipart body": {Body: bytes.NewReader(largeObjectBuf), ExpectBody: largeObjectBuf}, + "range get seekable body": { + Body: strings.NewReader("hello world"), + ExpectBody: []byte("hello world"), + OptFns: []func(*Options){ + func(opt *Options) { + opt.GetObjectType = types.GetObjectRanges + }, + }, + }, + "range get empty string body": { + Body: strings.NewReader(""), + ExpectError: "InvalidRange", + OptFns: []func(*Options){ + func(opt *Options) { + opt.GetObjectType = types.GetObjectRanges + }, + }, + }, + "range get multipart body": { + Body: bytes.NewReader(largeObjectBuf), + ExpectBody: largeObjectBuf, + OptFns: []func(*Options){ + func(opt *Options) { + opt.GetObjectType = types.GetObjectRanges + }, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + testDownloadObject(t, setupMetadata.Buckets.Source.Name, c) + }) + } +} diff --git a/feature/s3/transfermanager/api_op_GetObject.go b/feature/s3/transfermanager/api_op_GetObject.go new file mode 100644 index 00000000000..19d8b449b75 --- /dev/null +++ b/feature/s3/transfermanager/api_op_GetObject.go @@ -0,0 +1,916 @@ 
+package transfermanager + +import ( + "bytes" + "context" + "fmt" + "io" + "strconv" + "strings" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" + "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" + smithymiddleware "github.com/aws/smithy-go/middleware" +) + +type errReadingBody struct { + err error +} + +func (e *errReadingBody) Error() string { + return fmt.Sprintf("failed to read part body: %v", e.err) +} + +type errInvalidRange struct { + max int64 +} + +func (e *errInvalidRange) Error() string { + return fmt.Sprintf("invalid input range, must be between 0 and %d", e.max) +} + +// GetObjectInput represents a request to the GetObject() or DownloadObject() call. It contains common fields +// of s3 GetObject input +type GetObjectInput struct { + // Bucket where the object is downloaded from + Bucket string + + // Key of the object to get. + Key string + + Reader *ConcurrentReader + + // To retrieve the checksum, this mode must be enabled. + // + // General purpose buckets - In addition, if you enable checksum mode and the + // object is uploaded with a [checksum]and encrypted with an Key Management Service (KMS) + // key, you must have permission to use the kms:Decrypt action to retrieve the + // checksum. + // + // [checksum]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_Checksum.html + ChecksumMode types.ChecksumMode + + // The account ID of the expected bucket owner. If the account ID that you provide + // does not match the actual owner of the bucket, the request fails with the HTTP + // status code 403 Forbidden (access denied). + ExpectedBucketOwner string + + // Return the object only if its entity tag (ETag) is the same as the one + // specified in this header; otherwise, return a 412 Precondition Failed error. + // + // If both of the If-Match and If-Unmodified-Since headers are present in the + // request as follows: If-Match condition evaluates to true , and; + // If-Unmodified-Since condition evaluates to false ; then, S3 returns 200 OK and + // the data requested. + // + // For more information about conditional requests, see [RFC 7232]. + // + // [RFC 7232]: https://tools.ietf.org/html/rfc7232 + IfMatch string + + // Return the object only if it has been modified since the specified time; + // otherwise, return a 304 Not Modified error. + // + // If both of the If-None-Match and If-Modified-Since headers are present in the + // request as follows: If-None-Match condition evaluates to false , and; + // If-Modified-Since condition evaluates to true ; then, S3 returns 304 Not + // Modified status code. + // + // For more information about conditional requests, see [RFC 7232]. + // + // [RFC 7232]: https://tools.ietf.org/html/rfc7232 + IfModifiedSince time.Time + + // Return the object only if its entity tag (ETag) is different from the one + // specified in this header; otherwise, return a 304 Not Modified error. + // + // If both of the If-None-Match and If-Modified-Since headers are present in the + // request as follows: If-None-Match condition evaluates to false , and; + // If-Modified-Since condition evaluates to true ; then, S3 returns 304 Not + // Modified HTTP status code. + // + // For more information about conditional requests, see [RFC 7232]. 
+ // + // [RFC 7232]: https://tools.ietf.org/html/rfc7232 + IfNoneMatch string + + // Return the object only if it has not been modified since the specified time; + // otherwise, return a 412 Precondition Failed error. + // + // If both of the If-Match and If-Unmodified-Since headers are present in the + // request as follows: If-Match condition evaluates to true , and; + // If-Unmodified-Since condition evaluates to false ; then, S3 returns 200 OK and + // the data requested. + // + // For more information about conditional requests, see [RFC 7232]. + // + // [RFC 7232]: https://tools.ietf.org/html/rfc7232 + IfUnmodifiedSince time.Time + + // Part number of the object being read. This is a positive integer between 1 and + // 10,000. Effectively performs a 'ranged' GET request for the part specified. + // Useful for downloading just a part of an object. + PartNumber int32 + + // Downloads the specified byte range of an object. For more information about the + // HTTP Range header, see [https://www.rfc-editor.org/rfc/rfc9110.html#name-range]. + // + // Amazon S3 doesn't support retrieving multiple ranges of data per GET request. + // + // [https://www.rfc-editor.org/rfc/rfc9110.html#name-range]: https://www.rfc-editor.org/rfc/rfc9110.html#name-range + Range string + + // Confirms that the requester knows that they will be charged for the request. + // Bucket owners need not specify this parameter in their requests. If either the + // source or destination S3 bucket has Requester Pays enabled, the requester will + // pay for corresponding charges to copy the object. For information about + // downloading objects from Requester Pays buckets, see [Downloading Objects in Requester Pays Buckets]in the Amazon S3 User + // Guide. + // + // This functionality is not supported for directory buckets. + // + // [Downloading Objects in Requester Pays Buckets]: https://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectsinRequesterPaysBuckets.html + RequestPayer types.RequestPayer + + // Sets the Cache-Control header of the response. + ResponseCacheControl string + + // Sets the Content-Disposition header of the response. + ResponseContentDisposition string + + // Sets the Content-Encoding header of the response. + ResponseContentEncoding string + + // Sets the Content-Language header of the response. + ResponseContentLanguage string + + // Sets the Content-Type header of the response. + ResponseContentType string + + // Sets the Expires header of the response. + ResponseExpires time.Time + + // Specifies the algorithm to use when decrypting the object (for example, AES256 ). + // + // If you encrypt an object by using server-side encryption with customer-provided + // encryption keys (SSE-C) when you store the object in Amazon S3, then when you + // GET the object, you must use the following headers: + // + // - x-amz-server-side-encryption-customer-algorithm + // + // - x-amz-server-side-encryption-customer-key + // + // - x-amz-server-side-encryption-customer-key-MD5 + // + // For more information about SSE-C, see [Server-Side Encryption (Using Customer-Provided Encryption Keys)] in the Amazon S3 User Guide. + // + // This functionality is not supported for directory buckets. + // + // [Server-Side Encryption (Using Customer-Provided Encryption Keys)]: https://docs.aws.amazon.com/AmazonS3/latest/dev/ServerSideEncryptionCustomerKeys.html + SSECustomerAlgorithm string + + // Specifies the customer-provided encryption key that you originally provided for + // Amazon S3 to encrypt the data before storing it. 
This value is used to decrypt + // the object when recovering it and must match the one used when storing the data. + // The key must be appropriate for use with the algorithm specified in the + // x-amz-server-side-encryption-customer-algorithm header. + // + // If you encrypt an object by using server-side encryption with customer-provided + // encryption keys (SSE-C) when you store the object in Amazon S3, then when you + // GET the object, you must use the following headers: + // + // - x-amz-server-side-encryption-customer-algorithm + // + // - x-amz-server-side-encryption-customer-key + // + // - x-amz-server-side-encryption-customer-key-MD5 + // + // For more information about SSE-C, see [Server-Side Encryption (Using Customer-Provided Encryption Keys)] in the Amazon S3 User Guide. + // + // This functionality is not supported for directory buckets. + // + // [Server-Side Encryption (Using Customer-Provided Encryption Keys)]: https://docs.aws.amazon.com/AmazonS3/latest/dev/ServerSideEncryptionCustomerKeys.html + SSECustomerKey string + + // Specifies the 128-bit MD5 digest of the customer-provided encryption key + // according to RFC 1321. Amazon S3 uses this header for a message integrity check + // to ensure that the encryption key was transmitted without error. + // + // If you encrypt an object by using server-side encryption with customer-provided + // encryption keys (SSE-C) when you store the object in Amazon S3, then when you + // GET the object, you must use the following headers: + // + // - x-amz-server-side-encryption-customer-algorithm + // + // - x-amz-server-side-encryption-customer-key + // + // - x-amz-server-side-encryption-customer-key-MD5 + // + // For more information about SSE-C, see [Server-Side Encryption (Using Customer-Provided Encryption Keys)] in the Amazon S3 User Guide. + // + // This functionality is not supported for directory buckets. + // + // [Server-Side Encryption (Using Customer-Provided Encryption Keys)]: https://docs.aws.amazon.com/AmazonS3/latest/dev/ServerSideEncryptionCustomerKeys.html + SSECustomerKeyMD5 string + + // Version ID used to reference a specific version of the object. + // + // By default, the GetObject operation returns the current version of an object. + // To return a different version, use the versionId subresource. + // + // - If you include a versionId in your request header, you must have the + // s3:GetObjectVersion permission to access a specific version of an object. The + // s3:GetObject permission is not required in this scenario. + // + // - If you request the current version of an object without a specific versionId + // in the request header, only the s3:GetObject permission is required. The + // s3:GetObjectVersion permission is not required in this scenario. + // + // - Directory buckets - S3 Versioning isn't enabled and supported for directory + // buckets. For this API operation, only the null value of the version ID is + // supported by directory buckets. You can only specify null to the versionId + // query parameter in the request. + // + // For more information about versioning, see [PutBucketVersioning]. 
+ // + // [PutBucketVersioning]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketVersioning.html + VersionID string +} + +func (i GetObjectInput) mapGetObjectInput(enableChecksumValidation bool) *s3.GetObjectInput { + input := &s3.GetObjectInput{ + Bucket: aws.String(i.Bucket), + Key: aws.String(i.Key), + } + + if i.ChecksumMode != "" { + input.ChecksumMode = s3types.ChecksumMode(i.ChecksumMode) + } else if enableChecksumValidation { + input.ChecksumMode = s3types.ChecksumModeEnabled + } + + if i.RequestPayer != "" { + input.RequestPayer = s3types.RequestPayer(i.RequestPayer) + } + + input.ExpectedBucketOwner = nzstring(i.ExpectedBucketOwner) + input.IfMatch = nzstring(i.IfMatch) + input.IfNoneMatch = nzstring(i.IfNoneMatch) + input.IfModifiedSince = nztime(i.IfModifiedSince) + input.IfUnmodifiedSince = nztime(i.IfUnmodifiedSince) + input.ResponseCacheControl = nzstring(i.ResponseCacheControl) + input.ResponseContentDisposition = nzstring(i.ResponseContentDisposition) + input.ResponseContentEncoding = nzstring(i.ResponseContentEncoding) + input.ResponseContentLanguage = nzstring(i.ResponseContentLanguage) + input.ResponseContentType = nzstring(i.ResponseContentType) + input.ResponseExpires = nztime(i.ResponseExpires) + input.SSECustomerAlgorithm = nzstring(i.SSECustomerAlgorithm) + input.SSECustomerKey = nzstring(i.SSECustomerKey) + input.SSECustomerKeyMD5 = nzstring(i.SSECustomerKeyMD5) + input.VersionId = nzstring(i.VersionID) + + return input +} + +// GetObjectOutput represents a response from GetObject() or DownloadObject() call. It contains common fields +// of s3 GetObject output +type GetObjectOutput struct { + // Indicates that a range of bytes was specified in the request. + AcceptRanges string + + // Object data. + Body io.ReadCloser + + // Indicates whether the object uses an S3 Bucket Key for server-side encryption + // with Key Management Service (KMS) keys (SSE-KMS). + BucketKeyEnabled bool + + // Specifies caching behavior along the request/reply chain. + CacheControl string + + // Specifies if the response checksum validation is enabled + ChecksumMode types.ChecksumMode + + // The base64-encoded, 32-bit CRC-32 checksum of the object. This will only be + // present if it was uploaded with the object. For more information, see [Checking object integrity]in the + // Amazon S3 User Guide. + // + // [Checking object integrity]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + ChecksumCRC32 string + + // The base64-encoded, 32-bit CRC-32C checksum of the object. This will only be + // present if it was uploaded with the object. For more information, see [Checking object integrity]in the + // Amazon S3 User Guide. + // + // [Checking object integrity]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + ChecksumCRC32C string + + // The base64-encoded, 160-bit SHA-1 digest of the object. This will only be + // present if it was uploaded with the object. For more information, see [Checking object integrity]in the + // Amazon S3 User Guide. + // + // [Checking object integrity]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + ChecksumSHA1 string + + // The base64-encoded, 256-bit SHA-256 digest of the object. This will only be + // present if it was uploaded with the object. For more information, see [Checking object integrity]in the + // Amazon S3 User Guide. 
+ // + // [Checking object integrity]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + ChecksumSHA256 string + + // Specifies presentational information for the object. + ContentDisposition string + + // Indicates what content encodings have been applied to the object and thus what + // decoding mechanisms must be applied to obtain the media-type referenced by the + // Content-Type header field. + ContentEncoding string + + // The language the content is in. + ContentLanguage string + + // Size of the body in bytes. + ContentLength int64 + + // The portion of the object returned in the response. + ContentRange string + + // A standard MIME type describing the format of the object data. + ContentType string + + // Indicates whether the object retrieved was (true) or was not (false) a Delete + // Marker. If false, this response header does not appear in the response. + // + // - If the current version of the object is a delete marker, Amazon S3 behaves + // as if the object was deleted and includes x-amz-delete-marker: true in the + // response. + // + // - If the specified version in the request is a delete marker, the response + // returns a 405 Method Not Allowed error and the Last-Modified: timestamp + // response header. + DeleteMarker bool + + // An entity tag (ETag) is an opaque identifier assigned by a web server to a + // specific version of a resource found at a URL. + ETag string + + // If the object expiration is configured (see [PutBucketLifecycleConfiguration]PutBucketLifecycleConfiguration ), + // the response includes this header. It includes the expiry-date and rule-id + // key-value pairs providing object expiration information. The value of the + // rule-id is URL-encoded. + // + // This functionality is not supported for directory buckets. + // + // [PutBucketLifecycleConfiguration]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketLifecycleConfiguration.html + Expiration string + + // The date and time at which the object is no longer cacheable. + // + // Deprecated: This field is handled inconsistently across AWS SDKs. Prefer using + // the ExpiresString field which contains the unparsed value from the service + // response. + Expires time.Time + + // The unparsed value of the Expires field from the service response. Prefer use + // of this value over the normal Expires response field where possible. + ExpiresString string + + // Date and time when the object was last modified. + // + // General purpose buckets - When you specify a versionId of the object in your + // request, if the specified version in the request is a delete marker, the + // response returns a 405 Method Not Allowed error and the Last-Modified: timestamp + // response header. + LastModified time.Time + + // A map of metadata to store with the object in S3. + // + // Map keys will be normalized to lower-case. + Metadata map[string]string + + // This is set to the number of metadata entries not returned in the headers that + // are prefixed with x-amz-meta- . This can happen if you create metadata using an + // API like SOAP that supports more flexible metadata than the REST API. For + // example, using SOAP, you can create metadata whose values are not legal HTTP + // headers. + // + // This functionality is not supported for directory buckets. + MissingMeta int32 + + // Indicates whether this object has an active legal hold. This field is only + // returned if you have permission to view an object's legal hold status. 
+ // + // This functionality is not supported for directory buckets. + ObjectLockLegalHoldStatus types.ObjectLockLegalHoldStatus + + // The Object Lock mode that's currently in place for this object. + // + // This functionality is not supported for directory buckets. + ObjectLockMode types.ObjectLockMode + + // The date and time when this object's Object Lock will expire. + // + // This functionality is not supported for directory buckets. + ObjectLockRetainUntilDate time.Time + + // The count of parts this object has. This value is only returned if you specify + // partNumber in your request and the object was uploaded as a multipart upload. + PartsCount int32 + + // Amazon S3 can return this if your request involves a bucket that is either a + // source or destination in a replication rule. + // + // This functionality is not supported for directory buckets. + ReplicationStatus types.ReplicationStatus + + // If present, indicates that the requester was successfully charged for the + // request. + // + // This functionality is not supported for directory buckets. + RequestCharged types.RequestCharged + + // Provides information about object restoration action and expiration time of the + // restored object copy. + // + // This functionality is not supported for directory buckets. Only the S3 Express + // One Zone storage class is supported by directory buckets to store objects. + Restore string + + // If server-side encryption with a customer-provided encryption key was + // requested, the response will include this header to confirm the encryption + // algorithm that's used. + // + // This functionality is not supported for directory buckets. + SSECustomerAlgorithm string + + // If server-side encryption with a customer-provided encryption key was + // requested, the response will include this header to provide the round-trip + // message integrity verification of the customer-provided encryption key. + // + // This functionality is not supported for directory buckets. + SSECustomerKeyMD5 string + + // If present, indicates the ID of the KMS key that was used for object encryption. + SSEKMSKeyID string + + // The server-side encryption algorithm used when you store this object in Amazon + // S3. + ServerSideEncryption types.ServerSideEncryption + + // Provides storage class information of the object. Amazon S3 returns this header + // for all objects except for S3 Standard storage class objects. + // + // Directory buckets - Only the S3 Express One Zone storage class is supported by + // directory buckets to store objects. + StorageClass types.StorageClass + + // The number of tags, if any, on the object, when you have the relevant + // permission to read object tags. + // + // You can use [GetObjectTagging] to retrieve the tag set associated with an object. + // + // This functionality is not supported for directory buckets. + // + // [GetObjectTagging]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObjectTagging.html + TagCount int32 + + // Version ID of the object. + // + // This functionality is not supported for directory buckets. + VersionID string + + // If the bucket is configured as a website, redirects requests for this object to + // another object in the same bucket or to an external URL. Amazon S3 stores the + // value of this header in the object metadata. + // + // This functionality is not supported for directory buckets. + WebsiteRedirectLocation string + + // Metadata pertaining to the operation's result. 
+ ResultMetadata smithymiddleware.Metadata +} + +func (o *GetObjectOutput) mapFromGetObjectOutput(out *s3.GetObjectOutput, checksumMode s3types.ChecksumMode) { + o.AcceptRanges = aws.ToString(out.AcceptRanges) + o.CacheControl = aws.ToString(out.CacheControl) + o.ChecksumMode = types.ChecksumMode(checksumMode) + o.ChecksumCRC32 = aws.ToString(out.ChecksumCRC32) + o.ChecksumCRC32C = aws.ToString(out.ChecksumCRC32C) + o.ChecksumSHA1 = aws.ToString(out.ChecksumSHA1) + o.ChecksumSHA256 = aws.ToString(out.ChecksumSHA256) + o.ContentDisposition = aws.ToString(out.ContentDisposition) + o.ContentEncoding = aws.ToString(out.ContentEncoding) + o.ContentLanguage = aws.ToString(out.ContentLanguage) + o.ContentRange = aws.ToString(out.ContentRange) + o.ContentType = aws.ToString(out.ContentType) + o.ETag = aws.ToString(out.ETag) + o.Expiration = aws.ToString(out.Expiration) + o.ExpiresString = aws.ToString(out.ExpiresString) + o.Restore = aws.ToString(out.Restore) + o.SSECustomerAlgorithm = aws.ToString(out.SSECustomerAlgorithm) + o.SSECustomerKeyMD5 = aws.ToString(out.SSECustomerKeyMD5) + o.SSEKMSKeyID = aws.ToString(out.SSEKMSKeyId) + o.VersionID = aws.ToString(out.VersionId) + o.WebsiteRedirectLocation = aws.ToString(out.WebsiteRedirectLocation) + o.BucketKeyEnabled = aws.ToBool(out.BucketKeyEnabled) + o.DeleteMarker = aws.ToBool(out.DeleteMarker) + o.MissingMeta = aws.ToInt32(out.MissingMeta) + o.PartsCount = aws.ToInt32(out.PartsCount) + o.TagCount = aws.ToInt32(out.TagCount) + o.ContentLength = aws.ToInt64(out.ContentLength) + o.Body = out.Body + o.Expires = aws.ToTime(out.Expires) + o.LastModified = aws.ToTime(out.LastModified) + o.ObjectLockRetainUntilDate = aws.ToTime(out.ObjectLockRetainUntilDate) + o.Metadata = out.Metadata + o.ObjectLockLegalHoldStatus = types.ObjectLockLegalHoldStatus(out.ObjectLockLegalHoldStatus) + o.ObjectLockMode = types.ObjectLockMode(out.ObjectLockMode) + o.ReplicationStatus = types.ReplicationStatus(out.ReplicationStatus) + o.RequestCharged = types.RequestCharged(out.RequestCharged) + o.ServerSideEncryption = types.ServerSideEncryption(out.ServerSideEncryption) + o.StorageClass = types.StorageClass(out.StorageClass) + o.ResultMetadata = out.ResultMetadata.Clone() +} + +// GetObject downloads an object from S3, intelligently splitting large +// files into smaller parts/ranges according to config and getting them in parallel across +// multiple goroutines. You can configure the download type, chunk size and concurrency +// through the Options parameters. +// +// Additional functional options can be provided to configure the individual +// download. These options are copies of the original Options instance, the client of which GetObject is called from. +// Modifying the options will not impact the original Client and Options instance. +// +// Before calling GetObject to download object, you must create a ConcurrentReader and use that reader to +// copy response content to your final destination file or buffer. This new reader type implements io.Reader to +// concurrently download parts of large object while limiting the max local cache size during download to prevent +// too much memory space consumption when getting large objects up to multi-gigabytes. You could configure that buffer +// size by changing Options.GetBufferSize. 
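+//
+// As a rough illustration of the buffering behavior, the getter only schedules
+// about GetBufferSize/PartSizeBytes chunks ahead of what the reader has
+// consumed; with the default 50 MiB buffer and, say, 8 MiB parts (an assumed
+// part size), that is roughly six parts buffered at a time.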
+// +// Example of creating ConcurrentReader to call GetObject: +// +// file, err := os.Create("your filename") +// if err != nil { +// log.Fatal("error when creating local file: ", err) +// } +// r := transfermanager.NewConcurrentReader() +// var wg sync.WaitGroup +// wg.Add(1) +// +// // You must read from the r in a separate goroutine to drive getter to get all parts. +// +// go func() { +// defer wg.Done() +// _, err := io.Copy(file, r) +// if err != nil { +// log.Fatal("error when writing to local file: ", err) +// } +// }() +// +// out, err := svc.GetObject(context.Background(), &transfermanager.GetObjectInput{ +// Bucket: "your-bucket", +// Key: "your-key", +// Reader: r, +// }) +// +// // must wait for r.Read() to finish +// wg.Wait() +// if err != nil { +// log.Fatal("error when downloading file: ", err) +// } +func (c *Client) GetObject(ctx context.Context, input *GetObjectInput, opts ...func(*Options)) (*GetObjectOutput, error) { + i := getter{in: input, options: c.options.Copy(), r: input.Reader} + for _, opt := range opts { + opt(&i.options) + } + + return i.get(ctx) +} + +type getter struct { + options Options + in *GetObjectInput + out *GetObjectOutput + w *types.WriteAtBuffer + r *ConcurrentReader + + wg sync.WaitGroup + m sync.Mutex + + offset int64 + pos int64 + totalBytes int64 + written int64 + + err error +} + +func (g *getter) get(ctx context.Context) (out *GetObjectOutput, err error) { + if err := g.init(ctx); err != nil { + return nil, fmt.Errorf("unable to initialize download: %w", err) + } + + clientOptions := []func(*s3.Options){ + func(o *s3.Options) { + o.APIOptions = append(o.APIOptions, + middleware.AddSDKAgentKey(middleware.FeatureMetadata, userAgentKey), + addFeatureUserAgent, + ) + }} + + defer close(g.r.ch) + if g.in.PartNumber > 0 { + return g.singleDownload(ctx, clientOptions...) + } + + if g.options.GetObjectType == types.GetObjectParts { + if g.in.Range != "" { + return g.singleDownload(ctx, clientOptions...) + } + // must know the part size before creating stream reader + out, err := g.options.S3.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(g.in.Bucket), + Key: aws.String(g.in.Key), + PartNumber: aws.Int32(1), + }, clientOptions...) + if err != nil { + g.r.setErr(err) + return nil, err + } + + partsCount := max(aws.ToInt32(out.PartsCount), 1) + partSize := max(aws.ToInt64(out.ContentLength), 1) + sectionParts := int32(max(1, g.options.GetBufferSize/partSize)) + capacity := sectionParts + g.r.setPartSize(partSize) + g.r.setCapacity(min(capacity, partsCount)) + g.r.setPartsCount(partsCount) + + ch := make(chan getChunk, g.options.Concurrency) + for i := 0; i < g.options.Concurrency; i++ { + g.wg.Add(1) + go g.downloadPart(ctx, ch, clientOptions...) + } + + var i int32 + for i < partsCount { + if g.getErr() != nil { + break + } + + if g.r.getRead() == capacity { + capacity = min(capacity+sectionParts, partsCount) + g.r.setCapacity(capacity) + } + + if i == capacity { + continue + } + + ch <- getChunk{start: g.pos - g.offset, part: i + 1, index: i} + i++ + g.pos += partSize + } + + close(ch) + g.wg.Wait() + } else { + out, err := g.options.S3.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(g.in.Bucket), + Key: aws.String(g.in.Key), + }, clientOptions...) + if err != nil { + g.r.setErr(err) + return nil, err + } + if aws.ToInt64(out.ContentLength) == 0 { + return g.singleDownload(ctx, clientOptions...) 
+ } + g.totalBytes = aws.ToInt64(out.ContentLength) + if g.in.Range != "" { + start, totalBytes := g.getDownloadRange() + if start < 0 || start >= g.totalBytes || totalBytes > g.totalBytes || start >= totalBytes { + err := &errInvalidRange{ + max: g.totalBytes - 1, + } + g.r.setErr(err) + return nil, err + } + g.pos = start + g.totalBytes = totalBytes + g.offset = start + } + total := g.totalBytes - g.offset + + partsCount := int32((total-1)/g.options.PartSizeBytes + 1) + sectionParts := int32(max(1, g.options.GetBufferSize/g.options.PartSizeBytes)) + capacity := sectionParts + g.r.setPartSize(g.options.PartSizeBytes) + g.r.setCapacity(min(capacity, partsCount)) + g.r.setPartsCount(partsCount) + + ch := make(chan getChunk, g.options.Concurrency) + for i := 0; i < g.options.Concurrency; i++ { + g.wg.Add(1) + go g.downloadPart(ctx, ch, clientOptions...) + } + + var i int32 + for i < partsCount { + if g.getErr() != nil { + break + } + + if g.r.getRead() == capacity { + capacity = min(capacity+sectionParts, partsCount) + g.r.setCapacity(capacity) + } + + if i == capacity { + continue + } + + ch <- getChunk{start: g.pos - g.offset, withRange: g.byteRange(), index: i} + i++ + g.pos += g.options.PartSizeBytes + } + + // Wait for completion + close(ch) + g.wg.Wait() + } + + if g.err != nil { + g.r.setErr(g.err) + return nil, g.err + } + + g.out.ContentLength = g.written + g.out.ContentRange = fmt.Sprintf("bytes=%d-%d", g.offset, g.offset+g.written-1) + return g.out, nil +} + +func (g *getter) init(ctx context.Context) error { + if g.options.PartSizeBytes < minPartSizeBytes { + return fmt.Errorf("part size must be at least %d bytes", minPartSizeBytes) + } + if g.r == nil { + return fmt.Errorf("concurrent reader is required in input") + } + + g.r.ch = make(chan outChunk, g.options.Concurrency) + g.totalBytes = -1 + + return nil +} + +func (g *getter) singleDownload(ctx context.Context, clientOptions ...func(*s3.Options)) (*GetObjectOutput, error) { + params := g.in.mapGetObjectInput(!g.options.DisableChecksumValidation) + out, err := g.options.S3.GetObject(ctx, params, clientOptions...) + if err != nil { + g.r.setErr(err) + return nil, err + } + + defer out.Body.Close() + buf, err := io.ReadAll(out.Body) + if err != nil { + g.r.setErr(err) + return nil, err + } + + g.r.setPartSize(max(1, aws.ToInt64(out.ContentLength))) + g.r.setCapacity(1) + g.r.setPartsCount(1) + g.r.ch <- outChunk{body: bytes.NewReader(buf), length: aws.ToInt64(out.ContentLength)} + var output GetObjectOutput + output.mapFromGetObjectOutput(out, params.ChecksumMode) + return &output, nil +} + +func (g *getter) downloadPart(ctx context.Context, ch chan getChunk, clientOptions ...func(*s3.Options)) { + defer g.wg.Done() + for { + chunk, ok := <-ch + if !ok { + break + } + if g.getErr() != nil { + continue + } + out, err := g.downloadChunk(ctx, chunk, clientOptions...) + if err != nil { + g.setErr(err) + } else { + g.setOutput(out) + } + } +} + +// downloadChunk downloads the chunk from s3 +func (g *getter) downloadChunk(ctx context.Context, chunk getChunk, clientOptions ...func(*s3.Options)) (*GetObjectOutput, error) { + params := g.in.mapGetObjectInput(!g.options.DisableChecksumValidation) + if chunk.part != 0 { + params.PartNumber = aws.Int32(chunk.part) + } + if chunk.withRange != "" { + params.Range = aws.String(chunk.withRange) + } + + out, err := g.options.S3.GetObject(ctx, params, clientOptions...) 
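+	// the chunk body is read fully into memory below before being forwarded to the
+	// ConcurrentReader's channel, so each worker holds at most one chunk body at a
+	// time in addition to whatever the reader itself is still buffering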
+ if err != nil { + return nil, err + } + + defer out.Body.Close() + buf, err := io.ReadAll(out.Body) + g.incrWritten(int64(len(buf))) + + if err != nil { + return nil, err + } + g.r.ch <- outChunk{body: bytes.NewReader(buf), index: chunk.index, length: aws.ToInt64(out.ContentLength)} + + output := &GetObjectOutput{} + output.mapFromGetObjectOutput(out, params.ChecksumMode) + return output, err +} + +func (g *getter) setOutput(resp *GetObjectOutput) { + g.m.Lock() + defer g.m.Unlock() + + if g.out != nil { + return + } + g.out = resp +} + +func (g *getter) incrWritten(n int64) { + g.m.Lock() + defer g.m.Unlock() + + g.written += n +} + +func (g *getter) getDownloadRange() (int64, int64) { + parts := strings.Split(strings.Split(g.in.Range, "=")[1], "-") + + start, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + g.err = err + return 0, 0 + } + + end, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + g.err = err + return 0, 0 + } + + return start, end + 1 +} + +// byteRange returns an HTTP Byte-Range header value that should be used by the +// client to request a chunk range. +func (g *getter) byteRange() string { + return fmt.Sprintf("bytes=%d-%d", g.pos, min(g.totalBytes-1, g.pos+g.options.PartSizeBytes-1)) +} + +func (g *getter) getErr() error { + g.m.Lock() + defer g.m.Unlock() + + return g.err +} + +func (g *getter) setErr(e error) { + g.m.Lock() + defer g.m.Unlock() + + g.err = e +} + +type getChunk struct { + start int64 + cur int64 + + part int32 + withRange string + + index int32 +} + +type outChunk struct { + body io.Reader + index int32 + + length int64 + cur int64 +} diff --git a/feature/s3/transfermanager/api_op_GetObject_integ_test.go b/feature/s3/transfermanager/api_op_GetObject_integ_test.go new file mode 100644 index 00000000000..8acd85e3cab --- /dev/null +++ b/feature/s3/transfermanager/api_op_GetObject_integ_test.go @@ -0,0 +1,52 @@ +//go:build integration +// +build integration + +package transfermanager + +import ( + "bytes" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" + "strings" + "testing" +) + +func TestInteg_GetObject(t *testing.T) { + cases := map[string]getObjectTestData{ + "part get seekable body": {Body: strings.NewReader("hello world"), ExpectBody: []byte("hello world")}, + "part get empty string body": {Body: strings.NewReader(""), ExpectBody: []byte("")}, + "part get multipart body": {Body: bytes.NewReader(largeObjectBuf), ExpectBody: largeObjectBuf}, + "range get seekable body": { + Body: strings.NewReader("hello world"), + ExpectBody: []byte("hello world"), + OptFns: []func(*Options){ + func(opt *Options) { + opt.GetObjectType = types.GetObjectRanges + }, + }, + }, + "range get empty string body": { + Body: strings.NewReader(""), + ExpectBody: []byte(""), + OptFns: []func(*Options){ + func(opt *Options) { + opt.GetObjectType = types.GetObjectRanges + }, + }, + }, + "range get multipart body": { + Body: bytes.NewReader(largeObjectBuf), + ExpectBody: largeObjectBuf, + OptFns: []func(*Options){ + func(opt *Options) { + opt.GetObjectType = types.GetObjectRanges + }, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + testGetObject(t, setupMetadata.Buckets.Source.Name, c) + }) + } +} diff --git a/feature/s3/transfermanager/api_op_PutObject.go b/feature/s3/transfermanager/api_op_PutObject.go index 0785ec4e81e..22d6eba084d 100644 --- a/feature/s3/transfermanager/api_op_PutObject.go +++ b/feature/s3/transfermanager/api_op_PutObject.go @@ -432,7 +432,7 @@ func (i PutObjectInput) 
mapSingleUploadInput(body io.Reader, checksumAlgorithm t return input } -func (i PutObjectInput) mapCreateMultipartUploadInput() *s3.CreateMultipartUploadInput { +func (i PutObjectInput) mapCreateMultipartUploadInput(checksumAlgorithm types.ChecksumAlgorithm) *s3.CreateMultipartUploadInput { input := &s3.CreateMultipartUploadInput{ Bucket: aws.String(i.Bucket), Key: aws.String(i.Key), @@ -443,7 +443,7 @@ func (i PutObjectInput) mapCreateMultipartUploadInput() *s3.CreateMultipartUploa if i.ChecksumAlgorithm != "" { input.ChecksumAlgorithm = s3types.ChecksumAlgorithm(i.ChecksumAlgorithm) } else { - input.ChecksumAlgorithm = s3types.ChecksumAlgorithm(i.ChecksumAlgorithm) + input.ChecksumAlgorithm = s3types.ChecksumAlgorithm(checksumAlgorithm) } if i.ObjectLockLegalHoldStatus != "" { input.ObjectLockLegalHoldStatus = s3types.ObjectLockLegalHoldStatus(i.ObjectLockLegalHoldStatus) @@ -505,7 +505,7 @@ func (i PutObjectInput) mapCompleteMultipartUploadInput(uploadID *string, comple return input } -func (i PutObjectInput) mapUploadPartInput(body io.Reader, partNum *int32, uploadID *string) *s3.UploadPartInput { +func (i PutObjectInput) mapUploadPartInput(body io.Reader, partNum *int32, uploadID *string, checksumAlgorithm types.ChecksumAlgorithm) *s3.UploadPartInput { input := &s3.UploadPartInput{ Bucket: aws.String(i.Bucket), Key: aws.String(i.Key), @@ -515,7 +515,10 @@ func (i PutObjectInput) mapUploadPartInput(body io.Reader, partNum *int32, uploa } if i.ChecksumAlgorithm != "" { input.ChecksumAlgorithm = s3types.ChecksumAlgorithm(i.ChecksumAlgorithm) + } else { + input.ChecksumAlgorithm = s3types.ChecksumAlgorithm(checksumAlgorithm) } + if i.RequestPayer != "" { input.RequestPayer = s3types.RequestPayer(i.RequestPayer) } @@ -534,7 +537,7 @@ func (i *PutObjectInput) mapAbortMultipartUploadInput(uploadID *string) *s3.Abor return input } -// PutObjectOutput represents a response from the Upload() call. It contains common fields +// PutObjectOutput represents a response from the PutObject() call. It contains common fields // of s3 PutObject and CompleteMultipartUpload output type PutObjectOutput struct { // The ID for a multipart upload to S3. In the case of an error the error @@ -802,7 +805,7 @@ func (cp completedParts) Swap(i, j int) { // upload will perform a multipart upload using the firstBuf buffer containing // the first chunk of data. func (u *multiUploader) upload(ctx context.Context, firstBuf io.Reader, cleanup func(), clientOptions ...func(*s3.Options)) (*PutObjectOutput, error) { - params := u.uploader.in.mapCreateMultipartUploadInput() + params := u.uploader.in.mapCreateMultipartUploadInput(u.options.ChecksumAlgorithm) // Create a multipart resp, err := u.uploader.options.S3.CreateMultipartUpload(ctx, params, clientOptions...) @@ -902,7 +905,7 @@ func (u *multiUploader) readChunk(ctx context.Context, ch chan ulChunk, clientOp // send performs an UploadPart request and keeps track of the completed // part information. func (u *multiUploader) send(ctx context.Context, c ulChunk, clientOptions ...func(*s3.Options)) error { - params := u.in.mapUploadPartInput(c.buf, c.partNum, u.uploadID) + params := u.in.mapUploadPartInput(c.buf, c.partNum, u.uploadID, u.options.ChecksumAlgorithm) resp, err := u.options.S3.UploadPart(ctx, params, clientOptions...) 
 	if err != nil {
 		return err
diff --git a/feature/s3/transfermanager/api_op_PutObject_integ_test.go b/feature/s3/transfermanager/api_op_PutObject_integ_test.go
index 853f0441b4a..921e237e5dc 100644
--- a/feature/s3/transfermanager/api_op_PutObject_integ_test.go
+++ b/feature/s3/transfermanager/api_op_PutObject_integ_test.go
@@ -10,8 +10,6 @@ import (
 )
 
 func TestInteg_PutObject(t *testing.T) {
-	t.Skip("broken until multipart upload addressed")
-
 	cases := map[string]putObjectTestData{
 		"seekable body":     {Body: strings.NewReader("hello world"), ExpectBody: []byte("hello world")},
 		"empty string body": {Body: strings.NewReader(""), ExpectBody: []byte("")},
diff --git a/feature/s3/transfermanager/concurrent_reader.go b/feature/s3/transfermanager/concurrent_reader.go
new file mode 100644
index 00000000000..085df8ee8e5
--- /dev/null
+++ b/feature/s3/transfermanager/concurrent_reader.go
@@ -0,0 +1,195 @@
+package transfermanager
+
+import (
+	"io"
+	"sync"
+)
+
+// ConcurrentReader receives object parts from worker goroutines, composes those chunks in order, and reads
+// them into the user's buffer. ConcurrentReader limits the max number of chunks it can receive and read at the same
+// time, so the getter won't request subsequent parts from S3 until the user has read all current chunks, which avoids
+// excessive memory consumption when caching large object parts.
+type ConcurrentReader struct {
+	ch  chan outChunk
+	buf map[int32]*outChunk
+
+	partsCount int32
+	capacity   int32
+	count      int32
+	read       int32
+
+	written  int64
+	partSize int64
+
+	m sync.Mutex
+
+	err error
+}
+
+// NewConcurrentReader returns a ConcurrentReader to be set as the Reader of a GetObject input
+func NewConcurrentReader() *ConcurrentReader {
+	return &ConcurrentReader{
+		buf:      make(map[int32]*outChunk),
+		partSize: 1, // just a placeholder value
+	}
+}
+
+// Read implements io.Reader to compose object parts in order and read them into p.
+// It will receive up to r.capacity chunks, read them to p if any chunk index +// fits into p scope, otherwise it will buffer those chunks and read them in +// following calls +func (r *ConcurrentReader) Read(p []byte) (int, error) { + if cap(p) == 0 { + return 0, nil + } + + var written int + + for r.count < r.getCapacity() { + if e := r.getErr(); e != nil && e != io.EOF { + r.written += int64(written) + r.clean() + return written, r.getErr() + } + if written >= cap(p) { + r.written += int64(written) + return written, r.getErr() + } + + oc, ok := <-r.ch + if !ok { + r.written += int64(written) + return written, r.getErr() + } + + r.count++ + index := r.getPartSize()*int64(oc.index) - r.written + + if index < int64(cap(p)) { + n, err := oc.body.Read(p[index:]) + oc.cur += int64(n) + written += n + if err != nil && err != io.EOF { + r.setErr(err) + r.clean() + r.written += int64(written) + return written, r.getErr() + } + } + if oc.cur < oc.length { + r.buf[oc.index] = &oc + } else { + r.incrRead(1) + if r.getRead() >= r.partsCount { + r.setErr(io.EOF) + } + } + } + + partSize := r.getPartSize() + minIndex := int32(r.written / partSize) + maxIndex := min(int32((r.written+int64(cap(p))-1)/partSize), r.getCapacity()-1) + for i := minIndex; i <= maxIndex; i++ { + if e := r.getErr(); e != nil && e != io.EOF { + r.written += int64(written) + r.clean() + return written, r.getErr() + } + + c, ok := r.buf[i] + if ok { + index := int64(i)*partSize + c.cur - r.written + n, err := c.body.Read(p[index:]) + c.cur += int64(n) + written += n + if err != nil && err != io.EOF { + r.setErr(err) + r.clean() + r.written += int64(written) + return written, r.getErr() + } + if c.cur >= c.length { + r.incrRead(1) + delete(r.buf, i) + if r.getRead() >= r.partsCount { + r.setErr(io.EOF) + } + } + } + } + + r.written += int64(written) + return written, r.getErr() +} + +func (r *ConcurrentReader) setPartSize(n int64) { + r.m.Lock() + defer r.m.Unlock() + + r.partSize = n +} + +func (r *ConcurrentReader) getPartSize() int64 { + r.m.Lock() + defer r.m.Unlock() + + return r.partSize +} + +func (r *ConcurrentReader) setCapacity(n int32) { + r.m.Lock() + defer r.m.Unlock() + + r.capacity = n +} + +func (r *ConcurrentReader) getCapacity() int32 { + r.m.Lock() + defer r.m.Unlock() + + return r.capacity +} + +func (r *ConcurrentReader) setPartsCount(n int32) { + r.m.Lock() + defer r.m.Unlock() + + r.partsCount = n +} + +func (r *ConcurrentReader) incrRead(n int32) { + r.m.Lock() + defer r.m.Unlock() + + r.read += n +} + +func (r *ConcurrentReader) getRead() int32 { + r.m.Lock() + defer r.m.Unlock() + + return r.read +} + +func (r *ConcurrentReader) setErr(err error) { + r.m.Lock() + defer r.m.Unlock() + + r.err = err +} + +func (r *ConcurrentReader) getErr() error { + r.m.Lock() + defer r.m.Unlock() + + return r.err +} + +func (r *ConcurrentReader) clean() { + for { + _, ok := <-r.ch + if !ok { + break + } + } +} diff --git a/feature/s3/transfermanager/concurrent_reader_test.go b/feature/s3/transfermanager/concurrent_reader_test.go new file mode 100644 index 00000000000..83c894da69b --- /dev/null +++ b/feature/s3/transfermanager/concurrent_reader_test.go @@ -0,0 +1,163 @@ +package transfermanager + +import ( + "bytes" + "context" + "io" + "math" + "math/rand" + "sync" + "testing" +) + +func TestConcurrentReader(t *testing.T) { + cases := map[string]struct { + partSize int64 + partsCount int32 + concurrency int + sectionParts int32 + }{ + "single goroutine": { + partSize: 10, + partsCount: 1000, + concurrency: 1, + 
sectionParts: 6, + }, + "single goroutine with only one section": { + partSize: 1000, + partsCount: 5, + concurrency: 3, + sectionParts: 6, + }, + "single goroutine with only one part": { + partSize: 1000, + partsCount: 1, + concurrency: 3, + sectionParts: 6, + }, + "multiple goroutines": { + partSize: 10, + partsCount: 1000, + concurrency: 5, + sectionParts: 6, + }, + "multiple goroutines with only one section": { + partSize: 10, + partsCount: 6, + concurrency: 5, + sectionParts: 6, + }, + "multiple goroutines with only one part": { + partSize: 10, + partsCount: 1, + concurrency: 5, + sectionParts: 6, + }, + "multiple goroutines with large part size": { + partSize: 10000, + partsCount: 10000, + concurrency: 5, + sectionParts: 6, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + r := NewConcurrentReader() + r.ch = make(chan outChunk, c.concurrency) + r.setCapacity(int32(math.Min(float64(c.sectionParts), float64(c.partsCount)))) + r.setPartSize(c.partSize) + r.setPartsCount(c.partsCount) + ctx := context.Background() + var wg sync.WaitGroup + expectBuf := make([]byte, 0) + actualBuf := make([]byte, 0) + + wg.Add(1) + go func() { + defer wg.Done() + b, err := io.ReadAll(r) + if err != nil { + if err != io.EOF { + t.Error("error copying file: ", err) + } + return + } + + actualBuf = append(actualBuf, b...) + }() + + getter := mockGetter{} + ch := make(chan inChunk, c.concurrency) + + for i := 0; i < c.concurrency; i++ { + getter.wg.Add(1) + go getter.partGet(ctx, ch, r.ch) + } + + var i int32 + for { + if i == c.partsCount { + break + } + + if capacity := r.getCapacity(); r.getRead() == capacity { + r.setCapacity(int32(math.Min(float64(capacity+c.sectionParts), float64(c.partsCount)))) + } + + if i == r.getCapacity() { + continue + } + + b := make([]byte, c.partSize) + if i == c.partsCount-1 { + b = make([]byte, rand.Intn(int(c.partSize))+1) + } + rand.Read(b) + expectBuf = append(expectBuf, b...) 
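+				// hand the generated chunk to a mock getter goroutine, which wraps it
+				// in an outChunk and forwards it to the reader's channel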
+ ch <- inChunk{ + index: i, + body: b, + } + i++ + } + + wg.Wait() + close(ch) + getter.wg.Wait() + close(r.ch) + + if e, a := len(expectBuf), len(actualBuf); e != a { + t.Errorf("expect data sent to have length %d, but got %d", e, a) + } + if e, a := expectBuf, actualBuf; !bytes.Equal(e, a) { + t.Errorf("expect data sent to be %v, got %v", e, a) + } + }) + } +} + +type mockGetter struct { + wg sync.WaitGroup +} + +func (g *mockGetter) partGet(ctx context.Context, inputCh chan inChunk, outCh chan outChunk) { + defer g.wg.Done() + for { + inC, ok := <-inputCh + if !ok { + break + } + + outCh <- outChunk{ + index: inC.index, + body: bytes.NewReader(inC.body), + length: int64(len(inC.body)), + } + } +} + +type inChunk struct { + body []byte + index int32 +} diff --git a/feature/s3/transfermanager/downloadobject_test.go b/feature/s3/transfermanager/downloadobject_test.go new file mode 100644 index 00000000000..7e4b7d19811 --- /dev/null +++ b/feature/s3/transfermanager/downloadobject_test.go @@ -0,0 +1,458 @@ +package transfermanager + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "reflect" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/internal/testing" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" + "github.com/aws/aws-sdk-go-v2/internal/awstesting" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +func TestDownloadObject(t *testing.T) { + cases := map[string]struct { + data []byte + errReaders []s3testing.TestErrReader + getObjectFn func(*s3testing.TransferManagerLoggingClient, *s3.GetObjectInput) (*s3.GetObjectOutput, error) + options Options + downloadRange string + expectInvocations int + expectRanges []string + partNumber int32 + partsCount int32 + expectParts []int32 + expectErr string + dataValidationFn func(*testing.T, *types.WriteAtBuffer) + }{ + "range download in order": { + data: buf20MB, + getObjectFn: s3testing.RangeGetObjectFn, + options: Options{ + Concurrency: 1, + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 3, + expectRanges: []string{"bytes=0-8388607", "bytes=8388608-16777215", "bytes=16777216-20971519"}, + }, + "range download zero": { + data: []byte{}, + getObjectFn: s3testing.RangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + expectRanges: []string{"bytes=0-8388607"}, + }, + "range download with customized part size": { + data: buf20MB, + getObjectFn: s3testing.RangeGetObjectFn, + options: Options{ + Concurrency: 1, + GetObjectType: types.GetObjectRanges, + PartSizeBytes: 10 * 1024 * 1024, + }, + expectInvocations: 2, + expectRanges: []string{"bytes=0-10485759", "bytes=10485760-20971519"}, + }, + "range download with s3 error": { + data: buf20MB, + getObjectFn: s3testing.ErrRangeGetObjectFn, + options: Options{ + Concurrency: 1, + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 2, + expectErr: "s3 service error", + }, + "content length download single chunk": { + data: buf2MB, + getObjectFn: s3testing.NonRangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + expectRanges: []string{"bytes=0-8388607"}, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + count := 0 + for _, b := range w.Bytes() { + count += int(b) + } + if count != 0 { + t.Errorf("expect 0 count, got %d", count) + } + }, + }, + "range download single chunk": { + data: buf2MB, + getObjectFn: 
s3testing.RangeGetObjectFn, + options: Options{ + Concurrency: 1, + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + expectRanges: []string{"bytes=0-8388607"}, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + count := 0 + for _, b := range w.Bytes() { + count += int(b) + } + if count != 0 { + t.Errorf("expect 0 count, got %d", count) + } + }, + }, + "range download with success retry": { + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, + {Buf: []byte("123"), Len: 3, Err: io.EOF}, + }, + options: Options{ + Concurrency: 1, + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 2, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + if e, a := "123", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "range download success without retry": { + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("123"), Len: 3, Err: io.EOF}, + }, + options: Options{ + Concurrency: 1, + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + if e, a := "123", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "range download fail retry": { + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, + }, + options: Options{ + Concurrency: 1, + PartBodyMaxRetries: 1, + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + expectErr: "unexpected EOF", + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + if e, a := "ab", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "range download a range of object": { + data: buf20MB, + getObjectFn: s3testing.RangeGetObjectFn, + options: Options{ + Concurrency: 1, + GetObjectType: types.GetObjectRanges, + }, + downloadRange: "bytes=0-10485759", + expectInvocations: 2, + expectRanges: []string{"bytes=0-8388607", "bytes=8388608-10485759"}, + }, + "parts download in order": { + data: buf2MB, + getObjectFn: s3testing.PartGetObjectFn, + options: Options{ + Concurrency: 1, + }, + partsCount: 3, + expectInvocations: 3, + expectParts: []int32{1, 2, 3}, + }, + "part download zero": { + data: buf2MB, + getObjectFn: s3testing.PartGetObjectFn, + options: Options{}, + partsCount: 1, + expectInvocations: 1, + expectParts: []int32{1}, + }, + "part download with s3 error": { + data: buf2MB, + getObjectFn: s3testing.ErrPartGetObjectFn, + options: Options{ + Concurrency: 1, + }, + partsCount: 3, + expectInvocations: 2, + expectErr: "s3 service error", + }, + "part download single chunk": { + data: []byte("123"), + getObjectFn: s3testing.PartGetObjectFn, + options: Options{}, + partsCount: 1, + expectInvocations: 1, + expectParts: []int32{1}, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + if e, a := "123", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "part download with success retry": { + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, + {Buf: []byte("123"), Len: 3, Err: io.EOF}, + }, + options: Options{ + Concurrency: 1, + }, + partsCount: 1, + expectInvocations: 2, + expectParts: []int32{1, 1}, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) 
{ + if e, a := "123", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "part download success without retry": { + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("ab"), Len: 3, Err: io.EOF}, + }, + options: Options{ + Concurrency: 1, + }, + partsCount: 1, + expectInvocations: 1, + expectParts: []int32{1}, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + if e, a := "ab", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "part download fail retry": { + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, + }, + options: Options{ + Concurrency: 1, + PartBodyMaxRetries: 1, + }, + expectInvocations: 1, + expectErr: "unexpected EOF", + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + if e, a := "ab", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "parts download with range input": { + data: []byte("123"), + getObjectFn: s3testing.PartGetObjectFn, + options: Options{}, + downloadRange: "bytes=0-100", + partsCount: 3, + expectInvocations: 1, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + if e, a := "123", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "parts download with part number input": { + data: []byte("ab"), + getObjectFn: s3testing.PartGetObjectFn, + options: Options{}, + partsCount: 3, + partNumber: 5, + expectInvocations: 1, + dataValidationFn: func(t *testing.T, w *types.WriteAtBuffer) { + if e, a := "ab", string(w.Bytes()); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + s3Client, invocations, parts, ranges := s3testing.NewDownloadClient() + s3Client.Data = c.data + s3Client.GetObjectFn = c.getObjectFn + s3Client.ErrReaders = c.errReaders + s3Client.PartsCount = c.partsCount + mgr := New(s3Client, c.options) + w := types.NewWriteAtBuffer(make([]byte, 0)) + + input := &DownloadObjectInput{ + Bucket: "bucket", + Key: "key", + WriterAt: w, + Range: c.downloadRange, + PartNumber: c.partNumber, + } + + _, err := mgr.DownloadObject(context.Background(), input) + if err != nil { + if c.expectErr == "" { + t.Fatalf("expect no error, got %q", err) + } else if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %s error message to be in %s", e, a) + } + } else { + if c.expectErr != "" { + t.Fatal("expect error, got nil") + } + } + + if err != nil { + return + } + + if e, a := c.expectInvocations, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + if len(c.expectParts) > 0 { + if e, a := c.expectParts, *parts; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v parts, got %v", e, a) + } + } + if len(c.expectRanges) > 0 { + if e, a := c.expectRanges, *ranges; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v ranges, got %v", e, a) + } + } + + if c.dataValidationFn != nil { + c.dataValidationFn(t, w) + } + }) + } +} + +func TestDownloadAsyncWithFailure(t *testing.T) { + cases := map[string]struct { + downloadType types.GetObjectType + }{ + "part download by default": {}, + "range download": { + downloadType: types.GetObjectRanges, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + startingByte := 0 + reqCount := int64(0) + + s3Client, _, _, _ := s3testing.NewDownloadClient() + 
s3Client.GetObjectFn = func(c *s3testing.TransferManagerLoggingClient, params *s3.GetObjectInput) (out *s3.GetObjectOutput, err error) { + switch atomic.LoadInt64(&reqCount) { + case 1: + // Give a chance for the multipart chunks to be queued up + time.Sleep(1 * time.Second) + err = fmt.Errorf("some connection error") + default: + body := bytes.NewReader(make([]byte, minPartSizeBytes)) + out = &s3.GetObjectOutput{ + Body: ioutil.NopCloser(body), + ContentLength: aws.Int64(int64(body.Len())), + ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", startingByte, body.Len()-1, body.Len()*10)), + PartsCount: aws.Int32(10), + } + + startingByte += body.Len() + if reqCount > 0 { + // sleep here to ensure context switching between goroutines + time.Sleep(25 * time.Millisecond) + } + } + atomic.AddInt64(&reqCount, 1) + return out, err + } + + d := New(s3Client, Options{ + Concurrency: 2, + GetObjectType: c.downloadType, + }) + + w := types.NewWriteAtBuffer(make([]byte, 0)) + + // Expect this request to exit quickly after failure + _, err := d.DownloadObject(context.Background(), &DownloadObjectInput{ + Bucket: "Bucket", + Key: "Key", + WriterAt: w, + }) + if err == nil { + t.Fatal("expect error, got none") + } else if e, a := "some connection error", err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %s error message to be in %s", e, a) + } + + if atomic.LoadInt64(&reqCount) > 3 { + t.Errorf("expect no more than 3 requests, but received %d", reqCount) + } + }) + } +} + +func TestDownloadObjectWithContextCanceled(t *testing.T) { + cases := map[string]struct { + downloadType types.GetObjectType + }{ + "part download by default": {}, + "range download": { + downloadType: types.GetObjectRanges, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + d := New(s3.New(s3.Options{ + Region: "mock-region", + }), Options{ + GetObjectType: c.downloadType, + }) + + ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} + ctx.Error = fmt.Errorf("context canceled") + close(ctx.DoneCh) + + w := types.NewWriteAtBuffer(make([]byte, 0)) + + _, err := d.DownloadObject(ctx, &DownloadObjectInput{ + Bucket: "bucket", + Key: "Key", + WriterAt: w, + }) + if err == nil { + t.Fatalf("expected error, did not get one") + } + if e, a := "canceled", err.Error(); !strings.Contains(a, e) { + t.Errorf("expected error message to contain %q, but did not %q", e, a) + } + }) + } +} diff --git a/feature/s3/transfermanager/getobject_test.go b/feature/s3/transfermanager/getobject_test.go new file mode 100644 index 00000000000..8948b2b7172 --- /dev/null +++ b/feature/s3/transfermanager/getobject_test.go @@ -0,0 +1,446 @@ +package transfermanager + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/internal/testing" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" + "github.com/aws/aws-sdk-go-v2/internal/awstesting" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +func TestGetObject(t *testing.T) { + cases := map[string]struct { + data []byte + errReaders []s3testing.TestErrReader + getObjectFn func(*s3testing.TransferManagerLoggingClient, *s3.GetObjectInput) (*s3.GetObjectOutput, error) + options Options + downloadRange string + expectInvocations int + expectRanges []string + partNumber int32 + partsCount int32 + expectParts []int32 + expectErr string + dataValidationFn func(*testing.T, []byte) + }{ 
+ "range download in order": { + data: buf20MB, + getObjectFn: s3testing.RangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + Concurrency: 1, + }, + expectInvocations: 3, + expectRanges: []string{"bytes=0-8388607", "bytes=8388608-16777215", "bytes=16777216-20971519"}, + }, + "range download zero": { + data: []byte{}, + getObjectFn: s3testing.NonRangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + }, + "range download with customized part size": { + data: buf20MB, + getObjectFn: s3testing.RangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + PartSizeBytes: 10 * 1024 * 1024, + Concurrency: 1, + }, + expectInvocations: 2, + expectRanges: []string{"bytes=0-10485759", "bytes=10485760-20971519"}, + }, + "range download with s3 error": { + data: buf20MB, + getObjectFn: s3testing.ErrRangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 2, + expectErr: "s3 service error", + }, + "content length download single chunk": { + data: buf2MB, + getObjectFn: s3testing.NonRangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + expectRanges: []string{"bytes=0-2097151"}, + dataValidationFn: func(t *testing.T, bytes []byte) { + count := 0 + for _, b := range bytes { + count += int(b) + } + if count != 0 { + t.Errorf("expect 0 count, got %d", count) + } + }, + }, + "range download single chunk": { + data: buf2MB, + getObjectFn: s3testing.RangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + expectRanges: []string{"bytes=0-2097151"}, + dataValidationFn: func(t *testing.T, bytes []byte) { + count := 0 + for _, b := range bytes { + count += int(b) + } + if count != 0 { + t.Errorf("expect 0 count, got %d", count) + } + }, + }, + "range download success without retry": { + data: []byte("123"), + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("123"), Len: 3, Err: io.EOF}, + }, + options: Options{ + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + dataValidationFn: func(t *testing.T, bytes []byte) { + if e, a := "123", string(bytes); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "range download fail retry": { + data: []byte("ab"), + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("ab"), Len: 2, Err: io.ErrUnexpectedEOF}, + }, + options: Options{ + GetObjectType: types.GetObjectRanges, + }, + expectInvocations: 1, + expectErr: "unexpected EOF", + }, + "range download a range of object": { + data: buf20MB, + getObjectFn: s3testing.RangeGetObjectFn, + options: Options{ + GetObjectType: types.GetObjectRanges, + Concurrency: 1, + }, + downloadRange: "bytes=0-10485759", + expectInvocations: 2, + expectRanges: []string{"bytes=0-8388607", "bytes=8388608-10485759"}, + }, + "parts download in order": { + data: buf2MB, + getObjectFn: s3testing.PartGetObjectFn, + options: Options{ + Concurrency: 1, + }, + partsCount: 3, + expectInvocations: 3, + expectParts: []int32{1, 2, 3}, + }, + "part download zero": { + data: buf2MB, + getObjectFn: s3testing.PartGetObjectFn, + options: Options{}, + partsCount: 1, + expectInvocations: 1, + expectParts: []int32{1}, + }, + "part download with s3 error": { + data: buf2MB, + getObjectFn: s3testing.ErrPartGetObjectFn, + options: Options{}, + partsCount: 3, + expectInvocations: 2, + expectErr: "s3 
service error", + }, + "part download single chunk": { + data: []byte("123"), + getObjectFn: s3testing.PartGetObjectFn, + options: Options{}, + partsCount: 1, + expectInvocations: 1, + expectParts: []int32{1}, + dataValidationFn: func(t *testing.T, bytes []byte) { + if e, a := "123", string(bytes); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "part download success without retry": { + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("ab"), Len: 2, Err: io.EOF}, + }, + options: Options{}, + partsCount: 1, + expectInvocations: 1, + expectParts: []int32{1}, + dataValidationFn: func(t *testing.T, bytes []byte) { + if e, a := "ab", string(bytes); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "part download fail retry": { + data: []byte("ab"), + getObjectFn: s3testing.ErrReaderFn, + errReaders: []s3testing.TestErrReader{ + {Buf: []byte("ab"), Len: 2, Err: io.ErrUnexpectedEOF}, + }, + options: Options{}, + expectInvocations: 1, + expectErr: "unexpected EOF", + dataValidationFn: func(t *testing.T, bytes []byte) { + if e, a := "ab", string(bytes); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "parts download with range input": { + data: []byte("123"), + getObjectFn: s3testing.PartGetObjectFn, + options: Options{}, + downloadRange: "bytes=0-100", + partsCount: 3, + expectInvocations: 1, + dataValidationFn: func(t *testing.T, bytes []byte) { + if e, a := "123", string(bytes); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + "parts download with part number input": { + data: []byte("ab"), + getObjectFn: s3testing.PartGetObjectFn, + options: Options{}, + partsCount: 3, + partNumber: 5, + expectInvocations: 1, + dataValidationFn: func(t *testing.T, bytes []byte) { + if e, a := "ab", string(bytes); e != a { + t.Errorf("expect %q response, got %q", e, a) + } + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + s3Client, invocations, parts, ranges := s3testing.NewDownloadClient() + s3Client.Data = c.data + s3Client.GetObjectFn = c.getObjectFn + s3Client.ErrReaders = c.errReaders + s3Client.PartsCount = c.partsCount + mgr := New(s3Client, c.options) + + input := &GetObjectInput{ + Bucket: "bucket", + Key: "key", + } + input.Range = c.downloadRange + input.PartNumber = c.partNumber + r := NewConcurrentReader() + input.Reader = r + + var wg sync.WaitGroup + actualBuf := make([]byte, 0) + + wg.Add(1) + go func() { + defer wg.Done() + b, err := io.ReadAll(r) + if err != nil { + if c.expectErr == "" { + t.Errorf("expect no error when copying file, got %q", err) + } else if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %s error message to be in %s", e, a) + } + //return + } else if c.expectErr != "" { + t.Error("expect an error, but got none") + //return + } + + actualBuf = append(actualBuf, b...) 
+ }() + + _, err := mgr.GetObject(context.Background(), input) + wg.Wait() + + if err != nil { + if c.expectErr == "" { + t.Fatalf("expect no error, got %q", err) + } else if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %s error message to be in %s", e, a) + } + } else if c.expectErr != "" { + t.Fatal("expect error, got nil") + } + + if err != nil { + return + } + + if e, a := c.expectInvocations, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + if len(c.expectParts) > 0 { + if e, a := c.expectParts, *parts; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v parts, got %v", e, a) + } + } + if len(c.expectRanges) > 0 { + if e, a := c.expectRanges, *ranges; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v ranges, got %v", e, a) + } + } + + if c.dataValidationFn != nil { + c.dataValidationFn(t, actualBuf) + } + }) + } +} + +func TestGetAsyncWithFailure(t *testing.T) { + cases := map[string]struct { + downloadType types.GetObjectType + }{ + "part download by default": {}, + "range download": { + downloadType: types.GetObjectRanges, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + startingByte := 0 + reqCount := int64(0) + + s3Client, _, _, _ := s3testing.NewDownloadClient() + s3Client.PartsCount = 10 + s3Client.Data = buf40MB + s3Client.GetObjectFn = func(c *s3testing.TransferManagerLoggingClient, params *s3.GetObjectInput) (out *s3.GetObjectOutput, err error) { + switch atomic.LoadInt64(&reqCount) { + case 1: + // Give a chance for the multipart chunks to be queued up + time.Sleep(1 * time.Second) + err = fmt.Errorf("some connection error") + default: + body := bytes.NewReader(make([]byte, minPartSizeBytes)) + out = &s3.GetObjectOutput{ + Body: ioutil.NopCloser(body), + ContentLength: aws.Int64(int64(body.Len())), + ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", startingByte, body.Len()-1, body.Len()*10)), + } + + startingByte += body.Len() + if reqCount > 0 { + // sleep here to ensure context switching between goroutines + time.Sleep(25 * time.Millisecond) + } + } + atomic.AddInt64(&reqCount, 1) + return out, err + } + + mgr := New(s3Client, Options{ + Concurrency: 2, + GetObjectType: c.downloadType, + }) + r := NewConcurrentReader() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _, _ = io.ReadAll(r) + }() + + // Expect this request to exit quickly after failure + _, err := mgr.GetObject(context.Background(), &GetObjectInput{ + Bucket: "Bucket", + Key: "Key", + Reader: r, + }) + wg.Wait() + + if err == nil { + t.Fatal("expect error, got none") + } else if e, a := "some connection error", err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %s error message to be in %s", e, a) + } + + if atomic.LoadInt64(&reqCount) > 3 { + t.Errorf("expect no more than 3 requests, but received %d", reqCount) + } + }) + } +} + +func TestGetObjectWithContextCanceled(t *testing.T) { + cases := map[string]struct { + downloadType types.GetObjectType + }{ + "part download by default": {}, + "range download": { + downloadType: types.GetObjectRanges, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + mgr := New(s3.New(s3.Options{ + Region: "mock-region", + }), Options{ + GetObjectType: c.downloadType, + }) + + ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} + ctx.Error = fmt.Errorf("context canceled") + close(ctx.DoneCh) + + r := NewConcurrentReader() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _, _ = io.ReadAll(r) + }() + _, 
err := mgr.GetObject(ctx, &GetObjectInput{ + Bucket: "bucket", + Key: "Key", + Reader: r, + }) + wg.Wait() + + if err == nil { + t.Fatalf("expected error, did not get one") + } + if e, a := "canceled", err.Error(); !strings.Contains(a, e) { + t.Errorf("expected error message to contain %q, but did not %q", e, a) + } + }) + } +} diff --git a/feature/s3/transfermanager/internal/testing/client.go b/feature/s3/transfermanager/internal/testing/client.go new file mode 100644 index 00000000000..2fda4a9ba65 --- /dev/null +++ b/feature/s3/transfermanager/internal/testing/client.go @@ -0,0 +1,353 @@ +package testing + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "regexp" + "slices" + "strconv" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// TransferManagerLoggingClient is a mock client that can be used to record and stub responses for testing the transfer manager. +type TransferManagerLoggingClient struct { + // params for upload test + + UploadInvocations []string + Params []interface{} + + ConsumeBody bool + + ignoredOperations []string + + PartNum int + + // params for download test + + Data []byte + PartsCount int32 + + GetObjectInvocations int + + RetrievedRanges []string + RetrievedParts []int32 + + ErrReaders []TestErrReader + index int + + m sync.Mutex + + PutObjectFn func(*TransferManagerLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) + UploadPartFn func(*TransferManagerLoggingClient, *s3.UploadPartInput) (*s3.UploadPartOutput, error) + CreateMultipartUploadFn func(*TransferManagerLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) + CompleteMultipartUploadFn func(*TransferManagerLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) + AbortMultipartUploadFn func(*TransferManagerLoggingClient, *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error) + GetObjectFn func(*TransferManagerLoggingClient, *s3.GetObjectInput) (*s3.GetObjectOutput, error) +} + +func (c *TransferManagerLoggingClient) simulateHTTPClientOption(optFns ...func(*s3.Options)) error { + + o := s3.Options{ + HTTPClient: httpDoFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + Request: r, + }, nil + }), + } + + for _, fn := range optFns { + fn(&o) + } + + _, err := o.HTTPClient.Do(&http.Request{ + URL: &url.URL{ + Scheme: "https", + Host: "mock.amazonaws.com", + Path: "/key", + RawQuery: "foo=bar", + }, + }) + if err != nil { + return err + } + + return nil +} + +type httpDoFunc func(*http.Request) (*http.Response, error) + +func (f httpDoFunc) Do(r *http.Request) (*http.Response, error) { + return f(r) +} + +func (c *TransferManagerLoggingClient) traceOperation(name string, params interface{}) { + if slices.Contains(c.ignoredOperations, name) { + return + } + c.UploadInvocations = append(c.UploadInvocations, name) + c.Params = append(c.Params, params) + +} + +// PutObject is the S3 PutObject API. 
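+// It records the invocation and its parameters, then delegates to PutObjectFn
+// when provided; otherwise it returns a stubbed VersionId.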
+func (c *TransferManagerLoggingClient) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + c.m.Lock() + defer c.m.Unlock() + + if c.ConsumeBody { + io.Copy(ioutil.Discard, params.Body) + } + + c.traceOperation("PutObject", params) + + if err := c.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if c.PutObjectFn != nil { + return c.PutObjectFn(c, params) + } + + return &s3.PutObjectOutput{ + VersionId: aws.String("VERSION-ID"), + }, nil +} + +// UploadPart is the S3 UploadPart API. +func (c *TransferManagerLoggingClient) UploadPart(ctx context.Context, params *s3.UploadPartInput, optFns ...func(*s3.Options)) (*s3.UploadPartOutput, error) { + c.m.Lock() + defer c.m.Unlock() + + if c.ConsumeBody { + io.Copy(ioutil.Discard, params.Body) + } + + c.traceOperation("UploadPart", params) + + if err := c.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if c.UploadPartFn != nil { + return c.UploadPartFn(c, params) + } + + return &s3.UploadPartOutput{ + ETag: aws.String(fmt.Sprintf("ETAG%d", *params.PartNumber)), + }, nil +} + +// CreateMultipartUpload is the S3 CreateMultipartUpload API. +func (c *TransferManagerLoggingClient) CreateMultipartUpload(ctx context.Context, params *s3.CreateMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) { + c.m.Lock() + defer c.m.Unlock() + + c.traceOperation("CreateMultipartUpload", params) + + if err := c.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if c.CreateMultipartUploadFn != nil { + return c.CreateMultipartUploadFn(c, params) + } + + return &s3.CreateMultipartUploadOutput{ + UploadId: aws.String("UPLOAD-ID"), + }, nil +} + +// CompleteMultipartUpload is the S3 CompleteMultipartUpload API. +func (c *TransferManagerLoggingClient) CompleteMultipartUpload(ctx context.Context, params *s3.CompleteMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) { + c.m.Lock() + defer c.m.Unlock() + + c.traceOperation("CompleteMultipartUpload", params) + + if err := c.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if c.CompleteMultipartUploadFn != nil { + return c.CompleteMultipartUploadFn(c, params) + } + + return &s3.CompleteMultipartUploadOutput{ + Location: aws.String("http://location"), + VersionId: aws.String("VERSION-ID"), + }, nil +} + +// AbortMultipartUpload is the S3 AbortMultipartUpload API. +func (c *TransferManagerLoggingClient) AbortMultipartUpload(ctx context.Context, params *s3.AbortMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) { + c.m.Lock() + defer c.m.Unlock() + + c.traceOperation("AbortMultipartUpload", params) + if err := c.simulateHTTPClientOption(optFns...); err != nil { + return nil, err + } + + if c.AbortMultipartUploadFn != nil { + return c.AbortMultipartUploadFn(c, params) + } + + return &s3.AbortMultipartUploadOutput{}, nil +} + +// GetObject is the S3 GetObject API. 
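+// It counts invocations, records any requested range or part number, and
+// delegates to GetObjectFn when provided.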
+func (c *TransferManagerLoggingClient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + c.m.Lock() + defer c.m.Unlock() + + c.GetObjectInvocations++ + + if params.Range != nil { + c.RetrievedRanges = append(c.RetrievedRanges, aws.ToString(params.Range)) + } + if params.PartNumber != nil { + c.RetrievedParts = append(c.RetrievedParts, aws.ToInt32(params.PartNumber)) + } + + if c.GetObjectFn != nil { + return c.GetObjectFn(c, params) + } + + return &s3.GetObjectOutput{}, nil +} + +// HeadObject is the S3 HeadObject API +func (c *TransferManagerLoggingClient) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { + c.m.Lock() + defer c.m.Unlock() + + return &s3.HeadObjectOutput{ + PartsCount: aws.Int32(c.PartsCount), + ContentLength: aws.Int64(int64(len(c.Data))), + }, nil +} + +// NewUploadLoggingClient returns a new TransferManagerLoggingClient for upload testing. +func NewUploadLoggingClient(ignoredOps []string) (*TransferManagerLoggingClient, *[]string, *[]interface{}) { + c := &TransferManagerLoggingClient{ + ignoredOperations: ignoredOps, + } + + return c, &c.UploadInvocations, &c.Params +} + +// NewDownloadClient returns a new TransferManagerLoggingClient for download testing +func NewDownloadClient() (*TransferManagerLoggingClient, *int, *[]int32, *[]string) { + c := &TransferManagerLoggingClient{} + + return c, &c.GetObjectInvocations, &c.RetrievedParts, &c.RetrievedRanges +} + +var rangeValueRegex = regexp.MustCompile(`bytes=(\d+)-(\d+)`) + +func parseRange(rangeValue string) (start, fin int64) { + rng := rangeValueRegex.FindStringSubmatch(rangeValue) + start, _ = strconv.ParseInt(rng[1], 10, 64) + fin, _ = strconv.ParseInt(rng[2], 10, 64) + return start, fin +} + +// RangeGetObjectFn mocks getobject behavior of s3 client to return object in ranges +var RangeGetObjectFn = func(c *TransferManagerLoggingClient, params *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + start, fin := parseRange(aws.ToString(params.Range)) + fin++ + + if fin >= int64(len(c.Data)) { + fin = int64(len(c.Data)) + } + + bodyBytes := c.Data[start:fin] + + return &s3.GetObjectOutput{ + Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)), + ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", start, fin-1, len(c.Data))), + ContentLength: aws.Int64(int64(len(bodyBytes))), + }, nil +} + +// ErrRangeGetObjectFn mocks getobject behavior of s3 client to return service error when certain number of range get is called from s3 client +var ErrRangeGetObjectFn = func(c *TransferManagerLoggingClient, params *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + out, err := RangeGetObjectFn(c, params) + c.index++ + if c.index > 1 { + return &s3.GetObjectOutput{}, fmt.Errorf("s3 service error") + } + return out, err +} + +// NonRangeGetObjectFn mocks getobject behavior of s3 client to return the whole object +var NonRangeGetObjectFn = func(c *TransferManagerLoggingClient, params *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + return &s3.GetObjectOutput{ + Body: ioutil.NopCloser(bytes.NewReader(c.Data[:])), + ContentLength: aws.Int64(int64(len(c.Data))), + }, nil +} + +// ErrReaderFn mocks getobject behavior of s3 client to return object parts triggering different readerror +var ErrReaderFn = func(c *TransferManagerLoggingClient, params *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + r := c.ErrReaders[c.index] + out := &s3.GetObjectOutput{ + Body: ioutil.NopCloser(&r), + 
ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", 0, r.Len-1, r.Len)), + ContentLength: aws.Int64(r.Len), + PartsCount: aws.Int32(c.PartsCount), + } + c.index++ + return out, nil +} + +// PartGetObjectFn mocks getobject behavior of s3 client to return object parts and total parts count +var PartGetObjectFn = func(c *TransferManagerLoggingClient, params *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + return &s3.GetObjectOutput{ + Body: ioutil.NopCloser(bytes.NewReader(c.Data)), + ContentLength: aws.Int64(int64(len(c.Data))), + PartsCount: aws.Int32(c.PartsCount), + }, nil +} + +// ErrPartGetObjectFn mocks getobject behavior of s3 client to return service error when certain number of part get is called from s3 client +var ErrPartGetObjectFn = func(c *TransferManagerLoggingClient, params *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + out, err := PartGetObjectFn(c, params) + c.index++ + if c.index > 1 { + return &s3.GetObjectOutput{}, fmt.Errorf("s3 service error") + } + return out, err +} + +// TestErrReader mocks response's object body triggering specified error when read +type TestErrReader struct { + Buf []byte + Err error + Len int64 + + off int +} + +// Read implements io.Reader.Read() +func (r *TestErrReader) Read(p []byte) (int, error) { + to := len(r.Buf) - r.off + + n := copy(p, r.Buf[r.off:to]) + r.off += n + + if n < len(p) { + return n, r.Err + + } + + return n, nil +} diff --git a/feature/s3/transfermanager/internal/testing/upload.go b/feature/s3/transfermanager/internal/testing/upload.go deleted file mode 100644 index 1764fc089e2..00000000000 --- a/feature/s3/transfermanager/internal/testing/upload.go +++ /dev/null @@ -1,193 +0,0 @@ -package testing - -import ( - "context" - "fmt" - "io" - "io/ioutil" - "net/http" - "net/url" - "slices" - "sync" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" -) - -// UploadLoggingClient is a mock client that can be used to record and stub responses for testing the Uploader. 
-type UploadLoggingClient struct { - Invocations []string - Params []interface{} - - ConsumeBody bool - - ignoredOperations []string - - PartNum int - m sync.Mutex - - PutObjectFn func(*UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) - UploadPartFn func(*UploadLoggingClient, *s3.UploadPartInput) (*s3.UploadPartOutput, error) - CreateMultipartUploadFn func(*UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) - CompleteMultipartUploadFn func(*UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) - AbortMultipartUploadFn func(*UploadLoggingClient, *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error) -} - -func (u *UploadLoggingClient) simulateHTTPClientOption(optFns ...func(*s3.Options)) error { - - o := s3.Options{ - HTTPClient: httpDoFunc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - Request: r, - }, nil - }), - } - - for _, fn := range optFns { - fn(&o) - } - - _, err := o.HTTPClient.Do(&http.Request{ - URL: &url.URL{ - Scheme: "https", - Host: "mock.amazonaws.com", - Path: "/key", - RawQuery: "foo=bar", - }, - }) - if err != nil { - return err - } - - return nil -} - -type httpDoFunc func(*http.Request) (*http.Response, error) - -func (f httpDoFunc) Do(r *http.Request) (*http.Response, error) { - return f(r) -} - -func (u *UploadLoggingClient) traceOperation(name string, params interface{}) { - if slices.Contains(u.ignoredOperations, name) { - return - } - u.Invocations = append(u.Invocations, name) - u.Params = append(u.Params, params) - -} - -// PutObject is the S3 PutObject API. -func (u *UploadLoggingClient) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { - u.m.Lock() - defer u.m.Unlock() - - if u.ConsumeBody { - io.Copy(ioutil.Discard, params.Body) - } - - u.traceOperation("PutObject", params) - - if err := u.simulateHTTPClientOption(optFns...); err != nil { - return nil, err - } - - if u.PutObjectFn != nil { - return u.PutObjectFn(u, params) - } - - return &s3.PutObjectOutput{ - VersionId: aws.String("VERSION-ID"), - }, nil -} - -// UploadPart is the S3 UploadPart API. -func (u *UploadLoggingClient) UploadPart(ctx context.Context, params *s3.UploadPartInput, optFns ...func(*s3.Options)) (*s3.UploadPartOutput, error) { - u.m.Lock() - defer u.m.Unlock() - - if u.ConsumeBody { - io.Copy(ioutil.Discard, params.Body) - } - - u.traceOperation("UploadPart", params) - - if err := u.simulateHTTPClientOption(optFns...); err != nil { - return nil, err - } - - if u.UploadPartFn != nil { - return u.UploadPartFn(u, params) - } - - return &s3.UploadPartOutput{ - ETag: aws.String(fmt.Sprintf("ETAG%d", *params.PartNumber)), - }, nil -} - -// CreateMultipartUpload is the S3 CreateMultipartUpload API. -func (u *UploadLoggingClient) CreateMultipartUpload(ctx context.Context, params *s3.CreateMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) { - u.m.Lock() - defer u.m.Unlock() - - u.traceOperation("CreateMultipartUpload", params) - - if err := u.simulateHTTPClientOption(optFns...); err != nil { - return nil, err - } - - if u.CreateMultipartUploadFn != nil { - return u.CreateMultipartUploadFn(u, params) - } - - return &s3.CreateMultipartUploadOutput{ - UploadId: aws.String("UPLOAD-ID"), - }, nil -} - -// CompleteMultipartUpload is the S3 CompleteMultipartUpload API. 
-func (u *UploadLoggingClient) CompleteMultipartUpload(ctx context.Context, params *s3.CompleteMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) { - u.m.Lock() - defer u.m.Unlock() - - u.traceOperation("CompleteMultipartUpload", params) - - if err := u.simulateHTTPClientOption(optFns...); err != nil { - return nil, err - } - - if u.CompleteMultipartUploadFn != nil { - return u.CompleteMultipartUploadFn(u, params) - } - - return &s3.CompleteMultipartUploadOutput{ - Location: aws.String("http://location"), - VersionId: aws.String("VERSION-ID"), - }, nil -} - -// AbortMultipartUpload is the S3 AbortMultipartUpload API. -func (u *UploadLoggingClient) AbortMultipartUpload(ctx context.Context, params *s3.AbortMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) { - u.m.Lock() - defer u.m.Unlock() - - u.traceOperation("AbortMultipartUpload", params) - if err := u.simulateHTTPClientOption(optFns...); err != nil { - return nil, err - } - - if u.AbortMultipartUploadFn != nil { - return u.AbortMultipartUploadFn(u, params) - } - - return &s3.AbortMultipartUploadOutput{}, nil -} - -// NewUploadLoggingClient returns a new UploadLoggingClient. -func NewUploadLoggingClient(ignoredOps []string) (*UploadLoggingClient, *[]string, *[]interface{}) { - c := &UploadLoggingClient{ - ignoredOperations: ignoredOps, - } - - return c, &c.Invocations, &c.Params -} diff --git a/feature/s3/transfermanager/options.go b/feature/s3/transfermanager/options.go index a49e74afd64..770774166fb 100644 --- a/feature/s3/transfermanager/options.go +++ b/feature/s3/transfermanager/options.go @@ -1,6 +1,8 @@ package transfermanager -import "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" +import ( + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" +) // Options provides params needed for transfer api calls type Options struct { @@ -16,7 +18,7 @@ type Options struct { MultipartUploadThreshold int64 // Option to disable checksum validation for download - DisableChecksum bool + DisableChecksumValidation bool // Checksum algorithm to use for upload ChecksumAlgorithm types.ChecksumAlgorithm @@ -27,6 +29,15 @@ type Options struct { // // The concurrency pool is not shared between calls to Upload. Concurrency int + + // The type indicating if object is multi-downloaded in parts or ranges + GetObjectType types.GetObjectType + + // PartBodyMaxRetries is the number of retry attempts to make for failed part downloads. 
+ PartBodyMaxRetries int + + // Max size for the get object buffer + GetBufferSize int64 } func (o *Options) init() { @@ -56,6 +67,24 @@ func resolveMultipartUploadThreshold(o *Options) { } } +func resolveGetObjectType(o *Options) { + if o.GetObjectType == "" { + o.GetObjectType = types.GetObjectParts + } +} + +func resolvePartBodyMaxRetries(o *Options) { + if o.PartBodyMaxRetries == 0 { + o.PartBodyMaxRetries = defaultPartBodyMaxRetries + } +} + +func resolveGetBufferSize(o *Options) { + if o.GetBufferSize == 0 { + o.GetBufferSize = defaultGetBufferSize + } +} + // Copy returns new copy of the Options func (o Options) Copy() Options { to := o diff --git a/feature/s3/transfermanager/putobject_test.go b/feature/s3/transfermanager/putobject_test.go index 06cd0ebe149..cc779e3d819 100644 --- a/feature/s3/transfermanager/putobject_test.go +++ b/feature/s3/transfermanager/putobject_test.go @@ -207,7 +207,7 @@ func TestUploadOrderSingle(t *testing.T) { func TestUploadSingleFailure(t *testing.T) { c, invocations, _ := s3testing.NewUploadLoggingClient(nil) - c.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) { + c.PutObjectFn = func(*s3testing.TransferManagerLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) { return nil, fmt.Errorf("put object failure") } @@ -260,7 +260,7 @@ func TestUploadOrderZero(t *testing.T) { func TestUploadOrderMultiFailure(t *testing.T) { c, invocations, _ := s3testing.NewUploadLoggingClient(nil) - c.UploadPartFn = func(u *s3testing.UploadLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) { + c.UploadPartFn = func(u *s3testing.TransferManagerLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) { if *params.PartNumber == 2 { return nil, fmt.Errorf("an unexpected error") } @@ -288,7 +288,7 @@ func TestUploadOrderMultiFailure(t *testing.T) { func TestUploadOrderMultiFailureOnComplete(t *testing.T) { c, invocations, _ := s3testing.NewUploadLoggingClient(nil) - c.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) { + c.CompleteMultipartUploadFn = func(*s3testing.TransferManagerLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) { return nil, fmt.Errorf("complete multipart error") } @@ -314,7 +314,7 @@ func TestUploadOrderMultiFailureOnComplete(t *testing.T) { func TestUploadOrderMultiFailureOnCreate(t *testing.T) { c, invocations, _ := s3testing.NewUploadLoggingClient(nil) - c.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) { + c.CreateMultipartUploadFn = func(*s3testing.TransferManagerLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) { return nil, fmt.Errorf("create multipart upload failure") } @@ -597,7 +597,7 @@ func TestUploadUnexpectedEOF(t *testing.T) { func TestSSE(t *testing.T) { c, _, _ := s3testing.NewUploadLoggingClient(nil) - c.UploadPartFn = func(u *s3testing.UploadLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) { + c.UploadPartFn = func(u *s3testing.TransferManagerLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) { if params.SSECustomerAlgorithm == nil { t.Fatal("SSECustomerAlgoritm should not be nil") } diff --git a/feature/s3/transfermanager/setup_integ_test.go b/feature/s3/transfermanager/setup_integ_test.go index d8efb29a220..764ce5a5ec8 100644 --- 
a/feature/s3/transfermanager/setup_integ_test.go +++ b/feature/s3/transfermanager/setup_integ_test.go @@ -15,14 +15,16 @@ import ( "net/http" "os" "strings" + "sync" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types" "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/aws/aws-sdk-go-v2/service/s3/types" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go-v2/service/sts" ) @@ -218,6 +220,13 @@ type putObjectTestData struct { ExpectError string } +type getObjectTestData struct { + Body io.Reader + ExpectBody []byte + ExpectError string + OptFns []func(*Options) +} + // UniqueID returns a unique UUID-like identifier for use in generating // resources for integration tests. // @@ -269,6 +278,100 @@ func testPutObject(t *testing.T, bucket string, testData putObjectTestData, opts } } +func testGetObject(t *testing.T, bucket string, testData getObjectTestData) { + key := UniqueID() + + _, err := s3Client.PutObject(context.Background(), + &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: testData.Body, + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + var b []byte + r := NewConcurrentReader() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + b, err = ioutil.ReadAll(r) + if err != nil { + t.Errorf("error when reading response body: %v", err) + } + }() + + _, err = s3TransferManagerClient.GetObject(context.Background(), + &GetObjectInput{ + Bucket: bucket, + Key: key, + Reader: r, + }, testData.OptFns...) + wg.Wait() + if err != nil { + if len(testData.ExpectError) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if e, a := testData.ExpectError, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %v, got %v", e, a) + } + } else { + if e := testData.ExpectError; len(e) != 0 { + t.Fatalf("expect error: %v, got none", e) + } + } + if len(testData.ExpectError) != 0 { + return + } + + if e, a := testData.ExpectBody, b; !bytes.EqualFold(e, a) { + t.Errorf("expect %s, got %s", e, a) + } +} + +func testDownloadObject(t *testing.T, bucket string, testData getObjectTestData) { + key := UniqueID() + + _, err := s3Client.PutObject(context.Background(), + &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: testData.Body, + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + w := types.NewWriteAtBuffer(make([]byte, 0)) + _, err = s3TransferManagerClient.DownloadObject(context.Background(), + &DownloadObjectInput{ + Bucket: bucket, + Key: key, + WriterAt: w, + }, testData.OptFns...) + if err != nil { + if len(testData.ExpectError) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if e, a := testData.ExpectError, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %v, got %v", e, a) + } + } else { + if e := testData.ExpectError; len(e) != 0 { + t.Fatalf("expect error: %v, got none", e) + } + } + if len(testData.ExpectError) != 0 { + return + } + + if e, a := testData.ExpectBody, w.Bytes(); !bytes.EqualFold(e, a) { + t.Errorf("expect %s, got %s", e, a) + } +} + // TODO: duped from service/internal/integrationtest, remove after beta. 
const expressAZID = "usw2-az3"

@@ -307,7 +410,7 @@ func SetupBucket(ctx context.Context, svc *s3.Client, bucketName string) (err er
 	fmt.Println("Setup: Creating test bucket,", bucketName)
 	_, err = svc.CreateBucket(ctx, &s3.CreateBucketInput{
 		Bucket: &bucketName,
-		CreateBucketConfiguration: &types.CreateBucketConfiguration{
+		CreateBucketConfiguration: &s3types.CreateBucketConfiguration{
 			LocationConstraint: "us-west-2",
 		},
 	})
@@ -413,14 +516,14 @@ func SetupExpressBucket(ctx context.Context, svc *s3.Client, bucketName string)
 	fmt.Println("Setup: Creating test express bucket,", bucketName)
 	_, err := svc.CreateBucket(ctx, &s3.CreateBucketInput{
 		Bucket: &bucketName,
-		CreateBucketConfiguration: &types.CreateBucketConfiguration{
-			Location: &types.LocationInfo{
+		CreateBucketConfiguration: &s3types.CreateBucketConfiguration{
+			Location: &s3types.LocationInfo{
 				Name: aws.String(expressAZID),
-				Type: types.LocationTypeAvailabilityZone,
+				Type: s3types.LocationTypeAvailabilityZone,
 			},
-			Bucket: &types.BucketInfo{
-				DataRedundancy: types.DataRedundancySingleAvailabilityZone,
-				Type: types.BucketTypeDirectory,
+			Bucket: &s3types.BucketInfo{
+				DataRedundancy: s3types.DataRedundancySingleAvailabilityZone,
+				Type: s3types.BucketTypeDirectory,
 			},
 		},
 	})
diff --git a/feature/s3/transfermanager/shared_test.go b/feature/s3/transfermanager/shared_test.go
index 364423e96c2..1c62ecab25b 100644
--- a/feature/s3/transfermanager/shared_test.go
+++ b/feature/s3/transfermanager/shared_test.go
@@ -2,3 +2,4 @@ package transfermanager

 var buf20MB = make([]byte, 1024*1024*20)
 var buf2MB = make([]byte, 1024*1024*2)
+var buf40MB = make([]byte, 1024*1024*40)
diff --git a/feature/s3/transfermanager/types/types.go b/feature/s3/transfermanager/types/types.go
index 8a2d877e461..26ab2719e81 100644
--- a/feature/s3/transfermanager/types/types.go
+++ b/feature/s3/transfermanager/types/types.go
@@ -2,6 +2,7 @@ package types

 import (
 	"io"
+	"sync"

 	"github.com/aws/aws-sdk-go-v2/service/s3"
 	"github.com/aws/aws-sdk-go-v2/service/s3/types"
@@ -344,3 +345,83 @@ type Metadata struct {
 	values map[interface{}]interface{}
 }
+
+// GetObjectType specifies how the transfer manager should perform a multipart download
+type GetObjectType string
+
+// Enum values for GetObjectType
+const (
+	GetObjectParts  GetObjectType = "PART"
+	GetObjectRanges GetObjectType = "RANGE"
+)
+
+// ChecksumMode indicates if the response checksum validation is enabled
+type ChecksumMode string
+
+// Enum values for ChecksumMode
+const (
+	ChecksumModeEnabled ChecksumMode = "ENABLED"
+)
+
+// ReplicationStatus indicates if your request involves a bucket that's either a
+// source or destination in a replication rule
+type ReplicationStatus string
+
+// Enum values for ReplicationStatus
+const (
+	ReplicationStatusComplete  ReplicationStatus = "COMPLETE"
+	ReplicationStatusPending   ReplicationStatus = "PENDING"
+	ReplicationStatusFailed    ReplicationStatus = "FAILED"
+	ReplicationStatusReplica   ReplicationStatus = "REPLICA"
+	ReplicationStatusCompleted ReplicationStatus = "COMPLETED"
+)
+
+// A WriteAtBuffer provides an in-memory buffer supporting the io.WriterAt
+// interface. It can be used as the DownloadObject destination to download
+// content into a buffer in memory, and is safe to use concurrently.
+type WriteAtBuffer struct {
+	buf []byte
+	m   sync.Mutex
+
+	// GrowthCoeff defines the growth rate of the internal buffer. By
+	// default, the growth rate is 1, where expanding the internal
+	// buffer will allocate only enough capacity to fit the new expected
+	// length.
+	GrowthCoeff float64
+}
+
+// NewWriteAtBuffer creates a WriteAtBuffer with an internal buffer
+// provided by buf.
+func NewWriteAtBuffer(buf []byte) *WriteAtBuffer {
+	return &WriteAtBuffer{buf: buf}
+}
+
+// WriteAt writes a slice of bytes to the buffer starting at the position provided.
+// It returns the number of bytes written, or an error. Overlapping WriteAt calls
+// can overwrite previously written slices.
+func (b *WriteAtBuffer) WriteAt(p []byte, pos int64) (n int, err error) {
+	pLen := len(p)
+	expLen := pos + int64(pLen)
+	b.m.Lock()
+	defer b.m.Unlock()
+	if int64(len(b.buf)) < expLen {
+		if int64(cap(b.buf)) < expLen {
+			if b.GrowthCoeff < 1 {
+				b.GrowthCoeff = 1
+			}
+			newBuf := make([]byte, expLen, int64(b.GrowthCoeff*float64(expLen)))
+			copy(newBuf, b.buf)
+			b.buf = newBuf
+		}
+		b.buf = b.buf[:expLen]
+	}
+	copy(b.buf[pos:], p)
+	return pLen, nil
+}
+
+// Bytes returns a slice of bytes written to the buffer.
+func (b *WriteAtBuffer) Bytes() []byte {
+	b.m.Lock()
+	defer b.m.Unlock()
+	return b.buf
+}
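
A minimal usage sketch of the download path introduced in this change, using types.WriteAtBuffer as the io.WriterAt destination. The bucket and key are placeholders, the config/s3 client wiring is assumed rather than part of this diff, and GetObjectRanges switches the client from the default part-based download to ranged GETs.

	package main

	import (
		"context"
		"log"

		"github.com/aws/aws-sdk-go-v2/config"
		"github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager"
		"github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/types"
		"github.com/aws/aws-sdk-go-v2/service/s3"
	)

	func main() {
		cfg, err := config.LoadDefaultConfig(context.Background())
		if err != nil {
			log.Fatal(err)
		}

		// Build the transfer manager client on top of a regular S3 client.
		client := transfermanager.New(s3.NewFromConfig(cfg), transfermanager.Options{
			GetObjectType:      types.GetObjectRanges, // download with ranged GETs instead of part GETs
			PartBodyMaxRetries: 3,                     // matches defaultPartBodyMaxRetries
		})

		// WriteAtBuffer collects the concurrently downloaded chunks in memory.
		w := types.NewWriteAtBuffer(make([]byte, 0))
		_, err = client.DownloadObject(context.Background(), &transfermanager.DownloadObjectInput{
			Bucket:   "example-bucket", // placeholder
			Key:      "example-key",    // placeholder
			WriterAt: w,
		})
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("downloaded %d bytes", len(w.Bytes()))
	}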