rpc: enforce the 128KB request limits on websockets too
This commit is contained in:
parent
6a2d2869f6
commit
555f42cfd8
@ -27,16 +27,16 @@ import (
|
|||||||
"mime"
|
"mime"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
contentType = "application/json"
|
contentType = "application/json"
|
||||||
maxHTTPRequestContentLength = 1024 * 128
|
maxRequestContentLength = 1024 * 128
|
||||||
)
|
)
|
||||||
|
|
||||||
var nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0")
|
var nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0")
|
||||||
@ -182,8 +182,8 @@ func validateRequest(r *http.Request) (int, error) {
|
|||||||
if r.Method == http.MethodPut || r.Method == http.MethodDelete {
|
if r.Method == http.MethodPut || r.Method == http.MethodDelete {
|
||||||
return http.StatusMethodNotAllowed, errors.New("method not allowed")
|
return http.StatusMethodNotAllowed, errors.New("method not allowed")
|
||||||
}
|
}
|
||||||
if r.ContentLength > maxHTTPRequestContentLength {
|
if r.ContentLength > maxRequestContentLength {
|
||||||
err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength)
|
err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxRequestContentLength)
|
||||||
return http.StatusRequestEntityTooLarge, err
|
return http.StatusRequestEntityTooLarge, err
|
||||||
}
|
}
|
||||||
mt, _, err := mime.ParseMediaType(r.Header.Get("content-type"))
|
mt, _, err := mime.ParseMediaType(r.Header.Get("content-type"))
|
||||||
|
@ -32,7 +32,7 @@ func TestHTTPErrorResponseWithPut(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
|
func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
|
||||||
body := make([]rune, maxHTTPRequestContentLength+1)
|
body := make([]rune, maxRequestContentLength+1)
|
||||||
testHTTPErrorResponse(t,
|
testHTTPErrorResponse(t,
|
||||||
http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge)
|
http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge)
|
||||||
}
|
}
|
||||||
|
40
rpc/json.go
40
rpc/json.go
@ -78,10 +78,10 @@ type jsonNotification struct {
|
|||||||
type jsonCodec struct {
|
type jsonCodec struct {
|
||||||
closer sync.Once // close closed channel once
|
closer sync.Once // close closed channel once
|
||||||
closed chan interface{} // closed on Close
|
closed chan interface{} // closed on Close
|
||||||
decMu sync.Mutex // guards d
|
decMu sync.Mutex // guards the decoder
|
||||||
d *json.Decoder // decodes incoming requests
|
decode func(v interface{}) error // decoder to allow multiple transports
|
||||||
encMu sync.Mutex // guards e
|
encMu sync.Mutex // guards the encoder
|
||||||
e *json.Encoder // encodes responses
|
encode func(v interface{}) error // encoder to allow multiple transports
|
||||||
rw io.ReadWriteCloser // connection
|
rw io.ReadWriteCloser // connection
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,11 +96,29 @@ func (err *jsonError) ErrorCode() int {
|
|||||||
return err.Code
|
return err.Code
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0
|
// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based
|
||||||
|
// on explicitly given encoding and decoding methods.
|
||||||
|
func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec {
|
||||||
|
return &jsonCodec{
|
||||||
|
closed: make(chan interface{}),
|
||||||
|
encode: encode,
|
||||||
|
decode: decode,
|
||||||
|
rw: rwc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0.
|
||||||
func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec {
|
func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec {
|
||||||
d := json.NewDecoder(rwc)
|
enc := json.NewEncoder(rwc)
|
||||||
d.UseNumber()
|
dec := json.NewDecoder(rwc)
|
||||||
return &jsonCodec{closed: make(chan interface{}), d: d, e: json.NewEncoder(rwc), rw: rwc}
|
dec.UseNumber()
|
||||||
|
|
||||||
|
return &jsonCodec{
|
||||||
|
closed: make(chan interface{}),
|
||||||
|
encode: enc.Encode,
|
||||||
|
decode: dec.Decode,
|
||||||
|
rw: rwc,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isBatch returns true when the first non-whitespace characters is '['
|
// isBatch returns true when the first non-whitespace characters is '['
|
||||||
@ -123,14 +141,12 @@ func (c *jsonCodec) ReadRequestHeaders() ([]rpcRequest, bool, Error) {
|
|||||||
defer c.decMu.Unlock()
|
defer c.decMu.Unlock()
|
||||||
|
|
||||||
var incomingMsg json.RawMessage
|
var incomingMsg json.RawMessage
|
||||||
if err := c.d.Decode(&incomingMsg); err != nil {
|
if err := c.decode(&incomingMsg); err != nil {
|
||||||
return nil, false, &invalidRequestError{err.Error()}
|
return nil, false, &invalidRequestError{err.Error()}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isBatch(incomingMsg) {
|
if isBatch(incomingMsg) {
|
||||||
return parseBatchRequest(incomingMsg)
|
return parseBatchRequest(incomingMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return parseRequest(incomingMsg)
|
return parseRequest(incomingMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -338,7 +354,7 @@ func (c *jsonCodec) Write(res interface{}) error {
|
|||||||
c.encMu.Lock()
|
c.encMu.Lock()
|
||||||
defer c.encMu.Unlock()
|
defer c.encMu.Unlock()
|
||||||
|
|
||||||
return c.e.Encode(res)
|
return c.encode(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the underlying connection
|
// Close the underlying connection
|
||||||
|
@ -17,8 +17,10 @@
|
|||||||
package rpc
|
package rpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -32,6 +34,23 @@ import (
|
|||||||
"gopkg.in/fatih/set.v0"
|
"gopkg.in/fatih/set.v0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// websocketJSONCodec is a custom JSON codec with payload size enforcement and
|
||||||
|
// special number parsing.
|
||||||
|
var websocketJSONCodec = websocket.Codec{
|
||||||
|
// Marshal is the stock JSON marshaller used by the websocket library too.
|
||||||
|
Marshal: func(v interface{}) ([]byte, byte, error) {
|
||||||
|
msg, err := json.Marshal(v)
|
||||||
|
return msg, websocket.TextFrame, err
|
||||||
|
},
|
||||||
|
// Unmarshal is a specialized unmarshaller to properly convert numbers.
|
||||||
|
Unmarshal: func(msg []byte, payloadType byte, v interface{}) error {
|
||||||
|
dec := json.NewDecoder(bytes.NewReader(msg))
|
||||||
|
dec.UseNumber()
|
||||||
|
|
||||||
|
return dec.Decode(v)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
|
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
|
||||||
//
|
//
|
||||||
// allowedOrigins should be a comma-separated list of allowed origin URLs.
|
// allowedOrigins should be a comma-separated list of allowed origin URLs.
|
||||||
@ -40,7 +59,16 @@ func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
|
|||||||
return websocket.Server{
|
return websocket.Server{
|
||||||
Handshake: wsHandshakeValidator(allowedOrigins),
|
Handshake: wsHandshakeValidator(allowedOrigins),
|
||||||
Handler: func(conn *websocket.Conn) {
|
Handler: func(conn *websocket.Conn) {
|
||||||
srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
|
// Create a custom encode/decode pair to enforce payload size and number encoding
|
||||||
|
conn.MaxPayloadBytes = maxRequestContentLength
|
||||||
|
|
||||||
|
encoder := func(v interface{}) error {
|
||||||
|
return websocketJSONCodec.Send(conn, v)
|
||||||
|
}
|
||||||
|
decoder := func(v interface{}) error {
|
||||||
|
return websocketJSONCodec.Receive(conn, v)
|
||||||
|
}
|
||||||
|
srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user