rpc: add limit for batch request items and response size (#26681)

This PR adds server-side limits for JSON-RPC batch requests. Before this change, batches
were limited only by processing time. The server would pick calls from the batch and
answer them until the response timeout occurred, then stop processing the remaining batch
items.

Here, we are adding two additional limits which can be configured:

- the 'item limit': batches can have at most N items
- the 'response size limit': batches can contain at most X response bytes

These limits are optional in package rpc. In Geth, we set a default limit of 1000 items
and 25MB response size.

When a batch goes over the limit, an error response is returned to the client. However,
doing this correctly isn't always possible. In JSON-RPC, only method calls with a valid
`id` can be responded to. Since batches may also contain non-call messages or
notifications, the best effort thing we can do to report an error with the batch itself is
reporting the limit violation as an error for the first method call in the batch. If a batch is
too large, but contains only notifications and responses, the error will be reported with
a null `id`.

The RPC client was also changed so it can deal with errors resulting from too large
batches. An older client connected to the server code in this PR could get stuck
until the request timeout occurred when the batch is too large. **Upgrading to a version
of the RPC client containing this change is strongly recommended to avoid timeout issues.**

For some weird reason, when writing the original client implementation, @fjl worked off of
the assumption that responses could be distributed across batches arbitrarily. So for a
batch request containing requests `[A B C]`, the server could respond with `[A B C]` but
also with `[A B] [C]` or even `[A] [B] [C]` and it wouldn't make a difference to the
client.

So in the implementation of BatchCallContext, the client waited for all requests in the
batch individually. If the server didn't respond to some of the requests in the batch, the
client would eventually just time out (if a context was used).

With the addition of batch limits into the server, we anticipate that people will hit this
kind of error way more often. To handle this properly, the client now waits for a single
response batch and expects it to contain all responses to the requests.

---------

Co-authored-by: Felix Lange <fjl@twurst.com>
Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
mmsqe 2023-06-13 19:38:58 +08:00 committed by GitHub
parent 5ac4da3653
commit f3314bb6df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 557 additions and 229 deletions

View File

@ -732,6 +732,7 @@ func signer(c *cli.Context) error {
cors := utils.SplitAndTrim(c.String(utils.HTTPCORSDomainFlag.Name))
srv := rpc.NewServer()
srv.SetBatchLimits(node.DefaultConfig.BatchRequestLimit, node.DefaultConfig.BatchResponseMaxSize)
err := node.RegisterApis(rpcAPI, []string{"account"}, srv)
if err != nil {
utils.Fatalf("Could not register API: %w", err)

View File

@ -168,6 +168,8 @@ var (
utils.RPCGlobalEVMTimeoutFlag,
utils.RPCGlobalTxFeeCapFlag,
utils.AllowUnprotectedTxs,
utils.BatchRequestLimit,
utils.BatchResponseMaxSize,
}
metricsFlags = []cli.Flag{

View File

@ -713,6 +713,18 @@ var (
Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC",
Category: flags.APICategory,
}
BatchRequestLimit = &cli.IntFlag{
Name: "rpc.batch-request-limit",
Usage: "Maximum number of requests in a batch",
Value: node.DefaultConfig.BatchRequestLimit,
Category: flags.APICategory,
}
BatchResponseMaxSize = &cli.IntFlag{
Name: "rpc.batch-response-max-size",
Usage: "Maximum number of bytes returned from a batched call",
Value: node.DefaultConfig.BatchResponseMaxSize,
Category: flags.APICategory,
}
EnablePersonal = &cli.BoolFlag{
Name: "rpc.enabledeprecatedpersonal",
Usage: "Enables the (deprecated) personal namespace",
@ -1130,6 +1142,14 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) {
if ctx.IsSet(AllowUnprotectedTxs.Name) {
cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name)
}
if ctx.IsSet(BatchRequestLimit.Name) {
cfg.BatchRequestLimit = ctx.Int(BatchRequestLimit.Name)
}
if ctx.IsSet(BatchResponseMaxSize.Name) {
cfg.BatchResponseMaxSize = ctx.Int(BatchResponseMaxSize.Name)
}
}
// setGraphQL creates the GraphQL listener interface string from the set

View File

@ -176,6 +176,10 @@ func (api *adminAPI) StartHTTP(host *string, port *int, cors *string, apis *stri
CorsAllowedOrigins: api.node.config.HTTPCors,
Vhosts: api.node.config.HTTPVirtualHosts,
Modules: api.node.config.HTTPModules,
rpcEndpointConfig: rpcEndpointConfig{
batchItemLimit: api.node.config.BatchRequestLimit,
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
},
}
if cors != nil {
config.CorsAllowedOrigins = nil
@ -250,6 +254,10 @@ func (api *adminAPI) StartWS(host *string, port *int, allowedOrigins *string, ap
Modules: api.node.config.WSModules,
Origins: api.node.config.WSOrigins,
// ExposeAll: api.node.config.WSExposeAll,
rpcEndpointConfig: rpcEndpointConfig{
batchItemLimit: api.node.config.BatchRequestLimit,
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
},
}
if apis != nil {
config.Modules = nil

View File

@ -197,6 +197,12 @@ type Config struct {
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
AllowUnprotectedTxs bool `toml:",omitempty"`
// BatchRequestLimit is the maximum number of requests in a batch.
BatchRequestLimit int `toml:",omitempty"`
// BatchResponseMaxSize is the maximum number of bytes returned from a batched rpc call.
BatchResponseMaxSize int `toml:",omitempty"`
// JWTSecret is the path to the hex-encoded jwt secret.
JWTSecret string `toml:",omitempty"`

View File

@ -46,17 +46,19 @@ var (
// DefaultConfig contains reasonable default settings.
var DefaultConfig = Config{
DataDir: DefaultDataDir(),
HTTPPort: DefaultHTTPPort,
AuthAddr: DefaultAuthHost,
AuthPort: DefaultAuthPort,
AuthVirtualHosts: DefaultAuthVhosts,
HTTPModules: []string{"net", "web3"},
HTTPVirtualHosts: []string{"localhost"},
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
WSPort: DefaultWSPort,
WSModules: []string{"net", "web3"},
GraphQLVirtualHosts: []string{"localhost"},
DataDir: DefaultDataDir(),
HTTPPort: DefaultHTTPPort,
AuthAddr: DefaultAuthHost,
AuthPort: DefaultAuthPort,
AuthVirtualHosts: DefaultAuthVhosts,
HTTPModules: []string{"net", "web3"},
HTTPVirtualHosts: []string{"localhost"},
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
WSPort: DefaultWSPort,
WSModules: []string{"net", "web3"},
BatchRequestLimit: 1000,
BatchResponseMaxSize: 25 * 1000 * 1000,
GraphQLVirtualHosts: []string{"localhost"},
P2P: p2p.Config{
ListenAddr: ":30303",
MaxPeers: 50,

View File

@ -101,10 +101,11 @@ func New(conf *Config) (*Node, error) {
if strings.HasSuffix(conf.Name, ".ipc") {
return nil, errors.New(`Config.Name cannot end in ".ipc"`)
}
server := rpc.NewServer()
server.SetBatchLimits(conf.BatchRequestLimit, conf.BatchResponseMaxSize)
node := &Node{
config: conf,
inprocHandler: rpc.NewServer(),
inprocHandler: server,
eventmux: new(event.TypeMux),
log: conf.Logger,
stop: make(chan struct{}),
@ -403,6 +404,11 @@ func (n *Node) startRPC() error {
openAPIs, allAPIs = n.getAPIs()
)
rpcConfig := rpcEndpointConfig{
batchItemLimit: n.config.BatchRequestLimit,
batchResponseSizeLimit: n.config.BatchResponseMaxSize,
}
initHttp := func(server *httpServer, port int) error {
if err := server.setListenAddr(n.config.HTTPHost, port); err != nil {
return err
@ -412,6 +418,7 @@ func (n *Node) startRPC() error {
Vhosts: n.config.HTTPVirtualHosts,
Modules: n.config.HTTPModules,
prefix: n.config.HTTPPathPrefix,
rpcEndpointConfig: rpcConfig,
}); err != nil {
return err
}
@ -425,9 +432,10 @@ func (n *Node) startRPC() error {
return err
}
if err := server.enableWS(openAPIs, wsConfig{
Modules: n.config.WSModules,
Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix,
Modules: n.config.WSModules,
Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix,
rpcEndpointConfig: rpcConfig,
}); err != nil {
return err
}
@ -441,26 +449,29 @@ func (n *Node) startRPC() error {
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
return err
}
sharedConfig := rpcConfig
sharedConfig.jwtSecret = secret
if err := server.enableRPC(allAPIs, httpConfig{
CorsAllowedOrigins: DefaultAuthCors,
Vhosts: n.config.AuthVirtualHosts,
Modules: DefaultAuthModules,
prefix: DefaultAuthPrefix,
jwtSecret: secret,
rpcEndpointConfig: sharedConfig,
}); err != nil {
return err
}
servers = append(servers, server)
// Enable auth via WS
server = n.wsServerForPort(port, true)
if err := server.setListenAddr(n.config.AuthAddr, port); err != nil {
return err
}
if err := server.enableWS(allAPIs, wsConfig{
Modules: DefaultAuthModules,
Origins: DefaultAuthOrigins,
prefix: DefaultAuthPrefix,
jwtSecret: secret,
Modules: DefaultAuthModules,
Origins: DefaultAuthOrigins,
prefix: DefaultAuthPrefix,
rpcEndpointConfig: sharedConfig,
}); err != nil {
return err
}

View File

@ -41,15 +41,21 @@ type httpConfig struct {
CorsAllowedOrigins []string
Vhosts []string
prefix string // path prefix on which to mount http handler
jwtSecret []byte // optional JWT secret
rpcEndpointConfig
}
// wsConfig is the JSON-RPC/Websocket configuration
type wsConfig struct {
Origins []string
Modules []string
prefix string // path prefix on which to mount ws handler
jwtSecret []byte // optional JWT secret
Origins []string
Modules []string
prefix string // path prefix on which to mount ws handler
rpcEndpointConfig
}
type rpcEndpointConfig struct {
jwtSecret []byte // optional JWT secret
batchItemLimit int
batchResponseSizeLimit int
}
type rpcHandler struct {
@ -297,6 +303,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
// Create RPC server and handler.
srv := rpc.NewServer()
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
if err := RegisterApis(apis, config.Modules, srv); err != nil {
return err
}
@ -328,6 +335,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
}
// Create RPC server and handler.
srv := rpc.NewServer()
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
if err := RegisterApis(apis, config.Modules, srv); err != nil {
return err
}

View File

@ -339,8 +339,10 @@ func TestJWT(t *testing.T) {
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
return ss
}
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil)
cfg := rpcEndpointConfig{jwtSecret: []byte("secret")}
httpcfg := &httpConfig{rpcEndpointConfig: cfg}
wscfg := &wsConfig{Origins: []string{"*"}, rpcEndpointConfig: cfg}
srv := createAndStartServer(t, httpcfg, true, wscfg, nil)
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())

View File

@ -34,14 +34,15 @@ import (
var (
ErrBadResult = errors.New("bad result in JSON-RPC response")
ErrClientQuit = errors.New("client is closed")
ErrNoResult = errors.New("no result in JSON-RPC response")
ErrNoResult = errors.New("JSON-RPC response has no result")
ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call")
ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")
errClientReconnected = errors.New("client reconnected")
errDead = errors.New("connection lost")
)
// Timeouts
const (
// Timeouts
defaultDialTimeout = 10 * time.Second // used if context has no deadline
subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls
)
@ -84,6 +85,10 @@ type Client struct {
// This function, if non-nil, is called when the connection is lost.
reconnectFunc reconnectFunc
// config fields
batchItemLimit int
batchResponseMaxSize int
// writeConn is used for writing to the connection on the caller's goroutine. It should
// only be accessed outside of dispatch, with the write lock held. The write lock is
// taken by sending on reqInit and released by sending on reqSent.
@ -114,7 +119,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.Background()
ctx = context.WithValue(ctx, clientContextKey{}, c)
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
handler := newHandler(ctx, conn, c.idgen, c.services)
handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize)
return &clientConn{conn, handler}
}
@ -128,14 +133,17 @@ type readOp struct {
batch bool
}
// requestOp represents a pending request. This is used for both batch and non-batch
// requests.
type requestOp struct {
ids []json.RawMessage
err error
resp chan *jsonrpcMessage // receives up to len(ids) responses
sub *ClientSubscription // only set for EthSubscribe requests
ids []json.RawMessage
err error
resp chan []*jsonrpcMessage // the response goes here
sub *ClientSubscription // set for Subscribe requests.
hadResponse bool // true when the request was responded to
}
func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) {
func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) {
select {
case <-ctx.Done():
// Send the timeout to dispatch so it can remove the request IDs.
@ -211,7 +219,7 @@ func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
}
return newClient(ctx, reconnect)
return newClient(ctx, cfg, reconnect)
}
// ClientFromContext retrieves the client from the context, if any. This can be used to perform
@ -221,33 +229,42 @@ func ClientFromContext(ctx context.Context) (*Client, bool) {
return client, ok
}
func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) {
func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) {
conn, err := connect(initctx)
if err != nil {
return nil, err
}
c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
c := initClient(conn, new(serviceRegistry), cfg)
c.reconnectFunc = connect
return c, nil
}
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client {
_, isHTTP := conn.(*httpConn)
c := &Client{
isHTTP: isHTTP,
idgen: idgen,
services: services,
writeConn: conn,
close: make(chan struct{}),
closing: make(chan struct{}),
didClose: make(chan struct{}),
reconnected: make(chan ServerCodec),
readOp: make(chan readOp),
readErr: make(chan error),
reqInit: make(chan *requestOp),
reqSent: make(chan error, 1),
reqTimeout: make(chan *requestOp),
isHTTP: isHTTP,
services: services,
idgen: cfg.idgen,
batchItemLimit: cfg.batchItemLimit,
batchResponseMaxSize: cfg.batchResponseLimit,
writeConn: conn,
close: make(chan struct{}),
closing: make(chan struct{}),
didClose: make(chan struct{}),
reconnected: make(chan ServerCodec),
readOp: make(chan readOp),
readErr: make(chan error),
reqInit: make(chan *requestOp),
reqSent: make(chan error, 1),
reqTimeout: make(chan *requestOp),
}
// Set defaults.
if c.idgen == nil {
c.idgen = randomIDGenerator()
}
// Launch the main loop.
if !isHTTP {
go c.dispatch(conn)
}
@ -325,7 +342,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
if err != nil {
return err
}
op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)}
op := &requestOp{
ids: []json.RawMessage{msg.ID},
resp: make(chan []*jsonrpcMessage, 1),
}
if c.isHTTP {
err = c.sendHTTP(ctx, op, msg)
@ -337,9 +357,12 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
}
// dispatch has accepted the request and will close the channel when it quits.
switch resp, err := op.wait(ctx, c); {
case err != nil:
batchresp, err := op.wait(ctx, c)
if err != nil {
return err
}
resp := batchresp[0]
switch {
case resp.Error != nil:
return resp.Error
case len(resp.Result) == 0:
@ -380,7 +403,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
)
op := &requestOp{
ids: make([]json.RawMessage, len(b)),
resp: make(chan *jsonrpcMessage, len(b)),
resp: make(chan []*jsonrpcMessage, 1),
}
for i, elem := range b {
msg, err := c.newMessage(elem.Method, elem.Args...)
@ -398,28 +421,48 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
} else {
err = c.send(ctx, op, msgs)
}
if err != nil {
return err
}
batchresp, err := op.wait(ctx, c)
if err != nil {
return err
}
// Wait for all responses to come back.
for n := 0; n < len(b) && err == nil; n++ {
var resp *jsonrpcMessage
resp, err = op.wait(ctx, c)
if err != nil {
break
for n := 0; n < len(batchresp) && err == nil; n++ {
resp := batchresp[n]
if resp == nil {
// Ignore null responses. These can happen for batches sent via HTTP.
continue
}
// Find the element corresponding to this response.
// The element is guaranteed to be present because dispatch
// only sends valid IDs to our channel.
elem := &b[byID[string(resp.ID)]]
if resp.Error != nil {
index, ok := byID[string(resp.ID)]
if !ok {
continue
}
delete(byID, string(resp.ID))
// Assign result and error.
elem := &b[index]
switch {
case resp.Error != nil:
elem.Error = resp.Error
continue
}
if len(resp.Result) == 0 {
case resp.Result == nil:
elem.Error = ErrNoResult
continue
default:
elem.Error = json.Unmarshal(resp.Result, elem.Result)
}
elem.Error = json.Unmarshal(resp.Result, elem.Result)
}
// Check that all expected responses have been received.
for _, index := range byID {
elem := &b[index]
elem.Error = ErrMissingBatchResponse
}
return err
}
@ -480,7 +523,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
}
op := &requestOp{
ids: []json.RawMessage{msg.ID},
resp: make(chan *jsonrpcMessage),
resp: make(chan []*jsonrpcMessage, 1),
sub: newClientSubscription(c, namespace, chanVal),
}

View File

@ -28,11 +28,18 @@ type ClientOption interface {
}
type clientConfig struct {
// HTTP settings
httpClient *http.Client
httpHeaders http.Header
httpAuth HTTPAuth
// WebSocket options
wsDialer *websocket.Dialer
// RPC handler options
idgen func() ID
batchItemLimit int
batchResponseLimit int
}
func (cfg *clientConfig) initHeaders() {
@ -104,3 +111,25 @@ func WithHTTPAuth(a HTTPAuth) ClientOption {
// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
// auth information to the request.
type HTTPAuth func(h http.Header) error
// WithBatchItemLimit changes the maximum number of items allowed in batch requests.
//
// Note: this option applies when processing incoming batch requests. It does not affect
// batch requests sent by the client.
func WithBatchItemLimit(limit int) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.batchItemLimit = limit
})
}
// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be
// generated for batch requests. When this limit is reached, further calls in the batch
// will not be processed.
//
// Note: this option applies when processing incoming batch requests. It does not affect
// batch requests sent by the client.
func WithBatchResponseSizeLimit(sizeLimit int) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.batchResponseLimit = sizeLimit
})
}

View File

@ -169,10 +169,12 @@ func TestClientBatchRequest(t *testing.T) {
}
}
// This checks that, for HTTP connections, the length of batch responses is validated to
// match the request exactly.
func TestClientBatchRequest_len(t *testing.T) {
b, err := json.Marshal([]jsonrpcMessage{
{Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)},
{Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)},
{Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)},
{Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)},
})
if err != nil {
t.Fatal("failed to encode jsonrpc message:", err)
@ -185,37 +187,102 @@ func TestClientBatchRequest_len(t *testing.T) {
}))
t.Cleanup(s.Close)
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
t.Run("too-few", func(t *testing.T) {
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
batch := []BatchElem{
{Method: "foo"},
{Method: "bar"},
{Method: "baz"},
{Method: "foo", Result: new(string)},
{Method: "bar", Result: new(string)},
{Method: "baz", Result: new(string)},
}
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn()
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
t.Errorf("expected %q but got: %v", ErrBadResult, err)
if err := client.BatchCallContext(ctx, batch); err != nil {
t.Fatal("error:", err)
}
for i, elem := range batch[:2] {
if elem.Error != nil {
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
}
}
for i, elem := range batch[2:] {
if elem.Error != ErrMissingBatchResponse {
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
}
}
})
t.Run("too-many", func(t *testing.T) {
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
batch := []BatchElem{
{Method: "foo"},
{Method: "foo", Result: new(string)},
}
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn()
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
t.Errorf("expected %q but got: %v", ErrBadResult, err)
if err := client.BatchCallContext(ctx, batch); err != nil {
t.Fatal("error:", err)
}
for i, elem := range batch[:1] {
if elem.Error != nil {
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
}
}
for i, elem := range batch[1:] {
if elem.Error != ErrMissingBatchResponse {
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
}
}
})
}
// This checks that the client can handle the case where the server doesn't
// respond to all requests in a batch.
func TestClientBatchRequestLimit(t *testing.T) {
server := newTestServer()
defer server.Stop()
server.SetBatchLimits(2, 100000)
client := DialInProc(server)
batch := []BatchElem{
{Method: "foo"},
{Method: "bar"},
{Method: "baz"},
}
err := client.BatchCall(batch)
if err != nil {
t.Fatal("unexpected error:", err)
}
// Check that the first response indicates an error with batch size.
var err0 Error
if !errors.As(batch[0].Error, &err0) {
t.Log("error zero:", batch[0].Error)
t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error)
} else {
if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge {
t.Fatalf("wrong error on batch elem zero: %v", err0)
}
}
// Check that remaining response batch elements are reported as absent.
for i, elem := range batch[1:] {
if elem.Error != ErrMissingBatchResponse {
t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error)
}
}
}
func TestClientNotify(t *testing.T) {
server := newTestServer()
defer server.Stop()
@ -310,7 +377,7 @@ func testClientCancel(transport string, t *testing.T) {
_, hasDeadline := ctx.Deadline()
t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline)
// default:
// t.Logf("got expected error with %v wait time: %v", timeout, err)
// t.Logf("got expected error with %v wait time: %v", timeout, err)
}
cancel()
}
@ -487,7 +554,8 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) {
defer srv.Stop()
// Create the client on the other end of the pipe.
client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) {
cfg := new(clientConfig)
client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) {
return NewCodec(p2), nil
})
defer client.Close()

View File

@ -61,12 +61,15 @@ const (
errcodeDefault = -32000
errcodeNotificationsUnsupported = -32001
errcodeTimeout = -32002
errcodeResponseTooLarge = -32003
errcodePanic = -32603
errcodeMarshalError = -32603
)
const (
errMsgTimeout = "request timed out"
errMsgTimeout = "request timed out"
errMsgResponseTooLarge = "response too large"
errMsgBatchTooLarge = "batch too large"
)
type methodNotFoundError struct{ method string }

View File

@ -49,17 +49,19 @@ import (
// h.removeRequestOp(op) // timeout, etc.
// }
type handler struct {
reg *serviceRegistry
unsubscribeCb *callback
idgen func() ID // subscription ID generator
respWait map[string]*requestOp // active client requests
clientSubs map[string]*ClientSubscription // active client subscriptions
callWG sync.WaitGroup // pending call goroutines
rootCtx context.Context // canceled by close()
cancelRoot func() // cancel function for rootCtx
conn jsonWriter // where responses will be sent
log log.Logger
allowSubscribe bool
reg *serviceRegistry
unsubscribeCb *callback
idgen func() ID // subscription ID generator
respWait map[string]*requestOp // active client requests
clientSubs map[string]*ClientSubscription // active client subscriptions
callWG sync.WaitGroup // pending call goroutines
rootCtx context.Context // canceled by close()
cancelRoot func() // cancel function for rootCtx
conn jsonWriter // where responses will be sent
log log.Logger
allowSubscribe bool
batchRequestLimit int
batchResponseMaxSize int
subLock sync.Mutex
serverSubs map[ID]*Subscription
@ -70,19 +72,21 @@ type callProc struct {
notifiers []*Notifier
}
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler {
rootCtx, cancelRoot := context.WithCancel(connCtx)
h := &handler{
reg: reg,
idgen: idgen,
conn: conn,
respWait: make(map[string]*requestOp),
clientSubs: make(map[string]*ClientSubscription),
rootCtx: rootCtx,
cancelRoot: cancelRoot,
allowSubscribe: true,
serverSubs: make(map[ID]*Subscription),
log: log.Root(),
reg: reg,
idgen: idgen,
conn: conn,
respWait: make(map[string]*requestOp),
clientSubs: make(map[string]*ClientSubscription),
rootCtx: rootCtx,
cancelRoot: cancelRoot,
allowSubscribe: true,
serverSubs: make(map[ID]*Subscription),
log: log.Root(),
batchRequestLimit: batchRequestLimit,
batchResponseMaxSize: batchResponseMaxSize,
}
if conn.remoteAddr() != "" {
h.log = h.log.New("conn", conn.remoteAddr())
@ -134,16 +138,15 @@ func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
b.doWrite(ctx, conn, false)
}
// timeout sends the responses added so far. For the remaining unanswered call
// messages, it sends a timeout error response.
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
// respondWithError sends the responses added so far. For the remaining unanswered call
// messages, it responds with the given error.
func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) {
b.mutex.Lock()
defer b.mutex.Unlock()
for _, msg := range b.calls {
if !msg.isNotification() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
b.resp = append(b.resp, resp)
b.resp = append(b.resp, msg.errorResponse(err))
}
}
b.doWrite(ctx, conn, true)
@ -171,17 +174,24 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
})
return
}
// Handle non-call messages first:
calls := make([]*jsonrpcMessage, 0, len(msgs))
for _, msg := range msgs {
if handled := h.handleImmediate(msg); !handled {
calls = append(calls, msg)
}
// Apply limit on total number of requests.
if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit {
h.startCallProc(func(cp *callProc) {
h.respondWithBatchTooLarge(cp, msgs)
})
return
}
// Handle non-call messages first.
// Here we need to find the requestOp that sent the request batch.
calls := make([]*jsonrpcMessage, 0, len(msgs))
h.handleResponses(msgs, func(msg *jsonrpcMessage) {
calls = append(calls, msg)
})
if len(calls) == 0 {
return
}
// Process calls on a goroutine because they may block indefinitely:
h.startCallProc(func(cp *callProc) {
var (
@ -199,10 +209,12 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
callBuffer.timeout(cp.ctx, h.conn)
err := &internalServerError{errcodeTimeout, errMsgTimeout}
callBuffer.respondWithError(cp.ctx, h.conn, err)
})
}
responseBytes := 0
for {
// No need to handle rest of calls if timed out.
if cp.ctx.Err() != nil {
@ -214,61 +226,88 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
}
resp := h.handleCallMsg(cp, msg)
callBuffer.pushResponse(resp)
if resp != nil && h.batchResponseMaxSize != 0 {
responseBytes += len(resp.Result)
if responseBytes > h.batchResponseMaxSize {
err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge}
callBuffer.respondWithError(cp.ctx, h.conn, err)
break
}
}
}
if timer != nil {
timer.Stop()
}
callBuffer.write(cp.ctx, h.conn)
h.addSubscriptions(cp.notifiers)
callBuffer.write(cp.ctx, h.conn)
for _, n := range cp.notifiers {
n.activate()
}
})
}
// handleMsg handles a single message.
func (h *handler) handleMsg(msg *jsonrpcMessage) {
if ok := h.handleImmediate(msg); ok {
return
func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) {
resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge})
// Find the first call and add its "id" field to the error.
// This is the best we can do, given that the protocol doesn't have a way
// of reporting an error for the entire batch.
for _, msg := range batch {
if msg.isCall() {
resp.ID = msg.ID
break
}
}
h.startCallProc(func(cp *callProc) {
var (
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true)
}
// Cancel the request context after timeout and send an error response. Since the
// running method might not return immediately on timeout, we must wait for the
// timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
responded.Do(func() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
h.conn.writeJSON(cp.ctx, resp, true)
})
})
}
answer := h.handleCallMsg(cp, msg)
if timer != nil {
timer.Stop()
}
h.addSubscriptions(cp.notifiers)
if answer != nil {
responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false)
})
}
for _, n := range cp.notifiers {
n.activate()
}
// handleMsg handles a single non-batch message.
func (h *handler) handleMsg(msg *jsonrpcMessage) {
msgs := []*jsonrpcMessage{msg}
h.handleResponses(msgs, func(msg *jsonrpcMessage) {
h.startCallProc(func(cp *callProc) {
h.handleNonBatchCall(cp, msg)
})
})
}
func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) {
var (
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the
// running method might not return immediately on timeout, we must wait for the
// timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
responded.Do(func() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
h.conn.writeJSON(cp.ctx, resp, true)
})
})
}
answer := h.handleCallMsg(cp, msg)
if timer != nil {
timer.Stop()
}
h.addSubscriptions(cp.notifiers)
if answer != nil {
responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false)
})
}
for _, n := range cp.notifiers {
n.activate()
}
}
// close cancels all requests except for inflightReq and waits for
// call goroutines to shut down.
func (h *handler) close(err error, inflightReq *requestOp) {
@ -349,23 +388,60 @@ func (h *handler) startCallProc(fn func(*callProc)) {
}()
}
// handleImmediate executes non-call messages. It returns false if the message is a
// call or requires a reply.
func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
start := time.Now()
switch {
case msg.isNotification():
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
h.handleSubscriptionResult(msg)
return true
// handleResponse processes method call responses.
func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) {
var resolvedops []*requestOp
handleResp := func(msg *jsonrpcMessage) {
op := h.respWait[string(msg.ID)]
if op == nil {
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
return
}
return false
case msg.isResponse():
h.handleResponse(msg)
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
return true
default:
return false
resolvedops = append(resolvedops, op)
delete(h.respWait, string(msg.ID))
// For subscription responses, start the subscription if the server
// indicates success. EthSubscribe gets unblocked in either case through
// the op.resp channel.
if op.sub != nil {
if msg.Error != nil {
op.err = msg.Error
} else {
op.err = json.Unmarshal(msg.Result, &op.sub.subid)
if op.err == nil {
go op.sub.run()
h.clientSubs[op.sub.subid] = op.sub
}
}
}
if !op.hadResponse {
op.hadResponse = true
op.resp <- batch
}
}
for _, msg := range batch {
start := time.Now()
switch {
case msg.isResponse():
handleResp(msg)
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
case msg.isNotification():
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
h.handleSubscriptionResult(msg)
continue
}
handleCall(msg)
default:
handleCall(msg)
}
}
for _, op := range resolvedops {
h.removeRequestOp(op)
}
}
@ -381,33 +457,6 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
}
}
// handleResponse processes method call responses.
func (h *handler) handleResponse(msg *jsonrpcMessage) {
op := h.respWait[string(msg.ID)]
if op == nil {
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
return
}
delete(h.respWait, string(msg.ID))
// For normal responses, just forward the reply to Call/BatchCall.
if op.sub == nil {
op.resp <- msg
return
}
// For subscription responses, start the subscription if the server
// indicates success. EthSubscribe gets unblocked in either case through
// the op.resp channel.
defer close(op.resp)
if msg.Error != nil {
op.err = msg.Error
return
}
if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
go op.sub.run()
h.clientSubs[op.sub.subid] = op.sub
}
}
// handleCallMsg executes a call message and returns the answer.
func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
start := time.Now()
@ -416,6 +465,7 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
h.handleCall(ctx, msg)
h.log.Debug("Served "+msg.Method, "duration", time.Since(start))
return nil
case msg.isCall():
resp := h.handleCall(ctx, msg)
var ctx []interface{}
@ -430,8 +480,10 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
h.log.Debug("Served "+msg.Method, ctx...)
}
return resp
case msg.hasValidID():
return msg.errorResponse(&invalidRequestError{"invalid request"})
default:
return errorMessage(&invalidRequestError{"invalid request"})
}
@ -451,12 +503,14 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
if callb == nil {
return msg.errorResponse(&methodNotFoundError{method: msg.Method})
}
args, err := parsePositionalArguments(msg.Params, callb.argTypes)
if err != nil {
return msg.errorResponse(&invalidParamsError{err.Error()})
}
start := time.Now()
answer := h.runMethod(cp.ctx, msg, callb, args)
// Collect the statistics for RPC calls if metrics is enabled.
// We only care about pure rpc call. Filter out subscription.
if callb != h.unsubscribeCb {
@ -469,6 +523,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
rpcServingTimer.UpdateSince(start)
updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start))
}
return answer
}

View File

@ -139,7 +139,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
var cfg clientConfig
cfg.httpClient = client
fn := newClientTransportHTTP(endpoint, &cfg)
return newClient(context.Background(), fn)
return newClient(context.Background(), &cfg, fn)
}
func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
@ -176,11 +176,12 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
}
defer respBody.Close()
var respmsg jsonrpcMessage
if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil {
var resp jsonrpcMessage
batch := [1]*jsonrpcMessage{&resp}
if err := json.NewDecoder(respBody).Decode(&resp); err != nil {
return err
}
op.resp <- &respmsg
op.resp <- batch[:]
return nil
}
@ -191,16 +192,12 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr
return err
}
defer respBody.Close()
var respmsgs []jsonrpcMessage
var respmsgs []*jsonrpcMessage
if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil {
return err
}
if len(respmsgs) != len(msgs) {
return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult)
}
for i := 0; i < len(respmsgs); i++ {
op.resp <- &respmsgs[i]
}
op.resp <- respmsgs
return nil
}

View File

@ -24,7 +24,8 @@ import (
// DialInProc attaches an in-process connection to the given RPC server.
func DialInProc(handler *Server) *Client {
initctx := context.Background()
c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
cfg := new(clientConfig)
c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) {
p1, p2 := net.Pipe()
go handler.ServeCodec(NewCodec(p1), 0)
return NewCodec(p2), nil

View File

@ -46,7 +46,8 @@ func (s *Server) ServeListener(l net.Listener) error {
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
return newClient(ctx, newClientTransportIPC(endpoint))
cfg := new(clientConfig)
return newClient(ctx, cfg, newClientTransportIPC(endpoint))
}
func newClientTransportIPC(endpoint string) reconnectFunc {

View File

@ -46,9 +46,11 @@ type Server struct {
services serviceRegistry
idgen func() ID
mutex sync.Mutex
codecs map[ServerCodec]struct{}
run atomic.Bool
mutex sync.Mutex
codecs map[ServerCodec]struct{}
run atomic.Bool
batchItemLimit int
batchResponseLimit int
}
// NewServer creates a new server instance with no registered handlers.
@ -65,6 +67,17 @@ func NewServer() *Server {
return server
}
// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit'
// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of
// response bytes across all requests in a batch.
//
// This method should be called before processing any requests via ServeCodec, ServeHTTP,
// ServeListener etc.
func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) {
s.batchItemLimit = itemLimit
s.batchResponseLimit = maxResponseSize
}
// RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either a RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the
@ -86,7 +99,12 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
}
defer s.untrackCodec(codec)
c := initClient(codec, s.idgen, &s.services)
cfg := &clientConfig{
idgen: s.idgen,
batchItemLimit: s.batchItemLimit,
batchResponseLimit: s.batchResponseLimit,
}
c := initClient(codec, &s.services, cfg)
<-codec.closed()
c.Close()
}
@ -118,7 +136,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
return
}
h := newHandler(ctx, codec, s.idgen, &s.services)
h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit)
h.allowSubscribe = false
defer h.close(io.EOF, nil)

View File

@ -70,6 +70,7 @@ func TestServer(t *testing.T) {
func runTestScript(t *testing.T, file string) {
server := newTestServer()
server.SetBatchLimits(4, 100000)
content, err := os.ReadFile(file)
if err != nil {
t.Fatal(err)
@ -152,3 +153,41 @@ func TestServerShortLivedConn(t *testing.T) {
}
}
}
func TestServerBatchResponseSizeLimit(t *testing.T) {
server := newTestServer()
defer server.Stop()
server.SetBatchLimits(100, 60)
var (
batch []BatchElem
client = DialInProc(server)
)
for i := 0; i < 5; i++ {
batch = append(batch, BatchElem{
Method: "test_echo",
Args: []any{"x", 1},
Result: new(echoResult),
})
}
if err := client.BatchCall(batch); err != nil {
t.Fatal("error sending batch:", err)
}
for i := range batch {
// We expect the first two queries to be ok, but after that the size limit takes effect.
if i < 2 {
if batch[i].Error != nil {
t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error)
}
continue
}
// After two, we expect an error.
re, ok := batch[i].Error.(Error)
if !ok {
t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error)
}
wantedCode := errcodeResponseTooLarge
if re.ErrorCode() != wantedCode {
t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode)
}
}
}

View File

@ -32,7 +32,8 @@ func DialStdIO(ctx context.Context) (*Client, error) {
// DialIO creates a client which uses the given IO channels
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
return newClient(ctx, newClientTransportIO(in, out))
cfg := new(clientConfig)
return newClient(ctx, cfg, newClientTransportIO(in, out))
}
func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {

13
rpc/testdata/invalid-batch-toolarge.js vendored Normal file
View File

@ -0,0 +1,13 @@
// This file checks the behavior of the batch item limit code.
// In tests, the batch item limit is set to 4. So to trigger the error,
// all batches in this file have 5 elements.
// For batches that do not contain any calls, a response message with "id" == null
// is returned.
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}]
// For batches with at least one call, the call's "id" is used.
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}]

View File

@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
if err != nil {
return nil, err
}
return newClient(ctx, connect)
return newClient(ctx, cfg, connect)
}
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
@ -214,7 +214,7 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error
if err != nil {
return nil, err
}
return newClient(ctx, connect)
return newClient(ctx, cfg, connect)
}
func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {