Skip to content

Commit

Permalink
support download callback
Browse files Browse the repository at this point in the history
  • Loading branch information
imroc committed Mar 27, 2022
1 parent 2d22ff0 commit e78b0a6
Show file tree
Hide file tree
Showing 12 changed files with 307 additions and 81 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,29 @@ if err != nil {
client.R().SetOutput(file).Get(url)
```

**Download Callback**

You can set `DownloadCallback` if you want to show download progress:

```go
client := req.C()
client.R().
SetOutputFile("test.gz").
SetUploadCallback(func(info req.UploadInfo) {
fmt.Printf("downloaded %.2f%%\n", float64(info.DownloadedSize)/float64(info.Response.ContentLength)*100.0)
}).Post("https://exmaple.com/upload")
/* Output
downloaded 17.92%
downloaded 41.77%
downloaded 67.71%
downloaded 98.89%
downloaded 100.00%
*/
```

> `info.Response.ContentLength` could be 0 or -1 when the total size is unknown.
> `DownloadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetDownloadCallbackWithInterval`.
**Multipart Upload**

```go
Expand Down
24 changes: 22 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -995,8 +995,28 @@ func (c *Client) do(r *Request) (resp *Response, err error) {
for _, cookie := range r.Cookies {
req.AddCookie(cookie)
}
if r.ctx != nil {
req = req.WithContext(r.ctx)
ctx := r.ctx
if r.isSaveResponse && r.downloadCallback != nil {
var wrap wrapResponseBodyFunc = func(rc io.ReadCloser) io.ReadCloser {
return &callbackReader{
ReadCloser: rc,
callback: func(read int64) {
r.downloadCallback(DownloadInfo{
Response: resp,
DownloadedSize: read,
})
},
lastTime: time.Now(),
interval: r.downloadCallbackInterval,
}
}
if ctx == nil {
ctx = context.Background()
}
ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap)
}
if ctx != nil {
req = req.WithContext(ctx)
}
r.RawRequest = req

Expand Down
2 changes: 2 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ Basically, you can know the meaning of most settings directly from the method na

* [SetOutput(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutput)
* [SetOutputFile(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutputFile)
* [SetDownloadCallback(callback DownloadCallback)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallback)
* [SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallbackWithInterval)

### <a name="Retry-Request">Retry</a>

Expand Down
2 changes: 1 addition & 1 deletion examples/uploadcallback/uploadclient/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ func main() {
FileSize: int64(size),
}).SetUploadCallbackWithInterval(func(info req.UploadInfo) {
fmt.Printf("%s: %.2f%%\n", info.FileName, float64(info.UploadedSize)/float64(info.FileSize)*100.0)
}, 1*time.Second).Post("http://127.0.0.1:8888/upload")
}, 30*time.Millisecond).Post("http://127.0.0.1:8888/upload")
}
192 changes: 138 additions & 54 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package req

import (
"bytes"
"errors"
"github.com/imroc/req/v3/internal/util"
"io"
"io/ioutil"
Expand Down Expand Up @@ -80,66 +79,82 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e
if err != nil {
return err
}

if r.uploadCallback != nil {
pw = &callbackWriter{
Writer: pw,
lastTime: lastTime,
interval: r.uploadCallbackInterval,
totalSize: file.FileSize,
callback: func(written int64) {
r.uploadCallback(UploadInfo{
ParamName: file.ParamName,
FileName: file.FileName,
FileSize: file.FileSize,
UploadedSize: written,
})
},
}
}

if _, err = pw.Write(cbuf[:size]); err != nil {
return err
}
if seeEOF {
return nil
}
if r.uploadCallback == nil {
_, err = io.Copy(pw, content)
return err
}

uploadedBytes := int64(size)
progressCallback := func() {
r.uploadCallback(UploadInfo{
ParamName: file.ParamName,
FileName: file.FileName,
FileSize: file.FileSize,
UploadedSize: uploadedBytes,
})
}
if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval {
lastTime = now
progressCallback()
}
buf := make([]byte, 1024)
for {
callback := false
nr, er := content.Read(buf)
if nr > 0 {
nw, ew := pw.Write(buf[:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errors.New("invalid write result")
}
}
uploadedBytes += int64(nw)
if ew != nil {
return ew
}
if nr != nw {
return io.ErrShortWrite
}
if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval {
lastTime = now
progressCallback()
callback = true
}
}
if er != nil {
if er == io.EOF {
if !callback {
progressCallback()
}
break
} else {
return er
}
}
}
_, err = io.Copy(pw, content)
return err
// uploadedBytes := int64(size)
// progressCallback := func() {
// r.uploadCallback(UploadInfo{
// ParamName: file.ParamName,
// FileName: file.FileName,
// FileSize: file.FileSize,
// UploadedSize: uploadedBytes,
// })
// }
// if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval {
// lastTime = now
// progressCallback()
// }
// buf := make([]byte, 1024)
// for {
// callback := false
// nr, er := content.Read(buf)
// if nr > 0 {
// nw, ew := pw.Write(buf[:nr])
// if nw < 0 || nr < nw {
// nw = 0
// if ew == nil {
// ew = errors.New("invalid write result")
// }
// }
// uploadedBytes += int64(nw)
// if ew != nil {
// return ew
// }
// if nr != nw {
// return io.ErrShortWrite
// }
// if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval {
// lastTime = now
// progressCallback()
// callback = true
// }
// }
// if er != nil {
// if er == io.EOF {
// if !callback {
// progressCallback()
// }
// break
// } else {
// return er
// }
// }
// }
return nil
}

Expand Down Expand Up @@ -266,6 +281,60 @@ func parseResponseBody(c *Client, r *Response) (err error) {
return
}

type callbackWriter struct {
io.Writer
written int64
totalSize int64
lastTime time.Time
interval time.Duration
callback func(written int64)
}

func (w *callbackWriter) Write(p []byte) (n int, err error) {
n, err = w.Writer.Write(p)
if n <= 0 {
return
}
w.written += int64(n)
if w.written == w.totalSize {
w.callback(w.written)
} else if now := time.Now(); now.Sub(w.lastTime) >= w.interval {
w.lastTime = now
w.callback(w.written)
}
return
}

type callbackReader struct {
io.ReadCloser
read int64
lastRead int64
callback func(read int64)
lastTime time.Time
interval time.Duration
}

func (r *callbackReader) Read(p []byte) (n int, err error) {
n, err = r.ReadCloser.Read(p)
if n <= 0 {
if err == io.EOF && r.read > r.lastRead {
r.callback(r.read)
r.lastRead = r.read
}
return
}
r.read += int64(n)
if err == io.EOF {
r.callback(r.read)
r.lastRead = r.read
} else if now := time.Now(); now.Sub(r.lastTime) >= r.interval {
r.lastTime = now
r.callback(r.read)
r.lastRead = r.read
}
return
}

func handleDownload(c *Client, r *Response) (err error) {
if !r.Request.isSaveResponse {
return nil
Expand Down Expand Up @@ -302,6 +371,21 @@ func handleDownload(c *Client, r *Response) (err error) {
body.Close()
closeq(output)
}()

// if r.Request.downloadCallback != nil {
// output = &callbackWriter{
// Writer: output,
// lastTime: time.Now(),
// interval: r.Request.downloadCallbackInterval,
// callback: func(written int64) {
// r.Request.downloadCallback(DownloadInfo{
// Response: r,
// DownloadedSize: written,
// })
// },
// }
// }

_, err = io.Copy(output, body)
r.setReceivedAt()
return
Expand Down
12 changes: 12 additions & 0 deletions req.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ type UploadInfo struct {
// multipart upload.
type UploadCallback func(info UploadInfo)

// DownloadInfo is the information for each DownloadCallback call.
type DownloadInfo struct {
// Response is the corresponding Response during download.
Response *Response
// downloaded body length in bytes.
DownloadedSize int64
}

// DownloadCallback is the callback which will be invoked during
// response body download.
type DownloadCallback func(info DownloadInfo)

func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
if len(cookies) == 0 {
return nil
Expand Down
19 changes: 19 additions & 0 deletions req_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -357,6 +358,24 @@ func handleGet(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.URL.RawQuery))
case "/search":
handleSearch(w, r)
case "/download":
size := 100 * 1024 * 1024
w.Header().Set("Content-Length", strconv.Itoa(size))
buf := make([]byte, 1024)
for i := 0; i < 1024; i++ {
buf[i] = 'h'
}
for i := 0; i < size; {
wbuf := buf
if size-i < 1024 {
wbuf = buf[:size-i]
}
n, err := w.Write(wbuf)
if err != nil {
break
}
i += n
}
case "/protected":
auth := r.Header.Get("Authorization")
if auth == "Bearer goodtoken" {
Expand Down
Loading

0 comments on commit e78b0a6

Please sign in to comment.