Patch for concurrent iterator & others (onto v1.11.6) #386

Closed
roysc wants to merge 1565 commits from v1.11.6-statediff-v5 into master
16 changed files with 606 additions and 80 deletions
Showing only changes of commit f20eba426a - Show all commits

View File

@ -321,10 +321,11 @@ func TestGraphQLTransactionLogs(t *testing.T) {
func createNode(t *testing.T) *node.Node {
stack, err := node.New(&node.Config{
HTTPHost: "127.0.0.1",
HTTPPort: 0,
WSHost: "127.0.0.1",
WSPort: 0,
HTTPHost: "127.0.0.1",
HTTPPort: 0,
WSHost: "127.0.0.1",
WSPort: 0,
HTTPTimeouts: node.DefaultConfig.HTTPTimeouts,
})
if err != nil {
t.Fatalf("could not create node: %v", err)

View File

@ -20,12 +20,16 @@ import (
"context"
"encoding/json"
"net/http"
"strconv"
"sync"
"time"
"github.com/ethereum/go-ethereum/eth/filters"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/rpc"
"github.com/graph-gophers/graphql-go"
gqlErrors "github.com/graph-gophers/graphql-go/errors"
)
type handler struct {
@ -43,21 +47,60 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
var (
ctx = r.Context()
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
ctx, cancel = context.WithCancel(ctx)
defer cancel()
response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables)
responseJSON, err := json.Marshal(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if len(response.Errors) > 0 {
w.WriteHeader(http.StatusBadRequest)
if timeout, ok := rpc.ContextRequestTimeout(ctx); ok {
timer = time.AfterFunc(timeout, func() {
responded.Do(func() {
// Cancel request handling.
cancel()
// Create the timeout response.
response := &graphql.Response{
Errors: []*gqlErrors.QueryError{{Message: "request timed out"}},
}
responseJSON, err := json.Marshal(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Setting this disables gzip compression in package node.
w.Header().Set("transfer-encoding", "identity")
// Flush the response. Since we are writing close to the response timeout,
// chunked transfer encoding must be disabled by setting content-length.
w.Header().Set("content-type", "application/json")
w.Header().Set("content-length", strconv.Itoa(len(responseJSON)))
w.Write(responseJSON)
if flush, ok := w.(http.Flusher); ok {
flush.Flush()
}
})
})
}
w.Header().Set("Content-Type", "application/json")
w.Write(responseJSON)
response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables)
timer.Stop()
responded.Do(func() {
responseJSON, err := json.Marshal(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if len(response.Errors) > 0 {
w.WriteHeader(http.StatusBadRequest)
}
w.Header().Set("Content-Type", "application/json")
w.Write(responseJSON)
})
}
// New constructs a new GraphQL service instance.

View File

@ -252,6 +252,9 @@ func TestStartRPC(t *testing.T) {
config := test.cfg
// config.Logger = testlog.Logger(t, log.LvlDebug)
config.P2P.NoDiscovery = true
if config.HTTPTimeouts == (rpc.HTTPTimeouts{}) {
config.HTTPTimeouts = rpc.DefaultHTTPTimeouts
}
// Create Node.
stack, err := New(&config)

View File

@ -559,13 +559,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
}
for _, path := range test.wantHTTP {
resp := rpcRequest(t, httpBase+path)
resp := rpcRequest(t, httpBase+path, testMethod)
if resp.StatusCode != 200 {
t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode)
}
}
for _, path := range test.wantNoHTTP {
resp := rpcRequest(t, httpBase+path)
resp := rpcRequest(t, httpBase+path, testMethod)
if resp.StatusCode != 404 {
t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode)
}
@ -586,10 +586,11 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
func createNode(t *testing.T, httpPort, wsPort int) *Node {
conf := &Config{
HTTPHost: "127.0.0.1",
HTTPPort: httpPort,
WSHost: "127.0.0.1",
WSPort: wsPort,
HTTPHost: "127.0.0.1",
HTTPPort: httpPort,
WSHost: "127.0.0.1",
WSPort: wsPort,
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
}
node, err := New(conf)
if err != nil {

View File

@ -24,6 +24,7 @@ import (
"net"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
@ -196,6 +197,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
return
}
// if http-rpc is enabled, try to serve request
rpc := h.httpHandler.Load().(*rpcHandler)
if rpc != nil {
@ -462,17 +464,94 @@ var gzPool = sync.Pool{
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
resp http.ResponseWriter
gz *gzip.Writer
contentLength uint64 // total length of the uncompressed response
written uint64 // amount of written bytes from the uncompressed response
hasLength bool // true if uncompressed response had Content-Length
inited bool // true after init was called for the first time
}
// init runs just before response headers are written. Among other things, this function
// also decides whether compression will be applied at all.
func (w *gzipResponseWriter) init() {
if w.inited {
return
}
w.inited = true
hdr := w.resp.Header()
length := hdr.Get("content-length")
if len(length) > 0 {
if n, err := strconv.ParseUint(length, 10, 64); err != nil {
w.hasLength = true
w.contentLength = n
}
}
// Setting Transfer-Encoding to "identity" explicitly disables compression. net/http
// also recognizes this header value and uses it to disable "chunked" transfer
// encoding, trimming the header from the response. This means downstream handlers can
// set this without harm, even if they aren't wrapped by newGzipHandler.
//
// In go-ethereum, we use this signal to disable compression for certain error
// responses which are flushed out close to the write deadline of the response. For
// these cases, we want to avoid chunked transfer encoding and compression because
// they require additional output that may not get written in time.
passthrough := hdr.Get("transfer-encoding") == "identity"
if !passthrough {
w.gz = gzPool.Get().(*gzip.Writer)
w.gz.Reset(w.resp)
hdr.Del("content-length")
hdr.Set("content-encoding", "gzip")
}
}
func (w *gzipResponseWriter) Header() http.Header {
return w.resp.Header()
}
func (w *gzipResponseWriter) WriteHeader(status int) {
w.Header().Del("Content-Length")
w.ResponseWriter.WriteHeader(status)
w.init()
w.resp.WriteHeader(status)
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
w.init()
if w.gz == nil {
// Compression is disabled.
return w.resp.Write(b)
}
n, err := w.gz.Write(b)
w.written += uint64(n)
if w.hasLength && w.written >= w.contentLength {
// The HTTP handler has finished writing the entire uncompressed response. Close
// the gzip stream to ensure the footer will be seen by the client in case the
// response is flushed after this call to write.
err = w.gz.Close()
}
return n, err
}
func (w *gzipResponseWriter) Flush() {
if w.gz != nil {
w.gz.Flush()
}
if f, ok := w.resp.(http.Flusher); ok {
f.Flush()
}
}
func (w *gzipResponseWriter) close() {
if w.gz == nil {
return
}
w.gz.Close()
gzPool.Put(w.gz)
w.gz = nil
}
func newGzipHandler(next http.Handler) http.Handler {
@ -482,15 +561,10 @@ func newGzipHandler(next http.Handler) http.Handler {
return
}
w.Header().Set("Content-Encoding", "gzip")
wrapper := &gzipResponseWriter{resp: w}
defer wrapper.close()
gz := gzPool.Get().(*gzip.Writer)
defer gzPool.Put(gz)
gz.Reset(w)
defer gz.Close()
next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
next.ServeHTTP(wrapper, r)
})
}

View File

@ -19,7 +19,9 @@ package node
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
@ -34,29 +36,31 @@ import (
"github.com/stretchr/testify/assert"
)
const testMethod = "rpc_modules"
// TestCorsHandler makes sure CORS are properly handled on the http server.
func TestCorsHandler(t *testing.T) {
srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{})
srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}, nil)
defer srv.stop()
url := "http://" + srv.listenAddr()
resp := rpcRequest(t, url, "origin", "test.com")
resp := rpcRequest(t, url, testMethod, "origin", "test.com")
assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin"))
resp2 := rpcRequest(t, url, "origin", "bad")
resp2 := rpcRequest(t, url, testMethod, "origin", "bad")
assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin"))
}
// TestVhosts makes sure vhosts are properly handled on the http server.
func TestVhosts(t *testing.T) {
srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{})
srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}, nil)
defer srv.stop()
url := "http://" + srv.listenAddr()
resp := rpcRequest(t, url, "host", "test")
resp := rpcRequest(t, url, testMethod, "host", "test")
assert.Equal(t, resp.StatusCode, http.StatusOK)
resp2 := rpcRequest(t, url, "host", "bad")
resp2 := rpcRequest(t, url, testMethod, "host", "bad")
assert.Equal(t, resp2.StatusCode, http.StatusForbidden)
}
@ -145,7 +149,7 @@ func TestWebsocketOrigins(t *testing.T) {
},
}
for _, tc := range tests {
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)})
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}, nil)
url := fmt.Sprintf("ws://%v", srv.listenAddr())
for _, origin := range tc.expOk {
if err := wsRequest(t, url, "Origin", origin); err != nil {
@ -231,11 +235,14 @@ func Test_checkPath(t *testing.T) {
}
}
func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig) *httpServer {
func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig, timeouts *rpc.HTTPTimeouts) *httpServer {
t.Helper()
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts)
assert.NoError(t, srv.enableRPC(nil, *conf))
if timeouts == nil {
timeouts = &rpc.DefaultHTTPTimeouts
}
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), *timeouts)
assert.NoError(t, srv.enableRPC(apis(), *conf))
if ws {
assert.NoError(t, srv.enableWS(nil, *wsConf))
}
@ -266,16 +273,33 @@ func wsRequest(t *testing.T, url string, extraHeaders ...string) error {
}
// rpcRequest performs a JSON-RPC request to the given URL.
func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response {
func rpcRequest(t *testing.T, url, method string, extraHeaders ...string) *http.Response {
t.Helper()
body := fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, method)
return baseRpcRequest(t, url, body, extraHeaders...)
}
func batchRpcRequest(t *testing.T, url string, methods []string, extraHeaders ...string) *http.Response {
reqs := make([]string, len(methods))
for i, m := range methods {
reqs[i] = fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, m)
}
body := fmt.Sprintf(`[%s]`, strings.Join(reqs, ","))
return baseRpcRequest(t, url, body, extraHeaders...)
}
func baseRpcRequest(t *testing.T, url, bodyStr string, extraHeaders ...string) *http.Response {
t.Helper()
// Create the request.
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]}`))
body := bytes.NewReader([]byte(bodyStr))
req, err := http.NewRequest("POST", url, body)
if err != nil {
t.Fatal("could not create http request:", err)
}
req.Header.Set("content-type", "application/json")
req.Header.Set("accept-encoding", "identity")
// Apply extra headers.
if len(extraHeaders)%2 != 0 {
@ -315,7 +339,7 @@ func TestJWT(t *testing.T) {
return ss
}
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")})
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil)
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
@ -348,7 +372,7 @@ func TestJWT(t *testing.T) {
t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err)
}
token = tokenFn()
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 {
if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 200 {
t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode)
}
}
@ -414,10 +438,176 @@ func TestJWT(t *testing.T) {
}
token = tokenFn()
resp := rpcRequest(t, htUrl, "Authorization", token)
resp := rpcRequest(t, htUrl, testMethod, "Authorization", token)
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode)
}
}
srv.stop()
}
func TestGzipHandler(t *testing.T) {
type gzipTest struct {
name string
handler http.HandlerFunc
status int
isGzip bool
header map[string]string
}
tests := []gzipTest{
{
name: "Write",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("response"))
},
isGzip: true,
status: 200,
},
{
name: "WriteHeader",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-foo", "bar")
w.WriteHeader(205)
w.Write([]byte("response"))
},
isGzip: true,
status: 205,
header: map[string]string{"x-foo": "bar"},
},
{
name: "WriteContentLength",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-length", "8")
w.Write([]byte("response"))
},
isGzip: true,
status: 200,
},
{
name: "Flush",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("res"))
w.(http.Flusher).Flush()
w.Write([]byte("ponse"))
},
isGzip: true,
status: 200,
},
{
name: "disable",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("transfer-encoding", "identity")
w.Header().Set("x-foo", "bar")
w.Write([]byte("response"))
},
isGzip: false,
status: 200,
header: map[string]string{"x-foo": "bar"},
},
{
name: "disable-WriteHeader",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("transfer-encoding", "identity")
w.Header().Set("x-foo", "bar")
w.WriteHeader(205)
w.Write([]byte("response"))
},
isGzip: false,
status: 205,
header: map[string]string{"x-foo": "bar"},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
srv := httptest.NewServer(newGzipHandler(test.handler))
defer srv.Close()
resp, err := http.Get(srv.URL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
content, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
wasGzip := resp.Uncompressed
if string(content) != "response" {
t.Fatalf("wrong response content %q", content)
}
if wasGzip != test.isGzip {
t.Fatalf("response gzipped == %t, want %t", wasGzip, test.isGzip)
}
if resp.StatusCode != test.status {
t.Fatalf("response status == %d, want %d", resp.StatusCode, test.status)
}
for name, expectedValue := range test.header {
if v := resp.Header.Get(name); v != expectedValue {
t.Fatalf("response header %s == %s, want %s", name, v, expectedValue)
}
}
})
}
}
func TestHTTPWriteTimeout(t *testing.T) {
const (
timeoutRes = `{"jsonrpc":"2.0","id":1,"error":{"code":-32002,"message":"request timed out"}}`
greetRes = `{"jsonrpc":"2.0","id":1,"result":"Hello"}`
)
// Set-up server
timeouts := rpc.DefaultHTTPTimeouts
timeouts.WriteTimeout = time.Second
srv := createAndStartServer(t, &httpConfig{Modules: []string{"test"}}, false, &wsConfig{}, &timeouts)
url := fmt.Sprintf("http://%v", srv.listenAddr())
// Send normal request
t.Run("message", func(t *testing.T) {
resp := rpcRequest(t, url, "test_sleep")
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != timeoutRes {
t.Errorf("wrong response. have %s, want %s", string(body), timeoutRes)
}
})
// Batch request
t.Run("batch", func(t *testing.T) {
want := fmt.Sprintf("[%s,%s,%s]", greetRes, timeoutRes, timeoutRes)
resp := batchRpcRequest(t, url, []string{"test_greet", "test_sleep", "test_greet"})
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != want {
t.Errorf("wrong response. have %s, want %s", string(body), want)
}
})
}
func apis() []rpc.API {
return []rpc.API{
{
Namespace: "test",
Service: &testService{},
},
}
}
type testService struct{}
func (s *testService) Greet() string {
return "Hello"
}
func (s *testService) Sleep() {
time.Sleep(1500 * time.Millisecond)
}

View File

@ -206,7 +206,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error {
if err != nil {
return err
}
codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON)
codec := rpc.NewFuncCodec(conn, func(v any, _ bool) error { return conn.WriteJSON(v) }, conn.ReadJSON)
handler.ServeCodec(codec, 0)
return nil
}

View File

@ -527,7 +527,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
return err
}
}
err := c.writeConn.writeJSON(ctx, msg)
err := c.writeConn.writeJSON(ctx, msg, false)
if err != nil {
c.writeConn = nil
if !retry {
@ -660,7 +660,8 @@ func (c *Client) read(codec ServerCodec) {
for {
msgs, batch, err := codec.readBatch()
if _, ok := err.(*json.SyntaxError); ok {
codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()}))
msg := errorMessage(&parseError{err.Error()})
codec.writeJSON(context.Background(), msg, true)
}
if err != nil {
c.readErr <- err

View File

@ -60,10 +60,15 @@ var (
const (
errcodeDefault = -32000
errcodeNotificationsUnsupported = -32001
errcodeTimeout = -32002
errcodePanic = -32603
errcodeMarshalError = -32603
)
const (
errMsgTimeout = "request timed out"
)
type methodNotFoundError struct{ method string }
func (e *methodNotFoundError) ErrorCode() int { return -32601 }

View File

@ -91,12 +91,83 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
return h
}
// batchCallBuffer manages in progress call messages and their responses during a batch
// call. Calls need to be synchronized between the processing and timeout-triggering
// goroutines.
type batchCallBuffer struct {
mutex sync.Mutex
calls []*jsonrpcMessage
resp []*jsonrpcMessage
wrote bool
}
// nextCall returns the next unprocessed message.
func (b *batchCallBuffer) nextCall() *jsonrpcMessage {
b.mutex.Lock()
defer b.mutex.Unlock()
if len(b.calls) == 0 {
return nil
}
// The popping happens in `pushAnswer`. The in progress call is kept
// so we can return an error for it in case of timeout.
msg := b.calls[0]
return msg
}
// pushResponse adds the response to last call returned by nextCall.
func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) {
b.mutex.Lock()
defer b.mutex.Unlock()
if answer != nil {
b.resp = append(b.resp, answer)
}
b.calls = b.calls[1:]
}
// write sends the responses.
func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
b.mutex.Lock()
defer b.mutex.Unlock()
b.doWrite(ctx, conn, false)
}
// timeout sends the responses added so far. For the remaining unanswered call
// messages, it sends a timeout error response.
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
b.mutex.Lock()
defer b.mutex.Unlock()
for _, msg := range b.calls {
if !msg.isNotification() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
b.resp = append(b.resp, resp)
}
}
b.doWrite(ctx, conn, true)
}
// doWrite actually writes the response.
// This assumes b.mutex is held.
func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) {
if b.wrote {
return
}
b.wrote = true // can only write once
if len(b.resp) > 0 {
conn.writeJSON(ctx, b.resp, isErrorResponse)
}
}
// handleBatch executes all messages in a batch and returns the responses.
func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
// Emit error response for empty batches:
if len(msgs) == 0 {
h.startCallProc(func(cp *callProc) {
h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
resp := errorMessage(&invalidRequestError{"empty batch"})
h.conn.writeJSON(cp.ctx, resp, true)
})
return
}
@ -113,16 +184,42 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
}
// Process calls on a goroutine because they may block indefinitely:
h.startCallProc(func(cp *callProc) {
answers := make([]*jsonrpcMessage, 0, len(msgs))
for _, msg := range calls {
if answer := h.handleCallMsg(cp, msg); answer != nil {
answers = append(answers, answer)
var (
timer *time.Timer
cancel context.CancelFunc
callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))}
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the
// currently-running method might not return immediately on timeout, we must wait
// for the timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
callBuffer.timeout(cp.ctx, h.conn)
})
}
for {
// No need to handle rest of calls if timed out.
if cp.ctx.Err() != nil {
break
}
msg := callBuffer.nextCall()
if msg == nil {
break
}
resp := h.handleCallMsg(cp, msg)
callBuffer.pushResponse(resp)
}
if timer != nil {
timer.Stop()
}
callBuffer.write(cp.ctx, h.conn)
h.addSubscriptions(cp.notifiers)
if len(answers) > 0 {
h.conn.writeJSON(cp.ctx, answers)
}
for _, n := range cp.notifiers {
n.activate()
}
@ -135,10 +232,36 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
return
}
h.startCallProc(func(cp *callProc) {
var (
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the
// running method might not return immediately on timeout, we must wait for the
// timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
responded.Do(func() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
h.conn.writeJSON(cp.ctx, resp, true)
})
})
}
answer := h.handleCallMsg(cp, msg)
if timer != nil {
timer.Stop()
}
h.addSubscriptions(cp.notifiers)
if answer != nil {
h.conn.writeJSON(cp.ctx, answer)
responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false)
})
}
for _, n := range cp.notifiers {
n.activate()
@ -334,7 +457,6 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
}
start := time.Now()
answer := h.runMethod(cp.ctx, msg, callb, args)
// Collect the statistics for RPC calls if metrics is enabled.
// We only care about pure rpc call. Filter out subscription.
if callb != h.unsubscribeCb {

View File

@ -23,9 +23,11 @@ import (
"errors"
"fmt"
"io"
"math"
"mime"
"net/http"
"net/url"
"strconv"
"sync"
"time"
)
@ -52,7 +54,7 @@ type httpConn struct {
// and some methods don't work. The panic() stubs here exist to ensure
// this special treatment is correct.
func (hc *httpConn) writeJSON(context.Context, interface{}) error {
func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error {
panic("writeJSON called on httpConn")
}
@ -256,7 +258,42 @@ type httpServerConn struct {
func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec {
body := io.LimitReader(r.Body, maxRequestContentLength)
conn := &httpServerConn{Reader: body, Writer: w, r: r}
return NewCodec(conn)
encoder := func(v any, isErrorResponse bool) error {
if !isErrorResponse {
return json.NewEncoder(conn).Encode(v)
}
// It's an error response and requires special treatment.
//
// In case of a timeout error, the response must be written before the HTTP
// server's write timeout occurs. So we need to flush the response. The
// Content-Length header also needs to be set to ensure the client knows
// when it has the full response.
encdata, err := json.Marshal(v)
if err != nil {
return err
}
w.Header().Set("content-length", strconv.Itoa(len(encdata)))
// If this request is wrapped in a handler that might remove Content-Length (such
// as the automatic gzip we do in package node), we need to ensure the HTTP server
// doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked
// encoding might not be finished correctly, and some clients do not like it when
// the final chunk is missing.
w.Header().Set("transfer-encoding", "identity")
_, err = w.Write(encdata)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return err
}
dec := json.NewDecoder(conn)
dec.UseNumber()
return NewFuncCodec(conn, encoder, dec.Decode)
}
// Close does nothing and always returns nil.
@ -326,3 +363,35 @@ func validateRequest(r *http.Request) (int, error) {
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
return http.StatusUnsupportedMediaType, err
}
// ContextRequestTimeout returns the request timeout derived from the given context.
func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) {
timeout := time.Duration(math.MaxInt64)
hasTimeout := false
setTimeout := func(d time.Duration) {
if d < timeout {
timeout = d
hasTimeout = true
}
}
if deadline, ok := ctx.Deadline(); ok {
setTimeout(time.Until(deadline))
}
// If the context is an HTTP request context, use the server's WriteTimeout.
httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server)
if ok && httpSrv.WriteTimeout > 0 {
wt := httpSrv.WriteTimeout
// When a write timeout is configured, we need to send the response message before
// the HTTP server cuts connection. So our internal timeout must be earlier than
// the server's true timeout.
//
// Note: Timeouts are sanitized to be a minimum of 1 second.
// Also see issue: https://github.com/golang/go/issues/47229
wt -= 100 * time.Millisecond
setTimeout(wt)
}
return timeout, hasTimeout
}

View File

@ -168,18 +168,22 @@ type ConnRemoteAddr interface {
// support for parsing arguments and serializing (result) objects.
type jsonCodec struct {
remote string
closer sync.Once // close closed channel once
closeCh chan interface{} // closed on Close
decode func(v interface{}) error // decoder to allow multiple transports
encMu sync.Mutex // guards the encoder
encode func(v interface{}) error // encoder to allow multiple transports
closer sync.Once // close closed channel once
closeCh chan interface{} // closed on Close
decode decodeFunc // decoder to allow multiple transports
encMu sync.Mutex // guards the encoder
encode encodeFunc // encoder to allow multiple transports
conn deadlineCloser
}
type encodeFunc = func(v interface{}, isErrorResponse bool) error
type decodeFunc = func(v interface{}) error
// NewFuncCodec creates a codec which uses the given functions to read and write. If conn
// implements ConnRemoteAddr, log messages will use it to include the remote address of
// the connection.
func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec {
func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec {
codec := &jsonCodec{
closeCh: make(chan interface{}),
encode: encode,
@ -198,7 +202,11 @@ func NewCodec(conn Conn) ServerCodec {
enc := json.NewEncoder(conn)
dec := json.NewDecoder(conn)
dec.UseNumber()
return NewFuncCodec(conn, enc.Encode, dec.Decode)
encode := func(v interface{}, isErrorResponse bool) error {
return enc.Encode(v)
}
return NewFuncCodec(conn, encode, dec.Decode)
}
func (c *jsonCodec) peerInfo() PeerInfo {
@ -228,7 +236,7 @@ func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err err
return messages, batch, nil
}
func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error {
func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error {
c.encMu.Lock()
defer c.encMu.Unlock()
@ -237,7 +245,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error {
deadline = time.Now().Add(defaultWriteTimeout)
}
c.conn.SetWriteDeadline(deadline)
return c.encode(v)
return c.encode(v, isErrorResponse)
}
func (c *jsonCodec) close() {

View File

@ -125,7 +125,8 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
reqs, batch, err := codec.readBatch()
if err != nil {
if err != io.EOF {
codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"}))
resp := errorMessage(&invalidMessageError{"parse error"})
codec.writeJSON(ctx, resp, true)
}
return
}

View File

@ -175,11 +175,13 @@ func (n *Notifier) activate() error {
func (n *Notifier) send(sub *Subscription, data json.RawMessage) error {
params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data})
ctx := context.Background()
return n.h.conn.writeJSON(ctx, &jsonrpcMessage{
msg := &jsonrpcMessage{
Version: vsn,
Method: n.namespace + notificationMethodSuffix,
Params: params,
})
}
return n.h.conn.writeJSON(ctx, msg, false)
}
// A Subscription is created by a notifier and tied to that notifier. The client can use

View File

@ -51,7 +51,9 @@ type ServerCodec interface {
// jsonWriter can write JSON messages to its underlying connection.
// Implementations must be safe for concurrent use.
type jsonWriter interface {
writeJSON(context.Context, interface{}) error
// writeJSON writes a message to the connection.
writeJSON(ctx context.Context, msg interface{}, isError bool) error
// Closed returns a channel which is closed when the connection is closed.
closed() <-chan interface{}
// RemoteAddr returns the peer address of the connection.

View File

@ -287,8 +287,12 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve
conn.SetReadDeadline(time.Time{})
return nil
})
encode := func(v interface{}, isErrorResponse bool) error {
return conn.WriteJSON(v)
}
wc := &websocketCodec{
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec),
conn: conn,
pingReset: make(chan struct{}, 1),
info: PeerInfo{
@ -315,8 +319,8 @@ func (wc *websocketCodec) peerInfo() PeerInfo {
return wc.info
}
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
err := wc.jsonCodec.writeJSON(ctx, v)
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error {
err := wc.jsonCodec.writeJSON(ctx, v, isError)
if err == nil {
// Notify pingLoop to delay the next idle ping.
select {