graphql, node, rpc: improve HTTP write timeout handling (#25457)

Here we add special handling for sending an error response when the write timeout of the
HTTP server is just about to expire. This is surprisingly difficult to get right, since is
must be ensured that all output is fully flushed in time, which needs support from
multiple levels of the RPC handler stack:

The timeout response can't use chunked transfer-encoding because there is no way to write
the final terminating chunk. net/http writes it when the topmost handler returns, but the
timeout will already be over by the time that happens. We decided to disable chunked
encoding by setting content-length explicitly.

Gzip compression must also be disabled for timeout responses because we don't know the
true content-length before compressing all output, i.e. compression would reintroduce
chunked transfer-encoding.
This commit is contained in:
Sina Mahmoodi 2022-12-07 14:02:14 +01:00 committed by GitHub
parent b44abf56a9
commit f20eba426a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 606 additions and 80 deletions

View File

@ -321,10 +321,11 @@ func TestGraphQLTransactionLogs(t *testing.T) {
func createNode(t *testing.T) *node.Node { func createNode(t *testing.T) *node.Node {
stack, err := node.New(&node.Config{ stack, err := node.New(&node.Config{
HTTPHost: "127.0.0.1", HTTPHost: "127.0.0.1",
HTTPPort: 0, HTTPPort: 0,
WSHost: "127.0.0.1", WSHost: "127.0.0.1",
WSPort: 0, WSPort: 0,
HTTPTimeouts: node.DefaultConfig.HTTPTimeouts,
}) })
if err != nil { if err != nil {
t.Fatalf("could not create node: %v", err) t.Fatalf("could not create node: %v", err)

View File

@ -20,12 +20,16 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv"
"sync"
"time" "time"
"github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/eth/filters"
"github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/rpc"
"github.com/graph-gophers/graphql-go" "github.com/graph-gophers/graphql-go"
gqlErrors "github.com/graph-gophers/graphql-go/errors"
) )
type handler struct { type handler struct {
@ -43,21 +47,60 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) var (
ctx = r.Context()
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
ctx, cancel = context.WithCancel(ctx)
defer cancel() defer cancel()
response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables) if timeout, ok := rpc.ContextRequestTimeout(ctx); ok {
responseJSON, err := json.Marshal(response) timer = time.AfterFunc(timeout, func() {
if err != nil { responded.Do(func() {
http.Error(w, err.Error(), http.StatusInternalServerError) // Cancel request handling.
return cancel()
}
if len(response.Errors) > 0 { // Create the timeout response.
w.WriteHeader(http.StatusBadRequest) response := &graphql.Response{
Errors: []*gqlErrors.QueryError{{Message: "request timed out"}},
}
responseJSON, err := json.Marshal(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Setting this disables gzip compression in package node.
w.Header().Set("transfer-encoding", "identity")
// Flush the response. Since we are writing close to the response timeout,
// chunked transfer encoding must be disabled by setting content-length.
w.Header().Set("content-type", "application/json")
w.Header().Set("content-length", strconv.Itoa(len(responseJSON)))
w.Write(responseJSON)
if flush, ok := w.(http.Flusher); ok {
flush.Flush()
}
})
})
} }
w.Header().Set("Content-Type", "application/json") response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables)
w.Write(responseJSON) timer.Stop()
responded.Do(func() {
responseJSON, err := json.Marshal(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if len(response.Errors) > 0 {
w.WriteHeader(http.StatusBadRequest)
}
w.Header().Set("Content-Type", "application/json")
w.Write(responseJSON)
})
} }
// New constructs a new GraphQL service instance. // New constructs a new GraphQL service instance.

View File

@ -252,6 +252,9 @@ func TestStartRPC(t *testing.T) {
config := test.cfg config := test.cfg
// config.Logger = testlog.Logger(t, log.LvlDebug) // config.Logger = testlog.Logger(t, log.LvlDebug)
config.P2P.NoDiscovery = true config.P2P.NoDiscovery = true
if config.HTTPTimeouts == (rpc.HTTPTimeouts{}) {
config.HTTPTimeouts = rpc.DefaultHTTPTimeouts
}
// Create Node. // Create Node.
stack, err := New(&config) stack, err := New(&config)

View File

@ -559,13 +559,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
} }
for _, path := range test.wantHTTP { for _, path := range test.wantHTTP {
resp := rpcRequest(t, httpBase+path) resp := rpcRequest(t, httpBase+path, testMethod)
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode) t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode)
} }
} }
for _, path := range test.wantNoHTTP { for _, path := range test.wantNoHTTP {
resp := rpcRequest(t, httpBase+path) resp := rpcRequest(t, httpBase+path, testMethod)
if resp.StatusCode != 404 { if resp.StatusCode != 404 {
t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode) t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode)
} }
@ -586,10 +586,11 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
func createNode(t *testing.T, httpPort, wsPort int) *Node { func createNode(t *testing.T, httpPort, wsPort int) *Node {
conf := &Config{ conf := &Config{
HTTPHost: "127.0.0.1", HTTPHost: "127.0.0.1",
HTTPPort: httpPort, HTTPPort: httpPort,
WSHost: "127.0.0.1", WSHost: "127.0.0.1",
WSPort: wsPort, WSPort: wsPort,
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
} }
node, err := New(conf) node, err := New(conf)
if err != nil { if err != nil {

View File

@ -24,6 +24,7 @@ import (
"net" "net"
"net/http" "net/http"
"sort" "sort"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -196,6 +197,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
return return
} }
// if http-rpc is enabled, try to serve request // if http-rpc is enabled, try to serve request
rpc := h.httpHandler.Load().(*rpcHandler) rpc := h.httpHandler.Load().(*rpcHandler)
if rpc != nil { if rpc != nil {
@ -462,17 +464,94 @@ var gzPool = sync.Pool{
} }
type gzipResponseWriter struct { type gzipResponseWriter struct {
io.Writer resp http.ResponseWriter
http.ResponseWriter
gz *gzip.Writer
contentLength uint64 // total length of the uncompressed response
written uint64 // amount of written bytes from the uncompressed response
hasLength bool // true if uncompressed response had Content-Length
inited bool // true after init was called for the first time
}
// init runs just before response headers are written. Among other things, this function
// also decides whether compression will be applied at all.
func (w *gzipResponseWriter) init() {
if w.inited {
return
}
w.inited = true
hdr := w.resp.Header()
length := hdr.Get("content-length")
if len(length) > 0 {
if n, err := strconv.ParseUint(length, 10, 64); err != nil {
w.hasLength = true
w.contentLength = n
}
}
// Setting Transfer-Encoding to "identity" explicitly disables compression. net/http
// also recognizes this header value and uses it to disable "chunked" transfer
// encoding, trimming the header from the response. This means downstream handlers can
// set this without harm, even if they aren't wrapped by newGzipHandler.
//
// In go-ethereum, we use this signal to disable compression for certain error
// responses which are flushed out close to the write deadline of the response. For
// these cases, we want to avoid chunked transfer encoding and compression because
// they require additional output that may not get written in time.
passthrough := hdr.Get("transfer-encoding") == "identity"
if !passthrough {
w.gz = gzPool.Get().(*gzip.Writer)
w.gz.Reset(w.resp)
hdr.Del("content-length")
hdr.Set("content-encoding", "gzip")
}
}
func (w *gzipResponseWriter) Header() http.Header {
return w.resp.Header()
} }
func (w *gzipResponseWriter) WriteHeader(status int) { func (w *gzipResponseWriter) WriteHeader(status int) {
w.Header().Del("Content-Length") w.init()
w.ResponseWriter.WriteHeader(status) w.resp.WriteHeader(status)
} }
func (w *gzipResponseWriter) Write(b []byte) (int, error) { func (w *gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b) w.init()
if w.gz == nil {
// Compression is disabled.
return w.resp.Write(b)
}
n, err := w.gz.Write(b)
w.written += uint64(n)
if w.hasLength && w.written >= w.contentLength {
// The HTTP handler has finished writing the entire uncompressed response. Close
// the gzip stream to ensure the footer will be seen by the client in case the
// response is flushed after this call to write.
err = w.gz.Close()
}
return n, err
}
func (w *gzipResponseWriter) Flush() {
if w.gz != nil {
w.gz.Flush()
}
if f, ok := w.resp.(http.Flusher); ok {
f.Flush()
}
}
func (w *gzipResponseWriter) close() {
if w.gz == nil {
return
}
w.gz.Close()
gzPool.Put(w.gz)
w.gz = nil
} }
func newGzipHandler(next http.Handler) http.Handler { func newGzipHandler(next http.Handler) http.Handler {
@ -482,15 +561,10 @@ func newGzipHandler(next http.Handler) http.Handler {
return return
} }
w.Header().Set("Content-Encoding", "gzip") wrapper := &gzipResponseWriter{resp: w}
defer wrapper.close()
gz := gzPool.Get().(*gzip.Writer) next.ServeHTTP(wrapper, r)
defer gzPool.Put(gz)
gz.Reset(w)
defer gz.Close()
next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
}) })
} }

View File

@ -19,7 +19,9 @@ package node
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
@ -34,29 +36,31 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const testMethod = "rpc_modules"
// TestCorsHandler makes sure CORS are properly handled on the http server. // TestCorsHandler makes sure CORS are properly handled on the http server.
func TestCorsHandler(t *testing.T) { func TestCorsHandler(t *testing.T) {
srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}) srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}, nil)
defer srv.stop() defer srv.stop()
url := "http://" + srv.listenAddr() url := "http://" + srv.listenAddr()
resp := rpcRequest(t, url, "origin", "test.com") resp := rpcRequest(t, url, testMethod, "origin", "test.com")
assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin")) assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin"))
resp2 := rpcRequest(t, url, "origin", "bad") resp2 := rpcRequest(t, url, testMethod, "origin", "bad")
assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin")) assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin"))
} }
// TestVhosts makes sure vhosts are properly handled on the http server. // TestVhosts makes sure vhosts are properly handled on the http server.
func TestVhosts(t *testing.T) { func TestVhosts(t *testing.T) {
srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}) srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}, nil)
defer srv.stop() defer srv.stop()
url := "http://" + srv.listenAddr() url := "http://" + srv.listenAddr()
resp := rpcRequest(t, url, "host", "test") resp := rpcRequest(t, url, testMethod, "host", "test")
assert.Equal(t, resp.StatusCode, http.StatusOK) assert.Equal(t, resp.StatusCode, http.StatusOK)
resp2 := rpcRequest(t, url, "host", "bad") resp2 := rpcRequest(t, url, testMethod, "host", "bad")
assert.Equal(t, resp2.StatusCode, http.StatusForbidden) assert.Equal(t, resp2.StatusCode, http.StatusForbidden)
} }
@ -145,7 +149,7 @@ func TestWebsocketOrigins(t *testing.T) {
}, },
} }
for _, tc := range tests { for _, tc := range tests {
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}) srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}, nil)
url := fmt.Sprintf("ws://%v", srv.listenAddr()) url := fmt.Sprintf("ws://%v", srv.listenAddr())
for _, origin := range tc.expOk { for _, origin := range tc.expOk {
if err := wsRequest(t, url, "Origin", origin); err != nil { if err := wsRequest(t, url, "Origin", origin); err != nil {
@ -231,11 +235,14 @@ func Test_checkPath(t *testing.T) {
} }
} }
func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig) *httpServer { func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig, timeouts *rpc.HTTPTimeouts) *httpServer {
t.Helper() t.Helper()
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts) if timeouts == nil {
assert.NoError(t, srv.enableRPC(nil, *conf)) timeouts = &rpc.DefaultHTTPTimeouts
}
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), *timeouts)
assert.NoError(t, srv.enableRPC(apis(), *conf))
if ws { if ws {
assert.NoError(t, srv.enableWS(nil, *wsConf)) assert.NoError(t, srv.enableWS(nil, *wsConf))
} }
@ -266,16 +273,33 @@ func wsRequest(t *testing.T, url string, extraHeaders ...string) error {
} }
// rpcRequest performs a JSON-RPC request to the given URL. // rpcRequest performs a JSON-RPC request to the given URL.
func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response { func rpcRequest(t *testing.T, url, method string, extraHeaders ...string) *http.Response {
t.Helper()
body := fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, method)
return baseRpcRequest(t, url, body, extraHeaders...)
}
func batchRpcRequest(t *testing.T, url string, methods []string, extraHeaders ...string) *http.Response {
reqs := make([]string, len(methods))
for i, m := range methods {
reqs[i] = fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, m)
}
body := fmt.Sprintf(`[%s]`, strings.Join(reqs, ","))
return baseRpcRequest(t, url, body, extraHeaders...)
}
func baseRpcRequest(t *testing.T, url, bodyStr string, extraHeaders ...string) *http.Response {
t.Helper() t.Helper()
// Create the request. // Create the request.
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]}`)) body := bytes.NewReader([]byte(bodyStr))
req, err := http.NewRequest("POST", url, body) req, err := http.NewRequest("POST", url, body)
if err != nil { if err != nil {
t.Fatal("could not create http request:", err) t.Fatal("could not create http request:", err)
} }
req.Header.Set("content-type", "application/json") req.Header.Set("content-type", "application/json")
req.Header.Set("accept-encoding", "identity")
// Apply extra headers. // Apply extra headers.
if len(extraHeaders)%2 != 0 { if len(extraHeaders)%2 != 0 {
@ -315,7 +339,7 @@ func TestJWT(t *testing.T) {
return ss return ss
} }
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil)
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
@ -348,7 +372,7 @@ func TestJWT(t *testing.T) {
t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err) t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err)
} }
token = tokenFn() token = tokenFn()
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 { if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 200 {
t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode) t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode)
} }
} }
@ -414,10 +438,176 @@ func TestJWT(t *testing.T) {
} }
token = tokenFn() token = tokenFn()
resp := rpcRequest(t, htUrl, "Authorization", token) resp := rpcRequest(t, htUrl, testMethod, "Authorization", token)
if resp.StatusCode != http.StatusUnauthorized { if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode) t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode)
} }
} }
srv.stop() srv.stop()
} }
func TestGzipHandler(t *testing.T) {
type gzipTest struct {
name string
handler http.HandlerFunc
status int
isGzip bool
header map[string]string
}
tests := []gzipTest{
{
name: "Write",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("response"))
},
isGzip: true,
status: 200,
},
{
name: "WriteHeader",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-foo", "bar")
w.WriteHeader(205)
w.Write([]byte("response"))
},
isGzip: true,
status: 205,
header: map[string]string{"x-foo": "bar"},
},
{
name: "WriteContentLength",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-length", "8")
w.Write([]byte("response"))
},
isGzip: true,
status: 200,
},
{
name: "Flush",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("res"))
w.(http.Flusher).Flush()
w.Write([]byte("ponse"))
},
isGzip: true,
status: 200,
},
{
name: "disable",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("transfer-encoding", "identity")
w.Header().Set("x-foo", "bar")
w.Write([]byte("response"))
},
isGzip: false,
status: 200,
header: map[string]string{"x-foo": "bar"},
},
{
name: "disable-WriteHeader",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("transfer-encoding", "identity")
w.Header().Set("x-foo", "bar")
w.WriteHeader(205)
w.Write([]byte("response"))
},
isGzip: false,
status: 205,
header: map[string]string{"x-foo": "bar"},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
srv := httptest.NewServer(newGzipHandler(test.handler))
defer srv.Close()
resp, err := http.Get(srv.URL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
content, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
wasGzip := resp.Uncompressed
if string(content) != "response" {
t.Fatalf("wrong response content %q", content)
}
if wasGzip != test.isGzip {
t.Fatalf("response gzipped == %t, want %t", wasGzip, test.isGzip)
}
if resp.StatusCode != test.status {
t.Fatalf("response status == %d, want %d", resp.StatusCode, test.status)
}
for name, expectedValue := range test.header {
if v := resp.Header.Get(name); v != expectedValue {
t.Fatalf("response header %s == %s, want %s", name, v, expectedValue)
}
}
})
}
}
func TestHTTPWriteTimeout(t *testing.T) {
const (
timeoutRes = `{"jsonrpc":"2.0","id":1,"error":{"code":-32002,"message":"request timed out"}}`
greetRes = `{"jsonrpc":"2.0","id":1,"result":"Hello"}`
)
// Set-up server
timeouts := rpc.DefaultHTTPTimeouts
timeouts.WriteTimeout = time.Second
srv := createAndStartServer(t, &httpConfig{Modules: []string{"test"}}, false, &wsConfig{}, &timeouts)
url := fmt.Sprintf("http://%v", srv.listenAddr())
// Send normal request
t.Run("message", func(t *testing.T) {
resp := rpcRequest(t, url, "test_sleep")
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != timeoutRes {
t.Errorf("wrong response. have %s, want %s", string(body), timeoutRes)
}
})
// Batch request
t.Run("batch", func(t *testing.T) {
want := fmt.Sprintf("[%s,%s,%s]", greetRes, timeoutRes, timeoutRes)
resp := batchRpcRequest(t, url, []string{"test_greet", "test_sleep", "test_greet"})
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != want {
t.Errorf("wrong response. have %s, want %s", string(body), want)
}
})
}
func apis() []rpc.API {
return []rpc.API{
{
Namespace: "test",
Service: &testService{},
},
}
}
type testService struct{}
func (s *testService) Greet() string {
return "Hello"
}
func (s *testService) Sleep() {
time.Sleep(1500 * time.Millisecond)
}

View File

@ -206,7 +206,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error {
if err != nil { if err != nil {
return err return err
} }
codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) codec := rpc.NewFuncCodec(conn, func(v any, _ bool) error { return conn.WriteJSON(v) }, conn.ReadJSON)
handler.ServeCodec(codec, 0) handler.ServeCodec(codec, 0)
return nil return nil
} }

View File

@ -527,7 +527,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
return err return err
} }
} }
err := c.writeConn.writeJSON(ctx, msg) err := c.writeConn.writeJSON(ctx, msg, false)
if err != nil { if err != nil {
c.writeConn = nil c.writeConn = nil
if !retry { if !retry {
@ -660,7 +660,8 @@ func (c *Client) read(codec ServerCodec) {
for { for {
msgs, batch, err := codec.readBatch() msgs, batch, err := codec.readBatch()
if _, ok := err.(*json.SyntaxError); ok { if _, ok := err.(*json.SyntaxError); ok {
codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()})) msg := errorMessage(&parseError{err.Error()})
codec.writeJSON(context.Background(), msg, true)
} }
if err != nil { if err != nil {
c.readErr <- err c.readErr <- err

View File

@ -60,10 +60,15 @@ var (
const ( const (
errcodeDefault = -32000 errcodeDefault = -32000
errcodeNotificationsUnsupported = -32001 errcodeNotificationsUnsupported = -32001
errcodeTimeout = -32002
errcodePanic = -32603 errcodePanic = -32603
errcodeMarshalError = -32603 errcodeMarshalError = -32603
) )
const (
errMsgTimeout = "request timed out"
)
type methodNotFoundError struct{ method string } type methodNotFoundError struct{ method string }
func (e *methodNotFoundError) ErrorCode() int { return -32601 } func (e *methodNotFoundError) ErrorCode() int { return -32601 }

View File

@ -91,12 +91,83 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
return h return h
} }
// batchCallBuffer manages in progress call messages and their responses during a batch
// call. Calls need to be synchronized between the processing and timeout-triggering
// goroutines.
type batchCallBuffer struct {
mutex sync.Mutex
calls []*jsonrpcMessage
resp []*jsonrpcMessage
wrote bool
}
// nextCall returns the next unprocessed message.
func (b *batchCallBuffer) nextCall() *jsonrpcMessage {
b.mutex.Lock()
defer b.mutex.Unlock()
if len(b.calls) == 0 {
return nil
}
// The popping happens in `pushAnswer`. The in progress call is kept
// so we can return an error for it in case of timeout.
msg := b.calls[0]
return msg
}
// pushResponse adds the response to last call returned by nextCall.
func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) {
b.mutex.Lock()
defer b.mutex.Unlock()
if answer != nil {
b.resp = append(b.resp, answer)
}
b.calls = b.calls[1:]
}
// write sends the responses.
func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
b.mutex.Lock()
defer b.mutex.Unlock()
b.doWrite(ctx, conn, false)
}
// timeout sends the responses added so far. For the remaining unanswered call
// messages, it sends a timeout error response.
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
b.mutex.Lock()
defer b.mutex.Unlock()
for _, msg := range b.calls {
if !msg.isNotification() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
b.resp = append(b.resp, resp)
}
}
b.doWrite(ctx, conn, true)
}
// doWrite actually writes the response.
// This assumes b.mutex is held.
func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) {
if b.wrote {
return
}
b.wrote = true // can only write once
if len(b.resp) > 0 {
conn.writeJSON(ctx, b.resp, isErrorResponse)
}
}
// handleBatch executes all messages in a batch and returns the responses. // handleBatch executes all messages in a batch and returns the responses.
func (h *handler) handleBatch(msgs []*jsonrpcMessage) { func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
// Emit error response for empty batches: // Emit error response for empty batches:
if len(msgs) == 0 { if len(msgs) == 0 {
h.startCallProc(func(cp *callProc) { h.startCallProc(func(cp *callProc) {
h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) resp := errorMessage(&invalidRequestError{"empty batch"})
h.conn.writeJSON(cp.ctx, resp, true)
}) })
return return
} }
@ -113,16 +184,42 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
} }
// Process calls on a goroutine because they may block indefinitely: // Process calls on a goroutine because they may block indefinitely:
h.startCallProc(func(cp *callProc) { h.startCallProc(func(cp *callProc) {
answers := make([]*jsonrpcMessage, 0, len(msgs)) var (
for _, msg := range calls { timer *time.Timer
if answer := h.handleCallMsg(cp, msg); answer != nil { cancel context.CancelFunc
answers = append(answers, answer) callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))}
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the
// currently-running method might not return immediately on timeout, we must wait
// for the timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
callBuffer.timeout(cp.ctx, h.conn)
})
}
for {
// No need to handle rest of calls if timed out.
if cp.ctx.Err() != nil {
break
} }
msg := callBuffer.nextCall()
if msg == nil {
break
}
resp := h.handleCallMsg(cp, msg)
callBuffer.pushResponse(resp)
} }
if timer != nil {
timer.Stop()
}
callBuffer.write(cp.ctx, h.conn)
h.addSubscriptions(cp.notifiers) h.addSubscriptions(cp.notifiers)
if len(answers) > 0 {
h.conn.writeJSON(cp.ctx, answers)
}
for _, n := range cp.notifiers { for _, n := range cp.notifiers {
n.activate() n.activate()
} }
@ -135,10 +232,36 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
return return
} }
h.startCallProc(func(cp *callProc) { h.startCallProc(func(cp *callProc) {
var (
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the
// running method might not return immediately on timeout, we must wait for the
// timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
responded.Do(func() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
h.conn.writeJSON(cp.ctx, resp, true)
})
})
}
answer := h.handleCallMsg(cp, msg) answer := h.handleCallMsg(cp, msg)
if timer != nil {
timer.Stop()
}
h.addSubscriptions(cp.notifiers) h.addSubscriptions(cp.notifiers)
if answer != nil { if answer != nil {
h.conn.writeJSON(cp.ctx, answer) responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false)
})
} }
for _, n := range cp.notifiers { for _, n := range cp.notifiers {
n.activate() n.activate()
@ -334,7 +457,6 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
} }
start := time.Now() start := time.Now()
answer := h.runMethod(cp.ctx, msg, callb, args) answer := h.runMethod(cp.ctx, msg, callb, args)
// Collect the statistics for RPC calls if metrics is enabled. // Collect the statistics for RPC calls if metrics is enabled.
// We only care about pure rpc call. Filter out subscription. // We only care about pure rpc call. Filter out subscription.
if callb != h.unsubscribeCb { if callb != h.unsubscribeCb {

View File

@ -23,9 +23,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math"
"mime" "mime"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"sync" "sync"
"time" "time"
) )
@ -52,7 +54,7 @@ type httpConn struct {
// and some methods don't work. The panic() stubs here exist to ensure // and some methods don't work. The panic() stubs here exist to ensure
// this special treatment is correct. // this special treatment is correct.
func (hc *httpConn) writeJSON(context.Context, interface{}) error { func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error {
panic("writeJSON called on httpConn") panic("writeJSON called on httpConn")
} }
@ -256,7 +258,42 @@ type httpServerConn struct {
func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec {
body := io.LimitReader(r.Body, maxRequestContentLength) body := io.LimitReader(r.Body, maxRequestContentLength)
conn := &httpServerConn{Reader: body, Writer: w, r: r} conn := &httpServerConn{Reader: body, Writer: w, r: r}
return NewCodec(conn)
encoder := func(v any, isErrorResponse bool) error {
if !isErrorResponse {
return json.NewEncoder(conn).Encode(v)
}
// It's an error response and requires special treatment.
//
// In case of a timeout error, the response must be written before the HTTP
// server's write timeout occurs. So we need to flush the response. The
// Content-Length header also needs to be set to ensure the client knows
// when it has the full response.
encdata, err := json.Marshal(v)
if err != nil {
return err
}
w.Header().Set("content-length", strconv.Itoa(len(encdata)))
// If this request is wrapped in a handler that might remove Content-Length (such
// as the automatic gzip we do in package node), we need to ensure the HTTP server
// doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked
// encoding might not be finished correctly, and some clients do not like it when
// the final chunk is missing.
w.Header().Set("transfer-encoding", "identity")
_, err = w.Write(encdata)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return err
}
dec := json.NewDecoder(conn)
dec.UseNumber()
return NewFuncCodec(conn, encoder, dec.Decode)
} }
// Close does nothing and always returns nil. // Close does nothing and always returns nil.
@ -326,3 +363,35 @@ func validateRequest(r *http.Request) (int, error) {
err := fmt.Errorf("invalid content type, only %s is supported", contentType) err := fmt.Errorf("invalid content type, only %s is supported", contentType)
return http.StatusUnsupportedMediaType, err return http.StatusUnsupportedMediaType, err
} }
// ContextRequestTimeout returns the request timeout derived from the given context.
func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) {
timeout := time.Duration(math.MaxInt64)
hasTimeout := false
setTimeout := func(d time.Duration) {
if d < timeout {
timeout = d
hasTimeout = true
}
}
if deadline, ok := ctx.Deadline(); ok {
setTimeout(time.Until(deadline))
}
// If the context is an HTTP request context, use the server's WriteTimeout.
httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server)
if ok && httpSrv.WriteTimeout > 0 {
wt := httpSrv.WriteTimeout
// When a write timeout is configured, we need to send the response message before
// the HTTP server cuts connection. So our internal timeout must be earlier than
// the server's true timeout.
//
// Note: Timeouts are sanitized to be a minimum of 1 second.
// Also see issue: https://github.com/golang/go/issues/47229
wt -= 100 * time.Millisecond
setTimeout(wt)
}
return timeout, hasTimeout
}

View File

@ -168,18 +168,22 @@ type ConnRemoteAddr interface {
// support for parsing arguments and serializing (result) objects. // support for parsing arguments and serializing (result) objects.
type jsonCodec struct { type jsonCodec struct {
remote string remote string
closer sync.Once // close closed channel once closer sync.Once // close closed channel once
closeCh chan interface{} // closed on Close closeCh chan interface{} // closed on Close
decode func(v interface{}) error // decoder to allow multiple transports decode decodeFunc // decoder to allow multiple transports
encMu sync.Mutex // guards the encoder encMu sync.Mutex // guards the encoder
encode func(v interface{}) error // encoder to allow multiple transports encode encodeFunc // encoder to allow multiple transports
conn deadlineCloser conn deadlineCloser
} }
type encodeFunc = func(v interface{}, isErrorResponse bool) error
type decodeFunc = func(v interface{}) error
// NewFuncCodec creates a codec which uses the given functions to read and write. If conn // NewFuncCodec creates a codec which uses the given functions to read and write. If conn
// implements ConnRemoteAddr, log messages will use it to include the remote address of // implements ConnRemoteAddr, log messages will use it to include the remote address of
// the connection. // the connection.
func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec {
codec := &jsonCodec{ codec := &jsonCodec{
closeCh: make(chan interface{}), closeCh: make(chan interface{}),
encode: encode, encode: encode,
@ -198,7 +202,11 @@ func NewCodec(conn Conn) ServerCodec {
enc := json.NewEncoder(conn) enc := json.NewEncoder(conn)
dec := json.NewDecoder(conn) dec := json.NewDecoder(conn)
dec.UseNumber() dec.UseNumber()
return NewFuncCodec(conn, enc.Encode, dec.Decode)
encode := func(v interface{}, isErrorResponse bool) error {
return enc.Encode(v)
}
return NewFuncCodec(conn, encode, dec.Decode)
} }
func (c *jsonCodec) peerInfo() PeerInfo { func (c *jsonCodec) peerInfo() PeerInfo {
@ -228,7 +236,7 @@ func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err err
return messages, batch, nil return messages, batch, nil
} }
func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error {
c.encMu.Lock() c.encMu.Lock()
defer c.encMu.Unlock() defer c.encMu.Unlock()
@ -237,7 +245,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error {
deadline = time.Now().Add(defaultWriteTimeout) deadline = time.Now().Add(defaultWriteTimeout)
} }
c.conn.SetWriteDeadline(deadline) c.conn.SetWriteDeadline(deadline)
return c.encode(v) return c.encode(v, isErrorResponse)
} }
func (c *jsonCodec) close() { func (c *jsonCodec) close() {

View File

@ -125,7 +125,8 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
reqs, batch, err := codec.readBatch() reqs, batch, err := codec.readBatch()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"})) resp := errorMessage(&invalidMessageError{"parse error"})
codec.writeJSON(ctx, resp, true)
} }
return return
} }

View File

@ -175,11 +175,13 @@ func (n *Notifier) activate() error {
func (n *Notifier) send(sub *Subscription, data json.RawMessage) error { func (n *Notifier) send(sub *Subscription, data json.RawMessage) error {
params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data}) params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data})
ctx := context.Background() ctx := context.Background()
return n.h.conn.writeJSON(ctx, &jsonrpcMessage{
msg := &jsonrpcMessage{
Version: vsn, Version: vsn,
Method: n.namespace + notificationMethodSuffix, Method: n.namespace + notificationMethodSuffix,
Params: params, Params: params,
}) }
return n.h.conn.writeJSON(ctx, msg, false)
} }
// A Subscription is created by a notifier and tied to that notifier. The client can use // A Subscription is created by a notifier and tied to that notifier. The client can use

View File

@ -51,7 +51,9 @@ type ServerCodec interface {
// jsonWriter can write JSON messages to its underlying connection. // jsonWriter can write JSON messages to its underlying connection.
// Implementations must be safe for concurrent use. // Implementations must be safe for concurrent use.
type jsonWriter interface { type jsonWriter interface {
writeJSON(context.Context, interface{}) error // writeJSON writes a message to the connection.
writeJSON(ctx context.Context, msg interface{}, isError bool) error
// Closed returns a channel which is closed when the connection is closed. // Closed returns a channel which is closed when the connection is closed.
closed() <-chan interface{} closed() <-chan interface{}
// RemoteAddr returns the peer address of the connection. // RemoteAddr returns the peer address of the connection.

View File

@ -287,8 +287,12 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve
conn.SetReadDeadline(time.Time{}) conn.SetReadDeadline(time.Time{})
return nil return nil
}) })
encode := func(v interface{}, isErrorResponse bool) error {
return conn.WriteJSON(v)
}
wc := &websocketCodec{ wc := &websocketCodec{
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec),
conn: conn, conn: conn,
pingReset: make(chan struct{}, 1), pingReset: make(chan struct{}, 1),
info: PeerInfo{ info: PeerInfo{
@ -315,8 +319,8 @@ func (wc *websocketCodec) peerInfo() PeerInfo {
return wc.info return wc.info
} }
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error {
err := wc.jsonCodec.writeJSON(ctx, v) err := wc.jsonCodec.writeJSON(ctx, v, isError)
if err == nil { if err == nil {
// Notify pingLoop to delay the next idle ping. // Notify pingLoop to delay the next idle ping.
select { select {