package httpreader import ( "context" "fmt" "io" "net/http" "strconv" logging "github.com/ipfs/go-log/v2" "go.uber.org/multierr" "golang.org/x/xerrors" ) var log = logging.Logger("httpreader") type ResumableReader struct { ctx context.Context initialURL string finalURL *string position int64 contentLength int64 client *http.Client reader io.ReadCloser } func NewResumableReader(ctx context.Context, url string) (*ResumableReader, error) { finalURL := "" client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { finalURL = req.URL.String() if len(via) >= 10 { return xerrors.New("stopped after 10 redirects") } return nil }, } r := &ResumableReader{ ctx: ctx, initialURL: url, finalURL: &finalURL, position: 0, client: client, } req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err } resp, err := r.client.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to fetch resource, status code: %d", resp.StatusCode) } contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) if err != nil { if err = resp.Body.Close(); err != nil { err = multierr.Append(err, err) } return nil, err } r.contentLength = contentLength r.reader = resp.Body return r, nil } func (r *ResumableReader) ContentLength() int64 { return r.contentLength } func (r *ResumableReader) Read(p []byte) (n int, err error) { for { if r.reader == nil { reqURL := r.initialURL if *r.finalURL != "" { reqURL = *r.finalURL } req, err := http.NewRequestWithContext(r.ctx, "GET", reqURL, nil) if err != nil { return 0, err } req.Header.Set("Range", fmt.Sprintf("bytes=%d-", r.position)) resp, err := r.client.Do(req) if err != nil { return 0, err } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { return 0, fmt.Errorf("non-resumable status code: %d", resp.StatusCode) } r.reader = resp.Body } n, err = r.reader.Read(p) r.position += int64(n) if err == io.EOF || err == io.ErrUnexpectedEOF { if r.position == r.contentLength { if err := r.reader.Close(); err != nil { log.Warnf("error closing reader: %+v", err) } return n, io.EOF } if err := r.reader.Close(); err != nil { log.Warnf("error closing reader: %+v", err) } r.reader = nil } else { return n, err } } }