Skip to content

Commit e78b0a6

Browse files
committed
support download callback
1 parent 2d22ff0 commit e78b0a6

File tree

12 files changed

+307
-81
lines changed

12 files changed

+307
-81
lines changed

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,29 @@ if err != nil {
757757
client.R().SetOutput(file).Get(url)
758758
```
759759

760+
**Download Callback**
761+
762+
You can set `DownloadCallback` if you want to show download progress:
763+
764+
```go
765+
client := req.C()
766+
client.R().
767+
SetOutputFile("test.gz").
768+
SetUploadCallback(func(info req.UploadInfo) {
769+
fmt.Printf("downloaded %.2f%%\n", float64(info.DownloadedSize)/float64(info.Response.ContentLength)*100.0)
770+
}).Post("https://exmaple.com/upload")
771+
/* Output
772+
downloaded 17.92%
773+
downloaded 41.77%
774+
downloaded 67.71%
775+
downloaded 98.89%
776+
downloaded 100.00%
777+
*/
778+
```
779+
780+
> `info.Response.ContentLength` could be 0 or -1 when the total size is unknown.
781+
> `DownloadCallback` will be invoked at least every 200ms by default, you can customize the minimal invoke interval using `SetDownloadCallbackWithInterval`.
782+
760783
**Multipart Upload**
761784

762785
```go

client.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -995,8 +995,28 @@ func (c *Client) do(r *Request) (resp *Response, err error) {
995995
for _, cookie := range r.Cookies {
996996
req.AddCookie(cookie)
997997
}
998-
if r.ctx != nil {
999-
req = req.WithContext(r.ctx)
998+
ctx := r.ctx
999+
if r.isSaveResponse && r.downloadCallback != nil {
1000+
var wrap wrapResponseBodyFunc = func(rc io.ReadCloser) io.ReadCloser {
1001+
return &callbackReader{
1002+
ReadCloser: rc,
1003+
callback: func(read int64) {
1004+
r.downloadCallback(DownloadInfo{
1005+
Response: resp,
1006+
DownloadedSize: read,
1007+
})
1008+
},
1009+
lastTime: time.Now(),
1010+
interval: r.downloadCallbackInterval,
1011+
}
1012+
}
1013+
if ctx == nil {
1014+
ctx = context.Background()
1015+
}
1016+
ctx = context.WithValue(ctx, wrapResponseBodyKey, wrap)
1017+
}
1018+
if ctx != nil {
1019+
req = req.WithContext(ctx)
10001020
}
10011021
r.RawRequest = req
10021022

docs/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ Basically, you can know the meaning of most settings directly from the method na
210210

211211
* [SetOutput(output io.Writer)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutput)
212212
* [SetOutputFile(file string)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetOutputFile)
213+
* [SetDownloadCallback(callback DownloadCallback)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallback)
214+
* [SetDownloadCallbackWithInterval(callback DownloadCallback, minInterval time.Duration)](https://pkg.go.dev/github.com/imroc/req/v3#Request.SetDownloadCallbackWithInterval)
213215

214216
### <a name="Retry-Request">Retry</a>
215217

examples/uploadcallback/uploadclient/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@ func main() {
4343
FileSize: int64(size),
4444
}).SetUploadCallbackWithInterval(func(info req.UploadInfo) {
4545
fmt.Printf("%s: %.2f%%\n", info.FileName, float64(info.UploadedSize)/float64(info.FileSize)*100.0)
46-
}, 1*time.Second).Post("http://127.0.0.1:8888/upload")
46+
}, 30*time.Millisecond).Post("http://127.0.0.1:8888/upload")
4747
}

middleware.go

Lines changed: 138 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package req
22

33
import (
44
"bytes"
5-
"errors"
65
"github.com/imroc/req/v3/internal/util"
76
"io"
87
"io/ioutil"
@@ -80,66 +79,82 @@ func writeMultipartFormFile(w *multipart.Writer, file *FileUpload, r *Request) e
8079
if err != nil {
8180
return err
8281
}
82+
83+
if r.uploadCallback != nil {
84+
pw = &callbackWriter{
85+
Writer: pw,
86+
lastTime: lastTime,
87+
interval: r.uploadCallbackInterval,
88+
totalSize: file.FileSize,
89+
callback: func(written int64) {
90+
r.uploadCallback(UploadInfo{
91+
ParamName: file.ParamName,
92+
FileName: file.FileName,
93+
FileSize: file.FileSize,
94+
UploadedSize: written,
95+
})
96+
},
97+
}
98+
}
99+
83100
if _, err = pw.Write(cbuf[:size]); err != nil {
84101
return err
85102
}
86103
if seeEOF {
87104
return nil
88105
}
89-
if r.uploadCallback == nil {
90-
_, err = io.Copy(pw, content)
91-
return err
92-
}
93106

94-
uploadedBytes := int64(size)
95-
progressCallback := func() {
96-
r.uploadCallback(UploadInfo{
97-
ParamName: file.ParamName,
98-
FileName: file.FileName,
99-
FileSize: file.FileSize,
100-
UploadedSize: uploadedBytes,
101-
})
102-
}
103-
if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval {
104-
lastTime = now
105-
progressCallback()
106-
}
107-
buf := make([]byte, 1024)
108-
for {
109-
callback := false
110-
nr, er := content.Read(buf)
111-
if nr > 0 {
112-
nw, ew := pw.Write(buf[:nr])
113-
if nw < 0 || nr < nw {
114-
nw = 0
115-
if ew == nil {
116-
ew = errors.New("invalid write result")
117-
}
118-
}
119-
uploadedBytes += int64(nw)
120-
if ew != nil {
121-
return ew
122-
}
123-
if nr != nw {
124-
return io.ErrShortWrite
125-
}
126-
if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval {
127-
lastTime = now
128-
progressCallback()
129-
callback = true
130-
}
131-
}
132-
if er != nil {
133-
if er == io.EOF {
134-
if !callback {
135-
progressCallback()
136-
}
137-
break
138-
} else {
139-
return er
140-
}
141-
}
142-
}
107+
_, err = io.Copy(pw, content)
108+
return err
109+
// uploadedBytes := int64(size)
110+
// progressCallback := func() {
111+
// r.uploadCallback(UploadInfo{
112+
// ParamName: file.ParamName,
113+
// FileName: file.FileName,
114+
// FileSize: file.FileSize,
115+
// UploadedSize: uploadedBytes,
116+
// })
117+
// }
118+
// if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval {
119+
// lastTime = now
120+
// progressCallback()
121+
// }
122+
// buf := make([]byte, 1024)
123+
// for {
124+
// callback := false
125+
// nr, er := content.Read(buf)
126+
// if nr > 0 {
127+
// nw, ew := pw.Write(buf[:nr])
128+
// if nw < 0 || nr < nw {
129+
// nw = 0
130+
// if ew == nil {
131+
// ew = errors.New("invalid write result")
132+
// }
133+
// }
134+
// uploadedBytes += int64(nw)
135+
// if ew != nil {
136+
// return ew
137+
// }
138+
// if nr != nw {
139+
// return io.ErrShortWrite
140+
// }
141+
// if now := time.Now(); now.Sub(lastTime) >= r.uploadCallbackInterval {
142+
// lastTime = now
143+
// progressCallback()
144+
// callback = true
145+
// }
146+
// }
147+
// if er != nil {
148+
// if er == io.EOF {
149+
// if !callback {
150+
// progressCallback()
151+
// }
152+
// break
153+
// } else {
154+
// return er
155+
// }
156+
// }
157+
// }
143158
return nil
144159
}
145160

@@ -266,6 +281,60 @@ func parseResponseBody(c *Client, r *Response) (err error) {
266281
return
267282
}
268283

284+
type callbackWriter struct {
285+
io.Writer
286+
written int64
287+
totalSize int64
288+
lastTime time.Time
289+
interval time.Duration
290+
callback func(written int64)
291+
}
292+
293+
func (w *callbackWriter) Write(p []byte) (n int, err error) {
294+
n, err = w.Writer.Write(p)
295+
if n <= 0 {
296+
return
297+
}
298+
w.written += int64(n)
299+
if w.written == w.totalSize {
300+
w.callback(w.written)
301+
} else if now := time.Now(); now.Sub(w.lastTime) >= w.interval {
302+
w.lastTime = now
303+
w.callback(w.written)
304+
}
305+
return
306+
}
307+
308+
type callbackReader struct {
309+
io.ReadCloser
310+
read int64
311+
lastRead int64
312+
callback func(read int64)
313+
lastTime time.Time
314+
interval time.Duration
315+
}
316+
317+
func (r *callbackReader) Read(p []byte) (n int, err error) {
318+
n, err = r.ReadCloser.Read(p)
319+
if n <= 0 {
320+
if err == io.EOF && r.read > r.lastRead {
321+
r.callback(r.read)
322+
r.lastRead = r.read
323+
}
324+
return
325+
}
326+
r.read += int64(n)
327+
if err == io.EOF {
328+
r.callback(r.read)
329+
r.lastRead = r.read
330+
} else if now := time.Now(); now.Sub(r.lastTime) >= r.interval {
331+
r.lastTime = now
332+
r.callback(r.read)
333+
r.lastRead = r.read
334+
}
335+
return
336+
}
337+
269338
func handleDownload(c *Client, r *Response) (err error) {
270339
if !r.Request.isSaveResponse {
271340
return nil
@@ -302,6 +371,21 @@ func handleDownload(c *Client, r *Response) (err error) {
302371
body.Close()
303372
closeq(output)
304373
}()
374+
375+
// if r.Request.downloadCallback != nil {
376+
// output = &callbackWriter{
377+
// Writer: output,
378+
// lastTime: time.Now(),
379+
// interval: r.Request.downloadCallbackInterval,
380+
// callback: func(written int64) {
381+
// r.Request.downloadCallback(DownloadInfo{
382+
// Response: r,
383+
// DownloadedSize: written,
384+
// })
385+
// },
386+
// }
387+
// }
388+
305389
_, err = io.Copy(output, body)
306390
r.setReceivedAt()
307391
return

req.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,18 @@ type UploadInfo struct {
8080
// multipart upload.
8181
type UploadCallback func(info UploadInfo)
8282

83+
// DownloadInfo is the information for each DownloadCallback call.
84+
type DownloadInfo struct {
85+
// Response is the corresponding Response during download.
86+
Response *Response
87+
// downloaded body length in bytes.
88+
DownloadedSize int64
89+
}
90+
91+
// DownloadCallback is the callback which will be invoked during
92+
// response body download.
93+
type DownloadCallback func(info DownloadInfo)
94+
8395
func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
8496
if len(cookies) == 0 {
8597
return nil

req_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"os"
1616
"path/filepath"
1717
"reflect"
18+
"strconv"
1819
"strings"
1920
"sync"
2021
"testing"
@@ -357,6 +358,24 @@ func handleGet(w http.ResponseWriter, r *http.Request) {
357358
w.Write([]byte(r.URL.RawQuery))
358359
case "/search":
359360
handleSearch(w, r)
361+
case "/download":
362+
size := 100 * 1024 * 1024
363+
w.Header().Set("Content-Length", strconv.Itoa(size))
364+
buf := make([]byte, 1024)
365+
for i := 0; i < 1024; i++ {
366+
buf[i] = 'h'
367+
}
368+
for i := 0; i < size; {
369+
wbuf := buf
370+
if size-i < 1024 {
371+
wbuf = buf[:size-i]
372+
}
373+
n, err := w.Write(wbuf)
374+
if err != nil {
375+
break
376+
}
377+
i += n
378+
}
360379
case "/protected":
361380
auth := r.Header.Get("Authorization")
362381
if auth == "Bearer goodtoken" {

0 commit comments

Comments
 (0)