forked from cerc-io/plugeth
graphql, node, rpc: improve HTTP write timeout handling (#25457)
Here we add special handling for sending an error response when the write timeout of the HTTP server is just about to expire. This is surprisingly difficult to get right, since is must be ensured that all output is fully flushed in time, which needs support from multiple levels of the RPC handler stack: The timeout response can't use chunked transfer-encoding because there is no way to write the final terminating chunk. net/http writes it when the topmost handler returns, but the timeout will already be over by the time that happens. We decided to disable chunked encoding by setting content-length explicitly. Gzip compression must also be disabled for timeout responses because we don't know the true content-length before compressing all output, i.e. compression would reintroduce chunked transfer-encoding.
This commit is contained in:
parent
b44abf56a9
commit
f20eba426a
@ -325,6 +325,7 @@ func createNode(t *testing.T) *node.Node {
|
||||
HTTPPort: 0,
|
||||
WSHost: "127.0.0.1",
|
||||
WSPort: 0,
|
||||
HTTPTimeouts: node.DefaultConfig.HTTPTimeouts,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("could not create node: %v", err)
|
||||
|
@ -20,12 +20,16 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/eth/filters"
|
||||
"github.com/ethereum/go-ethereum/internal/ethapi"
|
||||
"github.com/ethereum/go-ethereum/node"
|
||||
"github.com/ethereum/go-ethereum/rpc"
|
||||
"github.com/graph-gophers/graphql-go"
|
||||
gqlErrors "github.com/graph-gophers/graphql-go/errors"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
@ -43,10 +47,49 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||
var (
|
||||
ctx = r.Context()
|
||||
responded sync.Once
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
if timeout, ok := rpc.ContextRequestTimeout(ctx); ok {
|
||||
timer = time.AfterFunc(timeout, func() {
|
||||
responded.Do(func() {
|
||||
// Cancel request handling.
|
||||
cancel()
|
||||
|
||||
// Create the timeout response.
|
||||
response := &graphql.Response{
|
||||
Errors: []*gqlErrors.QueryError{{Message: "request timed out"}},
|
||||
}
|
||||
responseJSON, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Setting this disables gzip compression in package node.
|
||||
w.Header().Set("transfer-encoding", "identity")
|
||||
|
||||
// Flush the response. Since we are writing close to the response timeout,
|
||||
// chunked transfer encoding must be disabled by setting content-length.
|
||||
w.Header().Set("content-type", "application/json")
|
||||
w.Header().Set("content-length", strconv.Itoa(len(responseJSON)))
|
||||
w.Write(responseJSON)
|
||||
if flush, ok := w.(http.Flusher); ok {
|
||||
flush.Flush()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables)
|
||||
timer.Stop()
|
||||
responded.Do(func() {
|
||||
responseJSON, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
@ -55,9 +98,9 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if len(response.Errors) > 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(responseJSON)
|
||||
})
|
||||
}
|
||||
|
||||
// New constructs a new GraphQL service instance.
|
||||
|
@ -252,6 +252,9 @@ func TestStartRPC(t *testing.T) {
|
||||
config := test.cfg
|
||||
// config.Logger = testlog.Logger(t, log.LvlDebug)
|
||||
config.P2P.NoDiscovery = true
|
||||
if config.HTTPTimeouts == (rpc.HTTPTimeouts{}) {
|
||||
config.HTTPTimeouts = rpc.DefaultHTTPTimeouts
|
||||
}
|
||||
|
||||
// Create Node.
|
||||
stack, err := New(&config)
|
||||
|
@ -559,13 +559,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
|
||||
}
|
||||
|
||||
for _, path := range test.wantHTTP {
|
||||
resp := rpcRequest(t, httpBase+path)
|
||||
resp := rpcRequest(t, httpBase+path, testMethod)
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
for _, path := range test.wantNoHTTP {
|
||||
resp := rpcRequest(t, httpBase+path)
|
||||
resp := rpcRequest(t, httpBase+path, testMethod)
|
||||
if resp.StatusCode != 404 {
|
||||
t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode)
|
||||
}
|
||||
@ -590,6 +590,7 @@ func createNode(t *testing.T, httpPort, wsPort int) *Node {
|
||||
HTTPPort: httpPort,
|
||||
WSHost: "127.0.0.1",
|
||||
WSPort: wsPort,
|
||||
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
|
||||
}
|
||||
node, err := New(conf)
|
||||
if err != nil {
|
||||
|
100
node/rpcstack.go
100
node/rpcstack.go
@ -24,6 +24,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -196,6 +197,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// if http-rpc is enabled, try to serve request
|
||||
rpc := h.httpHandler.Load().(*rpcHandler)
|
||||
if rpc != nil {
|
||||
@ -462,17 +464,94 @@ var gzPool = sync.Pool{
|
||||
}
|
||||
|
||||
type gzipResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
resp http.ResponseWriter
|
||||
|
||||
gz *gzip.Writer
|
||||
contentLength uint64 // total length of the uncompressed response
|
||||
written uint64 // amount of written bytes from the uncompressed response
|
||||
hasLength bool // true if uncompressed response had Content-Length
|
||||
inited bool // true after init was called for the first time
|
||||
}
|
||||
|
||||
// init runs just before response headers are written. Among other things, this function
|
||||
// also decides whether compression will be applied at all.
|
||||
func (w *gzipResponseWriter) init() {
|
||||
if w.inited {
|
||||
return
|
||||
}
|
||||
w.inited = true
|
||||
|
||||
hdr := w.resp.Header()
|
||||
length := hdr.Get("content-length")
|
||||
if len(length) > 0 {
|
||||
if n, err := strconv.ParseUint(length, 10, 64); err != nil {
|
||||
w.hasLength = true
|
||||
w.contentLength = n
|
||||
}
|
||||
}
|
||||
|
||||
// Setting Transfer-Encoding to "identity" explicitly disables compression. net/http
|
||||
// also recognizes this header value and uses it to disable "chunked" transfer
|
||||
// encoding, trimming the header from the response. This means downstream handlers can
|
||||
// set this without harm, even if they aren't wrapped by newGzipHandler.
|
||||
//
|
||||
// In go-ethereum, we use this signal to disable compression for certain error
|
||||
// responses which are flushed out close to the write deadline of the response. For
|
||||
// these cases, we want to avoid chunked transfer encoding and compression because
|
||||
// they require additional output that may not get written in time.
|
||||
passthrough := hdr.Get("transfer-encoding") == "identity"
|
||||
if !passthrough {
|
||||
w.gz = gzPool.Get().(*gzip.Writer)
|
||||
w.gz.Reset(w.resp)
|
||||
hdr.Del("content-length")
|
||||
hdr.Set("content-encoding", "gzip")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Header() http.Header {
|
||||
return w.resp.Header()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) WriteHeader(status int) {
|
||||
w.Header().Del("Content-Length")
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
w.init()
|
||||
w.resp.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
return w.Writer.Write(b)
|
||||
w.init()
|
||||
|
||||
if w.gz == nil {
|
||||
// Compression is disabled.
|
||||
return w.resp.Write(b)
|
||||
}
|
||||
|
||||
n, err := w.gz.Write(b)
|
||||
w.written += uint64(n)
|
||||
if w.hasLength && w.written >= w.contentLength {
|
||||
// The HTTP handler has finished writing the entire uncompressed response. Close
|
||||
// the gzip stream to ensure the footer will be seen by the client in case the
|
||||
// response is flushed after this call to write.
|
||||
err = w.gz.Close()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Flush() {
|
||||
if w.gz != nil {
|
||||
w.gz.Flush()
|
||||
}
|
||||
if f, ok := w.resp.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) close() {
|
||||
if w.gz == nil {
|
||||
return
|
||||
}
|
||||
w.gz.Close()
|
||||
gzPool.Put(w.gz)
|
||||
w.gz = nil
|
||||
}
|
||||
|
||||
func newGzipHandler(next http.Handler) http.Handler {
|
||||
@ -482,15 +561,10 @@ func newGzipHandler(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
wrapper := &gzipResponseWriter{resp: w}
|
||||
defer wrapper.close()
|
||||
|
||||
gz := gzPool.Get().(*gzip.Writer)
|
||||
defer gzPool.Put(gz)
|
||||
|
||||
gz.Reset(w)
|
||||
defer gz.Close()
|
||||
|
||||
next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
|
||||
next.ServeHTTP(wrapper, r)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,9 @@ package node
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -34,29 +36,31 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const testMethod = "rpc_modules"
|
||||
|
||||
// TestCorsHandler makes sure CORS are properly handled on the http server.
|
||||
func TestCorsHandler(t *testing.T) {
|
||||
srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{})
|
||||
srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}, nil)
|
||||
defer srv.stop()
|
||||
url := "http://" + srv.listenAddr()
|
||||
|
||||
resp := rpcRequest(t, url, "origin", "test.com")
|
||||
resp := rpcRequest(t, url, testMethod, "origin", "test.com")
|
||||
assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin"))
|
||||
|
||||
resp2 := rpcRequest(t, url, "origin", "bad")
|
||||
resp2 := rpcRequest(t, url, testMethod, "origin", "bad")
|
||||
assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
// TestVhosts makes sure vhosts are properly handled on the http server.
|
||||
func TestVhosts(t *testing.T) {
|
||||
srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{})
|
||||
srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}, nil)
|
||||
defer srv.stop()
|
||||
url := "http://" + srv.listenAddr()
|
||||
|
||||
resp := rpcRequest(t, url, "host", "test")
|
||||
resp := rpcRequest(t, url, testMethod, "host", "test")
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
|
||||
resp2 := rpcRequest(t, url, "host", "bad")
|
||||
resp2 := rpcRequest(t, url, testMethod, "host", "bad")
|
||||
assert.Equal(t, resp2.StatusCode, http.StatusForbidden)
|
||||
}
|
||||
|
||||
@ -145,7 +149,7 @@ func TestWebsocketOrigins(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)})
|
||||
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}, nil)
|
||||
url := fmt.Sprintf("ws://%v", srv.listenAddr())
|
||||
for _, origin := range tc.expOk {
|
||||
if err := wsRequest(t, url, "Origin", origin); err != nil {
|
||||
@ -231,11 +235,14 @@ func Test_checkPath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig) *httpServer {
|
||||
func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig, timeouts *rpc.HTTPTimeouts) *httpServer {
|
||||
t.Helper()
|
||||
|
||||
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts)
|
||||
assert.NoError(t, srv.enableRPC(nil, *conf))
|
||||
if timeouts == nil {
|
||||
timeouts = &rpc.DefaultHTTPTimeouts
|
||||
}
|
||||
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), *timeouts)
|
||||
assert.NoError(t, srv.enableRPC(apis(), *conf))
|
||||
if ws {
|
||||
assert.NoError(t, srv.enableWS(nil, *wsConf))
|
||||
}
|
||||
@ -266,16 +273,33 @@ func wsRequest(t *testing.T, url string, extraHeaders ...string) error {
|
||||
}
|
||||
|
||||
// rpcRequest performs a JSON-RPC request to the given URL.
|
||||
func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response {
|
||||
func rpcRequest(t *testing.T, url, method string, extraHeaders ...string) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
body := fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, method)
|
||||
return baseRpcRequest(t, url, body, extraHeaders...)
|
||||
}
|
||||
|
||||
func batchRpcRequest(t *testing.T, url string, methods []string, extraHeaders ...string) *http.Response {
|
||||
reqs := make([]string, len(methods))
|
||||
for i, m := range methods {
|
||||
reqs[i] = fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, m)
|
||||
}
|
||||
body := fmt.Sprintf(`[%s]`, strings.Join(reqs, ","))
|
||||
return baseRpcRequest(t, url, body, extraHeaders...)
|
||||
}
|
||||
|
||||
func baseRpcRequest(t *testing.T, url, bodyStr string, extraHeaders ...string) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
// Create the request.
|
||||
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]}`))
|
||||
body := bytes.NewReader([]byte(bodyStr))
|
||||
req, err := http.NewRequest("POST", url, body)
|
||||
if err != nil {
|
||||
t.Fatal("could not create http request:", err)
|
||||
}
|
||||
req.Header.Set("content-type", "application/json")
|
||||
req.Header.Set("accept-encoding", "identity")
|
||||
|
||||
// Apply extra headers.
|
||||
if len(extraHeaders)%2 != 0 {
|
||||
@ -315,7 +339,7 @@ func TestJWT(t *testing.T) {
|
||||
return ss
|
||||
}
|
||||
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
|
||||
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")})
|
||||
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil)
|
||||
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
|
||||
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
|
||||
|
||||
@ -348,7 +372,7 @@ func TestJWT(t *testing.T) {
|
||||
t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err)
|
||||
}
|
||||
token = tokenFn()
|
||||
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 {
|
||||
if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 200 {
|
||||
t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
@ -414,10 +438,176 @@ func TestJWT(t *testing.T) {
|
||||
}
|
||||
|
||||
token = tokenFn()
|
||||
resp := rpcRequest(t, htUrl, "Authorization", token)
|
||||
resp := rpcRequest(t, htUrl, testMethod, "Authorization", token)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
srv.stop()
|
||||
}
|
||||
|
||||
func TestGzipHandler(t *testing.T) {
|
||||
type gzipTest struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
status int
|
||||
isGzip bool
|
||||
header map[string]string
|
||||
}
|
||||
tests := []gzipTest{
|
||||
{
|
||||
name: "Write",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: true,
|
||||
status: 200,
|
||||
},
|
||||
{
|
||||
name: "WriteHeader",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("x-foo", "bar")
|
||||
w.WriteHeader(205)
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: true,
|
||||
status: 205,
|
||||
header: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
{
|
||||
name: "WriteContentLength",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("content-length", "8")
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: true,
|
||||
status: 200,
|
||||
},
|
||||
{
|
||||
name: "Flush",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("res"))
|
||||
w.(http.Flusher).Flush()
|
||||
w.Write([]byte("ponse"))
|
||||
},
|
||||
isGzip: true,
|
||||
status: 200,
|
||||
},
|
||||
{
|
||||
name: "disable",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("transfer-encoding", "identity")
|
||||
w.Header().Set("x-foo", "bar")
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: false,
|
||||
status: 200,
|
||||
header: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
{
|
||||
name: "disable-WriteHeader",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("transfer-encoding", "identity")
|
||||
w.Header().Set("x-foo", "bar")
|
||||
w.WriteHeader(205)
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: false,
|
||||
status: 205,
|
||||
header: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
srv := httptest.NewServer(newGzipHandler(test.handler))
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wasGzip := resp.Uncompressed
|
||||
|
||||
if string(content) != "response" {
|
||||
t.Fatalf("wrong response content %q", content)
|
||||
}
|
||||
if wasGzip != test.isGzip {
|
||||
t.Fatalf("response gzipped == %t, want %t", wasGzip, test.isGzip)
|
||||
}
|
||||
if resp.StatusCode != test.status {
|
||||
t.Fatalf("response status == %d, want %d", resp.StatusCode, test.status)
|
||||
}
|
||||
for name, expectedValue := range test.header {
|
||||
if v := resp.Header.Get(name); v != expectedValue {
|
||||
t.Fatalf("response header %s == %s, want %s", name, v, expectedValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPWriteTimeout(t *testing.T) {
|
||||
const (
|
||||
timeoutRes = `{"jsonrpc":"2.0","id":1,"error":{"code":-32002,"message":"request timed out"}}`
|
||||
greetRes = `{"jsonrpc":"2.0","id":1,"result":"Hello"}`
|
||||
)
|
||||
// Set-up server
|
||||
timeouts := rpc.DefaultHTTPTimeouts
|
||||
timeouts.WriteTimeout = time.Second
|
||||
srv := createAndStartServer(t, &httpConfig{Modules: []string{"test"}}, false, &wsConfig{}, &timeouts)
|
||||
url := fmt.Sprintf("http://%v", srv.listenAddr())
|
||||
|
||||
// Send normal request
|
||||
t.Run("message", func(t *testing.T) {
|
||||
resp := rpcRequest(t, url, "test_sleep")
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(body) != timeoutRes {
|
||||
t.Errorf("wrong response. have %s, want %s", string(body), timeoutRes)
|
||||
}
|
||||
})
|
||||
|
||||
// Batch request
|
||||
t.Run("batch", func(t *testing.T) {
|
||||
want := fmt.Sprintf("[%s,%s,%s]", greetRes, timeoutRes, timeoutRes)
|
||||
resp := batchRpcRequest(t, url, []string{"test_greet", "test_sleep", "test_greet"})
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(body) != want {
|
||||
t.Errorf("wrong response. have %s, want %s", string(body), want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func apis() []rpc.API {
|
||||
return []rpc.API{
|
||||
{
|
||||
Namespace: "test",
|
||||
Service: &testService{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type testService struct{}
|
||||
|
||||
func (s *testService) Greet() string {
|
||||
return "Hello"
|
||||
}
|
||||
|
||||
func (s *testService) Sleep() {
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
}
|
||||
|
@ -206,7 +206,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON)
|
||||
codec := rpc.NewFuncCodec(conn, func(v any, _ bool) error { return conn.WriteJSON(v) }, conn.ReadJSON)
|
||||
handler.ServeCodec(codec, 0)
|
||||
return nil
|
||||
}
|
||||
|
@ -527,7 +527,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := c.writeConn.writeJSON(ctx, msg)
|
||||
err := c.writeConn.writeJSON(ctx, msg, false)
|
||||
if err != nil {
|
||||
c.writeConn = nil
|
||||
if !retry {
|
||||
@ -660,7 +660,8 @@ func (c *Client) read(codec ServerCodec) {
|
||||
for {
|
||||
msgs, batch, err := codec.readBatch()
|
||||
if _, ok := err.(*json.SyntaxError); ok {
|
||||
codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()}))
|
||||
msg := errorMessage(&parseError{err.Error()})
|
||||
codec.writeJSON(context.Background(), msg, true)
|
||||
}
|
||||
if err != nil {
|
||||
c.readErr <- err
|
||||
|
@ -60,10 +60,15 @@ var (
|
||||
const (
|
||||
errcodeDefault = -32000
|
||||
errcodeNotificationsUnsupported = -32001
|
||||
errcodeTimeout = -32002
|
||||
errcodePanic = -32603
|
||||
errcodeMarshalError = -32603
|
||||
)
|
||||
|
||||
const (
|
||||
errMsgTimeout = "request timed out"
|
||||
)
|
||||
|
||||
type methodNotFoundError struct{ method string }
|
||||
|
||||
func (e *methodNotFoundError) ErrorCode() int { return -32601 }
|
||||
|
142
rpc/handler.go
142
rpc/handler.go
@ -91,12 +91,83 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
|
||||
return h
|
||||
}
|
||||
|
||||
// batchCallBuffer manages in progress call messages and their responses during a batch
|
||||
// call. Calls need to be synchronized between the processing and timeout-triggering
|
||||
// goroutines.
|
||||
type batchCallBuffer struct {
|
||||
mutex sync.Mutex
|
||||
calls []*jsonrpcMessage
|
||||
resp []*jsonrpcMessage
|
||||
wrote bool
|
||||
}
|
||||
|
||||
// nextCall returns the next unprocessed message.
|
||||
func (b *batchCallBuffer) nextCall() *jsonrpcMessage {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
if len(b.calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
// The popping happens in `pushAnswer`. The in progress call is kept
|
||||
// so we can return an error for it in case of timeout.
|
||||
msg := b.calls[0]
|
||||
return msg
|
||||
}
|
||||
|
||||
// pushResponse adds the response to last call returned by nextCall.
|
||||
func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
if answer != nil {
|
||||
b.resp = append(b.resp, answer)
|
||||
}
|
||||
b.calls = b.calls[1:]
|
||||
}
|
||||
|
||||
// write sends the responses.
|
||||
func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
b.doWrite(ctx, conn, false)
|
||||
}
|
||||
|
||||
// timeout sends the responses added so far. For the remaining unanswered call
|
||||
// messages, it sends a timeout error response.
|
||||
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
for _, msg := range b.calls {
|
||||
if !msg.isNotification() {
|
||||
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
|
||||
b.resp = append(b.resp, resp)
|
||||
}
|
||||
}
|
||||
b.doWrite(ctx, conn, true)
|
||||
}
|
||||
|
||||
// doWrite actually writes the response.
|
||||
// This assumes b.mutex is held.
|
||||
func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) {
|
||||
if b.wrote {
|
||||
return
|
||||
}
|
||||
b.wrote = true // can only write once
|
||||
if len(b.resp) > 0 {
|
||||
conn.writeJSON(ctx, b.resp, isErrorResponse)
|
||||
}
|
||||
}
|
||||
|
||||
// handleBatch executes all messages in a batch and returns the responses.
|
||||
func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
||||
// Emit error response for empty batches:
|
||||
if len(msgs) == 0 {
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
|
||||
resp := errorMessage(&invalidRequestError{"empty batch"})
|
||||
h.conn.writeJSON(cp.ctx, resp, true)
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -113,16 +184,42 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
||||
}
|
||||
// Process calls on a goroutine because they may block indefinitely:
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
answers := make([]*jsonrpcMessage, 0, len(msgs))
|
||||
for _, msg := range calls {
|
||||
if answer := h.handleCallMsg(cp, msg); answer != nil {
|
||||
answers = append(answers, answer)
|
||||
var (
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))}
|
||||
)
|
||||
|
||||
cp.ctx, cancel = context.WithCancel(cp.ctx)
|
||||
defer cancel()
|
||||
|
||||
// Cancel the request context after timeout and send an error response. Since the
|
||||
// currently-running method might not return immediately on timeout, we must wait
|
||||
// for the timeout concurrently with processing the request.
|
||||
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
||||
timer = time.AfterFunc(timeout, func() {
|
||||
cancel()
|
||||
callBuffer.timeout(cp.ctx, h.conn)
|
||||
})
|
||||
}
|
||||
|
||||
for {
|
||||
// No need to handle rest of calls if timed out.
|
||||
if cp.ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
msg := callBuffer.nextCall()
|
||||
if msg == nil {
|
||||
break
|
||||
}
|
||||
resp := h.handleCallMsg(cp, msg)
|
||||
callBuffer.pushResponse(resp)
|
||||
}
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
callBuffer.write(cp.ctx, h.conn)
|
||||
h.addSubscriptions(cp.notifiers)
|
||||
if len(answers) > 0 {
|
||||
h.conn.writeJSON(cp.ctx, answers)
|
||||
}
|
||||
for _, n := range cp.notifiers {
|
||||
n.activate()
|
||||
}
|
||||
@ -135,10 +232,36 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
|
||||
return
|
||||
}
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
var (
|
||||
responded sync.Once
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
cp.ctx, cancel = context.WithCancel(cp.ctx)
|
||||
defer cancel()
|
||||
|
||||
// Cancel the request context after timeout and send an error response. Since the
|
||||
// running method might not return immediately on timeout, we must wait for the
|
||||
// timeout concurrently with processing the request.
|
||||
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
||||
timer = time.AfterFunc(timeout, func() {
|
||||
cancel()
|
||||
responded.Do(func() {
|
||||
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
|
||||
h.conn.writeJSON(cp.ctx, resp, true)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
answer := h.handleCallMsg(cp, msg)
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
h.addSubscriptions(cp.notifiers)
|
||||
if answer != nil {
|
||||
h.conn.writeJSON(cp.ctx, answer)
|
||||
responded.Do(func() {
|
||||
h.conn.writeJSON(cp.ctx, answer, false)
|
||||
})
|
||||
}
|
||||
for _, n := range cp.notifiers {
|
||||
n.activate()
|
||||
@ -334,7 +457,6 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
|
||||
}
|
||||
start := time.Now()
|
||||
answer := h.runMethod(cp.ctx, msg, callb, args)
|
||||
|
||||
// Collect the statistics for RPC calls if metrics is enabled.
|
||||
// We only care about pure rpc call. Filter out subscription.
|
||||
if callb != h.unsubscribeCb {
|
||||
|
73
rpc/http.go
73
rpc/http.go
@ -23,9 +23,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -52,7 +54,7 @@ type httpConn struct {
|
||||
// and some methods don't work. The panic() stubs here exist to ensure
|
||||
// this special treatment is correct.
|
||||
|
||||
func (hc *httpConn) writeJSON(context.Context, interface{}) error {
|
||||
func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error {
|
||||
panic("writeJSON called on httpConn")
|
||||
}
|
||||
|
||||
@ -256,7 +258,42 @@ type httpServerConn struct {
|
||||
func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec {
|
||||
body := io.LimitReader(r.Body, maxRequestContentLength)
|
||||
conn := &httpServerConn{Reader: body, Writer: w, r: r}
|
||||
return NewCodec(conn)
|
||||
|
||||
encoder := func(v any, isErrorResponse bool) error {
|
||||
if !isErrorResponse {
|
||||
return json.NewEncoder(conn).Encode(v)
|
||||
}
|
||||
|
||||
// It's an error response and requires special treatment.
|
||||
//
|
||||
// In case of a timeout error, the response must be written before the HTTP
|
||||
// server's write timeout occurs. So we need to flush the response. The
|
||||
// Content-Length header also needs to be set to ensure the client knows
|
||||
// when it has the full response.
|
||||
encdata, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("content-length", strconv.Itoa(len(encdata)))
|
||||
|
||||
// If this request is wrapped in a handler that might remove Content-Length (such
|
||||
// as the automatic gzip we do in package node), we need to ensure the HTTP server
|
||||
// doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked
|
||||
// encoding might not be finished correctly, and some clients do not like it when
|
||||
// the final chunk is missing.
|
||||
w.Header().Set("transfer-encoding", "identity")
|
||||
|
||||
_, err = w.Write(encdata)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(conn)
|
||||
dec.UseNumber()
|
||||
|
||||
return NewFuncCodec(conn, encoder, dec.Decode)
|
||||
}
|
||||
|
||||
// Close does nothing and always returns nil.
|
||||
@ -326,3 +363,35 @@ func validateRequest(r *http.Request) (int, error) {
|
||||
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
|
||||
return http.StatusUnsupportedMediaType, err
|
||||
}
|
||||
|
||||
// ContextRequestTimeout returns the request timeout derived from the given context.
|
||||
func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) {
|
||||
timeout := time.Duration(math.MaxInt64)
|
||||
hasTimeout := false
|
||||
setTimeout := func(d time.Duration) {
|
||||
if d < timeout {
|
||||
timeout = d
|
||||
hasTimeout = true
|
||||
}
|
||||
}
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
setTimeout(time.Until(deadline))
|
||||
}
|
||||
|
||||
// If the context is an HTTP request context, use the server's WriteTimeout.
|
||||
httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server)
|
||||
if ok && httpSrv.WriteTimeout > 0 {
|
||||
wt := httpSrv.WriteTimeout
|
||||
// When a write timeout is configured, we need to send the response message before
|
||||
// the HTTP server cuts connection. So our internal timeout must be earlier than
|
||||
// the server's true timeout.
|
||||
//
|
||||
// Note: Timeouts are sanitized to be a minimum of 1 second.
|
||||
// Also see issue: https://github.com/golang/go/issues/47229
|
||||
wt -= 100 * time.Millisecond
|
||||
setTimeout(wt)
|
||||
}
|
||||
|
||||
return timeout, hasTimeout
|
||||
}
|
||||
|
20
rpc/json.go
20
rpc/json.go
@ -170,16 +170,20 @@ type jsonCodec struct {
|
||||
remote string
|
||||
closer sync.Once // close closed channel once
|
||||
closeCh chan interface{} // closed on Close
|
||||
decode func(v interface{}) error // decoder to allow multiple transports
|
||||
decode decodeFunc // decoder to allow multiple transports
|
||||
encMu sync.Mutex // guards the encoder
|
||||
encode func(v interface{}) error // encoder to allow multiple transports
|
||||
encode encodeFunc // encoder to allow multiple transports
|
||||
conn deadlineCloser
|
||||
}
|
||||
|
||||
type encodeFunc = func(v interface{}, isErrorResponse bool) error
|
||||
|
||||
type decodeFunc = func(v interface{}) error
|
||||
|
||||
// NewFuncCodec creates a codec which uses the given functions to read and write. If conn
|
||||
// implements ConnRemoteAddr, log messages will use it to include the remote address of
|
||||
// the connection.
|
||||
func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec {
|
||||
func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec {
|
||||
codec := &jsonCodec{
|
||||
closeCh: make(chan interface{}),
|
||||
encode: encode,
|
||||
@ -198,7 +202,11 @@ func NewCodec(conn Conn) ServerCodec {
|
||||
enc := json.NewEncoder(conn)
|
||||
dec := json.NewDecoder(conn)
|
||||
dec.UseNumber()
|
||||
return NewFuncCodec(conn, enc.Encode, dec.Decode)
|
||||
|
||||
encode := func(v interface{}, isErrorResponse bool) error {
|
||||
return enc.Encode(v)
|
||||
}
|
||||
return NewFuncCodec(conn, encode, dec.Decode)
|
||||
}
|
||||
|
||||
func (c *jsonCodec) peerInfo() PeerInfo {
|
||||
@ -228,7 +236,7 @@ func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err err
|
||||
return messages, batch, nil
|
||||
}
|
||||
|
||||
func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error {
|
||||
func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error {
|
||||
c.encMu.Lock()
|
||||
defer c.encMu.Unlock()
|
||||
|
||||
@ -237,7 +245,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error {
|
||||
deadline = time.Now().Add(defaultWriteTimeout)
|
||||
}
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
return c.encode(v)
|
||||
return c.encode(v, isErrorResponse)
|
||||
}
|
||||
|
||||
func (c *jsonCodec) close() {
|
||||
|
@ -125,7 +125,8 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
|
||||
reqs, batch, err := codec.readBatch()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"}))
|
||||
resp := errorMessage(&invalidMessageError{"parse error"})
|
||||
codec.writeJSON(ctx, resp, true)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -175,11 +175,13 @@ func (n *Notifier) activate() error {
|
||||
func (n *Notifier) send(sub *Subscription, data json.RawMessage) error {
|
||||
params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data})
|
||||
ctx := context.Background()
|
||||
return n.h.conn.writeJSON(ctx, &jsonrpcMessage{
|
||||
|
||||
msg := &jsonrpcMessage{
|
||||
Version: vsn,
|
||||
Method: n.namespace + notificationMethodSuffix,
|
||||
Params: params,
|
||||
})
|
||||
}
|
||||
return n.h.conn.writeJSON(ctx, msg, false)
|
||||
}
|
||||
|
||||
// A Subscription is created by a notifier and tied to that notifier. The client can use
|
||||
|
@ -51,7 +51,9 @@ type ServerCodec interface {
|
||||
// jsonWriter can write JSON messages to its underlying connection.
|
||||
// Implementations must be safe for concurrent use.
|
||||
type jsonWriter interface {
|
||||
writeJSON(context.Context, interface{}) error
|
||||
// writeJSON writes a message to the connection.
|
||||
writeJSON(ctx context.Context, msg interface{}, isError bool) error
|
||||
|
||||
// Closed returns a channel which is closed when the connection is closed.
|
||||
closed() <-chan interface{}
|
||||
// RemoteAddr returns the peer address of the connection.
|
||||
|
@ -287,8 +287,12 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
return nil
|
||||
})
|
||||
|
||||
encode := func(v interface{}, isErrorResponse bool) error {
|
||||
return conn.WriteJSON(v)
|
||||
}
|
||||
wc := &websocketCodec{
|
||||
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
|
||||
jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec),
|
||||
conn: conn,
|
||||
pingReset: make(chan struct{}, 1),
|
||||
info: PeerInfo{
|
||||
@ -315,8 +319,8 @@ func (wc *websocketCodec) peerInfo() PeerInfo {
|
||||
return wc.info
|
||||
}
|
||||
|
||||
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
|
||||
err := wc.jsonCodec.writeJSON(ctx, v)
|
||||
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error {
|
||||
err := wc.jsonCodec.writeJSON(ctx, v, isError)
|
||||
if err == nil {
|
||||
// Notify pingLoop to delay the next idle ping.
|
||||
select {
|
||||
|
Loading…
Reference in New Issue
Block a user