From f20eba426a1a871f98d0d988bfd51767364650b7 Mon Sep 17 00:00:00 2001 From: Sina Mahmoodi <1591639+s1na@users.noreply.github.com> Date: Wed, 7 Dec 2022 14:02:14 +0100 Subject: [PATCH] graphql, node, rpc: improve HTTP write timeout handling (#25457) Here we add special handling for sending an error response when the write timeout of the HTTP server is just about to expire. This is surprisingly difficult to get right, since is must be ensured that all output is fully flushed in time, which needs support from multiple levels of the RPC handler stack: The timeout response can't use chunked transfer-encoding because there is no way to write the final terminating chunk. net/http writes it when the topmost handler returns, but the timeout will already be over by the time that happens. We decided to disable chunked encoding by setting content-length explicitly. Gzip compression must also be disabled for timeout responses because we don't know the true content-length before compressing all output, i.e. compression would reintroduce chunked transfer-encoding. --- graphql/graphql_test.go | 9 +- graphql/service.go | 65 +++++++-- node/api_test.go | 3 + node/node_test.go | 13 +- node/rpcstack.go | 100 +++++++++++-- node/rpcstack_test.go | 220 +++++++++++++++++++++++++++-- p2p/simulations/adapters/inproc.go | 2 +- rpc/client.go | 5 +- rpc/errors.go | 5 + rpc/handler.go | 142 +++++++++++++++++-- rpc/http.go | 73 +++++++++- rpc/json.go | 26 ++-- rpc/server.go | 3 +- rpc/subscription.go | 6 +- rpc/types.go | 4 +- rpc/websocket.go | 10 +- 16 files changed, 606 insertions(+), 80 deletions(-) diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go index 491c73152..46acd1529 100644 --- a/graphql/graphql_test.go +++ b/graphql/graphql_test.go @@ -321,10 +321,11 @@ func TestGraphQLTransactionLogs(t *testing.T) { func createNode(t *testing.T) *node.Node { stack, err := node.New(&node.Config{ - HTTPHost: "127.0.0.1", - HTTPPort: 0, - WSHost: "127.0.0.1", - WSPort: 0, + HTTPHost: "127.0.0.1", + HTTPPort: 0, + WSHost: "127.0.0.1", + WSPort: 0, + HTTPTimeouts: node.DefaultConfig.HTTPTimeouts, }) if err != nil { t.Fatalf("could not create node: %v", err) diff --git a/graphql/service.go b/graphql/service.go index 684fdc712..4392dd83e 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -20,12 +20,16 @@ import ( "context" "encoding/json" "net/http" + "strconv" + "sync" "time" "github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/rpc" "github.com/graph-gophers/graphql-go" + gqlErrors "github.com/graph-gophers/graphql-go/errors" ) type handler struct { @@ -43,21 +47,60 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + var ( + ctx = r.Context() + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + ctx, cancel = context.WithCancel(ctx) defer cancel() - response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables) - responseJSON, err := json.Marshal(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if len(response.Errors) > 0 { - w.WriteHeader(http.StatusBadRequest) + if timeout, ok := rpc.ContextRequestTimeout(ctx); ok { + timer = time.AfterFunc(timeout, func() { + responded.Do(func() { + // Cancel request handling. + cancel() + + // Create the timeout response. + response := &graphql.Response{ + Errors: []*gqlErrors.QueryError{{Message: "request timed out"}}, + } + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Setting this disables gzip compression in package node. + w.Header().Set("transfer-encoding", "identity") + + // Flush the response. Since we are writing close to the response timeout, + // chunked transfer encoding must be disabled by setting content-length. + w.Header().Set("content-type", "application/json") + w.Header().Set("content-length", strconv.Itoa(len(responseJSON))) + w.Write(responseJSON) + if flush, ok := w.(http.Flusher); ok { + flush.Flush() + } + }) + }) } - w.Header().Set("Content-Type", "application/json") - w.Write(responseJSON) + response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables) + timer.Stop() + responded.Do(func() { + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(response.Errors) > 0 { + w.WriteHeader(http.StatusBadRequest) + } + w.Header().Set("Content-Type", "application/json") + w.Write(responseJSON) + }) } // New constructs a new GraphQL service instance. diff --git a/node/api_test.go b/node/api_test.go index d76cb943e..8761c4883 100644 --- a/node/api_test.go +++ b/node/api_test.go @@ -252,6 +252,9 @@ func TestStartRPC(t *testing.T) { config := test.cfg // config.Logger = testlog.Logger(t, log.LvlDebug) config.P2P.NoDiscovery = true + if config.HTTPTimeouts == (rpc.HTTPTimeouts{}) { + config.HTTPTimeouts = rpc.DefaultHTTPTimeouts + } // Create Node. stack, err := New(&config) diff --git a/node/node_test.go b/node/node_test.go index 7c76e21f6..560d487fa 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -559,13 +559,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { } for _, path := range test.wantHTTP { - resp := rpcRequest(t, httpBase+path) + resp := rpcRequest(t, httpBase+path, testMethod) if resp.StatusCode != 200 { t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode) } } for _, path := range test.wantNoHTTP { - resp := rpcRequest(t, httpBase+path) + resp := rpcRequest(t, httpBase+path, testMethod) if resp.StatusCode != 404 { t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode) } @@ -586,10 +586,11 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { func createNode(t *testing.T, httpPort, wsPort int) *Node { conf := &Config{ - HTTPHost: "127.0.0.1", - HTTPPort: httpPort, - WSHost: "127.0.0.1", - WSPort: wsPort, + HTTPHost: "127.0.0.1", + HTTPPort: httpPort, + WSHost: "127.0.0.1", + WSPort: wsPort, + HTTPTimeouts: rpc.DefaultHTTPTimeouts, } node, err := New(conf) if err != nil { diff --git a/node/rpcstack.go b/node/rpcstack.go index 8244c892f..97d591642 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -24,6 +24,7 @@ import ( "net" "net/http" "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -196,6 +197,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } return } + // if http-rpc is enabled, try to serve request rpc := h.httpHandler.Load().(*rpcHandler) if rpc != nil { @@ -462,17 +464,94 @@ var gzPool = sync.Pool{ } type gzipResponseWriter struct { - io.Writer - http.ResponseWriter + resp http.ResponseWriter + + gz *gzip.Writer + contentLength uint64 // total length of the uncompressed response + written uint64 // amount of written bytes from the uncompressed response + hasLength bool // true if uncompressed response had Content-Length + inited bool // true after init was called for the first time +} + +// init runs just before response headers are written. Among other things, this function +// also decides whether compression will be applied at all. +func (w *gzipResponseWriter) init() { + if w.inited { + return + } + w.inited = true + + hdr := w.resp.Header() + length := hdr.Get("content-length") + if len(length) > 0 { + if n, err := strconv.ParseUint(length, 10, 64); err != nil { + w.hasLength = true + w.contentLength = n + } + } + + // Setting Transfer-Encoding to "identity" explicitly disables compression. net/http + // also recognizes this header value and uses it to disable "chunked" transfer + // encoding, trimming the header from the response. This means downstream handlers can + // set this without harm, even if they aren't wrapped by newGzipHandler. + // + // In go-ethereum, we use this signal to disable compression for certain error + // responses which are flushed out close to the write deadline of the response. For + // these cases, we want to avoid chunked transfer encoding and compression because + // they require additional output that may not get written in time. + passthrough := hdr.Get("transfer-encoding") == "identity" + if !passthrough { + w.gz = gzPool.Get().(*gzip.Writer) + w.gz.Reset(w.resp) + hdr.Del("content-length") + hdr.Set("content-encoding", "gzip") + } +} + +func (w *gzipResponseWriter) Header() http.Header { + return w.resp.Header() } func (w *gzipResponseWriter) WriteHeader(status int) { - w.Header().Del("Content-Length") - w.ResponseWriter.WriteHeader(status) + w.init() + w.resp.WriteHeader(status) } func (w *gzipResponseWriter) Write(b []byte) (int, error) { - return w.Writer.Write(b) + w.init() + + if w.gz == nil { + // Compression is disabled. + return w.resp.Write(b) + } + + n, err := w.gz.Write(b) + w.written += uint64(n) + if w.hasLength && w.written >= w.contentLength { + // The HTTP handler has finished writing the entire uncompressed response. Close + // the gzip stream to ensure the footer will be seen by the client in case the + // response is flushed after this call to write. + err = w.gz.Close() + } + return n, err +} + +func (w *gzipResponseWriter) Flush() { + if w.gz != nil { + w.gz.Flush() + } + if f, ok := w.resp.(http.Flusher); ok { + f.Flush() + } +} + +func (w *gzipResponseWriter) close() { + if w.gz == nil { + return + } + w.gz.Close() + gzPool.Put(w.gz) + w.gz = nil } func newGzipHandler(next http.Handler) http.Handler { @@ -482,15 +561,10 @@ func newGzipHandler(next http.Handler) http.Handler { return } - w.Header().Set("Content-Encoding", "gzip") + wrapper := &gzipResponseWriter{resp: w} + defer wrapper.close() - gz := gzPool.Get().(*gzip.Writer) - defer gzPool.Put(gz) - - gz.Reset(w) - defer gz.Close() - - next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r) + next.ServeHTTP(wrapper, r) }) } diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index ebc253800..795bc93c8 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -19,7 +19,9 @@ package node import ( "bytes" "fmt" + "io" "net/http" + "net/http/httptest" "net/url" "strconv" "strings" @@ -34,29 +36,31 @@ import ( "github.com/stretchr/testify/assert" ) +const testMethod = "rpc_modules" + // TestCorsHandler makes sure CORS are properly handled on the http server. func TestCorsHandler(t *testing.T) { - srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}) + srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}, nil) defer srv.stop() url := "http://" + srv.listenAddr() - resp := rpcRequest(t, url, "origin", "test.com") + resp := rpcRequest(t, url, testMethod, "origin", "test.com") assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin")) - resp2 := rpcRequest(t, url, "origin", "bad") + resp2 := rpcRequest(t, url, testMethod, "origin", "bad") assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin")) } // TestVhosts makes sure vhosts are properly handled on the http server. func TestVhosts(t *testing.T) { - srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}) + srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}, nil) defer srv.stop() url := "http://" + srv.listenAddr() - resp := rpcRequest(t, url, "host", "test") + resp := rpcRequest(t, url, testMethod, "host", "test") assert.Equal(t, resp.StatusCode, http.StatusOK) - resp2 := rpcRequest(t, url, "host", "bad") + resp2 := rpcRequest(t, url, testMethod, "host", "bad") assert.Equal(t, resp2.StatusCode, http.StatusForbidden) } @@ -145,7 +149,7 @@ func TestWebsocketOrigins(t *testing.T) { }, } for _, tc := range tests { - srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}) + srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}, nil) url := fmt.Sprintf("ws://%v", srv.listenAddr()) for _, origin := range tc.expOk { if err := wsRequest(t, url, "Origin", origin); err != nil { @@ -231,11 +235,14 @@ func Test_checkPath(t *testing.T) { } } -func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig) *httpServer { +func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig, timeouts *rpc.HTTPTimeouts) *httpServer { t.Helper() - srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts) - assert.NoError(t, srv.enableRPC(nil, *conf)) + if timeouts == nil { + timeouts = &rpc.DefaultHTTPTimeouts + } + srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), *timeouts) + assert.NoError(t, srv.enableRPC(apis(), *conf)) if ws { assert.NoError(t, srv.enableWS(nil, *wsConf)) } @@ -266,16 +273,33 @@ func wsRequest(t *testing.T, url string, extraHeaders ...string) error { } // rpcRequest performs a JSON-RPC request to the given URL. -func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response { +func rpcRequest(t *testing.T, url, method string, extraHeaders ...string) *http.Response { + t.Helper() + + body := fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, method) + return baseRpcRequest(t, url, body, extraHeaders...) +} + +func batchRpcRequest(t *testing.T, url string, methods []string, extraHeaders ...string) *http.Response { + reqs := make([]string, len(methods)) + for i, m := range methods { + reqs[i] = fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, m) + } + body := fmt.Sprintf(`[%s]`, strings.Join(reqs, ",")) + return baseRpcRequest(t, url, body, extraHeaders...) +} + +func baseRpcRequest(t *testing.T, url, bodyStr string, extraHeaders ...string) *http.Response { t.Helper() // Create the request. - body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]}`)) + body := bytes.NewReader([]byte(bodyStr)) req, err := http.NewRequest("POST", url, body) if err != nil { t.Fatal("could not create http request:", err) } req.Header.Set("content-type", "application/json") + req.Header.Set("accept-encoding", "identity") // Apply extra headers. if len(extraHeaders)%2 != 0 { @@ -315,7 +339,7 @@ func TestJWT(t *testing.T) { return ss } srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, - true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) + true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil) wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) @@ -348,7 +372,7 @@ func TestJWT(t *testing.T) { t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err) } token = tokenFn() - if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 { + if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 200 { t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode) } } @@ -414,10 +438,176 @@ func TestJWT(t *testing.T) { } token = tokenFn() - resp := rpcRequest(t, htUrl, "Authorization", token) + resp := rpcRequest(t, htUrl, testMethod, "Authorization", token) if resp.StatusCode != http.StatusUnauthorized { t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode) } } srv.stop() } + +func TestGzipHandler(t *testing.T) { + type gzipTest struct { + name string + handler http.HandlerFunc + status int + isGzip bool + header map[string]string + } + tests := []gzipTest{ + { + name: "Write", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("response")) + }, + isGzip: true, + status: 200, + }, + { + name: "WriteHeader", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("x-foo", "bar") + w.WriteHeader(205) + w.Write([]byte("response")) + }, + isGzip: true, + status: 205, + header: map[string]string{"x-foo": "bar"}, + }, + { + name: "WriteContentLength", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-length", "8") + w.Write([]byte("response")) + }, + isGzip: true, + status: 200, + }, + { + name: "Flush", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("res")) + w.(http.Flusher).Flush() + w.Write([]byte("ponse")) + }, + isGzip: true, + status: 200, + }, + { + name: "disable", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("transfer-encoding", "identity") + w.Header().Set("x-foo", "bar") + w.Write([]byte("response")) + }, + isGzip: false, + status: 200, + header: map[string]string{"x-foo": "bar"}, + }, + { + name: "disable-WriteHeader", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("transfer-encoding", "identity") + w.Header().Set("x-foo", "bar") + w.WriteHeader(205) + w.Write([]byte("response")) + }, + isGzip: false, + status: 205, + header: map[string]string{"x-foo": "bar"}, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + srv := httptest.NewServer(newGzipHandler(test.handler)) + defer srv.Close() + + resp, err := http.Get(srv.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + content, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + wasGzip := resp.Uncompressed + + if string(content) != "response" { + t.Fatalf("wrong response content %q", content) + } + if wasGzip != test.isGzip { + t.Fatalf("response gzipped == %t, want %t", wasGzip, test.isGzip) + } + if resp.StatusCode != test.status { + t.Fatalf("response status == %d, want %d", resp.StatusCode, test.status) + } + for name, expectedValue := range test.header { + if v := resp.Header.Get(name); v != expectedValue { + t.Fatalf("response header %s == %s, want %s", name, v, expectedValue) + } + } + }) + } +} + +func TestHTTPWriteTimeout(t *testing.T) { + const ( + timeoutRes = `{"jsonrpc":"2.0","id":1,"error":{"code":-32002,"message":"request timed out"}}` + greetRes = `{"jsonrpc":"2.0","id":1,"result":"Hello"}` + ) + // Set-up server + timeouts := rpc.DefaultHTTPTimeouts + timeouts.WriteTimeout = time.Second + srv := createAndStartServer(t, &httpConfig{Modules: []string{"test"}}, false, &wsConfig{}, &timeouts) + url := fmt.Sprintf("http://%v", srv.listenAddr()) + + // Send normal request + t.Run("message", func(t *testing.T) { + resp := rpcRequest(t, url, "test_sleep") + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != timeoutRes { + t.Errorf("wrong response. have %s, want %s", string(body), timeoutRes) + } + }) + + // Batch request + t.Run("batch", func(t *testing.T) { + want := fmt.Sprintf("[%s,%s,%s]", greetRes, timeoutRes, timeoutRes) + resp := batchRpcRequest(t, url, []string{"test_greet", "test_sleep", "test_greet"}) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != want { + t.Errorf("wrong response. have %s, want %s", string(body), want) + } + }) +} + +func apis() []rpc.API { + return []rpc.API{ + { + Namespace: "test", + Service: &testService{}, + }, + } +} + +type testService struct{} + +func (s *testService) Greet() string { + return "Hello" +} + +func (s *testService) Sleep() { + time.Sleep(1500 * time.Millisecond) +} diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index 1cb26a8ea..36b528651 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -206,7 +206,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error { if err != nil { return err } - codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) + codec := rpc.NewFuncCodec(conn, func(v any, _ bool) error { return conn.WriteJSON(v) }, conn.ReadJSON) handler.ServeCodec(codec, 0) return nil } diff --git a/rpc/client.go b/rpc/client.go index d89aa6927..a509cb2e0 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -527,7 +527,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { return err } } - err := c.writeConn.writeJSON(ctx, msg) + err := c.writeConn.writeJSON(ctx, msg, false) if err != nil { c.writeConn = nil if !retry { @@ -660,7 +660,8 @@ func (c *Client) read(codec ServerCodec) { for { msgs, batch, err := codec.readBatch() if _, ok := err.(*json.SyntaxError); ok { - codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()})) + msg := errorMessage(&parseError{err.Error()}) + codec.writeJSON(context.Background(), msg, true) } if err != nil { c.readErr <- err diff --git a/rpc/errors.go b/rpc/errors.go index 9a19e9fe6..7188332d5 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -60,10 +60,15 @@ var ( const ( errcodeDefault = -32000 errcodeNotificationsUnsupported = -32001 + errcodeTimeout = -32002 errcodePanic = -32603 errcodeMarshalError = -32603 ) +const ( + errMsgTimeout = "request timed out" +) + type methodNotFoundError struct{ method string } func (e *methodNotFoundError) ErrorCode() int { return -32601 } diff --git a/rpc/handler.go b/rpc/handler.go index f3052e7eb..c2e7d7dc0 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -91,12 +91,83 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * return h } +// batchCallBuffer manages in progress call messages and their responses during a batch +// call. Calls need to be synchronized between the processing and timeout-triggering +// goroutines. +type batchCallBuffer struct { + mutex sync.Mutex + calls []*jsonrpcMessage + resp []*jsonrpcMessage + wrote bool +} + +// nextCall returns the next unprocessed message. +func (b *batchCallBuffer) nextCall() *jsonrpcMessage { + b.mutex.Lock() + defer b.mutex.Unlock() + + if len(b.calls) == 0 { + return nil + } + // The popping happens in `pushAnswer`. The in progress call is kept + // so we can return an error for it in case of timeout. + msg := b.calls[0] + return msg +} + +// pushResponse adds the response to last call returned by nextCall. +func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) { + b.mutex.Lock() + defer b.mutex.Unlock() + + if answer != nil { + b.resp = append(b.resp, answer) + } + b.calls = b.calls[1:] +} + +// write sends the responses. +func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.doWrite(ctx, conn, false) +} + +// timeout sends the responses added so far. For the remaining unanswered call +// messages, it sends a timeout error response. +func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + for _, msg := range b.calls { + if !msg.isNotification() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + b.resp = append(b.resp, resp) + } + } + b.doWrite(ctx, conn, true) +} + +// doWrite actually writes the response. +// This assumes b.mutex is held. +func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) { + if b.wrote { + return + } + b.wrote = true // can only write once + if len(b.resp) > 0 { + conn.writeJSON(ctx, b.resp, isErrorResponse) + } +} + // handleBatch executes all messages in a batch and returns the responses. func (h *handler) handleBatch(msgs []*jsonrpcMessage) { // Emit error response for empty batches: if len(msgs) == 0 { h.startCallProc(func(cp *callProc) { - h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) + resp := errorMessage(&invalidRequestError{"empty batch"}) + h.conn.writeJSON(cp.ctx, resp, true) }) return } @@ -113,16 +184,42 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { } // Process calls on a goroutine because they may block indefinitely: h.startCallProc(func(cp *callProc) { - answers := make([]*jsonrpcMessage, 0, len(msgs)) - for _, msg := range calls { - if answer := h.handleCallMsg(cp, msg); answer != nil { - answers = append(answers, answer) + var ( + timer *time.Timer + cancel context.CancelFunc + callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))} + ) + + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // currently-running method might not return immediately on timeout, we must wait + // for the timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + callBuffer.timeout(cp.ctx, h.conn) + }) + } + + for { + // No need to handle rest of calls if timed out. + if cp.ctx.Err() != nil { + break } + msg := callBuffer.nextCall() + if msg == nil { + break + } + resp := h.handleCallMsg(cp, msg) + callBuffer.pushResponse(resp) } + if timer != nil { + timer.Stop() + } + callBuffer.write(cp.ctx, h.conn) h.addSubscriptions(cp.notifiers) - if len(answers) > 0 { - h.conn.writeJSON(cp.ctx, answers) - } for _, n := range cp.notifiers { n.activate() } @@ -135,10 +232,36 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) { return } h.startCallProc(func(cp *callProc) { + var ( + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // running method might not return immediately on timeout, we must wait for the + // timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + responded.Do(func() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + h.conn.writeJSON(cp.ctx, resp, true) + }) + }) + } + answer := h.handleCallMsg(cp, msg) + if timer != nil { + timer.Stop() + } h.addSubscriptions(cp.notifiers) if answer != nil { - h.conn.writeJSON(cp.ctx, answer) + responded.Do(func() { + h.conn.writeJSON(cp.ctx, answer, false) + }) } for _, n := range cp.notifiers { n.activate() @@ -334,7 +457,6 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage } start := time.Now() answer := h.runMethod(cp.ctx, msg, callb, args) - // Collect the statistics for RPC calls if metrics is enabled. // We only care about pure rpc call. Filter out subscription. if callb != h.unsubscribeCb { diff --git a/rpc/http.go b/rpc/http.go index 0ba6588f9..bbabe15ba 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -23,9 +23,11 @@ import ( "errors" "fmt" "io" + "math" "mime" "net/http" "net/url" + "strconv" "sync" "time" ) @@ -52,7 +54,7 @@ type httpConn struct { // and some methods don't work. The panic() stubs here exist to ensure // this special treatment is correct. -func (hc *httpConn) writeJSON(context.Context, interface{}) error { +func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error { panic("writeJSON called on httpConn") } @@ -256,7 +258,42 @@ type httpServerConn struct { func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { body := io.LimitReader(r.Body, maxRequestContentLength) conn := &httpServerConn{Reader: body, Writer: w, r: r} - return NewCodec(conn) + + encoder := func(v any, isErrorResponse bool) error { + if !isErrorResponse { + return json.NewEncoder(conn).Encode(v) + } + + // It's an error response and requires special treatment. + // + // In case of a timeout error, the response must be written before the HTTP + // server's write timeout occurs. So we need to flush the response. The + // Content-Length header also needs to be set to ensure the client knows + // when it has the full response. + encdata, err := json.Marshal(v) + if err != nil { + return err + } + w.Header().Set("content-length", strconv.Itoa(len(encdata))) + + // If this request is wrapped in a handler that might remove Content-Length (such + // as the automatic gzip we do in package node), we need to ensure the HTTP server + // doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked + // encoding might not be finished correctly, and some clients do not like it when + // the final chunk is missing. + w.Header().Set("transfer-encoding", "identity") + + _, err = w.Write(encdata) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return err + } + + dec := json.NewDecoder(conn) + dec.UseNumber() + + return NewFuncCodec(conn, encoder, dec.Decode) } // Close does nothing and always returns nil. @@ -326,3 +363,35 @@ func validateRequest(r *http.Request) (int, error) { err := fmt.Errorf("invalid content type, only %s is supported", contentType) return http.StatusUnsupportedMediaType, err } + +// ContextRequestTimeout returns the request timeout derived from the given context. +func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) { + timeout := time.Duration(math.MaxInt64) + hasTimeout := false + setTimeout := func(d time.Duration) { + if d < timeout { + timeout = d + hasTimeout = true + } + } + + if deadline, ok := ctx.Deadline(); ok { + setTimeout(time.Until(deadline)) + } + + // If the context is an HTTP request context, use the server's WriteTimeout. + httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server) + if ok && httpSrv.WriteTimeout > 0 { + wt := httpSrv.WriteTimeout + // When a write timeout is configured, we need to send the response message before + // the HTTP server cuts connection. So our internal timeout must be earlier than + // the server's true timeout. + // + // Note: Timeouts are sanitized to be a minimum of 1 second. + // Also see issue: https://github.com/golang/go/issues/47229 + wt -= 100 * time.Millisecond + setTimeout(wt) + } + + return timeout, hasTimeout +} diff --git a/rpc/json.go b/rpc/json.go index 1064939ff..8a3b162ca 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -168,18 +168,22 @@ type ConnRemoteAddr interface { // support for parsing arguments and serializing (result) objects. type jsonCodec struct { remote string - closer sync.Once // close closed channel once - closeCh chan interface{} // closed on Close - decode func(v interface{}) error // decoder to allow multiple transports - encMu sync.Mutex // guards the encoder - encode func(v interface{}) error // encoder to allow multiple transports + closer sync.Once // close closed channel once + closeCh chan interface{} // closed on Close + decode decodeFunc // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode encodeFunc // encoder to allow multiple transports conn deadlineCloser } +type encodeFunc = func(v interface{}, isErrorResponse bool) error + +type decodeFunc = func(v interface{}) error + // NewFuncCodec creates a codec which uses the given functions to read and write. If conn // implements ConnRemoteAddr, log messages will use it to include the remote address of // the connection. -func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { +func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec { codec := &jsonCodec{ closeCh: make(chan interface{}), encode: encode, @@ -198,7 +202,11 @@ func NewCodec(conn Conn) ServerCodec { enc := json.NewEncoder(conn) dec := json.NewDecoder(conn) dec.UseNumber() - return NewFuncCodec(conn, enc.Encode, dec.Decode) + + encode := func(v interface{}, isErrorResponse bool) error { + return enc.Encode(v) + } + return NewFuncCodec(conn, encode, dec.Decode) } func (c *jsonCodec) peerInfo() PeerInfo { @@ -228,7 +236,7 @@ func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err err return messages, batch, nil } -func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { +func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error { c.encMu.Lock() defer c.encMu.Unlock() @@ -237,7 +245,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { deadline = time.Now().Add(defaultWriteTimeout) } c.conn.SetWriteDeadline(deadline) - return c.encode(v) + return c.encode(v, isErrorResponse) } func (c *jsonCodec) close() { diff --git a/rpc/server.go b/rpc/server.go index fe162d5a4..9c72c26d7 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -125,7 +125,8 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { reqs, batch, err := codec.readBatch() if err != nil { if err != io.EOF { - codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"})) + resp := errorMessage(&invalidMessageError{"parse error"}) + codec.writeJSON(ctx, resp, true) } return } diff --git a/rpc/subscription.go b/rpc/subscription.go index d7ba784fc..334ead3ac 100644 --- a/rpc/subscription.go +++ b/rpc/subscription.go @@ -175,11 +175,13 @@ func (n *Notifier) activate() error { func (n *Notifier) send(sub *Subscription, data json.RawMessage) error { params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data}) ctx := context.Background() - return n.h.conn.writeJSON(ctx, &jsonrpcMessage{ + + msg := &jsonrpcMessage{ Version: vsn, Method: n.namespace + notificationMethodSuffix, Params: params, - }) + } + return n.h.conn.writeJSON(ctx, msg, false) } // A Subscription is created by a notifier and tied to that notifier. The client can use diff --git a/rpc/types.go b/rpc/types.go index e7158796e..9dda067e7 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -51,7 +51,9 @@ type ServerCodec interface { // jsonWriter can write JSON messages to its underlying connection. // Implementations must be safe for concurrent use. type jsonWriter interface { - writeJSON(context.Context, interface{}) error + // writeJSON writes a message to the connection. + writeJSON(ctx context.Context, msg interface{}, isError bool) error + // Closed returns a channel which is closed when the connection is closed. closed() <-chan interface{} // RemoteAddr returns the peer address of the connection. diff --git a/rpc/websocket.go b/rpc/websocket.go index 21e446e9f..0ac2a2792 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -287,8 +287,12 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve conn.SetReadDeadline(time.Time{}) return nil }) + + encode := func(v interface{}, isErrorResponse bool) error { + return conn.WriteJSON(v) + } wc := &websocketCodec{ - jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), + jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), conn: conn, pingReset: make(chan struct{}, 1), info: PeerInfo{ @@ -315,8 +319,8 @@ func (wc *websocketCodec) peerInfo() PeerInfo { return wc.info } -func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { - err := wc.jsonCodec.writeJSON(ctx, v) +func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { + err := wc.jsonCodec.writeJSON(ctx, v, isError) if err == nil { // Notify pingLoop to delay the next idle ping. select {