diff --git a/lib/rpcenc/reader.go b/lib/rpcenc/reader.go index 99b0faf05..64dcc4b49 100644 --- a/lib/rpcenc/reader.go +++ b/lib/rpcenc/reader.go @@ -282,7 +282,7 @@ func (w *RpcReader) Read(p []byte) (int, error) { } if w.postBody == nil { - return 0, xerrors.Errorf("reader already closed or redirected") + return 0, xerrors.Errorf("reader already closed, redirected or cancelled") } n, err := w.postBody.Read(p) @@ -406,6 +406,29 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) { return case <-req.Context().Done(): log.Errorf("context error in reader stream handler (2): %v", req.Context().Err()) + + closed := make(chan struct{}) + // start a draining goroutine + go func() { + for { + select { + case r, ok := <-wr.res: + if !ok { + return + } + log.Errorw("discarding read res", "type", r.rt, "meta", r.meta) + case <-closed: + return + } + } + }() + + wr.beginOnce.Do(func() {}) + wr.closeOnce.Do(func() { + close(wr.res) + }) + close(closed) + resp.WriteHeader(500) return } diff --git a/lib/rpcenc/reader_test.go b/lib/rpcenc/reader_test.go index 455cc8fcc..da436068f 100644 --- a/lib/rpcenc/reader_test.go +++ b/lib/rpcenc/reader_test.go @@ -7,7 +7,9 @@ import ( "io/ioutil" "net/http/httptest" "strings" + "sync" "testing" + "time" "github.com/gorilla/mux" "github.com/stretchr/testify/require" @@ -20,6 +22,9 @@ import ( type ReaderHandler struct { readApi func(ctx context.Context, r io.Reader) ([]byte, error) + + cont chan struct{} + subErr error } func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) { @@ -31,6 +36,24 @@ func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader, mustRedir b return h.readApi(ctx, r) } +func (h *ReaderHandler) ReadAllWaiting(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) { + if mustRedir { + if err := r.(*RpcReader).MustRedirect(); err != nil { + return nil, err + } + } + + h.cont <- struct{}{} + <-h.cont + + var m []byte + m, h.subErr = h.readApi(ctx, r) + + h.cont <- struct{}{} + + return m, h.subErr +} + func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) { if mustRedir { if err := r.(*RpcReader).MustRedirect(); err != nil { @@ -194,3 +217,110 @@ func TestReaderRedirect(t *testing.T) { err = redirClient.CloseReader(context.TODO(), strings.NewReader("rediracted pooooootato")) require.NoError(t, err) } + +func TestReaderRedirectDrop(t *testing.T) { + // lower timeout so that the dangling connection between client and reader is dropped quickly + // after the test. Otherwise httptest.Close is blocked. + Timeout = 200 * time.Millisecond + + var allClient struct { + ReadAll func(ctx context.Context, r io.Reader) ([]byte, error) + } + + { + allServerHandler := &ReaderHandler{} + readerHandler, readerServerOpt := ReaderParamDecoder() + rpcServer := jsonrpc.NewServer(readerServerOpt) + rpcServer.Register("ReaderHandler", allServerHandler) + + mux := mux.NewRouter() + mux.Handle("/rpc/v0", rpcServer) + mux.Handle("/rpc/streams/v0/push/{uuid}", readerHandler) + + testServ := httptest.NewServer(mux) + defer testServ.Close() + t.Logf("test server reading: %s", testServ.URL) + + re := ReaderParamEncoder("http://" + testServ.Listener.Addr().String() + "/rpc/streams/v0/push") + closer, err := jsonrpc.NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String()+"/rpc/v0", "ReaderHandler", []interface{}{&allClient}, nil, re) + require.NoError(t, err) + + defer closer() + } + + var redirClient struct { + ReadAllWaiting func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) + } + contCh := make(chan struct{}) + + allServerHandler := &ReaderHandler{readApi: allClient.ReadAll, cont: contCh} + readerHandler, readerServerOpt := ReaderParamDecoder() + rpcServer := jsonrpc.NewServer(readerServerOpt) + rpcServer.Register("ReaderHandler", allServerHandler) + + mux := mux.NewRouter() + mux.Handle("/rpc/v0", rpcServer) + mux.Handle("/rpc/streams/v0/push/{uuid}", readerHandler) + + testServ := httptest.NewServer(mux) + defer testServ.Close() + t.Logf("test server redirecting: %s", testServ.URL) + + re := ReaderParamEncoder("http://" + testServ.Listener.Addr().String() + "/rpc/streams/v0/push") + closer, err := jsonrpc.NewMergeClient(context.Background(), "http://"+testServ.Listener.Addr().String()+"/rpc/v0", "ReaderHandler", []interface{}{&redirClient}, nil, re) + require.NoError(t, err) + + defer closer() + + var done sync.WaitGroup + + // Happy case + + done.Add(1) + go func() { + defer done.Done() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + read, err := redirClient.ReadAllWaiting(ctx, strings.NewReader("rediracted pooooootato"), true) + require.NoError(t, err) + require.Equal(t, "rediracted pooooootato", string(read), "potatoes weren't equal") + }() + + <-contCh // exec enter ReadAllWaiting + contCh <- struct{}{} // stert subcall + <-contCh // wait for subcall to finish + + done.Wait() + + // Redir client drops before subcall + done.Add(1) + + go func() { + defer done.Done() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := redirClient.ReadAllWaiting(ctx, strings.NewReader("rediracted pooooootato"), true) + require.ErrorContains(t, err, "sendRequest failed") + }() + + // wait for execution to enter ReadAllWaiting + <-contCh + + // kill redirecting server connection + testServ.CloseClientConnections() + + // ReadAllWaiting should fail + done.Wait() + + // resume execution in ReadAllWaiting, calling redicect + contCh <- struct{}{} + + // wait for subcall to finish + <-contCh + + require.ErrorContains(t, allServerHandler.subErr, "decoding params for 'ReaderHandler.ReadAll' (param: 0; custom decoder): context canceled") +}