From f3314bb6df4c86e650f0e47cbb5a21ca0616ac11 Mon Sep 17 00:00:00 2001 From: mmsqe Date: Tue, 13 Jun 2023 19:38:58 +0800 Subject: [PATCH] rpc: add limit for batch request items and response size (#26681) This PR adds server-side limits for JSON-RPC batch requests. Before this change, batches were limited only by processing time. The server would pick calls from the batch and answer them until the response timeout occurred, then stop processing the remaining batch items. Here, we are adding two additional limits which can be configured: - the 'item limit': batches can have at most N items - the 'response size limit': batches can contain at most X response bytes These limits are optional in package rpc. In Geth, we set a default limit of 1000 items and 25MB response size. When a batch goes over the limit, an error response is returned to the client. However, doing this correctly isn't always possible. In JSON-RPC, only method calls with a valid `id` can be responded to. Since batches may also contain non-call messages or notifications, the best effort thing we can do to report an error with the batch itself is reporting the limit violation as an error for the first method call in the batch. If a batch is too large, but contains only notifications and responses, the error will be reported with a null `id`. The RPC client was also changed so it can deal with errors resulting from too large batches. An older client connected to the server code in this PR could get stuck until the request timeout occurred when the batch is too large. **Upgrading to a version of the RPC client containing this change is strongly recommended to avoid timeout issues.** For some weird reason, when writing the original client implementation, @fjl worked off of the assumption that responses could be distributed across batches arbitrarily. So for a batch request containing requests `[A B C]`, the server could respond with `[A B C]` but also with `[A B] [C]` or even `[A] [B] [C]` and it wouldn't make a difference to the client. So in the implementation of BatchCallContext, the client waited for all requests in the batch individually. If the server didn't respond to some of the requests in the batch, the client would eventually just time out (if a context was used). With the addition of batch limits into the server, we anticipate that people will hit this kind of error way more often. To handle this properly, the client now waits for a single response batch and expects it to contain all responses to the requests. --------- Co-authored-by: Felix Lange Co-authored-by: Martin Holst Swende --- cmd/clef/main.go | 1 + cmd/geth/main.go | 2 + cmd/utils/flags.go | 20 ++ node/api.go | 8 + node/config.go | 6 + node/defaults.go | 24 +- node/node.go | 31 ++- node/rpcstack.go | 18 +- node/rpcstack_test.go | 6 +- rpc/client.go | 131 +++++++---- rpc/client_opt.go | 29 +++ rpc/client_test.go | 104 +++++++-- rpc/errors.go | 5 +- rpc/handler.go | 289 +++++++++++++++---------- rpc/http.go | 19 +- rpc/inproc.go | 3 +- rpc/ipc.go | 3 +- rpc/server.go | 28 ++- rpc/server_test.go | 39 ++++ rpc/stdio.go | 3 +- rpc/testdata/invalid-batch-toolarge.js | 13 ++ rpc/websocket.go | 4 +- 22 files changed, 557 insertions(+), 229 deletions(-) create mode 100644 rpc/testdata/invalid-batch-toolarge.js diff --git a/cmd/clef/main.go b/cmd/clef/main.go index cebe74797..14f09dc1a 100644 --- a/cmd/clef/main.go +++ b/cmd/clef/main.go @@ -732,6 +732,7 @@ func signer(c *cli.Context) error { cors := utils.SplitAndTrim(c.String(utils.HTTPCORSDomainFlag.Name)) srv := rpc.NewServer() + srv.SetBatchLimits(node.DefaultConfig.BatchRequestLimit, node.DefaultConfig.BatchResponseMaxSize) err := node.RegisterApis(rpcAPI, []string{"account"}, srv) if err != nil { utils.Fatalf("Could not register API: %w", err) diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 2289a72a1..2794e37e3 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -168,6 +168,8 @@ var ( utils.RPCGlobalEVMTimeoutFlag, utils.RPCGlobalTxFeeCapFlag, utils.AllowUnprotectedTxs, + utils.BatchRequestLimit, + utils.BatchResponseMaxSize, } metricsFlags = []cli.Flag{ diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 1ebc998a4..692970477 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -713,6 +713,18 @@ var ( Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC", Category: flags.APICategory, } + BatchRequestLimit = &cli.IntFlag{ + Name: "rpc.batch-request-limit", + Usage: "Maximum number of requests in a batch", + Value: node.DefaultConfig.BatchRequestLimit, + Category: flags.APICategory, + } + BatchResponseMaxSize = &cli.IntFlag{ + Name: "rpc.batch-response-max-size", + Usage: "Maximum number of bytes returned from a batched call", + Value: node.DefaultConfig.BatchResponseMaxSize, + Category: flags.APICategory, + } EnablePersonal = &cli.BoolFlag{ Name: "rpc.enabledeprecatedpersonal", Usage: "Enables the (deprecated) personal namespace", @@ -1130,6 +1142,14 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { if ctx.IsSet(AllowUnprotectedTxs.Name) { cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name) } + + if ctx.IsSet(BatchRequestLimit.Name) { + cfg.BatchRequestLimit = ctx.Int(BatchRequestLimit.Name) + } + + if ctx.IsSet(BatchResponseMaxSize.Name) { + cfg.BatchResponseMaxSize = ctx.Int(BatchResponseMaxSize.Name) + } } // setGraphQL creates the GraphQL listener interface string from the set diff --git a/node/api.go b/node/api.go index 15892a270..f81f394be 100644 --- a/node/api.go +++ b/node/api.go @@ -176,6 +176,10 @@ func (api *adminAPI) StartHTTP(host *string, port *int, cors *string, apis *stri CorsAllowedOrigins: api.node.config.HTTPCors, Vhosts: api.node.config.HTTPVirtualHosts, Modules: api.node.config.HTTPModules, + rpcEndpointConfig: rpcEndpointConfig{ + batchItemLimit: api.node.config.BatchRequestLimit, + batchResponseSizeLimit: api.node.config.BatchResponseMaxSize, + }, } if cors != nil { config.CorsAllowedOrigins = nil @@ -250,6 +254,10 @@ func (api *adminAPI) StartWS(host *string, port *int, allowedOrigins *string, ap Modules: api.node.config.WSModules, Origins: api.node.config.WSOrigins, // ExposeAll: api.node.config.WSExposeAll, + rpcEndpointConfig: rpcEndpointConfig{ + batchItemLimit: api.node.config.BatchRequestLimit, + batchResponseSizeLimit: api.node.config.BatchResponseMaxSize, + }, } if apis != nil { config.Modules = nil diff --git a/node/config.go b/node/config.go index 1765811e8..37c1e4882 100644 --- a/node/config.go +++ b/node/config.go @@ -197,6 +197,12 @@ type Config struct { // AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC. AllowUnprotectedTxs bool `toml:",omitempty"` + // BatchRequestLimit is the maximum number of requests in a batch. + BatchRequestLimit int `toml:",omitempty"` + + // BatchResponseMaxSize is the maximum number of bytes returned from a batched rpc call. + BatchResponseMaxSize int `toml:",omitempty"` + // JWTSecret is the path to the hex-encoded jwt secret. JWTSecret string `toml:",omitempty"` diff --git a/node/defaults.go b/node/defaults.go index fcfbc934b..d8f718121 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -46,17 +46,19 @@ var ( // DefaultConfig contains reasonable default settings. var DefaultConfig = Config{ - DataDir: DefaultDataDir(), - HTTPPort: DefaultHTTPPort, - AuthAddr: DefaultAuthHost, - AuthPort: DefaultAuthPort, - AuthVirtualHosts: DefaultAuthVhosts, - HTTPModules: []string{"net", "web3"}, - HTTPVirtualHosts: []string{"localhost"}, - HTTPTimeouts: rpc.DefaultHTTPTimeouts, - WSPort: DefaultWSPort, - WSModules: []string{"net", "web3"}, - GraphQLVirtualHosts: []string{"localhost"}, + DataDir: DefaultDataDir(), + HTTPPort: DefaultHTTPPort, + AuthAddr: DefaultAuthHost, + AuthPort: DefaultAuthPort, + AuthVirtualHosts: DefaultAuthVhosts, + HTTPModules: []string{"net", "web3"}, + HTTPVirtualHosts: []string{"localhost"}, + HTTPTimeouts: rpc.DefaultHTTPTimeouts, + WSPort: DefaultWSPort, + WSModules: []string{"net", "web3"}, + BatchRequestLimit: 1000, + BatchResponseMaxSize: 25 * 1000 * 1000, + GraphQLVirtualHosts: []string{"localhost"}, P2P: p2p.Config{ ListenAddr: ":30303", MaxPeers: 50, diff --git a/node/node.go b/node/node.go index e8494ac3b..553d451ab 100644 --- a/node/node.go +++ b/node/node.go @@ -101,10 +101,11 @@ func New(conf *Config) (*Node, error) { if strings.HasSuffix(conf.Name, ".ipc") { return nil, errors.New(`Config.Name cannot end in ".ipc"`) } - + server := rpc.NewServer() + server.SetBatchLimits(conf.BatchRequestLimit, conf.BatchResponseMaxSize) node := &Node{ config: conf, - inprocHandler: rpc.NewServer(), + inprocHandler: server, eventmux: new(event.TypeMux), log: conf.Logger, stop: make(chan struct{}), @@ -403,6 +404,11 @@ func (n *Node) startRPC() error { openAPIs, allAPIs = n.getAPIs() ) + rpcConfig := rpcEndpointConfig{ + batchItemLimit: n.config.BatchRequestLimit, + batchResponseSizeLimit: n.config.BatchResponseMaxSize, + } + initHttp := func(server *httpServer, port int) error { if err := server.setListenAddr(n.config.HTTPHost, port); err != nil { return err @@ -412,6 +418,7 @@ func (n *Node) startRPC() error { Vhosts: n.config.HTTPVirtualHosts, Modules: n.config.HTTPModules, prefix: n.config.HTTPPathPrefix, + rpcEndpointConfig: rpcConfig, }); err != nil { return err } @@ -425,9 +432,10 @@ func (n *Node) startRPC() error { return err } if err := server.enableWS(openAPIs, wsConfig{ - Modules: n.config.WSModules, - Origins: n.config.WSOrigins, - prefix: n.config.WSPathPrefix, + Modules: n.config.WSModules, + Origins: n.config.WSOrigins, + prefix: n.config.WSPathPrefix, + rpcEndpointConfig: rpcConfig, }); err != nil { return err } @@ -441,26 +449,29 @@ func (n *Node) startRPC() error { if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { return err } + sharedConfig := rpcConfig + sharedConfig.jwtSecret = secret if err := server.enableRPC(allAPIs, httpConfig{ CorsAllowedOrigins: DefaultAuthCors, Vhosts: n.config.AuthVirtualHosts, Modules: DefaultAuthModules, prefix: DefaultAuthPrefix, - jwtSecret: secret, + rpcEndpointConfig: sharedConfig, }); err != nil { return err } servers = append(servers, server) + // Enable auth via WS server = n.wsServerForPort(port, true) if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { return err } if err := server.enableWS(allAPIs, wsConfig{ - Modules: DefaultAuthModules, - Origins: DefaultAuthOrigins, - prefix: DefaultAuthPrefix, - jwtSecret: secret, + Modules: DefaultAuthModules, + Origins: DefaultAuthOrigins, + prefix: DefaultAuthPrefix, + rpcEndpointConfig: sharedConfig, }); err != nil { return err } diff --git a/node/rpcstack.go b/node/rpcstack.go index 97d591642..e91585a2b 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -41,15 +41,21 @@ type httpConfig struct { CorsAllowedOrigins []string Vhosts []string prefix string // path prefix on which to mount http handler - jwtSecret []byte // optional JWT secret + rpcEndpointConfig } // wsConfig is the JSON-RPC/Websocket configuration type wsConfig struct { - Origins []string - Modules []string - prefix string // path prefix on which to mount ws handler - jwtSecret []byte // optional JWT secret + Origins []string + Modules []string + prefix string // path prefix on which to mount ws handler + rpcEndpointConfig +} + +type rpcEndpointConfig struct { + jwtSecret []byte // optional JWT secret + batchItemLimit int + batchResponseSizeLimit int } type rpcHandler struct { @@ -297,6 +303,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { // Create RPC server and handler. srv := rpc.NewServer() + srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit) if err := RegisterApis(apis, config.Modules, srv); err != nil { return err } @@ -328,6 +335,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } // Create RPC server and handler. srv := rpc.NewServer() + srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit) if err := RegisterApis(apis, config.Modules, srv); err != nil { return err } diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 0790dddec..e41cc51ad 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -339,8 +339,10 @@ func TestJWT(t *testing.T) { ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret) return ss } - srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, - true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil) + cfg := rpcEndpointConfig{jwtSecret: []byte("secret")} + httpcfg := &httpConfig{rpcEndpointConfig: cfg} + wscfg := &wsConfig{Origins: []string{"*"}, rpcEndpointConfig: cfg} + srv := createAndStartServer(t, httpcfg, true, wscfg, nil) wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) diff --git a/rpc/client.go b/rpc/client.go index fae8536b2..c3114ef1d 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -34,14 +34,15 @@ import ( var ( ErrBadResult = errors.New("bad result in JSON-RPC response") ErrClientQuit = errors.New("client is closed") - ErrNoResult = errors.New("no result in JSON-RPC response") + ErrNoResult = errors.New("JSON-RPC response has no result") + ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") errClientReconnected = errors.New("client reconnected") errDead = errors.New("connection lost") ) +// Timeouts const ( - // Timeouts defaultDialTimeout = 10 * time.Second // used if context has no deadline subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls ) @@ -84,6 +85,10 @@ type Client struct { // This function, if non-nil, is called when the connection is lost. reconnectFunc reconnectFunc + // config fields + batchItemLimit int + batchResponseMaxSize int + // writeConn is used for writing to the connection on the caller's goroutine. It should // only be accessed outside of dispatch, with the write lock held. The write lock is // taken by sending on reqInit and released by sending on reqSent. @@ -114,7 +119,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn { ctx := context.Background() ctx = context.WithValue(ctx, clientContextKey{}, c) ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo()) - handler := newHandler(ctx, conn, c.idgen, c.services) + handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize) return &clientConn{conn, handler} } @@ -128,14 +133,17 @@ type readOp struct { batch bool } +// requestOp represents a pending request. This is used for both batch and non-batch +// requests. type requestOp struct { - ids []json.RawMessage - err error - resp chan *jsonrpcMessage // receives up to len(ids) responses - sub *ClientSubscription // only set for EthSubscribe requests + ids []json.RawMessage + err error + resp chan []*jsonrpcMessage // the response goes here + sub *ClientSubscription // set for Subscribe requests. + hadResponse bool // true when the request was responded to } -func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) { +func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) { select { case <-ctx.Done(): // Send the timeout to dispatch so it can remove the request IDs. @@ -211,7 +219,7 @@ func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (* return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) } - return newClient(ctx, reconnect) + return newClient(ctx, cfg, reconnect) } // ClientFromContext retrieves the client from the context, if any. This can be used to perform @@ -221,33 +229,42 @@ func ClientFromContext(ctx context.Context) (*Client, bool) { return client, ok } -func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { +func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) { conn, err := connect(initctx) if err != nil { return nil, err } - c := initClient(conn, randomIDGenerator(), new(serviceRegistry)) + c := initClient(conn, new(serviceRegistry), cfg) c.reconnectFunc = connect return c, nil } -func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { +func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client { _, isHTTP := conn.(*httpConn) c := &Client{ - isHTTP: isHTTP, - idgen: idgen, - services: services, - writeConn: conn, - close: make(chan struct{}), - closing: make(chan struct{}), - didClose: make(chan struct{}), - reconnected: make(chan ServerCodec), - readOp: make(chan readOp), - readErr: make(chan error), - reqInit: make(chan *requestOp), - reqSent: make(chan error, 1), - reqTimeout: make(chan *requestOp), + isHTTP: isHTTP, + services: services, + idgen: cfg.idgen, + batchItemLimit: cfg.batchItemLimit, + batchResponseMaxSize: cfg.batchResponseLimit, + writeConn: conn, + close: make(chan struct{}), + closing: make(chan struct{}), + didClose: make(chan struct{}), + reconnected: make(chan ServerCodec), + readOp: make(chan readOp), + readErr: make(chan error), + reqInit: make(chan *requestOp), + reqSent: make(chan error, 1), + reqTimeout: make(chan *requestOp), } + + // Set defaults. + if c.idgen == nil { + c.idgen = randomIDGenerator() + } + + // Launch the main loop. if !isHTTP { go c.dispatch(conn) } @@ -325,7 +342,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str if err != nil { return err } - op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} + op := &requestOp{ + ids: []json.RawMessage{msg.ID}, + resp: make(chan []*jsonrpcMessage, 1), + } if c.isHTTP { err = c.sendHTTP(ctx, op, msg) @@ -337,9 +357,12 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str } // dispatch has accepted the request and will close the channel when it quits. - switch resp, err := op.wait(ctx, c); { - case err != nil: + batchresp, err := op.wait(ctx, c) + if err != nil { return err + } + resp := batchresp[0] + switch { case resp.Error != nil: return resp.Error case len(resp.Result) == 0: @@ -380,7 +403,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { ) op := &requestOp{ ids: make([]json.RawMessage, len(b)), - resp: make(chan *jsonrpcMessage, len(b)), + resp: make(chan []*jsonrpcMessage, 1), } for i, elem := range b { msg, err := c.newMessage(elem.Method, elem.Args...) @@ -398,28 +421,48 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { } else { err = c.send(ctx, op, msgs) } + if err != nil { + return err + } + + batchresp, err := op.wait(ctx, c) + if err != nil { + return err + } // Wait for all responses to come back. - for n := 0; n < len(b) && err == nil; n++ { - var resp *jsonrpcMessage - resp, err = op.wait(ctx, c) - if err != nil { - break + for n := 0; n < len(batchresp) && err == nil; n++ { + resp := batchresp[n] + if resp == nil { + // Ignore null responses. These can happen for batches sent via HTTP. + continue } + // Find the element corresponding to this response. - // The element is guaranteed to be present because dispatch - // only sends valid IDs to our channel. - elem := &b[byID[string(resp.ID)]] - if resp.Error != nil { + index, ok := byID[string(resp.ID)] + if !ok { + continue + } + delete(byID, string(resp.ID)) + + // Assign result and error. + elem := &b[index] + switch { + case resp.Error != nil: elem.Error = resp.Error - continue - } - if len(resp.Result) == 0 { + case resp.Result == nil: elem.Error = ErrNoResult - continue + default: + elem.Error = json.Unmarshal(resp.Result, elem.Result) } - elem.Error = json.Unmarshal(resp.Result, elem.Result) } + + // Check that all expected responses have been received. + for _, index := range byID { + elem := &b[index] + elem.Error = ErrMissingBatchResponse + } + return err } @@ -480,7 +523,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf } op := &requestOp{ ids: []json.RawMessage{msg.ID}, - resp: make(chan *jsonrpcMessage), + resp: make(chan []*jsonrpcMessage, 1), sub: newClientSubscription(c, namespace, chanVal), } diff --git a/rpc/client_opt.go b/rpc/client_opt.go index 5ad7c22b3..5bef08cca 100644 --- a/rpc/client_opt.go +++ b/rpc/client_opt.go @@ -28,11 +28,18 @@ type ClientOption interface { } type clientConfig struct { + // HTTP settings httpClient *http.Client httpHeaders http.Header httpAuth HTTPAuth + // WebSocket options wsDialer *websocket.Dialer + + // RPC handler options + idgen func() ID + batchItemLimit int + batchResponseLimit int } func (cfg *clientConfig) initHeaders() { @@ -104,3 +111,25 @@ func WithHTTPAuth(a HTTPAuth) ClientOption { // Usually, HTTPAuth functions will call h.Set("authorization", "...") to add // auth information to the request. type HTTPAuth func(h http.Header) error + +// WithBatchItemLimit changes the maximum number of items allowed in batch requests. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchItemLimit(limit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchItemLimit = limit + }) +} + +// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be +// generated for batch requests. When this limit is reached, further calls in the batch +// will not be processed. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchResponseSizeLimit(sizeLimit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchResponseLimit = sizeLimit + }) +} diff --git a/rpc/client_test.go b/rpc/client_test.go index a94a54929..7c96b2d66 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -169,10 +169,12 @@ func TestClientBatchRequest(t *testing.T) { } } +// This checks that, for HTTP connections, the length of batch responses is validated to +// match the request exactly. func TestClientBatchRequest_len(t *testing.T) { b, err := json.Marshal([]jsonrpcMessage{ - {Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)}, - {Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)}, + {Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)}, + {Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)}, }) if err != nil { t.Fatal("failed to encode jsonrpc message:", err) @@ -185,37 +187,102 @@ func TestClientBatchRequest_len(t *testing.T) { })) t.Cleanup(s.Close) - client, err := Dial(s.URL) - if err != nil { - t.Fatal("failed to dial test server:", err) - } - defer client.Close() - t.Run("too-few", func(t *testing.T) { + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + batch := []BatchElem{ - {Method: "foo"}, - {Method: "bar"}, - {Method: "baz"}, + {Method: "foo", Result: new(string)}, + {Method: "bar", Result: new(string)}, + {Method: "baz", Result: new(string)}, } ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) defer cancelFn() - if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) { - t.Errorf("expected %q but got: %v", ErrBadResult, err) + + if err := client.BatchCallContext(ctx, batch); err != nil { + t.Fatal("error:", err) + } + for i, elem := range batch[:2] { + if elem.Error != nil { + t.Errorf("expected no error for batch element %d, got %q", i, elem.Error) + } + } + for i, elem := range batch[2:] { + if elem.Error != ErrMissingBatchResponse { + t.Errorf("wrong error %q for batch element %d", elem.Error, i+2) + } } }) t.Run("too-many", func(t *testing.T) { + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + batch := []BatchElem{ - {Method: "foo"}, + {Method: "foo", Result: new(string)}, } ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) defer cancelFn() - if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) { - t.Errorf("expected %q but got: %v", ErrBadResult, err) + + if err := client.BatchCallContext(ctx, batch); err != nil { + t.Fatal("error:", err) + } + for i, elem := range batch[:1] { + if elem.Error != nil { + t.Errorf("expected no error for batch element %d, got %q", i, elem.Error) + } + } + for i, elem := range batch[1:] { + if elem.Error != ErrMissingBatchResponse { + t.Errorf("wrong error %q for batch element %d", elem.Error, i+2) + } } }) } +// This checks that the client can handle the case where the server doesn't +// respond to all requests in a batch. +func TestClientBatchRequestLimit(t *testing.T) { + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(2, 100000) + client := DialInProc(server) + + batch := []BatchElem{ + {Method: "foo"}, + {Method: "bar"}, + {Method: "baz"}, + } + err := client.BatchCall(batch) + if err != nil { + t.Fatal("unexpected error:", err) + } + + // Check that the first response indicates an error with batch size. + var err0 Error + if !errors.As(batch[0].Error, &err0) { + t.Log("error zero:", batch[0].Error) + t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error) + } else { + if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge { + t.Fatalf("wrong error on batch elem zero: %v", err0) + } + } + + // Check that remaining response batch elements are reported as absent. + for i, elem := range batch[1:] { + if elem.Error != ErrMissingBatchResponse { + t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error) + } + } +} + func TestClientNotify(t *testing.T) { server := newTestServer() defer server.Stop() @@ -310,7 +377,7 @@ func testClientCancel(transport string, t *testing.T) { _, hasDeadline := ctx.Deadline() t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) // default: - // t.Logf("got expected error with %v wait time: %v", timeout, err) + // t.Logf("got expected error with %v wait time: %v", timeout, err) } cancel() } @@ -487,7 +554,8 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) { defer srv.Stop() // Create the client on the other end of the pipe. - client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) { + cfg := new(clientConfig) + client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) { return NewCodec(p2), nil }) defer client.Close() diff --git a/rpc/errors.go b/rpc/errors.go index 7188332d5..abb698af7 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -61,12 +61,15 @@ const ( errcodeDefault = -32000 errcodeNotificationsUnsupported = -32001 errcodeTimeout = -32002 + errcodeResponseTooLarge = -32003 errcodePanic = -32603 errcodeMarshalError = -32603 ) const ( - errMsgTimeout = "request timed out" + errMsgTimeout = "request timed out" + errMsgResponseTooLarge = "response too large" + errMsgBatchTooLarge = "batch too large" ) type methodNotFoundError struct{ method string } diff --git a/rpc/handler.go b/rpc/handler.go index c2e7d7dc0..4f48c7931 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -49,17 +49,19 @@ import ( // h.removeRequestOp(op) // timeout, etc. // } type handler struct { - reg *serviceRegistry - unsubscribeCb *callback - idgen func() ID // subscription ID generator - respWait map[string]*requestOp // active client requests - clientSubs map[string]*ClientSubscription // active client subscriptions - callWG sync.WaitGroup // pending call goroutines - rootCtx context.Context // canceled by close() - cancelRoot func() // cancel function for rootCtx - conn jsonWriter // where responses will be sent - log log.Logger - allowSubscribe bool + reg *serviceRegistry + unsubscribeCb *callback + idgen func() ID // subscription ID generator + respWait map[string]*requestOp // active client requests + clientSubs map[string]*ClientSubscription // active client subscriptions + callWG sync.WaitGroup // pending call goroutines + rootCtx context.Context // canceled by close() + cancelRoot func() // cancel function for rootCtx + conn jsonWriter // where responses will be sent + log log.Logger + allowSubscribe bool + batchRequestLimit int + batchResponseMaxSize int subLock sync.Mutex serverSubs map[ID]*Subscription @@ -70,19 +72,21 @@ type callProc struct { notifiers []*Notifier } -func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler { +func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler { rootCtx, cancelRoot := context.WithCancel(connCtx) h := &handler{ - reg: reg, - idgen: idgen, - conn: conn, - respWait: make(map[string]*requestOp), - clientSubs: make(map[string]*ClientSubscription), - rootCtx: rootCtx, - cancelRoot: cancelRoot, - allowSubscribe: true, - serverSubs: make(map[ID]*Subscription), - log: log.Root(), + reg: reg, + idgen: idgen, + conn: conn, + respWait: make(map[string]*requestOp), + clientSubs: make(map[string]*ClientSubscription), + rootCtx: rootCtx, + cancelRoot: cancelRoot, + allowSubscribe: true, + serverSubs: make(map[ID]*Subscription), + log: log.Root(), + batchRequestLimit: batchRequestLimit, + batchResponseMaxSize: batchResponseMaxSize, } if conn.remoteAddr() != "" { h.log = h.log.New("conn", conn.remoteAddr()) @@ -134,16 +138,15 @@ func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) { b.doWrite(ctx, conn, false) } -// timeout sends the responses added so far. For the remaining unanswered call -// messages, it sends a timeout error response. -func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) { +// respondWithError sends the responses added so far. For the remaining unanswered call +// messages, it responds with the given error. +func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) { b.mutex.Lock() defer b.mutex.Unlock() for _, msg := range b.calls { if !msg.isNotification() { - resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) - b.resp = append(b.resp, resp) + b.resp = append(b.resp, msg.errorResponse(err)) } } b.doWrite(ctx, conn, true) @@ -171,17 +174,24 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { }) return } - - // Handle non-call messages first: - calls := make([]*jsonrpcMessage, 0, len(msgs)) - for _, msg := range msgs { - if handled := h.handleImmediate(msg); !handled { - calls = append(calls, msg) - } + // Apply limit on total number of requests. + if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit { + h.startCallProc(func(cp *callProc) { + h.respondWithBatchTooLarge(cp, msgs) + }) + return } + + // Handle non-call messages first. + // Here we need to find the requestOp that sent the request batch. + calls := make([]*jsonrpcMessage, 0, len(msgs)) + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + calls = append(calls, msg) + }) if len(calls) == 0 { return } + // Process calls on a goroutine because they may block indefinitely: h.startCallProc(func(cp *callProc) { var ( @@ -199,10 +209,12 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { if timeout, ok := ContextRequestTimeout(cp.ctx); ok { timer = time.AfterFunc(timeout, func() { cancel() - callBuffer.timeout(cp.ctx, h.conn) + err := &internalServerError{errcodeTimeout, errMsgTimeout} + callBuffer.respondWithError(cp.ctx, h.conn, err) }) } + responseBytes := 0 for { // No need to handle rest of calls if timed out. if cp.ctx.Err() != nil { @@ -214,61 +226,88 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { } resp := h.handleCallMsg(cp, msg) callBuffer.pushResponse(resp) + if resp != nil && h.batchResponseMaxSize != 0 { + responseBytes += len(resp.Result) + if responseBytes > h.batchResponseMaxSize { + err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge} + callBuffer.respondWithError(cp.ctx, h.conn, err) + break + } + } } if timer != nil { timer.Stop() } - callBuffer.write(cp.ctx, h.conn) + h.addSubscriptions(cp.notifiers) + callBuffer.write(cp.ctx, h.conn) for _, n := range cp.notifiers { n.activate() } }) } -// handleMsg handles a single message. -func (h *handler) handleMsg(msg *jsonrpcMessage) { - if ok := h.handleImmediate(msg); ok { - return +func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) { + resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge}) + // Find the first call and add its "id" field to the error. + // This is the best we can do, given that the protocol doesn't have a way + // of reporting an error for the entire batch. + for _, msg := range batch { + if msg.isCall() { + resp.ID = msg.ID + break + } } - h.startCallProc(func(cp *callProc) { - var ( - responded sync.Once - timer *time.Timer - cancel context.CancelFunc - ) - cp.ctx, cancel = context.WithCancel(cp.ctx) - defer cancel() + h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true) +} - // Cancel the request context after timeout and send an error response. Since the - // running method might not return immediately on timeout, we must wait for the - // timeout concurrently with processing the request. - if timeout, ok := ContextRequestTimeout(cp.ctx); ok { - timer = time.AfterFunc(timeout, func() { - cancel() - responded.Do(func() { - resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) - h.conn.writeJSON(cp.ctx, resp, true) - }) - }) - } - - answer := h.handleCallMsg(cp, msg) - if timer != nil { - timer.Stop() - } - h.addSubscriptions(cp.notifiers) - if answer != nil { - responded.Do(func() { - h.conn.writeJSON(cp.ctx, answer, false) - }) - } - for _, n := range cp.notifiers { - n.activate() - } +// handleMsg handles a single non-batch message. +func (h *handler) handleMsg(msg *jsonrpcMessage) { + msgs := []*jsonrpcMessage{msg} + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + h.startCallProc(func(cp *callProc) { + h.handleNonBatchCall(cp, msg) + }) }) } +func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) { + var ( + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // running method might not return immediately on timeout, we must wait for the + // timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + responded.Do(func() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + h.conn.writeJSON(cp.ctx, resp, true) + }) + }) + } + + answer := h.handleCallMsg(cp, msg) + if timer != nil { + timer.Stop() + } + h.addSubscriptions(cp.notifiers) + if answer != nil { + responded.Do(func() { + h.conn.writeJSON(cp.ctx, answer, false) + }) + } + for _, n := range cp.notifiers { + n.activate() + } +} + // close cancels all requests except for inflightReq and waits for // call goroutines to shut down. func (h *handler) close(err error, inflightReq *requestOp) { @@ -349,23 +388,60 @@ func (h *handler) startCallProc(fn func(*callProc)) { }() } -// handleImmediate executes non-call messages. It returns false if the message is a -// call or requires a reply. -func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { - start := time.Now() - switch { - case msg.isNotification(): - if strings.HasSuffix(msg.Method, notificationMethodSuffix) { - h.handleSubscriptionResult(msg) - return true +// handleResponse processes method call responses. +func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) { + var resolvedops []*requestOp + handleResp := func(msg *jsonrpcMessage) { + op := h.respWait[string(msg.ID)] + if op == nil { + h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) + return } - return false - case msg.isResponse(): - h.handleResponse(msg) - h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) - return true - default: - return false + resolvedops = append(resolvedops, op) + delete(h.respWait, string(msg.ID)) + + // For subscription responses, start the subscription if the server + // indicates success. EthSubscribe gets unblocked in either case through + // the op.resp channel. + if op.sub != nil { + if msg.Error != nil { + op.err = msg.Error + } else { + op.err = json.Unmarshal(msg.Result, &op.sub.subid) + if op.err == nil { + go op.sub.run() + h.clientSubs[op.sub.subid] = op.sub + } + } + } + + if !op.hadResponse { + op.hadResponse = true + op.resp <- batch + } + } + + for _, msg := range batch { + start := time.Now() + switch { + case msg.isResponse(): + handleResp(msg) + h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) + + case msg.isNotification(): + if strings.HasSuffix(msg.Method, notificationMethodSuffix) { + h.handleSubscriptionResult(msg) + continue + } + handleCall(msg) + + default: + handleCall(msg) + } + } + + for _, op := range resolvedops { + h.removeRequestOp(op) } } @@ -381,33 +457,6 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { } } -// handleResponse processes method call responses. -func (h *handler) handleResponse(msg *jsonrpcMessage) { - op := h.respWait[string(msg.ID)] - if op == nil { - h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) - return - } - delete(h.respWait, string(msg.ID)) - // For normal responses, just forward the reply to Call/BatchCall. - if op.sub == nil { - op.resp <- msg - return - } - // For subscription responses, start the subscription if the server - // indicates success. EthSubscribe gets unblocked in either case through - // the op.resp channel. - defer close(op.resp) - if msg.Error != nil { - op.err = msg.Error - return - } - if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil { - go op.sub.run() - h.clientSubs[op.sub.subid] = op.sub - } -} - // handleCallMsg executes a call message and returns the answer. func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { start := time.Now() @@ -416,6 +465,7 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess h.handleCall(ctx, msg) h.log.Debug("Served "+msg.Method, "duration", time.Since(start)) return nil + case msg.isCall(): resp := h.handleCall(ctx, msg) var ctx []interface{} @@ -430,8 +480,10 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess h.log.Debug("Served "+msg.Method, ctx...) } return resp + case msg.hasValidID(): return msg.errorResponse(&invalidRequestError{"invalid request"}) + default: return errorMessage(&invalidRequestError{"invalid request"}) } @@ -451,12 +503,14 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage if callb == nil { return msg.errorResponse(&methodNotFoundError{method: msg.Method}) } + args, err := parsePositionalArguments(msg.Params, callb.argTypes) if err != nil { return msg.errorResponse(&invalidParamsError{err.Error()}) } start := time.Now() answer := h.runMethod(cp.ctx, msg, callb, args) + // Collect the statistics for RPC calls if metrics is enabled. // We only care about pure rpc call. Filter out subscription. if callb != h.unsubscribeCb { @@ -469,6 +523,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage rpcServingTimer.UpdateSince(start) updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start)) } + return answer } diff --git a/rpc/http.go b/rpc/http.go index 8712f9961..741fa1c0e 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -139,7 +139,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { var cfg clientConfig cfg.httpClient = client fn := newClientTransportHTTP(endpoint, &cfg) - return newClient(context.Background(), fn) + return newClient(context.Background(), &cfg, fn) } func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { @@ -176,11 +176,12 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e } defer respBody.Close() - var respmsg jsonrpcMessage - if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil { + var resp jsonrpcMessage + batch := [1]*jsonrpcMessage{&resp} + if err := json.NewDecoder(respBody).Decode(&resp); err != nil { return err } - op.resp <- &respmsg + op.resp <- batch[:] return nil } @@ -191,16 +192,12 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr return err } defer respBody.Close() - var respmsgs []jsonrpcMessage + + var respmsgs []*jsonrpcMessage if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { return err } - if len(respmsgs) != len(msgs) { - return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult) - } - for i := 0; i < len(respmsgs); i++ { - op.resp <- &respmsgs[i] - } + op.resp <- respmsgs return nil } diff --git a/rpc/inproc.go b/rpc/inproc.go index fbe9a40ce..306974e04 100644 --- a/rpc/inproc.go +++ b/rpc/inproc.go @@ -24,7 +24,8 @@ import ( // DialInProc attaches an in-process connection to the given RPC server. func DialInProc(handler *Server) *Client { initctx := context.Background() - c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { + cfg := new(clientConfig) + c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) { p1, p2 := net.Pipe() go handler.ServeCodec(NewCodec(p1), 0) return NewCodec(p2), nil diff --git a/rpc/ipc.go b/rpc/ipc.go index d9e0de62e..a08245b27 100644 --- a/rpc/ipc.go +++ b/rpc/ipc.go @@ -46,7 +46,8 @@ func (s *Server) ServeListener(l net.Listener) error { // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialIPC(ctx context.Context, endpoint string) (*Client, error) { - return newClient(ctx, newClientTransportIPC(endpoint)) + cfg := new(clientConfig) + return newClient(ctx, cfg, newClientTransportIPC(endpoint)) } func newClientTransportIPC(endpoint string) reconnectFunc { diff --git a/rpc/server.go b/rpc/server.go index 089bbb1fd..2742adf07 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -46,9 +46,11 @@ type Server struct { services serviceRegistry idgen func() ID - mutex sync.Mutex - codecs map[ServerCodec]struct{} - run atomic.Bool + mutex sync.Mutex + codecs map[ServerCodec]struct{} + run atomic.Bool + batchItemLimit int + batchResponseLimit int } // NewServer creates a new server instance with no registered handlers. @@ -65,6 +67,17 @@ func NewServer() *Server { return server } +// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit' +// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of +// response bytes across all requests in a batch. +// +// This method should be called before processing any requests via ServeCodec, ServeHTTP, +// ServeListener etc. +func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) { + s.batchItemLimit = itemLimit + s.batchResponseLimit = maxResponseSize +} + // RegisterName creates a service for the given receiver type under the given name. When no // methods on the given receiver match the criteria to be either a RPC method or a // subscription an error is returned. Otherwise a new service is created and added to the @@ -86,7 +99,12 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { } defer s.untrackCodec(codec) - c := initClient(codec, s.idgen, &s.services) + cfg := &clientConfig{ + idgen: s.idgen, + batchItemLimit: s.batchItemLimit, + batchResponseLimit: s.batchResponseLimit, + } + c := initClient(codec, &s.services, cfg) <-codec.closed() c.Close() } @@ -118,7 +136,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { return } - h := newHandler(ctx, codec, s.idgen, &s.services) + h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit) h.allowSubscribe = false defer h.close(io.EOF, nil) diff --git a/rpc/server_test.go b/rpc/server_test.go index f1a9b3d5c..5d3929dfd 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -70,6 +70,7 @@ func TestServer(t *testing.T) { func runTestScript(t *testing.T, file string) { server := newTestServer() + server.SetBatchLimits(4, 100000) content, err := os.ReadFile(file) if err != nil { t.Fatal(err) @@ -152,3 +153,41 @@ func TestServerShortLivedConn(t *testing.T) { } } } + +func TestServerBatchResponseSizeLimit(t *testing.T) { + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(100, 60) + var ( + batch []BatchElem + client = DialInProc(server) + ) + for i := 0; i < 5; i++ { + batch = append(batch, BatchElem{ + Method: "test_echo", + Args: []any{"x", 1}, + Result: new(echoResult), + }) + } + if err := client.BatchCall(batch); err != nil { + t.Fatal("error sending batch:", err) + } + for i := range batch { + // We expect the first two queries to be ok, but after that the size limit takes effect. + if i < 2 { + if batch[i].Error != nil { + t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error) + } + continue + } + // After two, we expect an error. + re, ok := batch[i].Error.(Error) + if !ok { + t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error) + } + wantedCode := errcodeResponseTooLarge + if re.ErrorCode() != wantedCode { + t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode) + } + } +} diff --git a/rpc/stdio.go b/rpc/stdio.go index ae32db26e..084e5f070 100644 --- a/rpc/stdio.go +++ b/rpc/stdio.go @@ -32,7 +32,8 @@ func DialStdIO(ctx context.Context) (*Client, error) { // DialIO creates a client which uses the given IO channels func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { - return newClient(ctx, newClientTransportIO(in, out)) + cfg := new(clientConfig) + return newClient(ctx, cfg, newClientTransportIO(in, out)) } func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc { diff --git a/rpc/testdata/invalid-batch-toolarge.js b/rpc/testdata/invalid-batch-toolarge.js new file mode 100644 index 000000000..218fea58a --- /dev/null +++ b/rpc/testdata/invalid-batch-toolarge.js @@ -0,0 +1,13 @@ +// This file checks the behavior of the batch item limit code. +// In tests, the batch item limit is set to 4. So to trigger the error, +// all batches in this file have 5 elements. + +// For batches that do not contain any calls, a response message with "id" == null +// is returned. + +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}] + +// For batches with at least one call, the call's "id" is used. +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}] diff --git a/rpc/websocket.go b/rpc/websocket.go index 889562d1a..b1213fdfa 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale if err != nil { return nil, err } - return newClient(ctx, connect) + return newClient(ctx, cfg, connect) } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server @@ -214,7 +214,7 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error if err != nil { return nil, err } - return newClient(ctx, connect) + return newClient(ctx, cfg, connect) } func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {