diff --git a/README.md b/README.md index 76758473..126e0038 100644 --- a/README.md +++ b/README.md @@ -757,6 +757,29 @@ if err != nil { client.R().SetOutput(file).Get(url) ``` +**Download Callback** + +You can set `DownloadCallback` if you want to show download progress: + +```go +client := req.C() +client.R(). + SetOutputFile("test.gz"). + SetUploadCallback(func(info req.UploadInfo) { + fmt.Printf("downloaded %.2f%%\n", float64(info.DownloadedSize)/float64(info.Response.ContentLength)*100.0) + }).Post("https://exmaple.com/upload") +/* Output +downloaded 17.92% +downloaded 41.77% +downloaded 67.71% +downloaded 98.89% +downloaded 100.00% +*/ +``` + +> `info.Response.ContentLength` could be 0 or -1 when the total size is unknown. +> `DownloadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetDownloadCallbackWithInterval`. + **Multipart Upload** ```go diff --git a/client.go b/client.go index 875d39e6..85f74235 100644 --- a/client.go +++ b/client.go @@ -995,8 +995,28 @@ func (c *Client) do(r *Request) (resp *Response, err error) { for _, cookie := range r.Cookies { req.AddCookie(cookie) } - if r.ctx != nil { - req = req.WithContext(r.ctx) + ctx := r.ctx + if r.isSaveResponse && r.downloadCallback != nil { + var wrap wrapResponseBodyFunc = func(rc io.ReadCloser) io.ReadCloser { + return &callbackReader{ + ReadCloser: rc, + callback: func(read int64) { + r.downloadCallback(DownloadInfo{ + Response: resp, + DownloadedSize: read, + }) + }, + lastTime: time.Now(), + interval: r.downloadCallbackInterval, + } + } + if ctx == nil { + ctx = context.Background() + } + ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap) + } + if ctx != nil { + req = req.WithContext(ctx) } r.RawRequest = req diff --git a/docs/api.md b/docs/api.md index 491329b7..821ecf6a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -210,6 +210,8 @@ Basically, you can know the meaning of most settings directly from the method na * [SetOutput(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutput) * [SetOutputFile(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutputFile) +* [SetDownloadCallback(callback DownloadCallback)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallback) +* [SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallbackWithInterval) ### Retry diff --git a/examples/uploadcallback/uploadclient/main.go b/examples/uploadcallback/uploadclient/main.go index b617e8f7..19dd217e 100644 --- a/examples/uploadcallback/uploadclient/main.go +++ b/examples/uploadcallback/uploadclient/main.go @@ -43,5 +43,5 @@ func main() { FileSize: int64(size), }).SetUploadCallbackWithInterval(func(info req.UploadInfo) { fmt.Printf("%s: %.2f%%\n", info.FileName, float64(info.UploadedSize)/float64(info.FileSize)*100.0) - }, 1*time.Second).Post("http://127.0.0.1:8888/upload") + }, 30*time.Millisecond).Post("http://127.0.0.1:8888/upload") } diff --git a/middleware.go b/middleware.go index 21315f24..e6d2f86f 100644 --- a/middleware.go +++ b/middleware.go @@ -2,7 +2,6 @@ package req import ( "bytes" - "errors" "github.com/imroc/req/v3/internal/util" "io" "io/ioutil" @@ -80,66 +79,82 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e if err != nil { return err } + + if r.uploadCallback != nil { + pw = &callbackWriter{ + Writer: pw, + lastTime: lastTime, + interval: r.uploadCallbackInterval, + totalSize: file.FileSize, + callback: func(written int64) { + r.uploadCallback(UploadInfo{ + ParamName: file.ParamName, + FileName: file.FileName, + FileSize: file.FileSize, + UploadedSize: written, + }) + }, + } + } + if _, err = pw.Write(cbuf[:size]); err != nil { return err } if seeEOF { return nil } - if r.uploadCallback == nil { - _, err = io.Copy(pw, content) - return err - } - uploadedBytes := int64(size) - progressCallback := func() { - r.uploadCallback(UploadInfo{ - ParamName: file.ParamName, - FileName: file.FileName, - FileSize: file.FileSize, - UploadedSize: uploadedBytes, - }) - } - if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { - lastTime = now - progressCallback() - } - buf := make([]byte, 1024) - for { - callback := false - nr, er := content.Read(buf) - if nr > 0 { - nw, ew := pw.Write(buf[:nr]) - if nw < 0 || nr < nw { - nw = 0 - if ew == nil { - ew = errors.New("invalid write result") - } - } - uploadedBytes += int64(nw) - if ew != nil { - return ew - } - if nr != nw { - return io.ErrShortWrite - } - if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { - lastTime = now - progressCallback() - callback = true - } - } - if er != nil { - if er == io.EOF { - if !callback { - progressCallback() - } - break - } else { - return er - } - } - } + _, err = io.Copy(pw, content) + return err + // uploadedBytes := int64(size) + // progressCallback := func() { + // r.uploadCallback(UploadInfo{ + // ParamName: file.ParamName, + // FileName: file.FileName, + // FileSize: file.FileSize, + // UploadedSize: uploadedBytes, + // }) + // } + // if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { + // lastTime = now + // progressCallback() + // } + // buf := make([]byte, 1024) + // for { + // callback := false + // nr, er := content.Read(buf) + // if nr > 0 { + // nw, ew := pw.Write(buf[:nr]) + // if nw < 0 || nr < nw { + // nw = 0 + // if ew == nil { + // ew = errors.New("invalid write result") + // } + // } + // uploadedBytes += int64(nw) + // if ew != nil { + // return ew + // } + // if nr != nw { + // return io.ErrShortWrite + // } + // if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval { + // lastTime = now + // progressCallback() + // callback = true + // } + // } + // if er != nil { + // if er == io.EOF { + // if !callback { + // progressCallback() + // } + // break + // } else { + // return er + // } + // } + // } return nil } @@ -266,6 +281,60 @@ func parseResponseBody(c *Client, r *Response) (err error) { return } +type callbackWriter struct { + io.Writer + written int64 + totalSize int64 + lastTime time.Time + interval time.Duration + callback func(written int64) +} + +func (w *callbackWriter) Write(p []byte) (n int, err error) { + n, err = w.Writer.Write(p) + if n <= 0 { + return + } + w.written += int64(n) + if w.written == w.totalSize { + w.callback(w.written) + } else if now := time.Now(); now.Sub(w.lastTime) >= w.interval { + w.lastTime = now + w.callback(w.written) + } + return +} + +type callbackReader struct { + io.ReadCloser + read int64 + lastRead int64 + callback func(read int64) + lastTime time.Time + interval time.Duration +} + +func (r *callbackReader) Read(p []byte) (n int, err error) { + n, err = r.ReadCloser.Read(p) + if n <= 0 { + if err == io.EOF && r.read > r.lastRead { + r.callback(r.read) + r.lastRead = r.read + } + return + } + r.read += int64(n) + if err == io.EOF { + r.callback(r.read) + r.lastRead = r.read + } else if now := time.Now(); now.Sub(r.lastTime) >= r.interval { + r.lastTime = now + r.callback(r.read) + r.lastRead = r.read + } + return +} + func handleDownload(c *Client, r *Response) (err error) { if !r.Request.isSaveResponse { return nil @@ -302,6 +371,21 @@ func handleDownload(c *Client, r *Response) (err error) { body.Close() closeq(output) }() + + // if r.Request.downloadCallback != nil { + // output = &callbackWriter{ + // Writer: output, + // lastTime: time.Now(), + // interval: r.Request.downloadCallbackInterval, + // callback: func(written int64) { + // r.Request.downloadCallback(DownloadInfo{ + // Response: r, + // DownloadedSize: written, + // }) + // }, + // } + // } + _, err = io.Copy(output, body) r.setReceivedAt() return diff --git a/req.go b/req.go index 249f4bea..e3e4caee 100644 --- a/req.go +++ b/req.go @@ -80,6 +80,18 @@ type UploadInfo struct { // multipart upload. type UploadCallback func(info UploadInfo) +// DownloadInfo is the information for each DownloadCallback call. +type DownloadInfo struct { + // Response is the corresponding Response during download. + Response *Response + // downloaded body length in bytes. + DownloadedSize int64 +} + +// DownloadCallback is the callback which will be invoked during +// response body download. +type DownloadCallback func(info DownloadInfo) + func cloneCookies(cookies []*http.Cookie) []*http.Cookie { if len(cookies) == 0 { return nil diff --git a/req_test.go b/req_test.go index 50bddd1c..9687f311 100644 --- a/req_test.go +++ b/req_test.go @@ -15,6 +15,7 @@ import ( "os" "path/filepath" "reflect" + "strconv" "strings" "sync" "testing" @@ -357,6 +358,24 @@ func handleGet(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.URL.RawQuery)) case "/search": handleSearch(w, r) + case "/download": + size := 100 * 1024 * 1024 + w.Header().Set("Content-Length", strconv.Itoa(size)) + buf := make([]byte, 1024) + for i := 0; i < 1024; i++ { + buf[i] = 'h' + } + for i := 0; i < size; { + wbuf := buf + if size-i < 1024 { + wbuf = buf[:size-i] + } + n, err := w.Write(wbuf) + if err != nil { + break + } + i += n + } case "/protected": auth := r.Header.Get("Authorization") if auth == "Bearer goodtoken" { diff --git a/request.go b/request.go index a14af1ce..b5b873de 100644 --- a/request.go +++ b/request.go @@ -33,28 +33,30 @@ type Request struct { StartTime time.Time RetryAttempt int - RawURL string // read only - method string - URL *urlpkg.URL - getBody GetContentFunc - uploadCallback UploadCallback - uploadCallbackInterval time.Duration - unReplayableBody io.ReadCloser - retryOption *retryOption - bodyReadCloser io.ReadCloser - body []byte - dumpOptions *DumpOptions - marshalBody interface{} - ctx context.Context - isMultiPart bool - uploadFiles []*FileUpload - uploadReader []io.ReadCloser - outputFile string - isSaveResponse bool - output io.Writer - trace *clientTrace - dumpBuffer *bytes.Buffer - responseReturnTime time.Time + RawURL string // read only + method string + URL *urlpkg.URL + getBody GetContentFunc + uploadCallback UploadCallback + uploadCallbackInterval time.Duration + downloadCallback DownloadCallback + downloadCallbackInterval time.Duration + unReplayableBody io.ReadCloser + retryOption *retryOption + bodyReadCloser io.ReadCloser + body []byte + dumpOptions *DumpOptions + marshalBody interface{} + ctx context.Context + isMultiPart bool + uploadFiles []*FileUpload + uploadReader []io.ReadCloser + outputFile string + isSaveResponse bool + output io.Writer + trace *clientTrace + dumpBuffer *bytes.Buffer + responseReturnTime time.Time } type GetContentFunc func() (io.ReadCloser, error) @@ -294,6 +296,23 @@ func (r *Request) SetUploadCallbackWithInterval(callback UploadCallback, minInte return r } +// SetDownloadCallback set the DownloadCallback which will be invoked at least +// every 200ms during file upload, usually used to show download progress. +func (r *Request) SetDownloadCallback(callback DownloadCallback) *Request { + return r.SetDownloadCallbackWithInterval(callback, 200*time.Millisecond) +} + +// SetDownloadCallbackWithInterval set the DownloadCallback which will be invoked at least +// every `minInterval` during file upload, usually used to show download progress. +func (r *Request) SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration) *Request { + if callback == nil { + return r + } + r.downloadCallback = callback + r.downloadCallbackInterval = minInterval + return r +} + // SetResult set the result that response body will be unmarshaled to if // request is success (status `code >= 200 and <= 299`). func (r *Request) SetResult(result interface{}) *Request { diff --git a/request_test.go b/request_test.go index 50a75044..3d30054e 100644 --- a/request_test.go +++ b/request_test.go @@ -904,7 +904,7 @@ type SlowReader struct { } func (r *SlowReader) Read(p []byte) (int, error) { - time.Sleep(10 * time.Millisecond) + time.Sleep(100 * time.Millisecond) return r.ReadCloser.Read(p) } @@ -932,3 +932,14 @@ func TestUploadCallback(t *testing.T) { assertSuccess(t, resp, err) assertEqual(t, true, n > 1) } + +func TestDownloadCallback(t *testing.T) { + n := 0 + resp, err := tc().R(). + SetOutput(ioutil.Discard). + SetDownloadCallback(func(info DownloadInfo) { + n++ + }).Get("/download") + assertSuccess(t, resp, err) + assertEqual(t, true, n > 0) +} diff --git a/request_wrapper.go b/request_wrapper.go index 0006aecc..e496cdc4 100644 --- a/request_wrapper.go +++ b/request_wrapper.go @@ -421,3 +421,15 @@ func SetUploadCallback(callback UploadCallback) *Request { func SetUploadCallbackWithInterval(callback UploadCallback, minInterval time.Duration) *Request { return defaultClient.R().SetUploadCallbackWithInterval(callback, minInterval) } + +// SetDownloadCallback is a global wrapper methods which delegated +// to the default client, create a request and SetDownloadCallback for request. +func SetDownloadCallback(callback DownloadCallback) *Request { + return defaultClient.R().SetDownloadCallback(callback) +} + +// SetDownloadCallbackWithInterval is a global wrapper methods which delegated +// to the default client, create a request and SetDownloadCallbackWithInterval for request. +func SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration) *Request { + return defaultClient.R().SetDownloadCallbackWithInterval(callback, minInterval) +} diff --git a/request_wrapper_test.go b/request_wrapper_test.go index 14d43b45..fe080646 100644 --- a/request_wrapper_test.go +++ b/request_wrapper_test.go @@ -70,6 +70,8 @@ func TestGlobalWrapperForRequestSettings(t *testing.T) { SetContext(context.Background()), SetUploadCallback(nil), SetUploadCallbackWithInterval(nil, 0), + SetDownloadCallback(nil), + SetDownloadCallbackWithInterval(nil, 0), ) } diff --git a/transport.go b/transport.go index 302398c5..6d63e2bd 100644 --- a/transport.go +++ b/transport.go @@ -269,15 +269,37 @@ type Transport struct { *ResponseOptions - dump *dumper + dump *dumper + + // Debugf is the optional debug function. Debugf func(format string, v ...interface{}) } +type wrapResponseBodyKeyType int + +const wrapResponseBodyKey wrapResponseBodyKeyType = iota + +type wrapResponseBodyFunc func(rc io.ReadCloser) io.ReadCloser + func (t *Transport) handleResponseBody(res *http.Response, req *http.Request) { + if wrap, ok := req.Context().Value(wrapResponseBodyKey).(wrapResponseBodyFunc); ok { + t.wrapResponseBody(res, wrap) + } t.autoDecodeResponseBody(res) t.dumpResponseBody(res, req) } +func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFunc) { + switch b := res.Body.(type) { + case *gzipReader: + b.body.body = wrap(b.body.body) + case *http2gzipReader: + b.body = wrap(b.body) + default: + res.Body = wrap(res.Body) + } +} + func (t *Transport) dumpResponseBody(res *http.Response, req *http.Request) { dumps := getDumpers(req.Context(), t.dump) for _, dump := range dumps {