rpcenc: Test early close, add reader.MustRedirect

This commit is contained in:
Łukasz Magiera 2021-07-30 13:27:51 +02:00
parent 555c402ba3
commit 8426a62d15
2 changed files with 93 additions and 25 deletions

View File

@ -3,6 +3,7 @@ package rpcenc
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -54,14 +55,14 @@ var client = func() *http.Client {
Push(context.Context, io.Reader) error Push(context.Context, io.Reader) error
Request flow: Request flow:
1. Client invokes a method with an io.Reader param\ 1. Client invokes a method with an io.Reader param
2. go-jsonrpc invokes `ReaderParamEncoder` for the client-provided io.Reader 2. go-jsonrpc invokes `ReaderParamEncoder` for the client-provided io.Reader
3. `ReaderParamEncoder` transforms the reader into a `ReaderStream` which can 3. `ReaderParamEncoder` transforms the reader into a `ReaderStream` which can
be serialized as JSON, and sent as jsonrpc request parameter be serialized as JSON, and sent as jsonrpc request parameter
3.1. If the reader is of type `*sealing.NullReader`, the resulting object 3.1. If the reader is of type `*sealing.NullReader`, the resulting object
is `ReaderStream{ Type: "null", Info: "[base 10 number of bytes]" }` is `ReaderStream{ Type: "null", Info: "[base 10 number of bytes]" }`
3.2. If the reader is of type `*rpcReader`, and it wasn't read from, we 3.2. If the reader is of type `*RpcReader`, and it wasn't read from, we
notify that rpcReader to go a different push endpoint, and return notify that RpcReader to go a different push endpoint, and return
a `ReaderStream` object like in 3.4. a `ReaderStream` object like in 3.4.
3.3. In remaining cases we start a goroutine which: 3.3. In remaining cases we start a goroutine which:
3.3.1. Makes a HEAD request to the server push endpoint 3.3.1. Makes a HEAD request to the server push endpoint
@ -73,13 +74,13 @@ var client = func() *http.Client {
`ReaderStream{ Type: "push", Info: "[UUID string]" }` `ReaderStream{ Type: "push", Info: "[UUID string]" }`
4. If the reader wasn't a NullReader, the server will receive a HEAD (or 4. If the reader wasn't a NullReader, the server will receive a HEAD (or
POST in case of older clients) request to the push endpoint. POST in case of older clients) request to the push endpoint.
4.1. The server gets or registers an `*rpcReader` in the `readers` map. 4.1. The server gets or registers an `*RpcReader` in the `readers` map.
4.2. It waits for a request to a matching push endpoint to be opened 4.2. It waits for a request to a matching push endpoint to be opened
4.3. After the request is opened, it returns the `*rpcReader` to 4.3. After the request is opened, it returns the `*RpcReader` to
go-jsonrpc, which will pass it as the io.Reader parameter to the go-jsonrpc, which will pass it as the io.Reader parameter to the
rpc method implementation rpc method implementation
4.4. If the first request made to the push endpoint was a POST, the 4.4. If the first request made to the push endpoint was a POST, the
returned `*rpcReader` acts as a simple reader reading the POST returned `*RpcReader` acts as a simple reader reading the POST
request body request body
4.5. If the first request made to the push endpoint was a HEAD 4.5. If the first request made to the push endpoint was a HEAD
4.5.1. On the first call to Read or Close the server responds with 4.5.1. On the first call to Read or Close the server responds with
@ -111,7 +112,7 @@ func ReaderParamEncoder(addr string) jsonrpc.Option {
} }
u.Path = path.Join(u.Path, reqID.String()) u.Path = path.Join(u.Path, reqID.String())
rpcReader, redir := r.(*rpcReader) rpcReader, redir := r.(*RpcReader)
if redir { if redir {
// if we have an rpc stream, redirect instead of proxying all the data // if we have an rpc stream, redirect instead of proxying all the data
redir = rpcReader.redirect(u.String()) redir = rpcReader.redirect(u.String())
@ -191,6 +192,7 @@ type resType int
const ( const (
resStart resType = iota // send on first read after HEAD resStart resType = iota // send on first read after HEAD
resRedirect // send on redirect before first read after HEAD resRedirect // send on redirect before first read after HEAD
resError
// done/closed = close res channel // done/closed = close res channel
) )
@ -199,22 +201,53 @@ type readRes struct {
meta string meta string
} }
// rpcReader watches the ReadCloser and closes the res channel when // RpcReader watches the ReadCloser and closes the res channel when
// either: (1) the ReaderCloser fails on Read (including with a benign error // either: (1) the ReaderCloser fails on Read (including with a benign error
// like EOF), or (2) when Close is called. // like EOF), or (2) when Close is called.
// //
// Use it be notified of terminal states, in situations where a Read failure (or // Use it be notified of terminal states, in situations where a Read failure (or
// EOF) is considered a terminal state too (besides Close). // EOF) is considered a terminal state too (besides Close).
type rpcReader struct { type RpcReader struct {
postBody io.ReadCloser // nil on initial head request postBody io.ReadCloser // nil on initial head request
next chan *rpcReader // on head will get us the postBody after sending resStart next chan *RpcReader // on head will get us the postBody after sending resStart
mustRedirect bool
res chan readRes res chan readRes
beginOnce *sync.Once beginOnce *sync.Once
closeOnce sync.Once closeOnce sync.Once
} }
func (w *rpcReader) beginPost() { var ErrHasBody = errors.New("RPCReader has body, either already read from or from a client with no redirect support")
var ErrMustRedirect = errors.New("reader can't be read directly; marked as MustRedirect")
// MustRedirect marks the reader as required to be redirected. Will make local
// calls Read fail. MUST be called before this reader is used in any goroutine.
// If the reader can't be redirected will return ErrHasBody
func (w *RpcReader) MustRedirect() error {
if w.postBody != nil {
w.closeOnce.Do(func() {
w.res <- readRes{
rt: resError,
}
close(w.res)
})
return ErrHasBody
}
w.mustRedirect = true
return nil
}
func (w *RpcReader) beginPost() {
if w.mustRedirect {
w.res <- readRes{
rt: resError,
}
close(w.res)
return
}
if w.postBody == nil { if w.postBody == nil {
w.res <- readRes{ w.res <- readRes{
rt: resStart, rt: resStart,
@ -228,11 +261,15 @@ func (w *rpcReader) beginPost() {
} }
} }
func (w *rpcReader) Read(p []byte) (int, error) { func (w *RpcReader) Read(p []byte) (int, error) {
w.beginOnce.Do(func() { w.beginOnce.Do(func() {
w.beginPost() w.beginPost()
}) })
if w.mustRedirect {
return 0, ErrMustRedirect
}
if w.postBody == nil { if w.postBody == nil {
return 0, xerrors.Errorf("reader already closed or redirected") return 0, xerrors.Errorf("reader already closed or redirected")
} }
@ -246,14 +283,18 @@ func (w *rpcReader) Read(p []byte) (int, error) {
return n, err return n, err
} }
func (w *rpcReader) Close() error { func (w *RpcReader) Close() error {
w.beginOnce.Do(func() {})
w.closeOnce.Do(func() { w.closeOnce.Do(func() {
close(w.res) close(w.res)
}) })
if w.postBody == nil {
return nil
}
return w.postBody.Close() return w.postBody.Close()
} }
func (w *rpcReader) redirect(to string) bool { func (w *RpcReader) redirect(to string) bool {
if w.postBody != nil { if w.postBody != nil {
return false return false
} }
@ -277,7 +318,7 @@ func (w *rpcReader) redirect(to string) bool {
func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) { func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
var readersLk sync.Mutex var readersLk sync.Mutex
readers := map[uuid.UUID]chan *rpcReader{} readers := map[uuid.UUID]chan *RpcReader{}
// runs on the rpc server side, called by the client before making the jsonrpc request // runs on the rpc server side, called by the client before making the jsonrpc request
hnd := func(resp http.ResponseWriter, req *http.Request) { hnd := func(resp http.ResponseWriter, req *http.Request) {
@ -291,12 +332,12 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
readersLk.Lock() readersLk.Lock()
ch, found := readers[u] ch, found := readers[u]
if !found { if !found {
ch = make(chan *rpcReader) ch = make(chan *RpcReader)
readers[u] = ch readers[u] = ch
} }
readersLk.Unlock() readersLk.Unlock()
wr := &rpcReader{ wr := &RpcReader{
res: make(chan readRes), res: make(chan readRes),
next: ch, next: ch,
beginOnce: &sync.Once{}, beginOnce: &sync.Once{},
@ -341,6 +382,8 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
http.Redirect(resp, req, res.meta, http.StatusFound) http.Redirect(resp, req, res.meta, http.StatusFound)
case resStart: // responding to HEAD, request POST with reader data case resStart: // responding to HEAD, request POST with reader data
resp.WriteHeader(http.StatusOK) resp.WriteHeader(http.StatusOK)
case resError:
resp.WriteHeader(500)
default: default:
log.Errorf("unknown res.rt") log.Errorf("unknown res.rt")
resp.WriteHeader(500) resp.WriteHeader(500)
@ -378,7 +421,7 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
readersLk.Lock() readersLk.Lock()
ch, found := readers[u] ch, found := readers[u]
if !found { if !found {
ch = make(chan *rpcReader) ch = make(chan *RpcReader)
readers[u] = ch readers[u] = ch
} }
readersLk.Unlock() readersLk.Unlock()

View File

@ -20,11 +20,22 @@ type ReaderHandler struct {
readApi func(ctx context.Context, r io.Reader) ([]byte, error) readApi func(ctx context.Context, r io.Reader) ([]byte, error)
} }
func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader) ([]byte, error) { func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) {
if mustRedir {
if err := r.(*RpcReader).MustRedirect(); err != nil {
return nil, err
}
}
return h.readApi(ctx, r) return h.readApi(ctx, r)
} }
func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader) ([]byte, error) { func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) {
if mustRedir {
if err := r.(*RpcReader).MustRedirect(); err != nil {
return nil, err
}
}
n, err := r.Read([]byte{0}) n, err := r.Read([]byte{0})
if err != nil { if err != nil {
return nil, err return nil, err
@ -36,6 +47,10 @@ func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader) ([]byt
return h.readApi(ctx, r) return h.readApi(ctx, r)
} }
func (h *ReaderHandler) CloseReader(ctx context.Context, r io.Reader) error {
return r.(io.Closer).Close()
}
func (h *ReaderHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) { func (h *ReaderHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) {
return ioutil.ReadAll(r) return ioutil.ReadAll(r)
} }
@ -133,8 +148,9 @@ func TestReaderRedirect(t *testing.T) {
} }
var redirClient struct { var redirClient struct {
ReadAllApi func(ctx context.Context, r io.Reader) ([]byte, error) ReadAllApi func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error)
ReadStartAndApi func(ctx context.Context, r io.Reader) ([]byte, error) ReadStartAndApi func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error)
CloseReader func(ctx context.Context, r io.Reader) error
} }
{ {
@ -158,12 +174,21 @@ func TestReaderRedirect(t *testing.T) {
} }
// redirect // redirect
read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato")) read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato"), true)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "rediracted pooooootato", string(read), "potatoes weren't equal") require.Equal(t, "rediracted pooooootato", string(read), "potatoes weren't equal")
// proxy (because we started reading locally) // proxy (because we started reading locally)
read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato")) read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"), false)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "ediracted pooooootato", string(read), "otatoes weren't equal") require.Equal(t, "ediracted pooooootato", string(read), "otatoes weren't equal")
// check mustredir check; proxy (because we started reading locally)
read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"), true)
require.Error(t, err)
require.Contains(t, err.Error(), ErrMustRedirect.Error())
require.Empty(t, read)
err = redirClient.CloseReader(context.TODO(), strings.NewReader("rediracted pooooootato"))
require.NoError(t, err)
} }