diff --git a/cmd/clef/main.go b/cmd/clef/main.go index cebe74797..14f09dc1a 100644 --- a/cmd/clef/main.go +++ b/cmd/clef/main.go @@ -732,6 +732,7 @@ func signer(c *cli.Context) error { cors := utils.SplitAndTrim(c.String(utils.HTTPCORSDomainFlag.Name)) srv := rpc.NewServer() + srv.SetBatchLimits(node.DefaultConfig.BatchRequestLimit, node.DefaultConfig.BatchResponseMaxSize) err := node.RegisterApis(rpcAPI, []string{"account"}, srv) if err != nil { utils.Fatalf("Could not register API: %w", err) diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 2289a72a1..2794e37e3 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -168,6 +168,8 @@ var ( utils.RPCGlobalEVMTimeoutFlag, utils.RPCGlobalTxFeeCapFlag, utils.AllowUnprotectedTxs, + utils.BatchRequestLimit, + utils.BatchResponseMaxSize, } metricsFlags = []cli.Flag{ diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 1ebc998a4..692970477 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -713,6 +713,18 @@ var ( Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC", Category: flags.APICategory, } + BatchRequestLimit = &cli.IntFlag{ + Name: "rpc.batch-request-limit", + Usage: "Maximum number of requests in a batch", + Value: node.DefaultConfig.BatchRequestLimit, + Category: flags.APICategory, + } + BatchResponseMaxSize = &cli.IntFlag{ + Name: "rpc.batch-response-max-size", + Usage: "Maximum number of bytes returned from a batched call", + Value: node.DefaultConfig.BatchResponseMaxSize, + Category: flags.APICategory, + } EnablePersonal = &cli.BoolFlag{ Name: "rpc.enabledeprecatedpersonal", Usage: "Enables the (deprecated) personal namespace", @@ -1130,6 +1142,14 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { if ctx.IsSet(AllowUnprotectedTxs.Name) { cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name) } + + if ctx.IsSet(BatchRequestLimit.Name) { + cfg.BatchRequestLimit = ctx.Int(BatchRequestLimit.Name) + } + + if ctx.IsSet(BatchResponseMaxSize.Name) { + cfg.BatchResponseMaxSize = ctx.Int(BatchResponseMaxSize.Name) + } } // setGraphQL creates the GraphQL listener interface string from the set diff --git a/node/api.go b/node/api.go index 15892a270..f81f394be 100644 --- a/node/api.go +++ b/node/api.go @@ -176,6 +176,10 @@ func (api *adminAPI) StartHTTP(host *string, port *int, cors *string, apis *stri CorsAllowedOrigins: api.node.config.HTTPCors, Vhosts: api.node.config.HTTPVirtualHosts, Modules: api.node.config.HTTPModules, + rpcEndpointConfig: rpcEndpointConfig{ + batchItemLimit: api.node.config.BatchRequestLimit, + batchResponseSizeLimit: api.node.config.BatchResponseMaxSize, + }, } if cors != nil { config.CorsAllowedOrigins = nil @@ -250,6 +254,10 @@ func (api *adminAPI) StartWS(host *string, port *int, allowedOrigins *string, ap Modules: api.node.config.WSModules, Origins: api.node.config.WSOrigins, // ExposeAll: api.node.config.WSExposeAll, + rpcEndpointConfig: rpcEndpointConfig{ + batchItemLimit: api.node.config.BatchRequestLimit, + batchResponseSizeLimit: api.node.config.BatchResponseMaxSize, + }, } if apis != nil { config.Modules = nil diff --git a/node/config.go b/node/config.go index 1765811e8..37c1e4882 100644 --- a/node/config.go +++ b/node/config.go @@ -197,6 +197,12 @@ type Config struct { // AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC. AllowUnprotectedTxs bool `toml:",omitempty"` + // BatchRequestLimit is the maximum number of requests in a batch. + BatchRequestLimit int `toml:",omitempty"` + + // BatchResponseMaxSize is the maximum number of bytes returned from a batched rpc call. + BatchResponseMaxSize int `toml:",omitempty"` + // JWTSecret is the path to the hex-encoded jwt secret. JWTSecret string `toml:",omitempty"` diff --git a/node/defaults.go b/node/defaults.go index fcfbc934b..d8f718121 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -46,17 +46,19 @@ var ( // DefaultConfig contains reasonable default settings. var DefaultConfig = Config{ - DataDir: DefaultDataDir(), - HTTPPort: DefaultHTTPPort, - AuthAddr: DefaultAuthHost, - AuthPort: DefaultAuthPort, - AuthVirtualHosts: DefaultAuthVhosts, - HTTPModules: []string{"net", "web3"}, - HTTPVirtualHosts: []string{"localhost"}, - HTTPTimeouts: rpc.DefaultHTTPTimeouts, - WSPort: DefaultWSPort, - WSModules: []string{"net", "web3"}, - GraphQLVirtualHosts: []string{"localhost"}, + DataDir: DefaultDataDir(), + HTTPPort: DefaultHTTPPort, + AuthAddr: DefaultAuthHost, + AuthPort: DefaultAuthPort, + AuthVirtualHosts: DefaultAuthVhosts, + HTTPModules: []string{"net", "web3"}, + HTTPVirtualHosts: []string{"localhost"}, + HTTPTimeouts: rpc.DefaultHTTPTimeouts, + WSPort: DefaultWSPort, + WSModules: []string{"net", "web3"}, + BatchRequestLimit: 1000, + BatchResponseMaxSize: 25 * 1000 * 1000, + GraphQLVirtualHosts: []string{"localhost"}, P2P: p2p.Config{ ListenAddr: ":30303", MaxPeers: 50, diff --git a/node/node.go b/node/node.go index e8494ac3b..553d451ab 100644 --- a/node/node.go +++ b/node/node.go @@ -101,10 +101,11 @@ func New(conf *Config) (*Node, error) { if strings.HasSuffix(conf.Name, ".ipc") { return nil, errors.New(`Config.Name cannot end in ".ipc"`) } - + server := rpc.NewServer() + server.SetBatchLimits(conf.BatchRequestLimit, conf.BatchResponseMaxSize) node := &Node{ config: conf, - inprocHandler: rpc.NewServer(), + inprocHandler: server, eventmux: new(event.TypeMux), log: conf.Logger, stop: make(chan struct{}), @@ -403,6 +404,11 @@ func (n *Node) startRPC() error { openAPIs, allAPIs = n.getAPIs() ) + rpcConfig := rpcEndpointConfig{ + batchItemLimit: n.config.BatchRequestLimit, + batchResponseSizeLimit: n.config.BatchResponseMaxSize, + } + initHttp := func(server *httpServer, port int) error { if err := server.setListenAddr(n.config.HTTPHost, port); err != nil { return err @@ -412,6 +418,7 @@ func (n *Node) startRPC() error { Vhosts: n.config.HTTPVirtualHosts, Modules: n.config.HTTPModules, prefix: n.config.HTTPPathPrefix, + rpcEndpointConfig: rpcConfig, }); err != nil { return err } @@ -425,9 +432,10 @@ func (n *Node) startRPC() error { return err } if err := server.enableWS(openAPIs, wsConfig{ - Modules: n.config.WSModules, - Origins: n.config.WSOrigins, - prefix: n.config.WSPathPrefix, + Modules: n.config.WSModules, + Origins: n.config.WSOrigins, + prefix: n.config.WSPathPrefix, + rpcEndpointConfig: rpcConfig, }); err != nil { return err } @@ -441,26 +449,29 @@ func (n *Node) startRPC() error { if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { return err } + sharedConfig := rpcConfig + sharedConfig.jwtSecret = secret if err := server.enableRPC(allAPIs, httpConfig{ CorsAllowedOrigins: DefaultAuthCors, Vhosts: n.config.AuthVirtualHosts, Modules: DefaultAuthModules, prefix: DefaultAuthPrefix, - jwtSecret: secret, + rpcEndpointConfig: sharedConfig, }); err != nil { return err } servers = append(servers, server) + // Enable auth via WS server = n.wsServerForPort(port, true) if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { return err } if err := server.enableWS(allAPIs, wsConfig{ - Modules: DefaultAuthModules, - Origins: DefaultAuthOrigins, - prefix: DefaultAuthPrefix, - jwtSecret: secret, + Modules: DefaultAuthModules, + Origins: DefaultAuthOrigins, + prefix: DefaultAuthPrefix, + rpcEndpointConfig: sharedConfig, }); err != nil { return err } diff --git a/node/rpcstack.go b/node/rpcstack.go index 97d591642..e91585a2b 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -41,15 +41,21 @@ type httpConfig struct { CorsAllowedOrigins []string Vhosts []string prefix string // path prefix on which to mount http handler - jwtSecret []byte // optional JWT secret + rpcEndpointConfig } // wsConfig is the JSON-RPC/Websocket configuration type wsConfig struct { - Origins []string - Modules []string - prefix string // path prefix on which to mount ws handler - jwtSecret []byte // optional JWT secret + Origins []string + Modules []string + prefix string // path prefix on which to mount ws handler + rpcEndpointConfig +} + +type rpcEndpointConfig struct { + jwtSecret []byte // optional JWT secret + batchItemLimit int + batchResponseSizeLimit int } type rpcHandler struct { @@ -297,6 +303,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { // Create RPC server and handler. srv := rpc.NewServer() + srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit) if err := RegisterApis(apis, config.Modules, srv); err != nil { return err } @@ -328,6 +335,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } // Create RPC server and handler. srv := rpc.NewServer() + srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit) if err := RegisterApis(apis, config.Modules, srv); err != nil { return err } diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 0790dddec..e41cc51ad 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -339,8 +339,10 @@ func TestJWT(t *testing.T) { ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret) return ss } - srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, - true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil) + cfg := rpcEndpointConfig{jwtSecret: []byte("secret")} + httpcfg := &httpConfig{rpcEndpointConfig: cfg} + wscfg := &wsConfig{Origins: []string{"*"}, rpcEndpointConfig: cfg} + srv := createAndStartServer(t, httpcfg, true, wscfg, nil) wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) diff --git a/rpc/client.go b/rpc/client.go index fae8536b2..c3114ef1d 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -34,14 +34,15 @@ import ( var ( ErrBadResult = errors.New("bad result in JSON-RPC response") ErrClientQuit = errors.New("client is closed") - ErrNoResult = errors.New("no result in JSON-RPC response") + ErrNoResult = errors.New("JSON-RPC response has no result") + ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") errClientReconnected = errors.New("client reconnected") errDead = errors.New("connection lost") ) +// Timeouts const ( - // Timeouts defaultDialTimeout = 10 * time.Second // used if context has no deadline subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls ) @@ -84,6 +85,10 @@ type Client struct { // This function, if non-nil, is called when the connection is lost. reconnectFunc reconnectFunc + // config fields + batchItemLimit int + batchResponseMaxSize int + // writeConn is used for writing to the connection on the caller's goroutine. It should // only be accessed outside of dispatch, with the write lock held. The write lock is // taken by sending on reqInit and released by sending on reqSent. @@ -114,7 +119,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn { ctx := context.Background() ctx = context.WithValue(ctx, clientContextKey{}, c) ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo()) - handler := newHandler(ctx, conn, c.idgen, c.services) + handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize) return &clientConn{conn, handler} } @@ -128,14 +133,17 @@ type readOp struct { batch bool } +// requestOp represents a pending request. This is used for both batch and non-batch +// requests. type requestOp struct { - ids []json.RawMessage - err error - resp chan *jsonrpcMessage // receives up to len(ids) responses - sub *ClientSubscription // only set for EthSubscribe requests + ids []json.RawMessage + err error + resp chan []*jsonrpcMessage // the response goes here + sub *ClientSubscription // set for Subscribe requests. + hadResponse bool // true when the request was responded to } -func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) { +func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) { select { case <-ctx.Done(): // Send the timeout to dispatch so it can remove the request IDs. @@ -211,7 +219,7 @@ func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (* return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) } - return newClient(ctx, reconnect) + return newClient(ctx, cfg, reconnect) } // ClientFromContext retrieves the client from the context, if any. This can be used to perform @@ -221,33 +229,42 @@ func ClientFromContext(ctx context.Context) (*Client, bool) { return client, ok } -func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { +func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) { conn, err := connect(initctx) if err != nil { return nil, err } - c := initClient(conn, randomIDGenerator(), new(serviceRegistry)) + c := initClient(conn, new(serviceRegistry), cfg) c.reconnectFunc = connect return c, nil } -func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { +func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client { _, isHTTP := conn.(*httpConn) c := &Client{ - isHTTP: isHTTP, - idgen: idgen, - services: services, - writeConn: conn, - close: make(chan struct{}), - closing: make(chan struct{}), - didClose: make(chan struct{}), - reconnected: make(chan ServerCodec), - readOp: make(chan readOp), - readErr: make(chan error), - reqInit: make(chan *requestOp), - reqSent: make(chan error, 1), - reqTimeout: make(chan *requestOp), + isHTTP: isHTTP, + services: services, + idgen: cfg.idgen, + batchItemLimit: cfg.batchItemLimit, + batchResponseMaxSize: cfg.batchResponseLimit, + writeConn: conn, + close: make(chan struct{}), + closing: make(chan struct{}), + didClose: make(chan struct{}), + reconnected: make(chan ServerCodec), + readOp: make(chan readOp), + readErr: make(chan error), + reqInit: make(chan *requestOp), + reqSent: make(chan error, 1), + reqTimeout: make(chan *requestOp), } + + // Set defaults. + if c.idgen == nil { + c.idgen = randomIDGenerator() + } + + // Launch the main loop. if !isHTTP { go c.dispatch(conn) } @@ -325,7 +342,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str if err != nil { return err } - op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} + op := &requestOp{ + ids: []json.RawMessage{msg.ID}, + resp: make(chan []*jsonrpcMessage, 1), + } if c.isHTTP { err = c.sendHTTP(ctx, op, msg) @@ -337,9 +357,12 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str } // dispatch has accepted the request and will close the channel when it quits. - switch resp, err := op.wait(ctx, c); { - case err != nil: + batchresp, err := op.wait(ctx, c) + if err != nil { return err + } + resp := batchresp[0] + switch { case resp.Error != nil: return resp.Error case len(resp.Result) == 0: @@ -380,7 +403,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { ) op := &requestOp{ ids: make([]json.RawMessage, len(b)), - resp: make(chan *jsonrpcMessage, len(b)), + resp: make(chan []*jsonrpcMessage, 1), } for i, elem := range b { msg, err := c.newMessage(elem.Method, elem.Args...) @@ -398,28 +421,48 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { } else { err = c.send(ctx, op, msgs) } + if err != nil { + return err + } + + batchresp, err := op.wait(ctx, c) + if err != nil { + return err + } // Wait for all responses to come back. - for n := 0; n < len(b) && err == nil; n++ { - var resp *jsonrpcMessage - resp, err = op.wait(ctx, c) - if err != nil { - break + for n := 0; n < len(batchresp) && err == nil; n++ { + resp := batchresp[n] + if resp == nil { + // Ignore null responses. These can happen for batches sent via HTTP. + continue } + // Find the element corresponding to this response. - // The element is guaranteed to be present because dispatch - // only sends valid IDs to our channel. - elem := &b[byID[string(resp.ID)]] - if resp.Error != nil { + index, ok := byID[string(resp.ID)] + if !ok { + continue + } + delete(byID, string(resp.ID)) + + // Assign result and error. + elem := &b[index] + switch { + case resp.Error != nil: elem.Error = resp.Error - continue - } - if len(resp.Result) == 0 { + case resp.Result == nil: elem.Error = ErrNoResult - continue + default: + elem.Error = json.Unmarshal(resp.Result, elem.Result) } - elem.Error = json.Unmarshal(resp.Result, elem.Result) } + + // Check that all expected responses have been received. + for _, index := range byID { + elem := &b[index] + elem.Error = ErrMissingBatchResponse + } + return err } @@ -480,7 +523,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf } op := &requestOp{ ids: []json.RawMessage{msg.ID}, - resp: make(chan *jsonrpcMessage), + resp: make(chan []*jsonrpcMessage, 1), sub: newClientSubscription(c, namespace, chanVal), } diff --git a/rpc/client_opt.go b/rpc/client_opt.go index 5ad7c22b3..5bef08cca 100644 --- a/rpc/client_opt.go +++ b/rpc/client_opt.go @@ -28,11 +28,18 @@ type ClientOption interface { } type clientConfig struct { + // HTTP settings httpClient *http.Client httpHeaders http.Header httpAuth HTTPAuth + // WebSocket options wsDialer *websocket.Dialer + + // RPC handler options + idgen func() ID + batchItemLimit int + batchResponseLimit int } func (cfg *clientConfig) initHeaders() { @@ -104,3 +111,25 @@ func WithHTTPAuth(a HTTPAuth) ClientOption { // Usually, HTTPAuth functions will call h.Set("authorization", "...") to add // auth information to the request. type HTTPAuth func(h http.Header) error + +// WithBatchItemLimit changes the maximum number of items allowed in batch requests. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchItemLimit(limit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchItemLimit = limit + }) +} + +// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be +// generated for batch requests. When this limit is reached, further calls in the batch +// will not be processed. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchResponseSizeLimit(sizeLimit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchResponseLimit = sizeLimit + }) +} diff --git a/rpc/client_test.go b/rpc/client_test.go index a94a54929..7c96b2d66 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -169,10 +169,12 @@ func TestClientBatchRequest(t *testing.T) { } } +// This checks that, for HTTP connections, the length of batch responses is validated to +// match the request exactly. func TestClientBatchRequest_len(t *testing.T) { b, err := json.Marshal([]jsonrpcMessage{ - {Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)}, - {Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)}, + {Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)}, + {Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)}, }) if err != nil { t.Fatal("failed to encode jsonrpc message:", err) @@ -185,37 +187,102 @@ func TestClientBatchRequest_len(t *testing.T) { })) t.Cleanup(s.Close) - client, err := Dial(s.URL) - if err != nil { - t.Fatal("failed to dial test server:", err) - } - defer client.Close() - t.Run("too-few", func(t *testing.T) { + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + batch := []BatchElem{ - {Method: "foo"}, - {Method: "bar"}, - {Method: "baz"}, + {Method: "foo", Result: new(string)}, + {Method: "bar", Result: new(string)}, + {Method: "baz", Result: new(string)}, } ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) defer cancelFn() - if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) { - t.Errorf("expected %q but got: %v", ErrBadResult, err) + + if err := client.BatchCallContext(ctx, batch); err != nil { + t.Fatal("error:", err) + } + for i, elem := range batch[:2] { + if elem.Error != nil { + t.Errorf("expected no error for batch element %d, got %q", i, elem.Error) + } + } + for i, elem := range batch[2:] { + if elem.Error != ErrMissingBatchResponse { + t.Errorf("wrong error %q for batch element %d", elem.Error, i+2) + } } }) t.Run("too-many", func(t *testing.T) { + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + batch := []BatchElem{ - {Method: "foo"}, + {Method: "foo", Result: new(string)}, } ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) defer cancelFn() - if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) { - t.Errorf("expected %q but got: %v", ErrBadResult, err) + + if err := client.BatchCallContext(ctx, batch); err != nil { + t.Fatal("error:", err) + } + for i, elem := range batch[:1] { + if elem.Error != nil { + t.Errorf("expected no error for batch element %d, got %q", i, elem.Error) + } + } + for i, elem := range batch[1:] { + if elem.Error != ErrMissingBatchResponse { + t.Errorf("wrong error %q for batch element %d", elem.Error, i+2) + } } }) } +// This checks that the client can handle the case where the server doesn't +// respond to all requests in a batch. +func TestClientBatchRequestLimit(t *testing.T) { + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(2, 100000) + client := DialInProc(server) + + batch := []BatchElem{ + {Method: "foo"}, + {Method: "bar"}, + {Method: "baz"}, + } + err := client.BatchCall(batch) + if err != nil { + t.Fatal("unexpected error:", err) + } + + // Check that the first response indicates an error with batch size. + var err0 Error + if !errors.As(batch[0].Error, &err0) { + t.Log("error zero:", batch[0].Error) + t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error) + } else { + if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge { + t.Fatalf("wrong error on batch elem zero: %v", err0) + } + } + + // Check that remaining response batch elements are reported as absent. + for i, elem := range batch[1:] { + if elem.Error != ErrMissingBatchResponse { + t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error) + } + } +} + func TestClientNotify(t *testing.T) { server := newTestServer() defer server.Stop() @@ -310,7 +377,7 @@ func testClientCancel(transport string, t *testing.T) { _, hasDeadline := ctx.Deadline() t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) // default: - // t.Logf("got expected error with %v wait time: %v", timeout, err) + // t.Logf("got expected error with %v wait time: %v", timeout, err) } cancel() } @@ -487,7 +554,8 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) { defer srv.Stop() // Create the client on the other end of the pipe. - client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) { + cfg := new(clientConfig) + client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) { return NewCodec(p2), nil }) defer client.Close() diff --git a/rpc/errors.go b/rpc/errors.go index 7188332d5..abb698af7 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -61,12 +61,15 @@ const ( errcodeDefault = -32000 errcodeNotificationsUnsupported = -32001 errcodeTimeout = -32002 + errcodeResponseTooLarge = -32003 errcodePanic = -32603 errcodeMarshalError = -32603 ) const ( - errMsgTimeout = "request timed out" + errMsgTimeout = "request timed out" + errMsgResponseTooLarge = "response too large" + errMsgBatchTooLarge = "batch too large" ) type methodNotFoundError struct{ method string } diff --git a/rpc/handler.go b/rpc/handler.go index c2e7d7dc0..4f48c7931 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -49,17 +49,19 @@ import ( // h.removeRequestOp(op) // timeout, etc. // } type handler struct { - reg *serviceRegistry - unsubscribeCb *callback - idgen func() ID // subscription ID generator - respWait map[string]*requestOp // active client requests - clientSubs map[string]*ClientSubscription // active client subscriptions - callWG sync.WaitGroup // pending call goroutines - rootCtx context.Context // canceled by close() - cancelRoot func() // cancel function for rootCtx - conn jsonWriter // where responses will be sent - log log.Logger - allowSubscribe bool + reg *serviceRegistry + unsubscribeCb *callback + idgen func() ID // subscription ID generator + respWait map[string]*requestOp // active client requests + clientSubs map[string]*ClientSubscription // active client subscriptions + callWG sync.WaitGroup // pending call goroutines + rootCtx context.Context // canceled by close() + cancelRoot func() // cancel function for rootCtx + conn jsonWriter // where responses will be sent + log log.Logger + allowSubscribe bool + batchRequestLimit int + batchResponseMaxSize int subLock sync.Mutex serverSubs map[ID]*Subscription @@ -70,19 +72,21 @@ type callProc struct { notifiers []*Notifier } -func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler { +func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler { rootCtx, cancelRoot := context.WithCancel(connCtx) h := &handler{ - reg: reg, - idgen: idgen, - conn: conn, - respWait: make(map[string]*requestOp), - clientSubs: make(map[string]*ClientSubscription), - rootCtx: rootCtx, - cancelRoot: cancelRoot, - allowSubscribe: true, - serverSubs: make(map[ID]*Subscription), - log: log.Root(), + reg: reg, + idgen: idgen, + conn: conn, + respWait: make(map[string]*requestOp), + clientSubs: make(map[string]*ClientSubscription), + rootCtx: rootCtx, + cancelRoot: cancelRoot, + allowSubscribe: true, + serverSubs: make(map[ID]*Subscription), + log: log.Root(), + batchRequestLimit: batchRequestLimit, + batchResponseMaxSize: batchResponseMaxSize, } if conn.remoteAddr() != "" { h.log = h.log.New("conn", conn.remoteAddr()) @@ -134,16 +138,15 @@ func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) { b.doWrite(ctx, conn, false) } -// timeout sends the responses added so far. For the remaining unanswered call -// messages, it sends a timeout error response. -func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) { +// respondWithError sends the responses added so far. For the remaining unanswered call +// messages, it responds with the given error. +func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) { b.mutex.Lock() defer b.mutex.Unlock() for _, msg := range b.calls { if !msg.isNotification() { - resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) - b.resp = append(b.resp, resp) + b.resp = append(b.resp, msg.errorResponse(err)) } } b.doWrite(ctx, conn, true) @@ -171,17 +174,24 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { }) return } - - // Handle non-call messages first: - calls := make([]*jsonrpcMessage, 0, len(msgs)) - for _, msg := range msgs { - if handled := h.handleImmediate(msg); !handled { - calls = append(calls, msg) - } + // Apply limit on total number of requests. + if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit { + h.startCallProc(func(cp *callProc) { + h.respondWithBatchTooLarge(cp, msgs) + }) + return } + + // Handle non-call messages first. + // Here we need to find the requestOp that sent the request batch. + calls := make([]*jsonrpcMessage, 0, len(msgs)) + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + calls = append(calls, msg) + }) if len(calls) == 0 { return } + // Process calls on a goroutine because they may block indefinitely: h.startCallProc(func(cp *callProc) { var ( @@ -199,10 +209,12 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { if timeout, ok := ContextRequestTimeout(cp.ctx); ok { timer = time.AfterFunc(timeout, func() { cancel() - callBuffer.timeout(cp.ctx, h.conn) + err := &internalServerError{errcodeTimeout, errMsgTimeout} + callBuffer.respondWithError(cp.ctx, h.conn, err) }) } + responseBytes := 0 for { // No need to handle rest of calls if timed out. if cp.ctx.Err() != nil { @@ -214,61 +226,88 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { } resp := h.handleCallMsg(cp, msg) callBuffer.pushResponse(resp) + if resp != nil && h.batchResponseMaxSize != 0 { + responseBytes += len(resp.Result) + if responseBytes > h.batchResponseMaxSize { + err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge} + callBuffer.respondWithError(cp.ctx, h.conn, err) + break + } + } } if timer != nil { timer.Stop() } - callBuffer.write(cp.ctx, h.conn) + h.addSubscriptions(cp.notifiers) + callBuffer.write(cp.ctx, h.conn) for _, n := range cp.notifiers { n.activate() } }) } -// handleMsg handles a single message. -func (h *handler) handleMsg(msg *jsonrpcMessage) { - if ok := h.handleImmediate(msg); ok { - return +func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) { + resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge}) + // Find the first call and add its "id" field to the error. + // This is the best we can do, given that the protocol doesn't have a way + // of reporting an error for the entire batch. + for _, msg := range batch { + if msg.isCall() { + resp.ID = msg.ID + break + } } - h.startCallProc(func(cp *callProc) { - var ( - responded sync.Once - timer *time.Timer - cancel context.CancelFunc - ) - cp.ctx, cancel = context.WithCancel(cp.ctx) - defer cancel() + h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true) +} - // Cancel the request context after timeout and send an error response. Since the - // running method might not return immediately on timeout, we must wait for the - // timeout concurrently with processing the request. - if timeout, ok := ContextRequestTimeout(cp.ctx); ok { - timer = time.AfterFunc(timeout, func() { - cancel() - responded.Do(func() { - resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) - h.conn.writeJSON(cp.ctx, resp, true) - }) - }) - } - - answer := h.handleCallMsg(cp, msg) - if timer != nil { - timer.Stop() - } - h.addSubscriptions(cp.notifiers) - if answer != nil { - responded.Do(func() { - h.conn.writeJSON(cp.ctx, answer, false) - }) - } - for _, n := range cp.notifiers { - n.activate() - } +// handleMsg handles a single non-batch message. +func (h *handler) handleMsg(msg *jsonrpcMessage) { + msgs := []*jsonrpcMessage{msg} + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + h.startCallProc(func(cp *callProc) { + h.handleNonBatchCall(cp, msg) + }) }) } +func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) { + var ( + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // running method might not return immediately on timeout, we must wait for the + // timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + responded.Do(func() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + h.conn.writeJSON(cp.ctx, resp, true) + }) + }) + } + + answer := h.handleCallMsg(cp, msg) + if timer != nil { + timer.Stop() + } + h.addSubscriptions(cp.notifiers) + if answer != nil { + responded.Do(func() { + h.conn.writeJSON(cp.ctx, answer, false) + }) + } + for _, n := range cp.notifiers { + n.activate() + } +} + // close cancels all requests except for inflightReq and waits for // call goroutines to shut down. func (h *handler) close(err error, inflightReq *requestOp) { @@ -349,23 +388,60 @@ func (h *handler) startCallProc(fn func(*callProc)) { }() } -// handleImmediate executes non-call messages. It returns false if the message is a -// call or requires a reply. -func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { - start := time.Now() - switch { - case msg.isNotification(): - if strings.HasSuffix(msg.Method, notificationMethodSuffix) { - h.handleSubscriptionResult(msg) - return true +// handleResponse processes method call responses. +func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) { + var resolvedops []*requestOp + handleResp := func(msg *jsonrpcMessage) { + op := h.respWait[string(msg.ID)] + if op == nil { + h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) + return } - return false - case msg.isResponse(): - h.handleResponse(msg) - h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) - return true - default: - return false + resolvedops = append(resolvedops, op) + delete(h.respWait, string(msg.ID)) + + // For subscription responses, start the subscription if the server + // indicates success. EthSubscribe gets unblocked in either case through + // the op.resp channel. + if op.sub != nil { + if msg.Error != nil { + op.err = msg.Error + } else { + op.err = json.Unmarshal(msg.Result, &op.sub.subid) + if op.err == nil { + go op.sub.run() + h.clientSubs[op.sub.subid] = op.sub + } + } + } + + if !op.hadResponse { + op.hadResponse = true + op.resp <- batch + } + } + + for _, msg := range batch { + start := time.Now() + switch { + case msg.isResponse(): + handleResp(msg) + h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) + + case msg.isNotification(): + if strings.HasSuffix(msg.Method, notificationMethodSuffix) { + h.handleSubscriptionResult(msg) + continue + } + handleCall(msg) + + default: + handleCall(msg) + } + } + + for _, op := range resolvedops { + h.removeRequestOp(op) } } @@ -381,33 +457,6 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { } } -// handleResponse processes method call responses. -func (h *handler) handleResponse(msg *jsonrpcMessage) { - op := h.respWait[string(msg.ID)] - if op == nil { - h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) - return - } - delete(h.respWait, string(msg.ID)) - // For normal responses, just forward the reply to Call/BatchCall. - if op.sub == nil { - op.resp <- msg - return - } - // For subscription responses, start the subscription if the server - // indicates success. EthSubscribe gets unblocked in either case through - // the op.resp channel. - defer close(op.resp) - if msg.Error != nil { - op.err = msg.Error - return - } - if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil { - go op.sub.run() - h.clientSubs[op.sub.subid] = op.sub - } -} - // handleCallMsg executes a call message and returns the answer. func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { start := time.Now() @@ -416,6 +465,7 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess h.handleCall(ctx, msg) h.log.Debug("Served "+msg.Method, "duration", time.Since(start)) return nil + case msg.isCall(): resp := h.handleCall(ctx, msg) var ctx []interface{} @@ -430,8 +480,10 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess h.log.Debug("Served "+msg.Method, ctx...) } return resp + case msg.hasValidID(): return msg.errorResponse(&invalidRequestError{"invalid request"}) + default: return errorMessage(&invalidRequestError{"invalid request"}) } @@ -451,12 +503,14 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage if callb == nil { return msg.errorResponse(&methodNotFoundError{method: msg.Method}) } + args, err := parsePositionalArguments(msg.Params, callb.argTypes) if err != nil { return msg.errorResponse(&invalidParamsError{err.Error()}) } start := time.Now() answer := h.runMethod(cp.ctx, msg, callb, args) + // Collect the statistics for RPC calls if metrics is enabled. // We only care about pure rpc call. Filter out subscription. if callb != h.unsubscribeCb { @@ -469,6 +523,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage rpcServingTimer.UpdateSince(start) updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start)) } + return answer } diff --git a/rpc/http.go b/rpc/http.go index 8712f9961..741fa1c0e 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -139,7 +139,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { var cfg clientConfig cfg.httpClient = client fn := newClientTransportHTTP(endpoint, &cfg) - return newClient(context.Background(), fn) + return newClient(context.Background(), &cfg, fn) } func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { @@ -176,11 +176,12 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e } defer respBody.Close() - var respmsg jsonrpcMessage - if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil { + var resp jsonrpcMessage + batch := [1]*jsonrpcMessage{&resp} + if err := json.NewDecoder(respBody).Decode(&resp); err != nil { return err } - op.resp <- &respmsg + op.resp <- batch[:] return nil } @@ -191,16 +192,12 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr return err } defer respBody.Close() - var respmsgs []jsonrpcMessage + + var respmsgs []*jsonrpcMessage if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { return err } - if len(respmsgs) != len(msgs) { - return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult) - } - for i := 0; i < len(respmsgs); i++ { - op.resp <- &respmsgs[i] - } + op.resp <- respmsgs return nil } diff --git a/rpc/inproc.go b/rpc/inproc.go index fbe9a40ce..306974e04 100644 --- a/rpc/inproc.go +++ b/rpc/inproc.go @@ -24,7 +24,8 @@ import ( // DialInProc attaches an in-process connection to the given RPC server. func DialInProc(handler *Server) *Client { initctx := context.Background() - c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { + cfg := new(clientConfig) + c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) { p1, p2 := net.Pipe() go handler.ServeCodec(NewCodec(p1), 0) return NewCodec(p2), nil diff --git a/rpc/ipc.go b/rpc/ipc.go index d9e0de62e..a08245b27 100644 --- a/rpc/ipc.go +++ b/rpc/ipc.go @@ -46,7 +46,8 @@ func (s *Server) ServeListener(l net.Listener) error { // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialIPC(ctx context.Context, endpoint string) (*Client, error) { - return newClient(ctx, newClientTransportIPC(endpoint)) + cfg := new(clientConfig) + return newClient(ctx, cfg, newClientTransportIPC(endpoint)) } func newClientTransportIPC(endpoint string) reconnectFunc { diff --git a/rpc/server.go b/rpc/server.go index 089bbb1fd..2742adf07 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -46,9 +46,11 @@ type Server struct { services serviceRegistry idgen func() ID - mutex sync.Mutex - codecs map[ServerCodec]struct{} - run atomic.Bool + mutex sync.Mutex + codecs map[ServerCodec]struct{} + run atomic.Bool + batchItemLimit int + batchResponseLimit int } // NewServer creates a new server instance with no registered handlers. @@ -65,6 +67,17 @@ func NewServer() *Server { return server } +// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit' +// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of +// response bytes across all requests in a batch. +// +// This method should be called before processing any requests via ServeCodec, ServeHTTP, +// ServeListener etc. +func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) { + s.batchItemLimit = itemLimit + s.batchResponseLimit = maxResponseSize +} + // RegisterName creates a service for the given receiver type under the given name. When no // methods on the given receiver match the criteria to be either a RPC method or a // subscription an error is returned. Otherwise a new service is created and added to the @@ -86,7 +99,12 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { } defer s.untrackCodec(codec) - c := initClient(codec, s.idgen, &s.services) + cfg := &clientConfig{ + idgen: s.idgen, + batchItemLimit: s.batchItemLimit, + batchResponseLimit: s.batchResponseLimit, + } + c := initClient(codec, &s.services, cfg) <-codec.closed() c.Close() } @@ -118,7 +136,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { return } - h := newHandler(ctx, codec, s.idgen, &s.services) + h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit) h.allowSubscribe = false defer h.close(io.EOF, nil) diff --git a/rpc/server_test.go b/rpc/server_test.go index f1a9b3d5c..5d3929dfd 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -70,6 +70,7 @@ func TestServer(t *testing.T) { func runTestScript(t *testing.T, file string) { server := newTestServer() + server.SetBatchLimits(4, 100000) content, err := os.ReadFile(file) if err != nil { t.Fatal(err) @@ -152,3 +153,41 @@ func TestServerShortLivedConn(t *testing.T) { } } } + +func TestServerBatchResponseSizeLimit(t *testing.T) { + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(100, 60) + var ( + batch []BatchElem + client = DialInProc(server) + ) + for i := 0; i < 5; i++ { + batch = append(batch, BatchElem{ + Method: "test_echo", + Args: []any{"x", 1}, + Result: new(echoResult), + }) + } + if err := client.BatchCall(batch); err != nil { + t.Fatal("error sending batch:", err) + } + for i := range batch { + // We expect the first two queries to be ok, but after that the size limit takes effect. + if i < 2 { + if batch[i].Error != nil { + t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error) + } + continue + } + // After two, we expect an error. + re, ok := batch[i].Error.(Error) + if !ok { + t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error) + } + wantedCode := errcodeResponseTooLarge + if re.ErrorCode() != wantedCode { + t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode) + } + } +} diff --git a/rpc/stdio.go b/rpc/stdio.go index ae32db26e..084e5f070 100644 --- a/rpc/stdio.go +++ b/rpc/stdio.go @@ -32,7 +32,8 @@ func DialStdIO(ctx context.Context) (*Client, error) { // DialIO creates a client which uses the given IO channels func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { - return newClient(ctx, newClientTransportIO(in, out)) + cfg := new(clientConfig) + return newClient(ctx, cfg, newClientTransportIO(in, out)) } func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc { diff --git a/rpc/testdata/invalid-batch-toolarge.js b/rpc/testdata/invalid-batch-toolarge.js new file mode 100644 index 000000000..218fea58a --- /dev/null +++ b/rpc/testdata/invalid-batch-toolarge.js @@ -0,0 +1,13 @@ +// This file checks the behavior of the batch item limit code. +// In tests, the batch item limit is set to 4. So to trigger the error, +// all batches in this file have 5 elements. + +// For batches that do not contain any calls, a response message with "id" == null +// is returned. + +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}] + +// For batches with at least one call, the call's "id" is used. +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}] diff --git a/rpc/websocket.go b/rpc/websocket.go index 889562d1a..b1213fdfa 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale if err != nil { return nil, err } - return newClient(ctx, connect) + return newClient(ctx, cfg, connect) } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server @@ -214,7 +214,7 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error if err != nil { return nil, err } - return newClient(ctx, connect) + return newClient(ctx, cfg, connect) } func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {