diff --git a/internal/controllers/s3.go b/internal/controllers/s3.go index cf9832d..c526c96 100644 --- a/internal/controllers/s3.go +++ b/internal/controllers/s3.go @@ -60,36 +60,62 @@ func AwsS3(w http.ResponseWriter, r *http.Request) { } path += c.IndexDocument } - // Get a S3 object - obj, err := client.S3get(c.S3Bucket, c.S3KeyPrefix+path, rangeHeader) - metrics.UpdateS3Reads(err, metrics.GetObjectAction, metrics.ProxySource) - if err != nil { - code, message := toHTTPError(err) - - if (code == 404 || code == 403) && c.SPA && !strings.Contains(path, c.IndexDocument) { - idx := strings.LastIndex(path, "/") - - if idx > -1 { - indexPath := c.S3KeyPrefix + path[:idx+1] + c.IndexDocument - - var indexError error - obj, indexError = client.S3get(c.S3Bucket, indexPath, rangeHeader) - - if indexError != nil { - code, message = toHTTPError(indexError) - http.Error(w, message, code) - return + switch r.Method { + case "GET": + // Get a S3 object + obj, err := client.S3get(c.S3Bucket, c.S3KeyPrefix+path, rangeHeader) + metrics.UpdateS3Reads(err, metrics.GetObjectAction, metrics.ProxySource) + if err != nil { + code, message := toHTTPError(err) + if (code == 404 || code == 403) && c.SPA && !strings.Contains(path, c.IndexDocument) { + idx := strings.LastIndex(path, "/") + if idx > -1 { + indexPath := c.S3KeyPrefix + path[:idx+1] + c.IndexDocument + var indexError error + obj, indexError = client.S3get(c.S3Bucket, indexPath, rangeHeader) + if indexError != nil { + code, message = toHTTPError(indexError) + http.Error(w, message, code) + return + } } + } else { + http.Error(w, message, code) + return + } + } + setHeadersFromAwsResponse(w, obj, c.HTTPCacheControl, c.HTTPExpires) + _, _ = io.Copy(w, obj.Body) // nolint + case "HEAD": + // Head a S3 object + obj, err := client.S3head(c.S3Bucket, c.S3KeyPrefix+path, rangeHeader) + // metrics.UpdateS3Reads(err, metrics.GetObjectAction, metrics.ProxySource) + if err != nil { + code, message := toHTTPError(err) + if (code == 404 || code == 403) && c.SPA && !strings.Contains(path, c.IndexDocument) { + idx := strings.LastIndex(path, "/") + if idx > -1 { + indexPath := c.S3KeyPrefix + path[:idx+1] + c.IndexDocument + var indexError error + obj, indexError = client.S3head(c.S3Bucket, indexPath, rangeHeader) + if indexError != nil { + code, message = toHTTPError(indexError) + http.Error(w, message, code) + return + } + } + } else { + http.Error(w, message, code) + return } - } else { - http.Error(w, message, code) - return } + setHeadersFromAwsHeadResponse(w, obj, c.HTTPCacheControl, c.HTTPExpires) + default: + // return method not allowed, 405 + http.Error(w, "Method Not Allowed", 405) + return } - setHeadersFromAwsResponse(w, obj, c.HTTPCacheControl, c.HTTPExpires) - - _, _ = io.Copy(w, obj.Body) // nolint } func replacePathWithSymlink(client service.AWS, bucket, symlinkPath string) (*string, error) { @@ -157,6 +183,51 @@ func setHeadersFromAwsResponse(w http.ResponseWriter, obj *s3.GetObjectOutput, h w.WriteHeader(determineHTTPStatus(obj)) } +func setHeadersFromAwsHeadResponse(w http.ResponseWriter, obj *s3.HeadObjectOutput, httpCacheControl, httpExpires string) { + + // Cache-Control + if len(httpCacheControl) > 0 { + setStrHeader(w, "Cache-Control", &httpCacheControl) + } else { + setStrHeader(w, "Cache-Control", obj.CacheControl) + } + // Expires + if len(httpExpires) > 0 { + setStrHeader(w, "Expires", &httpExpires) + } else { + setStrHeader(w, "Expires", obj.Expires) + } + setStrHeader(w, "Content-Encoding", obj.ContentEncoding) + setStrHeader(w, "Content-Language", obj.ContentLanguage) + + if len(w.Header().Get("Content-Encoding")) == 0 { + setIntHeader(w, "Content-Length", obj.ContentLength) + } + if config.Config.ContentType == "" { + setStrHeader(w, "Content-Type", obj.ContentType) + } else { + setStrHeader(w, "Content-Type", &config.Config.ContentType) + } + if config.Config.ContentDisposition == "" { + setStrHeader(w, "Content-Disposition", obj.ContentDisposition) + } else { + setStrHeader(w, "Content-Disposition", &config.Config.ContentDisposition) + } + setStrHeader(w, "ETag", obj.ETag) + setTimeHeader(w, "Last-Modified", obj.LastModified) + + // Location, rewrite to our own + if len(w.Header().Get("Location")) > 0 { + l, err := url.Parse(w.Header().Get("Location")) + if err == nil && strings.Contains(l.Host, config.Config.S3Bucket) { + path := l.RequestURI() + setStrHeader(w, "Location", &path) + } + } + + w.WriteHeader(http.StatusOK) +} + func setStrHeader(w http.ResponseWriter, key string, value *string) { if value != nil && len(*value) > 0 { w.Header().Add(key, *value) diff --git a/internal/service/amazon-s3.go b/internal/service/amazon-s3.go index 34b5b33..84d7d35 100644 --- a/internal/service/amazon-s3.go +++ b/internal/service/amazon-s3.go @@ -16,6 +16,16 @@ func (c client) S3get(bucket, key string, rangeHeader *string) (*s3.GetObjectOut return s3.New(c.Session).GetObjectWithContext(c.Context, req) } +// S3head returns a specified object metadata from Amazon S3 +func (c client) S3head(bucket, key string, rangeHeader *string) (*s3.HeadObjectOutput, error) { + req := &s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Range: rangeHeader, + } + return s3.New(c.Session).HeadObjectWithContext(c.Context, req) +} + // S3exists returns true if a specified key exists in Amazon S3 func (c client) S3exists(bucket, key string) bool { req := &s3.HeadObjectInput{ diff --git a/internal/service/client.go b/internal/service/client.go index e0f9e63..97a4c3d 100644 --- a/internal/service/client.go +++ b/internal/service/client.go @@ -10,6 +10,7 @@ import ( // AWS is a service to interact with original AWS services type AWS interface { S3get(bucket, key string, rangeHeader *string) (*s3.GetObjectOutput, error) + S3head(bucket, key string, rangeHeader *string) (*s3.HeadObjectOutput, error) S3exists(bucket, key string) bool S3listObjects(bucket, prefix string) (*s3.ListObjectsOutput, error) }