Patch for concurrent iterator & others (onto v1.11.6) #386
@ -321,10 +321,11 @@ func TestGraphQLTransactionLogs(t *testing.T) {
|
||||
|
||||
func createNode(t *testing.T) *node.Node {
|
||||
stack, err := node.New(&node.Config{
|
||||
HTTPHost: "127.0.0.1",
|
||||
HTTPPort: 0,
|
||||
WSHost: "127.0.0.1",
|
||||
WSPort: 0,
|
||||
HTTPHost: "127.0.0.1",
|
||||
HTTPPort: 0,
|
||||
WSHost: "127.0.0.1",
|
||||
WSPort: 0,
|
||||
HTTPTimeouts: node.DefaultConfig.HTTPTimeouts,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("could not create node: %v", err)
|
||||
|
@ -20,12 +20,16 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/eth/filters"
|
||||
"github.com/ethereum/go-ethereum/internal/ethapi"
|
||||
"github.com/ethereum/go-ethereum/node"
|
||||
"github.com/ethereum/go-ethereum/rpc"
|
||||
"github.com/graph-gophers/graphql-go"
|
||||
gqlErrors "github.com/graph-gophers/graphql-go/errors"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
@ -43,21 +47,60 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||
var (
|
||||
ctx = r.Context()
|
||||
responded sync.Once
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables)
|
||||
responseJSON, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if len(response.Errors) > 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
if timeout, ok := rpc.ContextRequestTimeout(ctx); ok {
|
||||
timer = time.AfterFunc(timeout, func() {
|
||||
responded.Do(func() {
|
||||
// Cancel request handling.
|
||||
cancel()
|
||||
|
||||
// Create the timeout response.
|
||||
response := &graphql.Response{
|
||||
Errors: []*gqlErrors.QueryError{{Message: "request timed out"}},
|
||||
}
|
||||
responseJSON, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Setting this disables gzip compression in package node.
|
||||
w.Header().Set("transfer-encoding", "identity")
|
||||
|
||||
// Flush the response. Since we are writing close to the response timeout,
|
||||
// chunked transfer encoding must be disabled by setting content-length.
|
||||
w.Header().Set("content-type", "application/json")
|
||||
w.Header().Set("content-length", strconv.Itoa(len(responseJSON)))
|
||||
w.Write(responseJSON)
|
||||
if flush, ok := w.(http.Flusher); ok {
|
||||
flush.Flush()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(responseJSON)
|
||||
response := h.Schema.Exec(ctx, params.Query, params.OperationName, params.Variables)
|
||||
timer.Stop()
|
||||
responded.Do(func() {
|
||||
responseJSON, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if len(response.Errors) > 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(responseJSON)
|
||||
})
|
||||
}
|
||||
|
||||
// New constructs a new GraphQL service instance.
|
||||
|
@ -252,6 +252,9 @@ func TestStartRPC(t *testing.T) {
|
||||
config := test.cfg
|
||||
// config.Logger = testlog.Logger(t, log.LvlDebug)
|
||||
config.P2P.NoDiscovery = true
|
||||
if config.HTTPTimeouts == (rpc.HTTPTimeouts{}) {
|
||||
config.HTTPTimeouts = rpc.DefaultHTTPTimeouts
|
||||
}
|
||||
|
||||
// Create Node.
|
||||
stack, err := New(&config)
|
||||
|
@ -559,13 +559,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
|
||||
}
|
||||
|
||||
for _, path := range test.wantHTTP {
|
||||
resp := rpcRequest(t, httpBase+path)
|
||||
resp := rpcRequest(t, httpBase+path, testMethod)
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
for _, path := range test.wantNoHTTP {
|
||||
resp := rpcRequest(t, httpBase+path)
|
||||
resp := rpcRequest(t, httpBase+path, testMethod)
|
||||
if resp.StatusCode != 404 {
|
||||
t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode)
|
||||
}
|
||||
@ -586,10 +586,11 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
|
||||
|
||||
func createNode(t *testing.T, httpPort, wsPort int) *Node {
|
||||
conf := &Config{
|
||||
HTTPHost: "127.0.0.1",
|
||||
HTTPPort: httpPort,
|
||||
WSHost: "127.0.0.1",
|
||||
WSPort: wsPort,
|
||||
HTTPHost: "127.0.0.1",
|
||||
HTTPPort: httpPort,
|
||||
WSHost: "127.0.0.1",
|
||||
WSPort: wsPort,
|
||||
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
|
||||
}
|
||||
node, err := New(conf)
|
||||
if err != nil {
|
||||
|
100
node/rpcstack.go
100
node/rpcstack.go
@ -24,6 +24,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -196,6 +197,7 @@ func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// if http-rpc is enabled, try to serve request
|
||||
rpc := h.httpHandler.Load().(*rpcHandler)
|
||||
if rpc != nil {
|
||||
@ -462,17 +464,94 @@ var gzPool = sync.Pool{
|
||||
}
|
||||
|
||||
type gzipResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
resp http.ResponseWriter
|
||||
|
||||
gz *gzip.Writer
|
||||
contentLength uint64 // total length of the uncompressed response
|
||||
written uint64 // amount of written bytes from the uncompressed response
|
||||
hasLength bool // true if uncompressed response had Content-Length
|
||||
inited bool // true after init was called for the first time
|
||||
}
|
||||
|
||||
// init runs just before response headers are written. Among other things, this function
|
||||
// also decides whether compression will be applied at all.
|
||||
func (w *gzipResponseWriter) init() {
|
||||
if w.inited {
|
||||
return
|
||||
}
|
||||
w.inited = true
|
||||
|
||||
hdr := w.resp.Header()
|
||||
length := hdr.Get("content-length")
|
||||
if len(length) > 0 {
|
||||
if n, err := strconv.ParseUint(length, 10, 64); err != nil {
|
||||
w.hasLength = true
|
||||
w.contentLength = n
|
||||
}
|
||||
}
|
||||
|
||||
// Setting Transfer-Encoding to "identity" explicitly disables compression. net/http
|
||||
// also recognizes this header value and uses it to disable "chunked" transfer
|
||||
// encoding, trimming the header from the response. This means downstream handlers can
|
||||
// set this without harm, even if they aren't wrapped by newGzipHandler.
|
||||
//
|
||||
// In go-ethereum, we use this signal to disable compression for certain error
|
||||
// responses which are flushed out close to the write deadline of the response. For
|
||||
// these cases, we want to avoid chunked transfer encoding and compression because
|
||||
// they require additional output that may not get written in time.
|
||||
passthrough := hdr.Get("transfer-encoding") == "identity"
|
||||
if !passthrough {
|
||||
w.gz = gzPool.Get().(*gzip.Writer)
|
||||
w.gz.Reset(w.resp)
|
||||
hdr.Del("content-length")
|
||||
hdr.Set("content-encoding", "gzip")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Header() http.Header {
|
||||
return w.resp.Header()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) WriteHeader(status int) {
|
||||
w.Header().Del("Content-Length")
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
w.init()
|
||||
w.resp.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
return w.Writer.Write(b)
|
||||
w.init()
|
||||
|
||||
if w.gz == nil {
|
||||
// Compression is disabled.
|
||||
return w.resp.Write(b)
|
||||
}
|
||||
|
||||
n, err := w.gz.Write(b)
|
||||
w.written += uint64(n)
|
||||
if w.hasLength && w.written >= w.contentLength {
|
||||
// The HTTP handler has finished writing the entire uncompressed response. Close
|
||||
// the gzip stream to ensure the footer will be seen by the client in case the
|
||||
// response is flushed after this call to write.
|
||||
err = w.gz.Close()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Flush() {
|
||||
if w.gz != nil {
|
||||
w.gz.Flush()
|
||||
}
|
||||
if f, ok := w.resp.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) close() {
|
||||
if w.gz == nil {
|
||||
return
|
||||
}
|
||||
w.gz.Close()
|
||||
gzPool.Put(w.gz)
|
||||
w.gz = nil
|
||||
}
|
||||
|
||||
func newGzipHandler(next http.Handler) http.Handler {
|
||||
@ -482,15 +561,10 @@ func newGzipHandler(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
wrapper := &gzipResponseWriter{resp: w}
|
||||
defer wrapper.close()
|
||||
|
||||
gz := gzPool.Get().(*gzip.Writer)
|
||||
defer gzPool.Put(gz)
|
||||
|
||||
gz.Reset(w)
|
||||
defer gz.Close()
|
||||
|
||||
next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
|
||||
next.ServeHTTP(wrapper, r)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,9 @@ package node
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -34,29 +36,31 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const testMethod = "rpc_modules"
|
||||
|
||||
// TestCorsHandler makes sure CORS are properly handled on the http server.
|
||||
func TestCorsHandler(t *testing.T) {
|
||||
srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{})
|
||||
srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}, nil)
|
||||
defer srv.stop()
|
||||
url := "http://" + srv.listenAddr()
|
||||
|
||||
resp := rpcRequest(t, url, "origin", "test.com")
|
||||
resp := rpcRequest(t, url, testMethod, "origin", "test.com")
|
||||
assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin"))
|
||||
|
||||
resp2 := rpcRequest(t, url, "origin", "bad")
|
||||
resp2 := rpcRequest(t, url, testMethod, "origin", "bad")
|
||||
assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
// TestVhosts makes sure vhosts are properly handled on the http server.
|
||||
func TestVhosts(t *testing.T) {
|
||||
srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{})
|
||||
srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}, nil)
|
||||
defer srv.stop()
|
||||
url := "http://" + srv.listenAddr()
|
||||
|
||||
resp := rpcRequest(t, url, "host", "test")
|
||||
resp := rpcRequest(t, url, testMethod, "host", "test")
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
|
||||
resp2 := rpcRequest(t, url, "host", "bad")
|
||||
resp2 := rpcRequest(t, url, testMethod, "host", "bad")
|
||||
assert.Equal(t, resp2.StatusCode, http.StatusForbidden)
|
||||
}
|
||||
|
||||
@ -145,7 +149,7 @@ func TestWebsocketOrigins(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)})
|
||||
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}, nil)
|
||||
url := fmt.Sprintf("ws://%v", srv.listenAddr())
|
||||
for _, origin := range tc.expOk {
|
||||
if err := wsRequest(t, url, "Origin", origin); err != nil {
|
||||
@ -231,11 +235,14 @@ func Test_checkPath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig) *httpServer {
|
||||
func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig, timeouts *rpc.HTTPTimeouts) *httpServer {
|
||||
t.Helper()
|
||||
|
||||
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts)
|
||||
assert.NoError(t, srv.enableRPC(nil, *conf))
|
||||
if timeouts == nil {
|
||||
timeouts = &rpc.DefaultHTTPTimeouts
|
||||
}
|
||||
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), *timeouts)
|
||||
assert.NoError(t, srv.enableRPC(apis(), *conf))
|
||||
if ws {
|
||||
assert.NoError(t, srv.enableWS(nil, *wsConf))
|
||||
}
|
||||
@ -266,16 +273,33 @@ func wsRequest(t *testing.T, url string, extraHeaders ...string) error {
|
||||
}
|
||||
|
||||
// rpcRequest performs a JSON-RPC request to the given URL.
|
||||
func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response {
|
||||
func rpcRequest(t *testing.T, url, method string, extraHeaders ...string) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
body := fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, method)
|
||||
return baseRpcRequest(t, url, body, extraHeaders...)
|
||||
}
|
||||
|
||||
func batchRpcRequest(t *testing.T, url string, methods []string, extraHeaders ...string) *http.Response {
|
||||
reqs := make([]string, len(methods))
|
||||
for i, m := range methods {
|
||||
reqs[i] = fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s","params":[]}`, m)
|
||||
}
|
||||
body := fmt.Sprintf(`[%s]`, strings.Join(reqs, ","))
|
||||
return baseRpcRequest(t, url, body, extraHeaders...)
|
||||
}
|
||||
|
||||
func baseRpcRequest(t *testing.T, url, bodyStr string, extraHeaders ...string) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
// Create the request.
|
||||
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]}`))
|
||||
body := bytes.NewReader([]byte(bodyStr))
|
||||
req, err := http.NewRequest("POST", url, body)
|
||||
if err != nil {
|
||||
t.Fatal("could not create http request:", err)
|
||||
}
|
||||
req.Header.Set("content-type", "application/json")
|
||||
req.Header.Set("accept-encoding", "identity")
|
||||
|
||||
// Apply extra headers.
|
||||
if len(extraHeaders)%2 != 0 {
|
||||
@ -315,7 +339,7 @@ func TestJWT(t *testing.T) {
|
||||
return ss
|
||||
}
|
||||
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
|
||||
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")})
|
||||
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil)
|
||||
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
|
||||
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
|
||||
|
||||
@ -348,7 +372,7 @@ func TestJWT(t *testing.T) {
|
||||
t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err)
|
||||
}
|
||||
token = tokenFn()
|
||||
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 {
|
||||
if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 200 {
|
||||
t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
@ -414,10 +438,176 @@ func TestJWT(t *testing.T) {
|
||||
}
|
||||
|
||||
token = tokenFn()
|
||||
resp := rpcRequest(t, htUrl, "Authorization", token)
|
||||
resp := rpcRequest(t, htUrl, testMethod, "Authorization", token)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
srv.stop()
|
||||
}
|
||||
|
||||
func TestGzipHandler(t *testing.T) {
|
||||
type gzipTest struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
status int
|
||||
isGzip bool
|
||||
header map[string]string
|
||||
}
|
||||
tests := []gzipTest{
|
||||
{
|
||||
name: "Write",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: true,
|
||||
status: 200,
|
||||
},
|
||||
{
|
||||
name: "WriteHeader",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("x-foo", "bar")
|
||||
w.WriteHeader(205)
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: true,
|
||||
status: 205,
|
||||
header: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
{
|
||||
name: "WriteContentLength",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("content-length", "8")
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: true,
|
||||
status: 200,
|
||||
},
|
||||
{
|
||||
name: "Flush",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("res"))
|
||||
w.(http.Flusher).Flush()
|
||||
w.Write([]byte("ponse"))
|
||||
},
|
||||
isGzip: true,
|
||||
status: 200,
|
||||
},
|
||||
{
|
||||
name: "disable",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("transfer-encoding", "identity")
|
||||
w.Header().Set("x-foo", "bar")
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: false,
|
||||
status: 200,
|
||||
header: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
{
|
||||
name: "disable-WriteHeader",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("transfer-encoding", "identity")
|
||||
w.Header().Set("x-foo", "bar")
|
||||
w.WriteHeader(205)
|
||||
w.Write([]byte("response"))
|
||||
},
|
||||
isGzip: false,
|
||||
status: 205,
|
||||
header: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
srv := httptest.NewServer(newGzipHandler(test.handler))
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wasGzip := resp.Uncompressed
|
||||
|
||||
if string(content) != "response" {
|
||||
t.Fatalf("wrong response content %q", content)
|
||||
}
|
||||
if wasGzip != test.isGzip {
|
||||
t.Fatalf("response gzipped == %t, want %t", wasGzip, test.isGzip)
|
||||
}
|
||||
if resp.StatusCode != test.status {
|
||||
t.Fatalf("response status == %d, want %d", resp.StatusCode, test.status)
|
||||
}
|
||||
for name, expectedValue := range test.header {
|
||||
if v := resp.Header.Get(name); v != expectedValue {
|
||||
t.Fatalf("response header %s == %s, want %s", name, v, expectedValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPWriteTimeout(t *testing.T) {
|
||||
const (
|
||||
timeoutRes = `{"jsonrpc":"2.0","id":1,"error":{"code":-32002,"message":"request timed out"}}`
|
||||
greetRes = `{"jsonrpc":"2.0","id":1,"result":"Hello"}`
|
||||
)
|
||||
// Set-up server
|
||||
timeouts := rpc.DefaultHTTPTimeouts
|
||||
timeouts.WriteTimeout = time.Second
|
||||
srv := createAndStartServer(t, &httpConfig{Modules: []string{"test"}}, false, &wsConfig{}, &timeouts)
|
||||
url := fmt.Sprintf("http://%v", srv.listenAddr())
|
||||
|
||||
// Send normal request
|
||||
t.Run("message", func(t *testing.T) {
|
||||
resp := rpcRequest(t, url, "test_sleep")
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(body) != timeoutRes {
|
||||
t.Errorf("wrong response. have %s, want %s", string(body), timeoutRes)
|
||||
}
|
||||
})
|
||||
|
||||
// Batch request
|
||||
t.Run("batch", func(t *testing.T) {
|
||||
want := fmt.Sprintf("[%s,%s,%s]", greetRes, timeoutRes, timeoutRes)
|
||||
resp := batchRpcRequest(t, url, []string{"test_greet", "test_sleep", "test_greet"})
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(body) != want {
|
||||
t.Errorf("wrong response. have %s, want %s", string(body), want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func apis() []rpc.API {
|
||||
return []rpc.API{
|
||||
{
|
||||
Namespace: "test",
|
||||
Service: &testService{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type testService struct{}
|
||||
|
||||
func (s *testService) Greet() string {
|
||||
return "Hello"
|
||||
}
|
||||
|
||||
func (s *testService) Sleep() {
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
}
|
||||
|
@ -206,7 +206,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON)
|
||||
codec := rpc.NewFuncCodec(conn, func(v any, _ bool) error { return conn.WriteJSON(v) }, conn.ReadJSON)
|
||||
handler.ServeCodec(codec, 0)
|
||||
return nil
|
||||
}
|
||||
|
@ -527,7 +527,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := c.writeConn.writeJSON(ctx, msg)
|
||||
err := c.writeConn.writeJSON(ctx, msg, false)
|
||||
if err != nil {
|
||||
c.writeConn = nil
|
||||
if !retry {
|
||||
@ -660,7 +660,8 @@ func (c *Client) read(codec ServerCodec) {
|
||||
for {
|
||||
msgs, batch, err := codec.readBatch()
|
||||
if _, ok := err.(*json.SyntaxError); ok {
|
||||
codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()}))
|
||||
msg := errorMessage(&parseError{err.Error()})
|
||||
codec.writeJSON(context.Background(), msg, true)
|
||||
}
|
||||
if err != nil {
|
||||
c.readErr <- err
|
||||
|
@ -60,10 +60,15 @@ var (
|
||||
const (
|
||||
errcodeDefault = -32000
|
||||
errcodeNotificationsUnsupported = -32001
|
||||
errcodeTimeout = -32002
|
||||
errcodePanic = -32603
|
||||
errcodeMarshalError = -32603
|
||||
)
|
||||
|
||||
const (
|
||||
errMsgTimeout = "request timed out"
|
||||
)
|
||||
|
||||
type methodNotFoundError struct{ method string }
|
||||
|
||||
func (e *methodNotFoundError) ErrorCode() int { return -32601 }
|
||||
|
142
rpc/handler.go
142
rpc/handler.go
@ -91,12 +91,83 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
|
||||
return h
|
||||
}
|
||||
|
||||
// batchCallBuffer manages in progress call messages and their responses during a batch
|
||||
// call. Calls need to be synchronized between the processing and timeout-triggering
|
||||
// goroutines.
|
||||
type batchCallBuffer struct {
|
||||
mutex sync.Mutex
|
||||
calls []*jsonrpcMessage
|
||||
resp []*jsonrpcMessage
|
||||
wrote bool
|
||||
}
|
||||
|
||||
// nextCall returns the next unprocessed message.
|
||||
func (b *batchCallBuffer) nextCall() *jsonrpcMessage {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
if len(b.calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
// The popping happens in `pushAnswer`. The in progress call is kept
|
||||
// so we can return an error for it in case of timeout.
|
||||
msg := b.calls[0]
|
||||
return msg
|
||||
}
|
||||
|
||||
// pushResponse adds the response to last call returned by nextCall.
|
||||
func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
if answer != nil {
|
||||
b.resp = append(b.resp, answer)
|
||||
}
|
||||
b.calls = b.calls[1:]
|
||||
}
|
||||
|
||||
// write sends the responses.
|
||||
func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
b.doWrite(ctx, conn, false)
|
||||
}
|
||||
|
||||
// timeout sends the responses added so far. For the remaining unanswered call
|
||||
// messages, it sends a timeout error response.
|
||||
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
for _, msg := range b.calls {
|
||||
if !msg.isNotification() {
|
||||
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
|
||||
b.resp = append(b.resp, resp)
|
||||
}
|
||||
}
|
||||
b.doWrite(ctx, conn, true)
|
||||
}
|
||||
|
||||
// doWrite actually writes the response.
|
||||
// This assumes b.mutex is held.
|
||||
func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) {
|
||||
if b.wrote {
|
||||
return
|
||||
}
|
||||
b.wrote = true // can only write once
|
||||
if len(b.resp) > 0 {
|
||||
conn.writeJSON(ctx, b.resp, isErrorResponse)
|
||||
}
|
||||
}
|
||||
|
||||
// handleBatch executes all messages in a batch and returns the responses.
|
||||
func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
||||
// Emit error response for empty batches:
|
||||
if len(msgs) == 0 {
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
|
||||
resp := errorMessage(&invalidRequestError{"empty batch"})
|
||||
h.conn.writeJSON(cp.ctx, resp, true)
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -113,16 +184,42 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
|
||||
}
|
||||
// Process calls on a goroutine because they may block indefinitely:
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
answers := make([]*jsonrpcMessage, 0, len(msgs))
|
||||
for _, msg := range calls {
|
||||
if answer := h.handleCallMsg(cp, msg); answer != nil {
|
||||
answers = append(answers, answer)
|
||||
var (
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))}
|
||||
)
|
||||
|
||||
cp.ctx, cancel = context.WithCancel(cp.ctx)
|
||||
defer cancel()
|
||||
|
||||
// Cancel the request context after timeout and send an error response. Since the
|
||||
// currently-running method might not return immediately on timeout, we must wait
|
||||
// for the timeout concurrently with processing the request.
|
||||
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
||||
timer = time.AfterFunc(timeout, func() {
|
||||
cancel()
|
||||
callBuffer.timeout(cp.ctx, h.conn)
|
||||
})
|
||||
}
|
||||
|
||||
for {
|
||||
// No need to handle rest of calls if timed out.
|
||||
if cp.ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
msg := callBuffer.nextCall()
|
||||
if msg == nil {
|
||||
break
|
||||
}
|
||||
resp := h.handleCallMsg(cp, msg)
|
||||
callBuffer.pushResponse(resp)
|
||||
}
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
callBuffer.write(cp.ctx, h.conn)
|
||||
h.addSubscriptions(cp.notifiers)
|
||||
if len(answers) > 0 {
|
||||
h.conn.writeJSON(cp.ctx, answers)
|
||||
}
|
||||
for _, n := range cp.notifiers {
|
||||
n.activate()
|
||||
}
|
||||
@ -135,10 +232,36 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
|
||||
return
|
||||
}
|
||||
h.startCallProc(func(cp *callProc) {
|
||||
var (
|
||||
responded sync.Once
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
cp.ctx, cancel = context.WithCancel(cp.ctx)
|
||||
defer cancel()
|
||||
|
||||
// Cancel the request context after timeout and send an error response. Since the
|
||||
// running method might not return immediately on timeout, we must wait for the
|
||||
// timeout concurrently with processing the request.
|
||||
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
|
||||
timer = time.AfterFunc(timeout, func() {
|
||||
cancel()
|
||||
responded.Do(func() {
|
||||
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
|
||||
h.conn.writeJSON(cp.ctx, resp, true)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
answer := h.handleCallMsg(cp, msg)
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
h.addSubscriptions(cp.notifiers)
|
||||
if answer != nil {
|
||||
h.conn.writeJSON(cp.ctx, answer)
|
||||
responded.Do(func() {
|
||||
h.conn.writeJSON(cp.ctx, answer, false)
|
||||
})
|
||||
}
|
||||
for _, n := range cp.notifiers {
|
||||
n.activate()
|
||||
@ -334,7 +457,6 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
|
||||
}
|
||||
start := time.Now()
|
||||
answer := h.runMethod(cp.ctx, msg, callb, args)
|
||||
|
||||
// Collect the statistics for RPC calls if metrics is enabled.
|
||||
// We only care about pure rpc call. Filter out subscription.
|
||||
if callb != h.unsubscribeCb {
|
||||
|
73
rpc/http.go
73
rpc/http.go
@ -23,9 +23,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -52,7 +54,7 @@ type httpConn struct {
|
||||
// and some methods don't work. The panic() stubs here exist to ensure
|
||||
// this special treatment is correct.
|
||||
|
||||
func (hc *httpConn) writeJSON(context.Context, interface{}) error {
|
||||
func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error {
|
||||
panic("writeJSON called on httpConn")
|
||||
}
|
||||
|
||||
@ -256,7 +258,42 @@ type httpServerConn struct {
|
||||
func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec {
|
||||
body := io.LimitReader(r.Body, maxRequestContentLength)
|
||||
conn := &httpServerConn{Reader: body, Writer: w, r: r}
|
||||
return NewCodec(conn)
|
||||
|
||||
encoder := func(v any, isErrorResponse bool) error {
|
||||
if !isErrorResponse {
|
||||
return json.NewEncoder(conn).Encode(v)
|
||||
}
|
||||
|
||||
// It's an error response and requires special treatment.
|
||||
//
|
||||
// In case of a timeout error, the response must be written before the HTTP
|
||||
// server's write timeout occurs. So we need to flush the response. The
|
||||
// Content-Length header also needs to be set to ensure the client knows
|
||||
// when it has the full response.
|
||||
encdata, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("content-length", strconv.Itoa(len(encdata)))
|
||||
|
||||
// If this request is wrapped in a handler that might remove Content-Length (such
|
||||
// as the automatic gzip we do in package node), we need to ensure the HTTP server
|
||||
// doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked
|
||||
// encoding might not be finished correctly, and some clients do not like it when
|
||||
// the final chunk is missing.
|
||||
w.Header().Set("transfer-encoding", "identity")
|
||||
|
||||
_, err = w.Write(encdata)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(conn)
|
||||
dec.UseNumber()
|
||||
|
||||
return NewFuncCodec(conn, encoder, dec.Decode)
|
||||
}
|
||||
|
||||
// Close does nothing and always returns nil.
|
||||
@ -326,3 +363,35 @@ func validateRequest(r *http.Request) (int, error) {
|
||||
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
|
||||
return http.StatusUnsupportedMediaType, err
|
||||
}
|
||||
|
||||
// ContextRequestTimeout returns the request timeout derived from the given context.
|
||||
func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) {
|
||||
timeout := time.Duration(math.MaxInt64)
|
||||
hasTimeout := false
|
||||
setTimeout := func(d time.Duration) {
|
||||
if d < timeout {
|
||||
timeout = d
|
||||
hasTimeout = true
|
||||
}
|
||||
}
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
setTimeout(time.Until(deadline))
|
||||
}
|
||||
|
||||
// If the context is an HTTP request context, use the server's WriteTimeout.
|
||||
httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server)
|
||||
if ok && httpSrv.WriteTimeout > 0 {
|
||||
wt := httpSrv.WriteTimeout
|
||||
// When a write timeout is configured, we need to send the response message before
|
||||
// the HTTP server cuts connection. So our internal timeout must be earlier than
|
||||
// the server's true timeout.
|
||||
//
|
||||
// Note: Timeouts are sanitized to be a minimum of 1 second.
|
||||
// Also see issue: https://github.com/golang/go/issues/47229
|
||||
wt -= 100 * time.Millisecond
|
||||
setTimeout(wt)
|
||||
}
|
||||
|
||||
return timeout, hasTimeout
|
||||
}
|
||||
|
26
rpc/json.go
26
rpc/json.go
@ -168,18 +168,22 @@ type ConnRemoteAddr interface {
|
||||
// support for parsing arguments and serializing (result) objects.
|
||||
type jsonCodec struct {
|
||||
remote string
|
||||
closer sync.Once // close closed channel once
|
||||
closeCh chan interface{} // closed on Close
|
||||
decode func(v interface{}) error // decoder to allow multiple transports
|
||||
encMu sync.Mutex // guards the encoder
|
||||
encode func(v interface{}) error // encoder to allow multiple transports
|
||||
closer sync.Once // close closed channel once
|
||||
closeCh chan interface{} // closed on Close
|
||||
decode decodeFunc // decoder to allow multiple transports
|
||||
encMu sync.Mutex // guards the encoder
|
||||
encode encodeFunc // encoder to allow multiple transports
|
||||
conn deadlineCloser
|
||||
}
|
||||
|
||||
type encodeFunc = func(v interface{}, isErrorResponse bool) error
|
||||
|
||||
type decodeFunc = func(v interface{}) error
|
||||
|
||||
// NewFuncCodec creates a codec which uses the given functions to read and write. If conn
|
||||
// implements ConnRemoteAddr, log messages will use it to include the remote address of
|
||||
// the connection.
|
||||
func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec {
|
||||
func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec {
|
||||
codec := &jsonCodec{
|
||||
closeCh: make(chan interface{}),
|
||||
encode: encode,
|
||||
@ -198,7 +202,11 @@ func NewCodec(conn Conn) ServerCodec {
|
||||
enc := json.NewEncoder(conn)
|
||||
dec := json.NewDecoder(conn)
|
||||
dec.UseNumber()
|
||||
return NewFuncCodec(conn, enc.Encode, dec.Decode)
|
||||
|
||||
encode := func(v interface{}, isErrorResponse bool) error {
|
||||
return enc.Encode(v)
|
||||
}
|
||||
return NewFuncCodec(conn, encode, dec.Decode)
|
||||
}
|
||||
|
||||
func (c *jsonCodec) peerInfo() PeerInfo {
|
||||
@ -228,7 +236,7 @@ func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err err
|
||||
return messages, batch, nil
|
||||
}
|
||||
|
||||
func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error {
|
||||
func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error {
|
||||
c.encMu.Lock()
|
||||
defer c.encMu.Unlock()
|
||||
|
||||
@ -237,7 +245,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error {
|
||||
deadline = time.Now().Add(defaultWriteTimeout)
|
||||
}
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
return c.encode(v)
|
||||
return c.encode(v, isErrorResponse)
|
||||
}
|
||||
|
||||
func (c *jsonCodec) close() {
|
||||
|
@ -125,7 +125,8 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
|
||||
reqs, batch, err := codec.readBatch()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"}))
|
||||
resp := errorMessage(&invalidMessageError{"parse error"})
|
||||
codec.writeJSON(ctx, resp, true)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -175,11 +175,13 @@ func (n *Notifier) activate() error {
|
||||
func (n *Notifier) send(sub *Subscription, data json.RawMessage) error {
|
||||
params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data})
|
||||
ctx := context.Background()
|
||||
return n.h.conn.writeJSON(ctx, &jsonrpcMessage{
|
||||
|
||||
msg := &jsonrpcMessage{
|
||||
Version: vsn,
|
||||
Method: n.namespace + notificationMethodSuffix,
|
||||
Params: params,
|
||||
})
|
||||
}
|
||||
return n.h.conn.writeJSON(ctx, msg, false)
|
||||
}
|
||||
|
||||
// A Subscription is created by a notifier and tied to that notifier. The client can use
|
||||
|
@ -51,7 +51,9 @@ type ServerCodec interface {
|
||||
// jsonWriter can write JSON messages to its underlying connection.
|
||||
// Implementations must be safe for concurrent use.
|
||||
type jsonWriter interface {
|
||||
writeJSON(context.Context, interface{}) error
|
||||
// writeJSON writes a message to the connection.
|
||||
writeJSON(ctx context.Context, msg interface{}, isError bool) error
|
||||
|
||||
// Closed returns a channel which is closed when the connection is closed.
|
||||
closed() <-chan interface{}
|
||||
// RemoteAddr returns the peer address of the connection.
|
||||
|
@ -287,8 +287,12 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
return nil
|
||||
})
|
||||
|
||||
encode := func(v interface{}, isErrorResponse bool) error {
|
||||
return conn.WriteJSON(v)
|
||||
}
|
||||
wc := &websocketCodec{
|
||||
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
|
||||
jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec),
|
||||
conn: conn,
|
||||
pingReset: make(chan struct{}, 1),
|
||||
info: PeerInfo{
|
||||
@ -315,8 +319,8 @@ func (wc *websocketCodec) peerInfo() PeerInfo {
|
||||
return wc.info
|
||||
}
|
||||
|
||||
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
|
||||
err := wc.jsonCodec.writeJSON(ctx, v)
|
||||
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error {
|
||||
err := wc.jsonCodec.writeJSON(ctx, v, isError)
|
||||
if err == nil {
|
||||
// Notify pingLoop to delay the next idle ping.
|
||||
select {
|
||||
|
Loading…
Reference in New Issue
Block a user