rpcenc: Test early close, add reader.MustRedirect

This commit is contained in:
Łukasz Magiera 2021-07-30 13:27:51 +02:00
parent 555c402ba3
commit 8426a62d15
2 changed files with 93 additions and 25 deletions

View File

@ -3,6 +3,7 @@ package rpcenc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
@ -54,14 +55,14 @@ var client = func() *http.Client {
Push(context.Context, io.Reader) error
Request flow:
1. Client invokes a method with an io.Reader param\
1. Client invokes a method with an io.Reader param
2. go-jsonrpc invokes `ReaderParamEncoder` for the client-provided io.Reader
3. `ReaderParamEncoder` transforms the reader into a `ReaderStream` which can
be serialized as JSON, and sent as jsonrpc request parameter
3.1. If the reader is of type `*sealing.NullReader`, the resulting object
is `ReaderStream{ Type: "null", Info: "[base 10 number of bytes]" }`
3.2. If the reader is of type `*rpcReader`, and it wasn't read from, we
notify that rpcReader to go a different push endpoint, and return
3.2. If the reader is of type `*RpcReader`, and it wasn't read from, we
notify that RpcReader to go a different push endpoint, and return
a `ReaderStream` object like in 3.4.
3.3. In remaining cases we start a goroutine which:
3.3.1. Makes a HEAD request to the server push endpoint
@ -73,13 +74,13 @@ var client = func() *http.Client {
`ReaderStream{ Type: "push", Info: "[UUID string]" }`
4. If the reader wasn't a NullReader, the server will receive a HEAD (or
POST in case of older clients) request to the push endpoint.
4.1. The server gets or registers an `*rpcReader` in the `readers` map.
4.1. The server gets or registers an `*RpcReader` in the `readers` map.
4.2. It waits for a request to a matching push endpoint to be opened
4.3. After the request is opened, it returns the `*rpcReader` to
4.3. After the request is opened, it returns the `*RpcReader` to
go-jsonrpc, which will pass it as the io.Reader parameter to the
rpc method implementation
4.4. If the first request made to the push endpoint was a POST, the
returned `*rpcReader` acts as a simple reader reading the POST
returned `*RpcReader` acts as a simple reader reading the POST
request body
4.5. If the first request made to the push endpoint was a HEAD
4.5.1. On the first call to Read or Close the server responds with
@ -111,7 +112,7 @@ func ReaderParamEncoder(addr string) jsonrpc.Option {
}
u.Path = path.Join(u.Path, reqID.String())
rpcReader, redir := r.(*rpcReader)
rpcReader, redir := r.(*RpcReader)
if redir {
// if we have an rpc stream, redirect instead of proxying all the data
redir = rpcReader.redirect(u.String())
@ -191,6 +192,7 @@ type resType int
const (
resStart resType = iota // send on first read after HEAD
resRedirect // send on redirect before first read after HEAD
resError
// done/closed = close res channel
)
@ -199,22 +201,53 @@ type readRes struct {
meta string
}
// rpcReader watches the ReadCloser and closes the res channel when
// RpcReader watches the ReadCloser and closes the res channel when
// either: (1) the ReaderCloser fails on Read (including with a benign error
// like EOF), or (2) when Close is called.
//
// Use it be notified of terminal states, in situations where a Read failure (or
// EOF) is considered a terminal state too (besides Close).
type rpcReader struct {
type RpcReader struct {
postBody io.ReadCloser // nil on initial head request
next chan *rpcReader // on head will get us the postBody after sending resStart
next chan *RpcReader // on head will get us the postBody after sending resStart
mustRedirect bool
res chan readRes
beginOnce *sync.Once
closeOnce sync.Once
}
func (w *rpcReader) beginPost() {
var ErrHasBody = errors.New("RPCReader has body, either already read from or from a client with no redirect support")
var ErrMustRedirect = errors.New("reader can't be read directly; marked as MustRedirect")
// MustRedirect marks the reader as required to be redirected. Will make local
// calls Read fail. MUST be called before this reader is used in any goroutine.
// If the reader can't be redirected will return ErrHasBody
func (w *RpcReader) MustRedirect() error {
if w.postBody != nil {
w.closeOnce.Do(func() {
w.res <- readRes{
rt: resError,
}
close(w.res)
})
return ErrHasBody
}
w.mustRedirect = true
return nil
}
func (w *RpcReader) beginPost() {
if w.mustRedirect {
w.res <- readRes{
rt: resError,
}
close(w.res)
return
}
if w.postBody == nil {
w.res <- readRes{
rt: resStart,
@ -228,11 +261,15 @@ func (w *rpcReader) beginPost() {
}
}
func (w *rpcReader) Read(p []byte) (int, error) {
func (w *RpcReader) Read(p []byte) (int, error) {
w.beginOnce.Do(func() {
w.beginPost()
})
if w.mustRedirect {
return 0, ErrMustRedirect
}
if w.postBody == nil {
return 0, xerrors.Errorf("reader already closed or redirected")
}
@ -246,14 +283,18 @@ func (w *rpcReader) Read(p []byte) (int, error) {
return n, err
}
func (w *rpcReader) Close() error {
func (w *RpcReader) Close() error {
w.beginOnce.Do(func() {})
w.closeOnce.Do(func() {
close(w.res)
})
if w.postBody == nil {
return nil
}
return w.postBody.Close()
}
func (w *rpcReader) redirect(to string) bool {
func (w *RpcReader) redirect(to string) bool {
if w.postBody != nil {
return false
}
@ -277,7 +318,7 @@ func (w *rpcReader) redirect(to string) bool {
func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
var readersLk sync.Mutex
readers := map[uuid.UUID]chan *rpcReader{}
readers := map[uuid.UUID]chan *RpcReader{}
// runs on the rpc server side, called by the client before making the jsonrpc request
hnd := func(resp http.ResponseWriter, req *http.Request) {
@ -291,12 +332,12 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
readersLk.Lock()
ch, found := readers[u]
if !found {
ch = make(chan *rpcReader)
ch = make(chan *RpcReader)
readers[u] = ch
}
readersLk.Unlock()
wr := &rpcReader{
wr := &RpcReader{
res: make(chan readRes),
next: ch,
beginOnce: &sync.Once{},
@ -341,6 +382,8 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
http.Redirect(resp, req, res.meta, http.StatusFound)
case resStart: // responding to HEAD, request POST with reader data
resp.WriteHeader(http.StatusOK)
case resError:
resp.WriteHeader(500)
default:
log.Errorf("unknown res.rt")
resp.WriteHeader(500)
@ -378,7 +421,7 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
readersLk.Lock()
ch, found := readers[u]
if !found {
ch = make(chan *rpcReader)
ch = make(chan *RpcReader)
readers[u] = ch
}
readersLk.Unlock()

View File

@ -20,11 +20,22 @@ type ReaderHandler struct {
readApi func(ctx context.Context, r io.Reader) ([]byte, error)
}
func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader) ([]byte, error) {
func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) {
if mustRedir {
if err := r.(*RpcReader).MustRedirect(); err != nil {
return nil, err
}
}
return h.readApi(ctx, r)
}
func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader) ([]byte, error) {
func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) {
if mustRedir {
if err := r.(*RpcReader).MustRedirect(); err != nil {
return nil, err
}
}
n, err := r.Read([]byte{0})
if err != nil {
return nil, err
@ -36,6 +47,10 @@ func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader) ([]byt
return h.readApi(ctx, r)
}
func (h *ReaderHandler) CloseReader(ctx context.Context, r io.Reader) error {
return r.(io.Closer).Close()
}
func (h *ReaderHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) {
return ioutil.ReadAll(r)
}
@ -133,8 +148,9 @@ func TestReaderRedirect(t *testing.T) {
}
var redirClient struct {
ReadAllApi func(ctx context.Context, r io.Reader) ([]byte, error)
ReadStartAndApi func(ctx context.Context, r io.Reader) ([]byte, error)
ReadAllApi func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error)
ReadStartAndApi func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error)
CloseReader func(ctx context.Context, r io.Reader) error
}
{
@ -158,12 +174,21 @@ func TestReaderRedirect(t *testing.T) {
}
// redirect
read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato"))
read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato"), true)
require.NoError(t, err)
require.Equal(t, "rediracted pooooootato", string(read), "potatoes weren't equal")
// proxy (because we started reading locally)
read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"))
read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"), false)
require.NoError(t, err)
require.Equal(t, "ediracted pooooootato", string(read), "otatoes weren't equal")
// check mustredir check; proxy (because we started reading locally)
read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"), true)
require.Error(t, err)
require.Contains(t, err.Error(), ErrMustRedirect.Error())
require.Empty(t, read)
err = redirClient.CloseReader(context.TODO(), strings.NewReader("rediracted pooooootato"))
require.NoError(t, err)
}