rpcenc: Support reader redirect
This commit is contained in:
parent
c17a0c4fed
commit
0c809d3a5f
@ -40,7 +40,63 @@ type ReaderStream struct {
|
||||
Info string
|
||||
}
|
||||
|
||||
var client = func() *http.Client {
|
||||
c := *http.DefaultClient
|
||||
c.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
return &c
|
||||
}()
|
||||
|
||||
/*
|
||||
|
||||
Example rpc function:
|
||||
Push(context.Context, io.Reader) error
|
||||
|
||||
Request flow:
|
||||
1. Client invokes a method with an io.Reader param\
|
||||
2. go-jsonrpc invokes `ReaderParamEncoder` for the client-provided io.Reader
|
||||
3. `ReaderParamEncoder` transforms the reader into a `ReaderStream` which can
|
||||
be serialized as JSON, and sent as jsonrpc request parameter
|
||||
3.1. If the reader is of type `*sealing.NullReader`, the resulting object
|
||||
is `ReaderStream{ Type: "null", Info: "[base 10 number of bytes]" }`
|
||||
3.2. If the reader is of type `*rpcReader`, and it wasn't read from, we
|
||||
notify that rpcReader to go a different push endpoint, and return
|
||||
a `ReaderStream` object like in 3.4.
|
||||
3.3. In remaining cases we start a goroutine which:
|
||||
3.3.1. Makes a HEAD request to the server push endpoint
|
||||
3.3.2. If the HEAD request is redirected, it follows the redirect
|
||||
3.3.3. If the request succeeds, it starts a POST request to the
|
||||
endpoint to which the last HEAD request was sent with the
|
||||
reader set as request body.
|
||||
3.4. We return a `ReaderStream` indicating the uuid of push request, ex:
|
||||
`ReaderStream{ Type: "push", Info: "[UUID string]" }`
|
||||
4. If the reader wasn't a NullReader, the server will receive a HEAD (or
|
||||
POST in case of older clients) request to the push endpoint.
|
||||
4.1. The server gets or registers an `*rpcReader` in the `readers` map.
|
||||
4.2. It waits for a request to a matching push endpoint to be opened
|
||||
4.3. After the request is opened, it returns the `*rpcReader` to
|
||||
go-jsonrpc, which will pass it as the io.Reader parameter to the
|
||||
rpc method implementation
|
||||
4.4. If the first request made to the push endpoint was a POST, the
|
||||
returned `*rpcReader` acts as a simple reader reading the POST
|
||||
request body
|
||||
4.5. If the first request made to the push endpoint was a HEAD
|
||||
4.5.1. On the first call to Read or Close the server responds with
|
||||
a 200 OK header, the client starts a POST request to the same
|
||||
push URL, and the reader starts passing through the POST request
|
||||
body
|
||||
4.5.2. If the reader is passed to another (now client) RPC method as a
|
||||
reader parameter, the server for the first request responds to the
|
||||
HEAD request with http 302 Found, instructing the first client to
|
||||
go to the push endpoint of the second RPC server
|
||||
5. If the reader was a NullReader (ReaderStream.Type=="null"), we instantiate
|
||||
it, and provide to the method implementation
|
||||
|
||||
*/
|
||||
|
||||
func ReaderParamEncoder(addr string) jsonrpc.Option {
|
||||
// Client side parameter encoder. Runs on the rpc client side. io.Reader -> ReaderStream{}
|
||||
return jsonrpc.WithParamEncoder(new(io.Reader), func(value reflect.Value) (reflect.Value, error) {
|
||||
r := value.Interface().(io.Reader)
|
||||
|
||||
@ -55,62 +111,171 @@ func ReaderParamEncoder(addr string) jsonrpc.Option {
|
||||
}
|
||||
u.Path = path.Join(u.Path, reqID.String())
|
||||
|
||||
go func() {
|
||||
// TODO: figure out errors here
|
||||
rpcReader, redir := r.(*rpcReader)
|
||||
if redir {
|
||||
// if we have an rpc stream, redirect instead of proxying all the data
|
||||
redir = rpcReader.redirect(u.String())
|
||||
}
|
||||
|
||||
resp, err := http.Post(u.String(), "application/octet-stream", r)
|
||||
if err != nil {
|
||||
log.Errorf("sending reader param: %+v", err)
|
||||
return
|
||||
}
|
||||
if !redir {
|
||||
go func() {
|
||||
// TODO: figure out errors here
|
||||
for {
|
||||
req, err := http.NewRequest("HEAD", u.String(), nil)
|
||||
if err != nil {
|
||||
log.Errorf("sending HEAD request for the reder param: %+v", err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Errorf("sending reader param: %+v", err)
|
||||
return
|
||||
}
|
||||
// todo do we need to close the body for a head request?
|
||||
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
if resp.StatusCode == http.StatusFound {
|
||||
nextStr := resp.Header.Get("Location")
|
||||
u, err = url.Parse(nextStr)
|
||||
if err != nil {
|
||||
log.Errorf("sending HEAD request for the reder param, parsing next url (%s): %+v", nextStr, err)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
log.Errorf("sending reader param (%s): non-200 status: %s, msg: '%s'", u.String(), resp.Status, string(b))
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
}()
|
||||
if resp.StatusCode == http.StatusNoContent { // reader closed before reading anything
|
||||
// todo just return??
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
log.Errorf("sending reader param (%s): non-200 status: %s, msg: '%s'", u.String(), resp.Status, string(b))
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// now actually send the data
|
||||
req, err := http.NewRequest("POST", u.String(), r)
|
||||
if err != nil {
|
||||
log.Errorf("sending reader param: %+v", err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Errorf("sending reader param: %+v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer resp.Body.Close() //nolint
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
log.Errorf("sending reader param (%s): non-200 status: %s, msg: '%s'", u.String(), resp.Status, string(b))
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return reflect.ValueOf(ReaderStream{Type: PushStream, Info: reqID.String()}), nil
|
||||
})
|
||||
}
|
||||
|
||||
// watchReadCloser watches the ReadCloser and closes the watch channel when
|
||||
type resType int
|
||||
|
||||
const (
|
||||
resStart resType = iota // send on first read after HEAD
|
||||
resRedirect // send on redirect before first read after HEAD
|
||||
// done/closed = close res channel
|
||||
)
|
||||
|
||||
type readRes struct {
|
||||
rt resType
|
||||
meta string
|
||||
}
|
||||
|
||||
// rpcReader watches the ReadCloser and closes the res channel when
|
||||
// either: (1) the ReaderCloser fails on Read (including with a benign error
|
||||
// like EOF), or (2) when Close is called.
|
||||
//
|
||||
// Use it be notified of terminal states, in situations where a Read failure (or
|
||||
// EOF) is considered a terminal state too (besides Close).
|
||||
type watchReadCloser struct {
|
||||
io.ReadCloser
|
||||
watch chan struct{}
|
||||
type rpcReader struct {
|
||||
postBody io.ReadCloser // nil on initial head request
|
||||
next chan *rpcReader // on head will get us the postBody after sending resStart
|
||||
|
||||
res chan readRes
|
||||
beginOnce *sync.Once
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (w *watchReadCloser) Read(p []byte) (int, error) {
|
||||
n, err := w.ReadCloser.Read(p)
|
||||
func (w *rpcReader) beginPost() {
|
||||
if w.postBody == nil {
|
||||
w.res <- readRes{
|
||||
rt: resStart,
|
||||
}
|
||||
|
||||
nr := <-w.next
|
||||
|
||||
w.postBody = nr.postBody
|
||||
w.res = nr.res
|
||||
w.beginOnce = nr.beginOnce
|
||||
}
|
||||
}
|
||||
|
||||
func (w *rpcReader) Read(p []byte) (int, error) {
|
||||
w.beginOnce.Do(func() {
|
||||
w.beginPost()
|
||||
})
|
||||
|
||||
if w.postBody == nil {
|
||||
return 0, xerrors.Errorf("reader already closed or redirected")
|
||||
}
|
||||
|
||||
n, err := w.postBody.Read(p)
|
||||
if err != nil {
|
||||
w.closeOnce.Do(func() {
|
||||
close(w.watch)
|
||||
close(w.res)
|
||||
})
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *watchReadCloser) Close() error {
|
||||
func (w *rpcReader) Close() error {
|
||||
w.closeOnce.Do(func() {
|
||||
close(w.watch)
|
||||
close(w.res)
|
||||
})
|
||||
return w.ReadCloser.Close()
|
||||
return w.postBody.Close()
|
||||
}
|
||||
|
||||
func (w *rpcReader) redirect(to string) bool {
|
||||
done := false
|
||||
|
||||
w.beginOnce.Do(func() {
|
||||
w.closeOnce.Do(func() {
|
||||
w.res <- readRes{
|
||||
rt: resRedirect,
|
||||
meta: to,
|
||||
}
|
||||
|
||||
done = true
|
||||
close(w.res)
|
||||
})
|
||||
})
|
||||
|
||||
return done
|
||||
}
|
||||
|
||||
func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
|
||||
var readersLk sync.Mutex
|
||||
readers := map[uuid.UUID]chan *watchReadCloser{}
|
||||
readers := map[uuid.UUID]chan *rpcReader{}
|
||||
|
||||
// runs on the rpc server side, called by the client before making the jsonrpc request
|
||||
hnd := func(resp http.ResponseWriter, req *http.Request) {
|
||||
strId := path.Base(req.URL.Path)
|
||||
u, err := uuid.Parse(strId)
|
||||
@ -122,14 +287,24 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
|
||||
readersLk.Lock()
|
||||
ch, found := readers[u]
|
||||
if !found {
|
||||
ch = make(chan *watchReadCloser)
|
||||
ch = make(chan *rpcReader)
|
||||
readers[u] = ch
|
||||
}
|
||||
readersLk.Unlock()
|
||||
|
||||
wr := &watchReadCloser{
|
||||
ReadCloser: req.Body,
|
||||
watch: make(chan struct{}),
|
||||
wr := &rpcReader{
|
||||
res: make(chan readRes),
|
||||
next: ch,
|
||||
beginOnce: &sync.Once{},
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodHead:
|
||||
// leave body nil
|
||||
case http.MethodPost:
|
||||
wr.postBody = req.Body
|
||||
default:
|
||||
http.Error(resp, "unsupported method", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
tctx, cancel := context.WithTimeout(req.Context(), Timeout)
|
||||
@ -145,18 +320,37 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
|
||||
}
|
||||
|
||||
select {
|
||||
case <-wr.watch:
|
||||
case res, ok := <-wr.res:
|
||||
if !ok {
|
||||
if req.Method == http.MethodHead {
|
||||
resp.WriteHeader(http.StatusNoContent)
|
||||
} else {
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return
|
||||
}
|
||||
// TODO should we check if we failed the Read, and if so
|
||||
// return an HTTP 500? i.e. turn watch into a chan error?
|
||||
// return an HTTP 500? i.e. turn res into a chan error?
|
||||
|
||||
switch res.rt {
|
||||
case resRedirect:
|
||||
http.Redirect(resp, req, res.meta, http.StatusFound)
|
||||
case resStart: // responding to HEAD, request POST with reader data
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
log.Errorf("unknown res.rt")
|
||||
resp.WriteHeader(500)
|
||||
}
|
||||
|
||||
return
|
||||
case <-req.Context().Done():
|
||||
log.Errorf("context error in reader stream handler (2): %v", req.Context().Err())
|
||||
resp.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
|
||||
resp.WriteHeader(200)
|
||||
}
|
||||
|
||||
// Server side reader decoder. runs on the rpc server side, invoked when decoding client request parameters. json(ReaderStream{}) -> io.Reader
|
||||
dec := jsonrpc.WithParamDecoder(new(io.Reader), func(ctx context.Context, b []byte) (reflect.Value, error) {
|
||||
var rs ReaderStream
|
||||
if err := json.Unmarshal(b, &rs); err != nil {
|
||||
@ -180,7 +374,7 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
|
||||
readersLk.Lock()
|
||||
ch, found := readers[u]
|
||||
if !found {
|
||||
ch = make(chan *watchReadCloser)
|
||||
ch = make(chan *rpcReader)
|
||||
readers[u] = ch
|
||||
}
|
||||
readersLk.Unlock()
|
||||
|
@ -16,6 +16,11 @@ import (
|
||||
)
|
||||
|
||||
type ReaderHandler struct {
|
||||
readApi func(ctx context.Context, r io.Reader) ([]byte, error)
|
||||
}
|
||||
|
||||
func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader) ([]byte, error) {
|
||||
return h.readApi(ctx, r)
|
||||
}
|
||||
|
||||
func (h *ReaderHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) {
|
||||
@ -88,3 +93,57 @@ func TestNullReaderProxy(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1016), n)
|
||||
}
|
||||
|
||||
func TestReaderRedirect(t *testing.T) {
|
||||
var allClient struct {
|
||||
ReadAll func(ctx context.Context, r io.Reader) ([]byte, error)
|
||||
}
|
||||
|
||||
{
|
||||
allServerHandler := &ReaderHandler{}
|
||||
readerHandler, readerServerOpt := ReaderParamDecoder()
|
||||
rpcServer := jsonrpc.NewServer(readerServerOpt)
|
||||
rpcServer.Register("ReaderHandler", allServerHandler)
|
||||
|
||||
mux := mux.NewRouter()
|
||||
mux.Handle("/rpc/v0", rpcServer)
|
||||
mux.Handle("/rpc/streams/v0/push/{uuid}", readerHandler)
|
||||
|
||||
testServ := httptest.NewServer(mux)
|
||||
defer testServ.Close()
|
||||
|
||||
re := ReaderParamEncoder("http://" + testServ.Listener.Addr().String() + "/rpc/streams/v0/push")
|
||||
closer, err := jsonrpc.NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String()+"/rpc/v0", "ReaderHandler", []interface{}{&allClient}, nil, re)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer closer()
|
||||
}
|
||||
|
||||
var redirClient struct {
|
||||
ReadAllApi func(ctx context.Context, r io.Reader) ([]byte, error)
|
||||
}
|
||||
|
||||
{
|
||||
allServerHandler := &ReaderHandler{readApi: allClient.ReadAll}
|
||||
readerHandler, readerServerOpt := ReaderParamDecoder()
|
||||
rpcServer := jsonrpc.NewServer(readerServerOpt)
|
||||
rpcServer.Register("ReaderHandler", allServerHandler)
|
||||
|
||||
mux := mux.NewRouter()
|
||||
mux.Handle("/rpc/v0", rpcServer)
|
||||
mux.Handle("/rpc/streams/v0/push/{uuid}", readerHandler)
|
||||
|
||||
testServ := httptest.NewServer(mux)
|
||||
defer testServ.Close()
|
||||
|
||||
re := ReaderParamEncoder("http://" + testServ.Listener.Addr().String() + "/rpc/streams/v0/push")
|
||||
closer, err := jsonrpc.NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String()+"/rpc/v0", "ReaderHandler", []interface{}{&redirClient}, nil, re)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer closer()
|
||||
}
|
||||
|
||||
read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "rediracted pooooootato", string(read), "potatoes weren't equal")
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user