diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go index 491c73152..46acd1529 100644 --- a/graphql/graphql_test.go +++ b/graphql/graphql_test.go @@ -321,10 +321,11 @@ func TestGraphQLTransactionLogs(t *testing.T) { func createNode(t *testing.T) *node.Node { stack, err := node.New(&node.Config{ - HTTPHost: "127.0.0.1", - HTTPPort: 0, - WSHost: "127.0.0.1", - WSPort: 0, + HTTPHost: "127.0.0.1", + HTTPPort: 0, + WSHost: "127.0.0.1", + WSPort: 0, + HTTPTimeouts: node.DefaultConfig.HTTPTimeouts, }) if err != nil { t.Fatalf("could not create node: %v", err) diff --git a/graphql/service.go b/graphql/service.go index 684fdc712..4392dd83e 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -20,12 +20,16 @@ import ( "context" "encoding/json" "net/http" + "strconv" + "sync" "time" "github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/rpc" "github.com/graph-gophers/graphql-go" + gqlErrors "github.com/graph-gophers/graphql-go/errors" ) type handler struct { @@ -43,21 +47,60 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + var ( + ctx = r.Context() + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + ctx, cancel = context.WithCancel(ctx) defer cancel() - response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables) - responseJSON, err := json.Marshal(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if len(response.Errors) > 0 { - w.WriteHeader(http.StatusBadRequest) + if timeout, ok := rpc.ContextRequestTimeout(ctx); ok { + timer = time.AfterFunc(timeout, func() { + responded.Do(func() { + // Cancel request handling. + cancel() + + // Create the timeout response. + response := &graphql.Response{ + Errors: []*gqlErrors.QueryError{{Message: "request timed out"}}, + } + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Setting this disables gzip compression in package node. + w.Header().Set("transfer-encoding", "identity") + + // Flush the response. Since we are writing close to the response timeout, + // chunked transfer encoding must be disabled by setting content-length. + w.Header().Set("content-type", "application/json") + w.Header().Set("content-length", strconv.Itoa(len(responseJSON))) + w.Write(responseJSON) + if flush, ok := w.(http.Flusher); ok { + flush.Flush() + } + }) + }) } - w.Header().Set("Content-Type", "application/json") - w.Write(responseJSON) + response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables) + timer.Stop() + responded.Do(func() { + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(response.Errors) > 0 { + w.WriteHeader(http.StatusBadRequest) + } + w.Header().Set("Content-Type", "application/json") + w.Write(responseJSON) + }) } // New constructs a new GraphQL service instance. diff --git a/node/api_test.go b/node/api_test.go index d76cb943e..8761c4883 100644 --- a/node/api_test.go +++ b/node/api_test.go @@ -252,6 +252,9 @@ func TestStartRPC(t *testing.T) { config := test.cfg // config.Logger = testlog.Logger(t, log.LvlDebug) config.P2P.NoDiscovery = true + if config.HTTPTimeouts == (rpc.HTTPTimeouts{}) { + config.HTTPTimeouts = rpc.DefaultHTTPTimeouts + } // Create Node. stack, err := New(&config) diff --git a/node/node_test.go b/node/node_test.go index 7c76e21f6..560d487fa 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -559,13 +559,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { } for _, path := range test.wantHTTP { - resp := rpcRequest(t, httpBase+path) + resp := rpcRequest(t, httpBase+path, testMethod) if resp.StatusCode != 200 { t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode) } } for _, path := range test.wantNoHTTP { - resp := rpcRequest(t, httpBase+path) + resp := rpcRequest(t, httpBase+path, testMethod) if resp.StatusCode != 404 { t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode) } @@ -586,10 +586,11 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { func createNode(t *testing.T, httpPort, wsPort int) *Node { conf := &Config{ - HTTPHost: "127.0.0.1", - HTTPPort: httpPort, - WSHost: "127.0.0.1", - WSPort: wsPort, + HTTPHost: "127.0.0.1", + HTTPPort: httpPort, + WSHost: "127.0.0.1", + WSPort: wsPort, + HTTPTimeouts: rpc.DefaultHTTPTimeouts, } node, err := New(conf) if err != nil { diff --git a/node/rpcstack.go b/node/rpcstack.go index 8244c892f..97d591642 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -24,6 +24,7 @@ import ( "net" "net/http" "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -196,6 +197,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } return } + // if http-rpc is enabled, try to serve request rpc := h.httpHandler.Load().(*rpcHandler) if rpc != nil { @@ -462,17 +464,94 @@ var gzPool = sync.Pool{ } type gzipResponseWriter struct { - io.Writer - http.ResponseWriter + resp http.ResponseWriter + + gz *gzip.Writer + contentLength uint64 // total length of the uncompressed response + written uint64 // amount of written bytes from the uncompressed response + hasLength bool // true if uncompressed response had Content-Length + inited bool // true after init was called for the first time +} + +// init runs just before response headers are written. Among other things, this function +// also decides whether compression will be applied at all. +func (w *gzipResponseWriter) init() { + if w.inited { + return + } + w.inited = true + + hdr := w.resp.Header() + length := hdr.Get("content-length") + if len(length) > 0 { + if n, err := strconv.ParseUint(length, 10, 64); err != nil { + w.hasLength = true + w.contentLength = n + } + } + + // Setting Transfer-Encoding to "identity" explicitly disables compression. net/http + // also recognizes this header value and uses it to disable "chunked" transfer + // encoding, trimming the header from the response. This means downstream handlers can + // set this without harm, even if they aren't wrapped by newGzipHandler. + // + // In go-ethereum, we use this signal to disable compression for certain error + // responses which are flushed out close to the write deadline of the response. For + // these cases, we want to avoid chunked transfer encoding and compression because + // they require additional output that may not get written in time. + passthrough := hdr.Get("transfer-encoding") == "identity" + if !passthrough { + w.gz = gzPool.Get().(*gzip.Writer) + w.gz.Reset(w.resp) + hdr.Del("content-length") + hdr.Set("content-encoding", "gzip") + } +} + +func (w *gzipResponseWriter) Header() http.Header { + return w.resp.Header() } func (w *gzipResponseWriter) WriteHeader(status int) { - w.Header().Del("Content-Length") - w.ResponseWriter.WriteHeader(status) + w.init() + w.resp.WriteHeader(status) } func (w *gzipResponseWriter) Write(b []byte) (int, error) { - return w.Writer.Write(b) + w.init() + + if w.gz == nil { + // Compression is disabled. + return w.resp.Write(b) + } + + n, err := w.gz.Write(b) + w.written += uint64(n) + if w.hasLength && w.written >= w.contentLength { + // The HTTP handler has finished writing the entire uncompressed response. Close + // the gzip stream to ensure the footer will be seen by the client in case the + // response is flushed after this call to write. + err = w.gz.Close() + } + return n, err +} + +func (w *gzipResponseWriter) Flush() { + if w.gz != nil { + w.gz.Flush() + } + if f, ok := w.resp.(http.Flusher); ok { + f.Flush() + } +} + +func (w *gzipResponseWriter) close() { + if w.gz == nil { + return + } + w.gz.Close() + gzPool.Put(w.gz) + w.gz = nil } func newGzipHandler(next http.Handler) http.Handler { @@ -482,15 +561,10 @@ func newGzipHandler(next http.Handler) http.Handler { return } - w.Header().Set("Content-Encoding", "gzip") + wrapper := &gzipResponseWriter{resp: w} + defer wrapper.close() - gz := gzPool.Get().(*gzip.Writer) - defer gzPool.Put(gz) - - gz.Reset(w) - defer gz.Close() - - next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r) + next.ServeHTTP(wrapper, r) }) } diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index ebc253800..795bc93c8 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -19,7 +19,9 @@ package node import ( "bytes" "fmt" + "io" "net/http" + "net/http/httptest" "net/url" "strconv" "strings" @@ -34,29 +36,31 @@ import ( "github.com/stretchr/testify/assert" ) +const testMethod = "rpc_modules" + // TestCorsHandler makes sure CORS are properly handled on the http server. func TestCorsHandler(t *testing.T) { - srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}) + srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}, nil) defer srv.stop() url := "http://" + srv.listenAddr() - resp := rpcRequest(t, url, "origin", "test.com") + resp := rpcRequest(t, url, testMethod, "origin", "test.com") assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin")) - resp2 := rpcRequest(t, url, "origin", "bad") + resp2 := rpcRequest(t, url, testMethod, "origin", "bad") assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin")) } // TestVhosts makes sure vhosts are properly handled on the http server. func TestVhosts(t *testing.T) { - srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}) + srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}, nil) defer srv.stop() url := "http://" + srv.listenAddr() - resp := rpcRequest(t, url, "host", "test") + resp := rpcRequest(t, url, testMethod, "host", "test") assert.Equal(t, resp.StatusCode, http.StatusOK) - resp2 := rpcRequest(t, url, "host", "bad") + resp2 := rpcRequest(t, url, testMethod, "host", "bad") assert.Equal(t, resp2.StatusCode, http.StatusForbidden) } @@ -145,7 +149,7 @@ func TestWebsocketOrigins(t *testing.T) { }, } for _, tc := range tests { - srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}) + srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}, nil) url := fmt.Sprintf("ws://%v", srv.listenAddr()) for _, origin := range tc.expOk { if err := wsRequest(t, url, "Origin", origin); err != nil { @@ -231,11 +235,14 @@ func Test_checkPath(t *testing.T) { } } -func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig) *httpServer { +func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig, timeouts *rpc.HTTPTimeouts) *httpServer { t.Helper() - srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts) - assert.NoError(t, srv.enableRPC(nil, *conf)) + if timeouts == nil { + timeouts = &rpc.DefaultHTTPTimeouts + } + srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), *timeouts) + assert.NoError(t, srv.enableRPC(apis(), *conf)) if ws { assert.NoError(t, srv.enableWS(nil, *wsConf)) } @@ -266,16 +273,33 @@ func wsRequest(t *testing.T, url string, extraHeaders ...string) error { } // rpcRequest performs a JSON-RPC request to the given URL. -func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response { +func rpcRequest(t *testing.T, url, method string, extraHeaders ...string) *http.Response { + t.Helper() + + body := fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, method) + return baseRpcRequest(t, url, body, extraHeaders...) +} + +func batchRpcRequest(t *testing.T, url string, methods []string, extraHeaders ...string) *http.Response { + reqs := make([]string, len(methods)) + for i, m := range methods { + reqs[i] = fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, m) + } + body := fmt.Sprintf(`[%s]`, strings.Join(reqs, ",")) + return baseRpcRequest(t, url, body, extraHeaders...) +} + +func baseRpcRequest(t *testing.T, url, bodyStr string, extraHeaders ...string) *http.Response { t.Helper() // Create the request. - body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]}`)) + body := bytes.NewReader([]byte(bodyStr)) req, err := http.NewRequest("POST", url, body) if err != nil { t.Fatal("could not create http request:", err) } req.Header.Set("content-type", "application/json") + req.Header.Set("accept-encoding", "identity") // Apply extra headers. if len(extraHeaders)%2 != 0 { @@ -315,7 +339,7 @@ func TestJWT(t *testing.T) { return ss } srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, - true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) + true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil) wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) @@ -348,7 +372,7 @@ func TestJWT(t *testing.T) { t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err) } token = tokenFn() - if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 { + if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 200 { t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode) } } @@ -414,10 +438,176 @@ func TestJWT(t *testing.T) { } token = tokenFn() - resp := rpcRequest(t, htUrl, "Authorization", token) + resp := rpcRequest(t, htUrl, testMethod, "Authorization", token) if resp.StatusCode != http.StatusUnauthorized { t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode) } } srv.stop() } + +func TestGzipHandler(t *testing.T) { + type gzipTest struct { + name string + handler http.HandlerFunc + status int + isGzip bool + header map[string]string + } + tests := []gzipTest{ + { + name: "Write", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("response")) + }, + isGzip: true, + status: 200, + }, + { + name: "WriteHeader", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("x-foo", "bar") + w.WriteHeader(205) + w.Write([]byte("response")) + }, + isGzip: true, + status: 205, + header: map[string]string{"x-foo": "bar"}, + }, + { + name: "WriteContentLength", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-length", "8") + w.Write([]byte("response")) + }, + isGzip: true, + status: 200, + }, + { + name: "Flush", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("res")) + w.(http.Flusher).Flush() + w.Write([]byte("ponse")) + }, + isGzip: true, + status: 200, + }, + { + name: "disable", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("transfer-encoding", "identity") + w.Header().Set("x-foo", "bar") + w.Write([]byte("response")) + }, + isGzip: false, + status: 200, + header: map[string]string{"x-foo": "bar"}, + }, + { + name: "disable-WriteHeader", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("transfer-encoding", "identity") + w.Header().Set("x-foo", "bar") + w.WriteHeader(205) + w.Write([]byte("response")) + }, + isGzip: false, + status: 205, + header: map[string]string{"x-foo": "bar"}, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + srv := httptest.NewServer(newGzipHandler(test.handler)) + defer srv.Close() + + resp, err := http.Get(srv.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + content, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + wasGzip := resp.Uncompressed + + if string(content) != "response" { + t.Fatalf("wrong response content %q", content) + } + if wasGzip != test.isGzip { + t.Fatalf("response gzipped == %t, want %t", wasGzip, test.isGzip) + } + if resp.StatusCode != test.status { + t.Fatalf("response status == %d, want %d", resp.StatusCode, test.status) + } + for name, expectedValue := range test.header { + if v := resp.Header.Get(name); v != expectedValue { + t.Fatalf("response header %s == %s, want %s", name, v, expectedValue) + } + } + }) + } +} + +func TestHTTPWriteTimeout(t *testing.T) { + const ( + timeoutRes = `{"jsonrpc":"2.0","id":1,"error":{"code":-32002,"message":"request timed out"}}` + greetRes = `{"jsonrpc":"2.0","id":1,"result":"Hello"}` + ) + // Set-up server + timeouts := rpc.DefaultHTTPTimeouts + timeouts.WriteTimeout = time.Second + srv := createAndStartServer(t, &httpConfig{Modules: []string{"test"}}, false, &wsConfig{}, &timeouts) + url := fmt.Sprintf("http://%v", srv.listenAddr()) + + // Send normal request + t.Run("message", func(t *testing.T) { + resp := rpcRequest(t, url, "test_sleep") + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != timeoutRes { + t.Errorf("wrong response. have %s, want %s", string(body), timeoutRes) + } + }) + + // Batch request + t.Run("batch", func(t *testing.T) { + want := fmt.Sprintf("[%s,%s,%s]", greetRes, timeoutRes, timeoutRes) + resp := batchRpcRequest(t, url, []string{"test_greet", "test_sleep", "test_greet"}) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != want { + t.Errorf("wrong response. have %s, want %s", string(body), want) + } + }) +} + +func apis() []rpc.API { + return []rpc.API{ + { + Namespace: "test", + Service: &testService{}, + }, + } +} + +type testService struct{} + +func (s *testService) Greet() string { + return "Hello" +} + +func (s *testService) Sleep() { + time.Sleep(1500 * time.Millisecond) +} diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index 1cb26a8ea..36b528651 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -206,7 +206,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error { if err != nil { return err } - codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) + codec := rpc.NewFuncCodec(conn, func(v any, _ bool) error { return conn.WriteJSON(v) }, conn.ReadJSON) handler.ServeCodec(codec, 0) return nil } diff --git a/rpc/client.go b/rpc/client.go index d89aa6927..a509cb2e0 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -527,7 +527,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { return err } } - err := c.writeConn.writeJSON(ctx, msg) + err := c.writeConn.writeJSON(ctx, msg, false) if err != nil { c.writeConn = nil if !retry { @@ -660,7 +660,8 @@ func (c *Client) read(codec ServerCodec) { for { msgs, batch, err := codec.readBatch() if _, ok := err.(*json.SyntaxError); ok { - codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()})) + msg := errorMessage(&parseError{err.Error()}) + codec.writeJSON(context.Background(), msg, true) } if err != nil { c.readErr <- err diff --git a/rpc/errors.go b/rpc/errors.go index 9a19e9fe6..7188332d5 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -60,10 +60,15 @@ var ( const ( errcodeDefault = -32000 errcodeNotificationsUnsupported = -32001 + errcodeTimeout = -32002 errcodePanic = -32603 errcodeMarshalError = -32603 ) +const ( + errMsgTimeout = "request timed out" +) + type methodNotFoundError struct{ method string } func (e *methodNotFoundError) ErrorCode() int { return -32601 } diff --git a/rpc/handler.go b/rpc/handler.go index f3052e7eb..c2e7d7dc0 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -91,12 +91,83 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * return h } +// batchCallBuffer manages in progress call messages and their responses during a batch +// call. Calls need to be synchronized between the processing and timeout-triggering +// goroutines. +type batchCallBuffer struct { + mutex sync.Mutex + calls []*jsonrpcMessage + resp []*jsonrpcMessage + wrote bool +} + +// nextCall returns the next unprocessed message. +func (b *batchCallBuffer) nextCall() *jsonrpcMessage { + b.mutex.Lock() + defer b.mutex.Unlock() + + if len(b.calls) == 0 { + return nil + } + // The popping happens in `pushAnswer`. The in progress call is kept + // so we can return an error for it in case of timeout. + msg := b.calls[0] + return msg +} + +// pushResponse adds the response to last call returned by nextCall. +func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) { + b.mutex.Lock() + defer b.mutex.Unlock() + + if answer != nil { + b.resp = append(b.resp, answer) + } + b.calls = b.calls[1:] +} + +// write sends the responses. +func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.doWrite(ctx, conn, false) +} + +// timeout sends the responses added so far. For the remaining unanswered call +// messages, it sends a timeout error response. +func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + for _, msg := range b.calls { + if !msg.isNotification() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + b.resp = append(b.resp, resp) + } + } + b.doWrite(ctx, conn, true) +} + +// doWrite actually writes the response. +// This assumes b.mutex is held. +func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) { + if b.wrote { + return + } + b.wrote = true // can only write once + if len(b.resp) > 0 { + conn.writeJSON(ctx, b.resp, isErrorResponse) + } +} + // handleBatch executes all messages in a batch and returns the responses. func (h *handler) handleBatch(msgs []*jsonrpcMessage) { // Emit error response for empty batches: if len(msgs) == 0 { h.startCallProc(func(cp *callProc) { - h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) + resp := errorMessage(&invalidRequestError{"empty batch"}) + h.conn.writeJSON(cp.ctx, resp, true) }) return } @@ -113,16 +184,42 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { } // Process calls on a goroutine because they may block indefinitely: h.startCallProc(func(cp *callProc) { - answers := make([]*jsonrpcMessage, 0, len(msgs)) - for _, msg := range calls { - if answer := h.handleCallMsg(cp, msg); answer != nil { - answers = append(answers, answer) + var ( + timer *time.Timer + cancel context.CancelFunc + callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))} + ) + + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // currently-running method might not return immediately on timeout, we must wait + // for the timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + callBuffer.timeout(cp.ctx, h.conn) + }) + } + + for { + // No need to handle rest of calls if timed out. + if cp.ctx.Err() != nil { + break } + msg := callBuffer.nextCall() + if msg == nil { + break + } + resp := h.handleCallMsg(cp, msg) + callBuffer.pushResponse(resp) } + if timer != nil { + timer.Stop() + } + callBuffer.write(cp.ctx, h.conn) h.addSubscriptions(cp.notifiers) - if len(answers) > 0 { - h.conn.writeJSON(cp.ctx, answers) - } for _, n := range cp.notifiers { n.activate() } @@ -135,10 +232,36 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) { return } h.startCallProc(func(cp *callProc) { + var ( + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // running method might not return immediately on timeout, we must wait for the + // timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + responded.Do(func() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + h.conn.writeJSON(cp.ctx, resp, true) + }) + }) + } + answer := h.handleCallMsg(cp, msg) + if timer != nil { + timer.Stop() + } h.addSubscriptions(cp.notifiers) if answer != nil { - h.conn.writeJSON(cp.ctx, answer) + responded.Do(func() { + h.conn.writeJSON(cp.ctx, answer, false) + }) } for _, n := range cp.notifiers { n.activate() @@ -334,7 +457,6 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage } start := time.Now() answer := h.runMethod(cp.ctx, msg, callb, args) - // Collect the statistics for RPC calls if metrics is enabled. // We only care about pure rpc call. Filter out subscription. if callb != h.unsubscribeCb { diff --git a/rpc/http.go b/rpc/http.go index 0ba6588f9..bbabe15ba 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -23,9 +23,11 @@ import ( "errors" "fmt" "io" + "math" "mime" "net/http" "net/url" + "strconv" "sync" "time" ) @@ -52,7 +54,7 @@ type httpConn struct { // and some methods don't work. The panic() stubs here exist to ensure // this special treatment is correct. -func (hc *httpConn) writeJSON(context.Context, interface{}) error { +func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error { panic("writeJSON called on httpConn") } @@ -256,7 +258,42 @@ type httpServerConn struct { func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { body := io.LimitReader(r.Body, maxRequestContentLength) conn := &httpServerConn{Reader: body, Writer: w, r: r} - return NewCodec(conn) + + encoder := func(v any, isErrorResponse bool) error { + if !isErrorResponse { + return json.NewEncoder(conn).Encode(v) + } + + // It's an error response and requires special treatment. + // + // In case of a timeout error, the response must be written before the HTTP + // server's write timeout occurs. So we need to flush the response. The + // Content-Length header also needs to be set to ensure the client knows + // when it has the full response. + encdata, err := json.Marshal(v) + if err != nil { + return err + } + w.Header().Set("content-length", strconv.Itoa(len(encdata))) + + // If this request is wrapped in a handler that might remove Content-Length (such + // as the automatic gzip we do in package node), we need to ensure the HTTP server + // doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked + // encoding might not be finished correctly, and some clients do not like it when + // the final chunk is missing. + w.Header().Set("transfer-encoding", "identity") + + _, err = w.Write(encdata) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return err + } + + dec := json.NewDecoder(conn) + dec.UseNumber() + + return NewFuncCodec(conn, encoder, dec.Decode) } // Close does nothing and always returns nil. @@ -326,3 +363,35 @@ func validateRequest(r *http.Request) (int, error) { err := fmt.Errorf("invalid content type, only %s is supported", contentType) return http.StatusUnsupportedMediaType, err } + +// ContextRequestTimeout returns the request timeout derived from the given context. +func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) { + timeout := time.Duration(math.MaxInt64) + hasTimeout := false + setTimeout := func(d time.Duration) { + if d < timeout { + timeout = d + hasTimeout = true + } + } + + if deadline, ok := ctx.Deadline(); ok { + setTimeout(time.Until(deadline)) + } + + // If the context is an HTTP request context, use the server's WriteTimeout. + httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server) + if ok && httpSrv.WriteTimeout > 0 { + wt := httpSrv.WriteTimeout + // When a write timeout is configured, we need to send the response message before + // the HTTP server cuts connection. So our internal timeout must be earlier than + // the server's true timeout. + // + // Note: Timeouts are sanitized to be a minimum of 1 second. + // Also see issue: https://github.com/golang/go/issues/47229 + wt -= 100 * time.Millisecond + setTimeout(wt) + } + + return timeout, hasTimeout +} diff --git a/rpc/json.go b/rpc/json.go index 1064939ff..8a3b162ca 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -168,18 +168,22 @@ type ConnRemoteAddr interface { // support for parsing arguments and serializing (result) objects. type jsonCodec struct { remote string - closer sync.Once // close closed channel once - closeCh chan interface{} // closed on Close - decode func(v interface{}) error // decoder to allow multiple transports - encMu sync.Mutex // guards the encoder - encode func(v interface{}) error // encoder to allow multiple transports + closer sync.Once // close closed channel once + closeCh chan interface{} // closed on Close + decode decodeFunc // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode encodeFunc // encoder to allow multiple transports conn deadlineCloser } +type encodeFunc = func(v interface{}, isErrorResponse bool) error + +type decodeFunc = func(v interface{}) error + // NewFuncCodec creates a codec which uses the given functions to read and write. If conn // implements ConnRemoteAddr, log messages will use it to include the remote address of // the connection. -func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { +func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec { codec := &jsonCodec{ closeCh: make(chan interface{}), encode: encode, @@ -198,7 +202,11 @@ func NewCodec(conn Conn) ServerCodec { enc := json.NewEncoder(conn) dec := json.NewDecoder(conn) dec.UseNumber() - return NewFuncCodec(conn, enc.Encode, dec.Decode) + + encode := func(v interface{}, isErrorResponse bool) error { + return enc.Encode(v) + } + return NewFuncCodec(conn, encode, dec.Decode) } func (c *jsonCodec) peerInfo() PeerInfo { @@ -228,7 +236,7 @@ func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err err return messages, batch, nil } -func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { +func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error { c.encMu.Lock() defer c.encMu.Unlock() @@ -237,7 +245,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { deadline = time.Now().Add(defaultWriteTimeout) } c.conn.SetWriteDeadline(deadline) - return c.encode(v) + return c.encode(v, isErrorResponse) } func (c *jsonCodec) close() { diff --git a/rpc/server.go b/rpc/server.go index fe162d5a4..9c72c26d7 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -125,7 +125,8 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { reqs, batch, err := codec.readBatch() if err != nil { if err != io.EOF { - codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"})) + resp := errorMessage(&invalidMessageError{"parse error"}) + codec.writeJSON(ctx, resp, true) } return } diff --git a/rpc/subscription.go b/rpc/subscription.go index d7ba784fc..334ead3ac 100644 --- a/rpc/subscription.go +++ b/rpc/subscription.go @@ -175,11 +175,13 @@ func (n *Notifier) activate() error { func (n *Notifier) send(sub *Subscription, data json.RawMessage) error { params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data}) ctx := context.Background() - return n.h.conn.writeJSON(ctx, &jsonrpcMessage{ + + msg := &jsonrpcMessage{ Version: vsn, Method: n.namespace + notificationMethodSuffix, Params: params, - }) + } + return n.h.conn.writeJSON(ctx, msg, false) } // A Subscription is created by a notifier and tied to that notifier. The client can use diff --git a/rpc/types.go b/rpc/types.go index e7158796e..9dda067e7 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -51,7 +51,9 @@ type ServerCodec interface { // jsonWriter can write JSON messages to its underlying connection. // Implementations must be safe for concurrent use. type jsonWriter interface { - writeJSON(context.Context, interface{}) error + // writeJSON writes a message to the connection. + writeJSON(ctx context.Context, msg interface{}, isError bool) error + // Closed returns a channel which is closed when the connection is closed. closed() <-chan interface{} // RemoteAddr returns the peer address of the connection. diff --git a/rpc/websocket.go b/rpc/websocket.go index 21e446e9f..0ac2a2792 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -287,8 +287,12 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve conn.SetReadDeadline(time.Time{}) return nil }) + + encode := func(v interface{}, isErrorResponse bool) error { + return conn.WriteJSON(v) + } wc := &websocketCodec{ - jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), + jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), conn: conn, pingReset: make(chan struct{}, 1), info: PeerInfo{ @@ -315,8 +319,8 @@ func (wc *websocketCodec) peerInfo() PeerInfo { return wc.info } -func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { - err := wc.jsonCodec.writeJSON(ctx, v) +func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { + err := wc.jsonCodec.writeJSON(ctx, v, isError) if err == nil { // Notify pingLoop to delay the next idle ping. select {