diff --git a/lib/rpcenc/reader.go b/lib/rpcenc/reader.go index c423bae55..6693dc83d 100644 --- a/lib/rpcenc/reader.go +++ b/lib/rpcenc/reader.go @@ -3,6 +3,7 @@ package rpcenc import ( "context" "encoding/json" + "errors" "fmt" "io" "io/ioutil" @@ -54,14 +55,14 @@ var client = func() *http.Client { Push(context.Context, io.Reader) error Request flow: - 1. Client invokes a method with an io.Reader param\ + 1. Client invokes a method with an io.Reader param 2. go-jsonrpc invokes `ReaderParamEncoder` for the client-provided io.Reader 3. `ReaderParamEncoder` transforms the reader into a `ReaderStream` which can be serialized as JSON, and sent as jsonrpc request parameter 3.1. If the reader is of type `*sealing.NullReader`, the resulting object is `ReaderStream{ Type: "null", Info: "[base 10 number of bytes]" }` - 3.2. If the reader is of type `*rpcReader`, and it wasn't read from, we - notify that rpcReader to go a different push endpoint, and return + 3.2. If the reader is of type `*RpcReader`, and it wasn't read from, we + notify that RpcReader to go a different push endpoint, and return a `ReaderStream` object like in 3.4. 3.3. In remaining cases we start a goroutine which: 3.3.1. Makes a HEAD request to the server push endpoint @@ -73,13 +74,13 @@ var client = func() *http.Client { `ReaderStream{ Type: "push", Info: "[UUID string]" }` 4. If the reader wasn't a NullReader, the server will receive a HEAD (or POST in case of older clients) request to the push endpoint. - 4.1. The server gets or registers an `*rpcReader` in the `readers` map. + 4.1. The server gets or registers an `*RpcReader` in the `readers` map. 4.2. It waits for a request to a matching push endpoint to be opened - 4.3. After the request is opened, it returns the `*rpcReader` to + 4.3. After the request is opened, it returns the `*RpcReader` to go-jsonrpc, which will pass it as the io.Reader parameter to the rpc method implementation 4.4. If the first request made to the push endpoint was a POST, the - returned `*rpcReader` acts as a simple reader reading the POST + returned `*RpcReader` acts as a simple reader reading the POST request body 4.5. If the first request made to the push endpoint was a HEAD 4.5.1. On the first call to Read or Close the server responds with @@ -111,7 +112,7 @@ func ReaderParamEncoder(addr string) jsonrpc.Option { } u.Path = path.Join(u.Path, reqID.String()) - rpcReader, redir := r.(*rpcReader) + rpcReader, redir := r.(*RpcReader) if redir { // if we have an rpc stream, redirect instead of proxying all the data redir = rpcReader.redirect(u.String()) @@ -191,6 +192,7 @@ type resType int const ( resStart resType = iota // send on first read after HEAD resRedirect // send on redirect before first read after HEAD + resError // done/closed = close res channel ) @@ -199,22 +201,53 @@ type readRes struct { meta string } -// rpcReader watches the ReadCloser and closes the res channel when +// RpcReader watches the ReadCloser and closes the res channel when // either: (1) the ReaderCloser fails on Read (including with a benign error // like EOF), or (2) when Close is called. // // Use it be notified of terminal states, in situations where a Read failure (or // EOF) is considered a terminal state too (besides Close). -type rpcReader struct { - postBody io.ReadCloser // nil on initial head request - next chan *rpcReader // on head will get us the postBody after sending resStart +type RpcReader struct { + postBody io.ReadCloser // nil on initial head request + next chan *RpcReader // on head will get us the postBody after sending resStart + mustRedirect bool res chan readRes beginOnce *sync.Once closeOnce sync.Once } -func (w *rpcReader) beginPost() { +var ErrHasBody = errors.New("RPCReader has body, either already read from or from a client with no redirect support") +var ErrMustRedirect = errors.New("reader can't be read directly; marked as MustRedirect") + +// MustRedirect marks the reader as required to be redirected. Will make local +// calls Read fail. MUST be called before this reader is used in any goroutine. +// If the reader can't be redirected will return ErrHasBody +func (w *RpcReader) MustRedirect() error { + if w.postBody != nil { + w.closeOnce.Do(func() { + w.res <- readRes{ + rt: resError, + } + close(w.res) + }) + + return ErrHasBody + } + + w.mustRedirect = true + return nil +} + +func (w *RpcReader) beginPost() { + if w.mustRedirect { + w.res <- readRes{ + rt: resError, + } + close(w.res) + return + } + if w.postBody == nil { w.res <- readRes{ rt: resStart, @@ -228,11 +261,15 @@ func (w *rpcReader) beginPost() { } } -func (w *rpcReader) Read(p []byte) (int, error) { +func (w *RpcReader) Read(p []byte) (int, error) { w.beginOnce.Do(func() { w.beginPost() }) + if w.mustRedirect { + return 0, ErrMustRedirect + } + if w.postBody == nil { return 0, xerrors.Errorf("reader already closed or redirected") } @@ -246,14 +283,18 @@ func (w *rpcReader) Read(p []byte) (int, error) { return n, err } -func (w *rpcReader) Close() error { +func (w *RpcReader) Close() error { + w.beginOnce.Do(func() {}) w.closeOnce.Do(func() { close(w.res) }) + if w.postBody == nil { + return nil + } return w.postBody.Close() } -func (w *rpcReader) redirect(to string) bool { +func (w *RpcReader) redirect(to string) bool { if w.postBody != nil { return false } @@ -277,7 +318,7 @@ func (w *rpcReader) redirect(to string) bool { func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) { var readersLk sync.Mutex - readers := map[uuid.UUID]chan *rpcReader{} + readers := map[uuid.UUID]chan *RpcReader{} // runs on the rpc server side, called by the client before making the jsonrpc request hnd := func(resp http.ResponseWriter, req *http.Request) { @@ -291,12 +332,12 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) { readersLk.Lock() ch, found := readers[u] if !found { - ch = make(chan *rpcReader) + ch = make(chan *RpcReader) readers[u] = ch } readersLk.Unlock() - wr := &rpcReader{ + wr := &RpcReader{ res: make(chan readRes), next: ch, beginOnce: &sync.Once{}, @@ -341,6 +382,8 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) { http.Redirect(resp, req, res.meta, http.StatusFound) case resStart: // responding to HEAD, request POST with reader data resp.WriteHeader(http.StatusOK) + case resError: + resp.WriteHeader(500) default: log.Errorf("unknown res.rt") resp.WriteHeader(500) @@ -378,7 +421,7 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) { readersLk.Lock() ch, found := readers[u] if !found { - ch = make(chan *rpcReader) + ch = make(chan *RpcReader) readers[u] = ch } readersLk.Unlock() diff --git a/lib/rpcenc/reader_test.go b/lib/rpcenc/reader_test.go index b425aec73..87296e1e5 100644 --- a/lib/rpcenc/reader_test.go +++ b/lib/rpcenc/reader_test.go @@ -20,11 +20,22 @@ type ReaderHandler struct { readApi func(ctx context.Context, r io.Reader) ([]byte, error) } -func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader) ([]byte, error) { +func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) { + if mustRedir { + if err := r.(*RpcReader).MustRedirect(); err != nil { + return nil, err + } + } return h.readApi(ctx, r) } -func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader) ([]byte, error) { +func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) { + if mustRedir { + if err := r.(*RpcReader).MustRedirect(); err != nil { + return nil, err + } + } + n, err := r.Read([]byte{0}) if err != nil { return nil, err @@ -36,6 +47,10 @@ func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader) ([]byt return h.readApi(ctx, r) } +func (h *ReaderHandler) CloseReader(ctx context.Context, r io.Reader) error { + return r.(io.Closer).Close() +} + func (h *ReaderHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) { return ioutil.ReadAll(r) } @@ -133,8 +148,9 @@ func TestReaderRedirect(t *testing.T) { } var redirClient struct { - ReadAllApi func(ctx context.Context, r io.Reader) ([]byte, error) - ReadStartAndApi func(ctx context.Context, r io.Reader) ([]byte, error) + ReadAllApi func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) + ReadStartAndApi func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) + CloseReader func(ctx context.Context, r io.Reader) error } { @@ -158,12 +174,21 @@ func TestReaderRedirect(t *testing.T) { } // redirect - read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato")) + read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato"), true) require.NoError(t, err) require.Equal(t, "rediracted pooooootato", string(read), "potatoes weren't equal") // proxy (because we started reading locally) - read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato")) + read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"), false) require.NoError(t, err) require.Equal(t, "ediracted pooooootato", string(read), "otatoes weren't equal") + + // check mustredir check; proxy (because we started reading locally) + read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"), true) + require.Error(t, err) + require.Contains(t, err.Error(), ErrMustRedirect.Error()) + require.Empty(t, read) + + err = redirClient.CloseReader(context.TODO(), strings.NewReader("rediracted pooooootato")) + require.NoError(t, err) }