// Copyright 2015 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . package rpc import ( "context" "encoding/base64" "fmt" "net/http" "net/url" "os" "strings" "sync" "time" mapset "github.com/deckarep/golang-set/v2" "github.com/ethereum/go-ethereum/log" "github.com/gorilla/websocket" ) const ( wsReadBuffer = 1024 wsWriteBuffer = 1024 wsPingInterval = 30 * time.Second wsPingWriteTimeout = 5 * time.Second wsPongTimeout = 30 * time.Second wsDefaultReadLimit = 32 * 1024 * 1024 ) var wsBufferPool = new(sync.Pool) // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. // // allowedOrigins should be a comma-separated list of allowed origin URLs. // To allow connections with any origin, pass "*". func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { var upgrader = websocket.Upgrader{ ReadBufferSize: wsReadBuffer, WriteBufferSize: wsWriteBuffer, WriteBufferPool: wsBufferPool, CheckOrigin: wsHandshakeValidator(allowedOrigins), } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Debug("WebSocket upgrade failed", "err", err) return } codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit) s.ServeCodec(codec, 0) }) } // wsHandshakeValidator returns a handler that verifies the origin during the // websocket upgrade process. When a '*' is specified as an allowed origins all // connections are accepted. func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { origins := mapset.NewSet[string]() allowAllOrigins := false for _, origin := range allowedOrigins { if origin == "*" { allowAllOrigins = true } if origin != "" { origins.Add(origin) } } // allow localhost if no allowedOrigins are specified. if len(origins.ToSlice()) == 0 { origins.Add("http://localhost") if hostname, err := os.Hostname(); err == nil { origins.Add("http://" + hostname) } } log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) f := func(req *http.Request) bool { // Skip origin verification if no Origin header is present. The origin check // is supposed to protect against browser based attacks. Browsers always set // Origin. Non-browser software can put anything in origin and checking it doesn't // provide additional security. if _, ok := req.Header["Origin"]; !ok { return true } // Verify origin against allow list. origin := strings.ToLower(req.Header.Get("Origin")) if allowAllOrigins || originIsAllowed(origins, origin) { return true } log.Warn("Rejected WebSocket connection", "origin", origin) return false } return f } type wsHandshakeError struct { err error status string } func (e wsHandshakeError) Error() string { s := e.err.Error() if e.status != "" { s += " (HTTP status " + e.status + ")" } return s } func originIsAllowed(allowedOrigins mapset.Set[string], browserOrigin string) bool { it := allowedOrigins.Iterator() for origin := range it.C { if ruleAllowsOrigin(origin, browserOrigin) { return true } } return false } func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { var ( allowedScheme, allowedHostname, allowedPort string browserScheme, browserHostname, browserPort string err error ) allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) if err != nil { log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) return false } browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) if err != nil { log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) return false } if allowedScheme != "" && allowedScheme != browserScheme { return false } if allowedHostname != "" && allowedHostname != browserHostname { return false } if allowedPort != "" && allowedPort != browserPort { return false } return true } func parseOriginURL(origin string) (string, string, string, error) { parsedURL, err := url.Parse(strings.ToLower(origin)) if err != nil { return "", "", "", err } var scheme, hostname, port string if strings.Contains(origin, "://") { scheme = parsedURL.Scheme hostname = parsedURL.Hostname() port = parsedURL.Port() } else { scheme = "" hostname = parsedURL.Scheme port = parsedURL.Opaque if hostname == "" { hostname = origin } } return scheme, hostname, port, nil } // DialWebsocketWithDialer creates a new RPC client using WebSocket. // // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. // // Deprecated: use DialOptions and the WithWebsocketDialer option. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { cfg := new(clientConfig) cfg.wsDialer = &dialer if origin != "" { cfg.setHeader("origin", origin) } connect, err := newClientTransportWS(endpoint, cfg) if err != nil { return nil, err } return newClient(ctx, cfg, connect) } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint. // // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { cfg := new(clientConfig) if origin != "" { cfg.setHeader("origin", origin) } connect, err := newClientTransportWS(endpoint, cfg) if err != nil { return nil, err } return newClient(ctx, cfg, connect) } func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { dialer := cfg.wsDialer if dialer == nil { dialer = &websocket.Dialer{ ReadBufferSize: wsReadBuffer, WriteBufferSize: wsWriteBuffer, WriteBufferPool: wsBufferPool, Proxy: http.ProxyFromEnvironment, } } dialURL, header, err := wsClientHeaders(endpoint, "") if err != nil { return nil, err } for key, values := range cfg.httpHeaders { header[key] = values } connect := func(ctx context.Context) (ServerCodec, error) { header := header.Clone() if cfg.httpAuth != nil { if err := cfg.httpAuth(header); err != nil { return nil, err } } conn, resp, err := dialer.DialContext(ctx, dialURL, header) if err != nil { hErr := wsHandshakeError{err: err} if resp != nil { hErr.status = resp.Status } return nil, hErr } messageSizeLimit := int64(wsDefaultReadLimit) if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 { messageSizeLimit = *cfg.wsMessageSizeLimit } return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil } return connect, nil } func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { endpointURL, err := url.Parse(endpoint) if err != nil { return endpoint, nil, err } header := make(http.Header) if origin != "" { header.Add("origin", origin) } if endpointURL.User != nil { b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) header.Add("authorization", "Basic "+b64auth) endpointURL.User = nil } return endpointURL.String(), header, nil } type websocketCodec struct { *jsonCodec conn *websocket.Conn info PeerInfo wg sync.WaitGroup pingReset chan struct{} pongReceived chan struct{} } func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec { conn.SetReadLimit(readLimit) encode := func(v interface{}, isErrorResponse bool) error { return conn.WriteJSON(v) } wc := &websocketCodec{ jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), conn: conn, pingReset: make(chan struct{}, 1), pongReceived: make(chan struct{}), info: PeerInfo{ Transport: "ws", RemoteAddr: conn.RemoteAddr().String(), }, } // Fill in connection details. wc.info.HTTP.Host = host wc.info.HTTP.Origin = req.Get("Origin") wc.info.HTTP.UserAgent = req.Get("User-Agent") // Start pinger. conn.SetPongHandler(func(appData string) error { select { case wc.pongReceived <- struct{}{}: case <-wc.closed(): } return nil }) wc.wg.Add(1) go wc.pingLoop() return wc } func (wc *websocketCodec) close() { wc.jsonCodec.close() wc.wg.Wait() } func (wc *websocketCodec) peerInfo() PeerInfo { return wc.info } func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { err := wc.jsonCodec.writeJSON(ctx, v, isError) if err == nil { // Notify pingLoop to delay the next idle ping. select { case wc.pingReset <- struct{}{}: default: } } return err } // pingLoop sends periodic ping frames when the connection is idle. func (wc *websocketCodec) pingLoop() { var pingTimer = time.NewTimer(wsPingInterval) defer wc.wg.Done() defer pingTimer.Stop() for { select { case <-wc.closed(): return case <-wc.pingReset: if !pingTimer.Stop() { <-pingTimer.C } pingTimer.Reset(wsPingInterval) case <-pingTimer.C: wc.jsonCodec.encMu.Lock() wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) wc.conn.WriteMessage(websocket.PingMessage, nil) wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) wc.jsonCodec.encMu.Unlock() pingTimer.Reset(wsPingInterval) case <-wc.pongReceived: wc.conn.SetReadDeadline(time.Time{}) } } }