rpc: enforce the 128KB request limits on websockets too

This commit is contained in:
Péter Szilágyi 2018-03-13 13:23:44 +02:00
parent 6a2d2869f6
commit 555f42cfd8
No known key found for this signature in database
GPG Key ID: E9AE538CEDF8293D
4 changed files with 66 additions and 22 deletions

View File

@ -27,16 +27,16 @@ import (
"mime" "mime"
"net" "net"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
"github.com/rs/cors" "github.com/rs/cors"
"strings"
) )
const ( const (
contentType = "application/json" contentType = "application/json"
maxHTTPRequestContentLength = 1024 * 128 maxRequestContentLength = 1024 * 128
) )
var nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0") var nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0")
@ -182,8 +182,8 @@ func validateRequest(r *http.Request) (int, error) {
if r.Method == http.MethodPut || r.Method == http.MethodDelete { if r.Method == http.MethodPut || r.Method == http.MethodDelete {
return http.StatusMethodNotAllowed, errors.New("method not allowed") return http.StatusMethodNotAllowed, errors.New("method not allowed")
} }
if r.ContentLength > maxHTTPRequestContentLength { if r.ContentLength > maxRequestContentLength {
err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxHTTPRequestContentLength) err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxRequestContentLength)
return http.StatusRequestEntityTooLarge, err return http.StatusRequestEntityTooLarge, err
} }
mt, _, err := mime.ParseMediaType(r.Header.Get("content-type")) mt, _, err := mime.ParseMediaType(r.Header.Get("content-type"))

View File

@ -32,7 +32,7 @@ func TestHTTPErrorResponseWithPut(t *testing.T) {
} }
func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) { func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
body := make([]rune, maxHTTPRequestContentLength+1) body := make([]rune, maxRequestContentLength+1)
testHTTPErrorResponse(t, testHTTPErrorResponse(t,
http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge) http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge)
} }

View File

@ -78,10 +78,10 @@ type jsonNotification struct {
type jsonCodec struct { type jsonCodec struct {
closer sync.Once // close closed channel once closer sync.Once // close closed channel once
closed chan interface{} // closed on Close closed chan interface{} // closed on Close
decMu sync.Mutex // guards d decMu sync.Mutex // guards the decoder
d *json.Decoder // decodes incoming requests decode func(v interface{}) error // decoder to allow multiple transports
encMu sync.Mutex // guards e encMu sync.Mutex // guards the encoder
e *json.Encoder // encodes responses encode func(v interface{}) error // encoder to allow multiple transports
rw io.ReadWriteCloser // connection rw io.ReadWriteCloser // connection
} }
@ -96,11 +96,29 @@ func (err *jsonError) ErrorCode() int {
return err.Code return err.Code
} }
// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0 // NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based
// on explicitly given encoding and decoding methods.
func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec {
return &jsonCodec{
closed: make(chan interface{}),
encode: encode,
decode: decode,
rw: rwc,
}
}
// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0.
func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec { func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec {
d := json.NewDecoder(rwc) enc := json.NewEncoder(rwc)
d.UseNumber() dec := json.NewDecoder(rwc)
return &jsonCodec{closed: make(chan interface{}), d: d, e: json.NewEncoder(rwc), rw: rwc} dec.UseNumber()
return &jsonCodec{
closed: make(chan interface{}),
encode: enc.Encode,
decode: dec.Decode,
rw: rwc,
}
} }
// isBatch returns true when the first non-whitespace characters is '[' // isBatch returns true when the first non-whitespace characters is '['
@ -123,14 +141,12 @@ func (c *jsonCodec) ReadRequestHeaders() ([]rpcRequest, bool, Error) {
defer c.decMu.Unlock() defer c.decMu.Unlock()
var incomingMsg json.RawMessage var incomingMsg json.RawMessage
if err := c.d.Decode(&incomingMsg); err != nil { if err := c.decode(&incomingMsg); err != nil {
return nil, false, &invalidRequestError{err.Error()} return nil, false, &invalidRequestError{err.Error()}
} }
if isBatch(incomingMsg) { if isBatch(incomingMsg) {
return parseBatchRequest(incomingMsg) return parseBatchRequest(incomingMsg)
} }
return parseRequest(incomingMsg) return parseRequest(incomingMsg)
} }
@ -338,7 +354,7 @@ func (c *jsonCodec) Write(res interface{}) error {
c.encMu.Lock() c.encMu.Lock()
defer c.encMu.Unlock() defer c.encMu.Unlock()
return c.e.Encode(res) return c.encode(res)
} }
// Close the underlying connection // Close the underlying connection

View File

@ -17,8 +17,10 @@
package rpc package rpc
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -32,6 +34,23 @@ import (
"gopkg.in/fatih/set.v0" "gopkg.in/fatih/set.v0"
) )
// websocketJSONCodec is a custom JSON codec with payload size enforcement and
// special number parsing.
var websocketJSONCodec = websocket.Codec{
// Marshal is the stock JSON marshaller used by the websocket library too.
Marshal: func(v interface{}) ([]byte, byte, error) {
msg, err := json.Marshal(v)
return msg, websocket.TextFrame, err
},
// Unmarshal is a specialized unmarshaller to properly convert numbers.
Unmarshal: func(msg []byte, payloadType byte, v interface{}) error {
dec := json.NewDecoder(bytes.NewReader(msg))
dec.UseNumber()
return dec.Decode(v)
},
}
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
// //
// allowedOrigins should be a comma-separated list of allowed origin URLs. // allowedOrigins should be a comma-separated list of allowed origin URLs.
@ -40,7 +59,16 @@ func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
return websocket.Server{ return websocket.Server{
Handshake: wsHandshakeValidator(allowedOrigins), Handshake: wsHandshakeValidator(allowedOrigins),
Handler: func(conn *websocket.Conn) { Handler: func(conn *websocket.Conn) {
srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions) // Create a custom encode/decode pair to enforce payload size and number encoding
conn.MaxPayloadBytes = maxRequestContentLength
encoder := func(v interface{}) error {
return websocketJSONCodec.Send(conn, v)
}
decoder := func(v interface{}) error {
return websocketJSONCodec.Receive(conn, v)
}
srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions)
}, },
} }
} }