eth, rpc: add configurable option for wsMessageSizeLimit (#27801)
This change adds a configurable limit to websocket message. --------- Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
parent
c39cbc1a78
commit
705a51e566
@ -34,7 +34,8 @@ type clientConfig struct {
|
|||||||
httpAuth HTTPAuth
|
httpAuth HTTPAuth
|
||||||
|
|
||||||
// WebSocket options
|
// WebSocket options
|
||||||
wsDialer *websocket.Dialer
|
wsDialer *websocket.Dialer
|
||||||
|
wsMessageSizeLimit *int64 // wsMessageSizeLimit nil = default, 0 = no limit
|
||||||
|
|
||||||
// RPC handler options
|
// RPC handler options
|
||||||
idgen func() ID
|
idgen func() ID
|
||||||
@ -66,6 +67,14 @@ func WithWebsocketDialer(dialer websocket.Dialer) ClientOption {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithWebsocketMessageSizeLimit configures the websocket message size limit used by the RPC
|
||||||
|
// client. Passing a limit of 0 means no limit.
|
||||||
|
func WithWebsocketMessageSizeLimit(messageSizeLimit int64) ClientOption {
|
||||||
|
return optionFunc(func(cfg *clientConfig) {
|
||||||
|
cfg.wsMessageSizeLimit = &messageSizeLimit
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// WithHeader configures HTTP headers set by the RPC client. Headers set using this option
|
// WithHeader configures HTTP headers set by the RPC client. Headers set using this option
|
||||||
// will be used for both HTTP and WebSocket connections.
|
// will be used for both HTTP and WebSocket connections.
|
||||||
func WithHeader(key, value string) ClientOption {
|
func WithHeader(key, value string) ClientOption {
|
||||||
|
@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) {
|
|||||||
t.Fatalf("Expected service calc to be registered")
|
t.Fatalf("Expected service calc to be registered")
|
||||||
}
|
}
|
||||||
|
|
||||||
wantCallbacks := 13
|
wantCallbacks := 14
|
||||||
if len(svc.callbacks) != wantCallbacks {
|
if len(svc.callbacks) != wantCallbacks {
|
||||||
t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
|
t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
|
||||||
}
|
}
|
||||||
|
@ -90,6 +90,10 @@ func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *
|
|||||||
return echoResult{str, i, args}
|
return echoResult{str, i, args}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *testService) Repeat(msg string, i int) string {
|
||||||
|
return strings.Repeat(msg, i)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *testService) PeerInfo(ctx context.Context) PeerInfo {
|
func (s *testService) PeerInfo(ctx context.Context) PeerInfo {
|
||||||
return PeerInfoFromContext(ctx)
|
return PeerInfoFromContext(ctx)
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ const (
|
|||||||
wsPingInterval = 30 * time.Second
|
wsPingInterval = 30 * time.Second
|
||||||
wsPingWriteTimeout = 5 * time.Second
|
wsPingWriteTimeout = 5 * time.Second
|
||||||
wsPongTimeout = 30 * time.Second
|
wsPongTimeout = 30 * time.Second
|
||||||
wsMessageSizeLimit = 32 * 1024 * 1024
|
wsDefaultReadLimit = 32 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
var wsBufferPool = new(sync.Pool)
|
var wsBufferPool = new(sync.Pool)
|
||||||
@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
|
|||||||
log.Debug("WebSocket upgrade failed", "err", err)
|
log.Debug("WebSocket upgrade failed", "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
codec := newWebsocketCodec(conn, r.Host, r.Header)
|
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
|
||||||
s.ServeCodec(codec, 0)
|
s.ServeCodec(codec, 0)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -251,7 +251,11 @@ func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, er
|
|||||||
}
|
}
|
||||||
return nil, hErr
|
return nil, hErr
|
||||||
}
|
}
|
||||||
return newWebsocketCodec(conn, dialURL, header), nil
|
messageSizeLimit := int64(wsDefaultReadLimit)
|
||||||
|
if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 {
|
||||||
|
messageSizeLimit = *cfg.wsMessageSizeLimit
|
||||||
|
}
|
||||||
|
return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil
|
||||||
}
|
}
|
||||||
return connect, nil
|
return connect, nil
|
||||||
}
|
}
|
||||||
@ -283,8 +287,8 @@ type websocketCodec struct {
|
|||||||
pongReceived chan struct{}
|
pongReceived chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec {
|
func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec {
|
||||||
conn.SetReadLimit(wsMessageSizeLimit)
|
conn.SetReadLimit(readLimit)
|
||||||
encode := func(v interface{}, isErrorResponse bool) error {
|
encode := func(v interface{}, isErrorResponse bool) error {
|
||||||
return conn.WriteJSON(v)
|
return conn.WriteJSON(v)
|
||||||
}
|
}
|
||||||
|
@ -113,6 +113,66 @@ func TestWebsocketLargeCall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test checks whether the wsMessageSizeLimit option is obeyed.
|
||||||
|
func TestWebsocketLargeRead(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
srv = newTestServer()
|
||||||
|
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
|
||||||
|
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
|
||||||
|
)
|
||||||
|
defer srv.Stop()
|
||||||
|
defer httpsrv.Close()
|
||||||
|
|
||||||
|
testLimit := func(limit *int64) {
|
||||||
|
opts := []ClientOption{}
|
||||||
|
expLimit := int64(wsDefaultReadLimit)
|
||||||
|
if limit != nil && *limit >= 0 {
|
||||||
|
opts = append(opts, WithWebsocketMessageSizeLimit(*limit))
|
||||||
|
if *limit > 0 {
|
||||||
|
expLimit = *limit // 0 means infinite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
client, err := DialOptions(context.Background(), wsURL, opts...)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't dial: %v", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
// Remove some bytes for json encoding overhead.
|
||||||
|
underLimit := int(expLimit - 128)
|
||||||
|
overLimit := expLimit + 1
|
||||||
|
if expLimit == wsDefaultReadLimit {
|
||||||
|
// No point trying the full 32MB in tests. Just sanity-check that
|
||||||
|
// it's not obviously limited.
|
||||||
|
underLimit = 1024
|
||||||
|
overLimit = -1
|
||||||
|
}
|
||||||
|
var res string
|
||||||
|
// Check under limit
|
||||||
|
if err = client.Call(&res, "test_repeat", "A", underLimit); err != nil {
|
||||||
|
t.Fatalf("unexpected error with limit %d: %v", expLimit, err)
|
||||||
|
}
|
||||||
|
if len(res) != underLimit || strings.Count(res, "A") != underLimit {
|
||||||
|
t.Fatal("incorrect data")
|
||||||
|
}
|
||||||
|
// Check over limit
|
||||||
|
if overLimit > 0 {
|
||||||
|
err = client.Call(&res, "test_repeat", "A", expLimit+1)
|
||||||
|
if err == nil || err != websocket.ErrReadLimit {
|
||||||
|
t.Fatalf("wrong error with limit %d: %v expecting %v", expLimit, err, websocket.ErrReadLimit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ptr := func(v int64) *int64 { return &v }
|
||||||
|
|
||||||
|
testLimit(ptr(-1)) // Should be ignored (use default)
|
||||||
|
testLimit(ptr(0)) // Should be ignored (use default)
|
||||||
|
testLimit(nil) // Should be ignored (use default)
|
||||||
|
testLimit(ptr(200))
|
||||||
|
testLimit(ptr(wsDefaultReadLimit * 2))
|
||||||
|
}
|
||||||
|
|
||||||
func TestWebsocketPeerInfo(t *testing.T) {
|
func TestWebsocketPeerInfo(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
s = newTestServer()
|
s = newTestServer()
|
||||||
@ -206,7 +266,7 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
|
|||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
defer httpsrv.Close()
|
defer httpsrv.Close()
|
||||||
|
|
||||||
respLength := wsMessageSizeLimit - 50
|
respLength := wsDefaultReadLimit - 50
|
||||||
srv.RegisterName("test", largeRespService{respLength})
|
srv.RegisterName("test", largeRespService{respLength})
|
||||||
|
|
||||||
c, err := DialWebsocket(context.Background(), wsURL, "")
|
c, err := DialWebsocket(context.Background(), wsURL, "")
|
||||||
|
Loading…
Reference in New Issue
Block a user