forked from cerc-io/plugeth
node: allow websocket and HTTP on the same port (#20810)
This change makes it possible to run geth with JSON-RPC over HTTP and WebSocket on the same TCP port. The default port for WebSocket is still 8546. geth --rpc --rpcport 8545 --ws --wsport 8545 This also removes a lot of deprecated API surface from package rpc. The rpc package is now purely about serving JSON-RPC and no longer provides a way to start an HTTP server.
This commit is contained in:
parent
5065cdefff
commit
07d909ff32
@ -583,9 +583,16 @@ func signer(c *cli.Context) error {
|
|||||||
vhosts := splitAndTrim(c.GlobalString(utils.RPCVirtualHostsFlag.Name))
|
vhosts := splitAndTrim(c.GlobalString(utils.RPCVirtualHostsFlag.Name))
|
||||||
cors := splitAndTrim(c.GlobalString(utils.RPCCORSDomainFlag.Name))
|
cors := splitAndTrim(c.GlobalString(utils.RPCCORSDomainFlag.Name))
|
||||||
|
|
||||||
|
srv := rpc.NewServer()
|
||||||
|
err := node.RegisterApisFromWhitelist(rpcAPI, []string{"account"}, srv, false)
|
||||||
|
if err != nil {
|
||||||
|
utils.Fatalf("Could not register API: %w", err)
|
||||||
|
}
|
||||||
|
handler := node.NewHTTPHandlerStack(srv, cors, vhosts)
|
||||||
|
|
||||||
// start http server
|
// start http server
|
||||||
httpEndpoint := fmt.Sprintf("%s:%d", c.GlobalString(utils.RPCListenAddrFlag.Name), c.Int(rpcPortFlag.Name))
|
httpEndpoint := fmt.Sprintf("%s:%d", c.GlobalString(utils.RPCListenAddrFlag.Name), c.Int(rpcPortFlag.Name))
|
||||||
listener, _, err := rpc.StartHTTPEndpoint(httpEndpoint, rpcAPI, []string{"account"}, cors, vhosts, rpc.DefaultHTTPTimeouts)
|
listener, err := node.StartHTTPEndpoint(httpEndpoint, rpc.DefaultHTTPTimeouts, handler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Fatalf("Could not start RPC api: %v", err)
|
utils.Fatalf("Could not start RPC api: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -890,6 +890,14 @@ func retesteth(ctx *cli.Context) error {
|
|||||||
vhosts := splitAndTrim(ctx.GlobalString(utils.RPCVirtualHostsFlag.Name))
|
vhosts := splitAndTrim(ctx.GlobalString(utils.RPCVirtualHostsFlag.Name))
|
||||||
cors := splitAndTrim(ctx.GlobalString(utils.RPCCORSDomainFlag.Name))
|
cors := splitAndTrim(ctx.GlobalString(utils.RPCCORSDomainFlag.Name))
|
||||||
|
|
||||||
|
// register apis and create handler stack
|
||||||
|
srv := rpc.NewServer()
|
||||||
|
err := node.RegisterApisFromWhitelist(rpcAPI, []string{"test", "eth", "debug", "web3"}, srv, false)
|
||||||
|
if err != nil {
|
||||||
|
utils.Fatalf("Could not register RPC apis: %w", err)
|
||||||
|
}
|
||||||
|
handler := node.NewHTTPHandlerStack(srv, cors, vhosts)
|
||||||
|
|
||||||
// start http server
|
// start http server
|
||||||
var RetestethHTTPTimeouts = rpc.HTTPTimeouts{
|
var RetestethHTTPTimeouts = rpc.HTTPTimeouts{
|
||||||
ReadTimeout: 120 * time.Second,
|
ReadTimeout: 120 * time.Second,
|
||||||
@ -897,7 +905,7 @@ func retesteth(ctx *cli.Context) error {
|
|||||||
IdleTimeout: 120 * time.Second,
|
IdleTimeout: 120 * time.Second,
|
||||||
}
|
}
|
||||||
httpEndpoint := fmt.Sprintf("%s:%d", ctx.GlobalString(utils.RPCListenAddrFlag.Name), ctx.Int(rpcPortFlag.Name))
|
httpEndpoint := fmt.Sprintf("%s:%d", ctx.GlobalString(utils.RPCListenAddrFlag.Name), ctx.Int(rpcPortFlag.Name))
|
||||||
listener, _, err := rpc.StartHTTPEndpoint(httpEndpoint, rpcAPI, []string{"test", "eth", "debug", "web3"}, cors, vhosts, RetestethHTTPTimeouts)
|
listener, err := node.StartHTTPEndpoint(httpEndpoint, RetestethHTTPTimeouts, handler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Fatalf("Could not start RPC api: %v", err)
|
utils.Fatalf("Could not start RPC api: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ethereum/go-ethereum/internal/ethapi"
|
"github.com/ethereum/go-ethereum/internal/ethapi"
|
||||||
"github.com/ethereum/go-ethereum/log"
|
"github.com/ethereum/go-ethereum/log"
|
||||||
|
"github.com/ethereum/go-ethereum/node"
|
||||||
"github.com/ethereum/go-ethereum/p2p"
|
"github.com/ethereum/go-ethereum/p2p"
|
||||||
"github.com/ethereum/go-ethereum/rpc"
|
"github.com/ethereum/go-ethereum/rpc"
|
||||||
"github.com/graph-gophers/graphql-go"
|
"github.com/graph-gophers/graphql-go"
|
||||||
@ -68,7 +69,18 @@ func (s *Service) Start(server *p2p.Server) error {
|
|||||||
if s.listener, err = net.Listen("tcp", s.endpoint); err != nil {
|
if s.listener, err = net.Listen("tcp", s.endpoint); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go rpc.NewHTTPServer(s.cors, s.vhosts, s.timeouts, s.handler).Serve(s.listener)
|
// create handler stack and wrap the graphql handler
|
||||||
|
handler := node.NewHTTPHandlerStack(s.handler, s.cors, s.vhosts)
|
||||||
|
// make sure timeout values are meaningful
|
||||||
|
node.CheckTimeouts(&s.timeouts)
|
||||||
|
// create http server
|
||||||
|
httpSrv := &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: s.timeouts.ReadTimeout,
|
||||||
|
WriteTimeout: s.timeouts.WriteTimeout,
|
||||||
|
IdleTimeout: s.timeouts.IdleTimeout,
|
||||||
|
}
|
||||||
|
go httpSrv.Serve(s.listener)
|
||||||
log.Info("GraphQL endpoint opened", "url", fmt.Sprintf("http://%s", s.endpoint))
|
log.Info("GraphQL endpoint opened", "url", fmt.Sprintf("http://%s", s.endpoint))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -186,7 +186,7 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts); err != nil {
|
if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts, api.node.config.WSOrigins); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
|
99
node/endpoints.go
Normal file
99
node/endpoints.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
// Copyright 2018 The go-ethereum Authors
|
||||||
|
// This file is part of the go-ethereum library.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Lesser General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Lesser General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Lesser General Public License
|
||||||
|
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/log"
|
||||||
|
"github.com/ethereum/go-ethereum/rpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartHTTPEndpoint starts the HTTP RPC endpoint.
|
||||||
|
func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (net.Listener, error) {
|
||||||
|
// start the HTTP listener
|
||||||
|
var (
|
||||||
|
listener net.Listener
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if listener, err = net.Listen("tcp", endpoint); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// make sure timeout values are meaningful
|
||||||
|
CheckTimeouts(&timeouts)
|
||||||
|
// Bundle and start the HTTP server
|
||||||
|
httpSrv := &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: timeouts.ReadTimeout,
|
||||||
|
WriteTimeout: timeouts.WriteTimeout,
|
||||||
|
IdleTimeout: timeouts.IdleTimeout,
|
||||||
|
}
|
||||||
|
go httpSrv.Serve(listener)
|
||||||
|
return listener, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// startWSEndpoint starts a websocket endpoint.
|
||||||
|
func startWSEndpoint(endpoint string, handler http.Handler) (net.Listener, error) {
|
||||||
|
// start the HTTP listener
|
||||||
|
var (
|
||||||
|
listener net.Listener
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if listener, err = net.Listen("tcp", endpoint); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
wsSrv := &http.Server{Handler: handler}
|
||||||
|
go wsSrv.Serve(listener)
|
||||||
|
return listener, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkModuleAvailability checks that all names given in modules are actually
|
||||||
|
// available API services. It assumes that the MetadataApi module ("rpc") is always available;
|
||||||
|
// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints.
|
||||||
|
func checkModuleAvailability(modules []string, apis []rpc.API) (bad, available []string) {
|
||||||
|
availableSet := make(map[string]struct{})
|
||||||
|
for _, api := range apis {
|
||||||
|
if _, ok := availableSet[api.Namespace]; !ok {
|
||||||
|
availableSet[api.Namespace] = struct{}{}
|
||||||
|
available = append(available, api.Namespace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, name := range modules {
|
||||||
|
if _, ok := availableSet[name]; !ok && name != rpc.MetadataApi {
|
||||||
|
bad = append(bad, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bad, available
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTimeouts ensures that timeout values are meaningful
|
||||||
|
func CheckTimeouts(timeouts *rpc.HTTPTimeouts) {
|
||||||
|
if timeouts.ReadTimeout < time.Second {
|
||||||
|
log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", rpc.DefaultHTTPTimeouts.ReadTimeout)
|
||||||
|
timeouts.ReadTimeout = rpc.DefaultHTTPTimeouts.ReadTimeout
|
||||||
|
}
|
||||||
|
if timeouts.WriteTimeout < time.Second {
|
||||||
|
log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", rpc.DefaultHTTPTimeouts.WriteTimeout)
|
||||||
|
timeouts.WriteTimeout = rpc.DefaultHTTPTimeouts.WriteTimeout
|
||||||
|
}
|
||||||
|
if timeouts.IdleTimeout < time.Second {
|
||||||
|
log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", rpc.DefaultHTTPTimeouts.IdleTimeout)
|
||||||
|
timeouts.IdleTimeout = rpc.DefaultHTTPTimeouts.IdleTimeout
|
||||||
|
}
|
||||||
|
}
|
69
node/node.go
69
node/node.go
@ -291,17 +291,21 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error {
|
|||||||
n.stopInProc()
|
n.stopInProc()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts); err != nil {
|
if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts, n.config.WSOrigins); err != nil {
|
||||||
n.stopIPC()
|
n.stopIPC()
|
||||||
n.stopInProc()
|
n.stopInProc()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil {
|
// if endpoints are not the same, start separate servers
|
||||||
n.stopHTTP()
|
if n.httpEndpoint != n.wsEndpoint {
|
||||||
n.stopIPC()
|
if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil {
|
||||||
n.stopInProc()
|
n.stopHTTP()
|
||||||
return err
|
n.stopIPC()
|
||||||
|
n.stopInProc()
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// All API endpoints started successfully
|
// All API endpoints started successfully
|
||||||
n.rpcAPIs = apis
|
n.rpcAPIs = apis
|
||||||
return nil
|
return nil
|
||||||
@ -359,22 +363,36 @@ func (n *Node) stopIPC() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// startHTTP initializes and starts the HTTP RPC endpoint.
|
// startHTTP initializes and starts the HTTP RPC endpoint.
|
||||||
func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts) error {
|
func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts, wsOrigins []string) error {
|
||||||
// Short circuit if the HTTP endpoint isn't being exposed
|
// Short circuit if the HTTP endpoint isn't being exposed
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts)
|
// register apis and create handler stack
|
||||||
|
srv := rpc.NewServer()
|
||||||
|
err := RegisterApisFromWhitelist(apis, modules, srv, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
handler := NewHTTPHandlerStack(srv, cors, vhosts)
|
||||||
|
// wrap handler in websocket handler only if websocket port is the same as http rpc
|
||||||
|
if n.httpEndpoint == n.wsEndpoint {
|
||||||
|
handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins))
|
||||||
|
}
|
||||||
|
listener, err := StartHTTPEndpoint(endpoint, timeouts, handler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()),
|
n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()),
|
||||||
"cors", strings.Join(cors, ","),
|
"cors", strings.Join(cors, ","),
|
||||||
"vhosts", strings.Join(vhosts, ","))
|
"vhosts", strings.Join(vhosts, ","))
|
||||||
|
if n.httpEndpoint == n.wsEndpoint {
|
||||||
|
n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", listener.Addr()))
|
||||||
|
}
|
||||||
// All listeners booted successfully
|
// All listeners booted successfully
|
||||||
n.httpEndpoint = endpoint
|
n.httpEndpoint = endpoint
|
||||||
n.httpListener = listener
|
n.httpListener = listener
|
||||||
n.httpHandler = handler
|
n.httpHandler = srv
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -399,7 +417,14 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig
|
|||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
listener, handler, err := rpc.StartWSEndpoint(endpoint, apis, modules, wsOrigins, exposeAll)
|
|
||||||
|
srv := rpc.NewServer()
|
||||||
|
handler := srv.WebsocketHandler(wsOrigins)
|
||||||
|
err := RegisterApisFromWhitelist(apis, modules, srv, exposeAll)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
listener, err := startWSEndpoint(endpoint, handler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -407,7 +432,7 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig
|
|||||||
// All listeners booted successfully
|
// All listeners booted successfully
|
||||||
n.wsEndpoint = endpoint
|
n.wsEndpoint = endpoint
|
||||||
n.wsListener = listener
|
n.wsListener = listener
|
||||||
n.wsHandler = handler
|
n.wsHandler = srv
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -664,3 +689,25 @@ func (n *Node) apis() []rpc.API {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules,
|
||||||
|
// and then registers all of the APIs exposed by the services.
|
||||||
|
func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error {
|
||||||
|
if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
|
||||||
|
log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available)
|
||||||
|
}
|
||||||
|
// Generate the whitelist based on the allowed modules
|
||||||
|
whitelist := make(map[string]bool)
|
||||||
|
for _, module := range modules {
|
||||||
|
whitelist[module] = true
|
||||||
|
}
|
||||||
|
// Register all the APIs exposed by the services
|
||||||
|
for _, api := range apis {
|
||||||
|
if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
|
||||||
|
if err := srv.RegisterName(api.Namespace, api.Service); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -19,6 +19,7 @@ package node
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
@ -27,6 +28,8 @@ import (
|
|||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
"github.com/ethereum/go-ethereum/p2p"
|
"github.com/ethereum/go-ethereum/p2p"
|
||||||
"github.com/ethereum/go-ethereum/rpc"
|
"github.com/ethereum/go-ethereum/rpc"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -597,3 +600,58 @@ func TestAPIGather(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) {
|
||||||
|
node := startHTTP(t)
|
||||||
|
defer node.stopHTTP()
|
||||||
|
|
||||||
|
wsReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("could not issue new http request ", err)
|
||||||
|
}
|
||||||
|
wsReq.Header.Set("Connection", "upgrade")
|
||||||
|
wsReq.Header.Set("Upgrade", "websocket")
|
||||||
|
wsReq.Header.Set("Sec-WebSocket-Version", "13")
|
||||||
|
wsReq.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
|
||||||
|
|
||||||
|
resp := doHTTPRequest(t, wsReq)
|
||||||
|
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketHTTPOnSamePort_HTTPRequest(t *testing.T) {
|
||||||
|
node := startHTTP(t)
|
||||||
|
defer node.stopHTTP()
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("could not issue new http request ", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Accept-Encoding", "gzip")
|
||||||
|
|
||||||
|
resp := doHTTPRequest(t, httpReq)
|
||||||
|
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func startHTTP(t *testing.T) *Node {
|
||||||
|
conf := &Config{HTTPPort: 7453, WSPort: 7453}
|
||||||
|
node, err := New(conf)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("could not create a new node ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = node.startHTTP("127.0.0.1:7453", []rpc.API{}, []string{}, []string{}, []string{}, rpc.HTTPTimeouts{}, []string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Error("could not start http service on node ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
func doHTTPRequest(t *testing.T, req *http.Request) *http.Response {
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("could not issue a GET request to the given endpoint", err)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
159
node/rpcstack.go
Normal file
159
node/rpcstack.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
// Copyright 2015 The go-ethereum Authors
|
||||||
|
// This file is part of the go-ethereum library.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Lesser General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Lesser General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Lesser General Public License
|
||||||
|
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/log"
|
||||||
|
"github.com/rs/cors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewHTTPHandlerStack returns wrapped http-related handlers
|
||||||
|
func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler {
|
||||||
|
// Wrap the CORS-handler within a host-handler
|
||||||
|
handler := newCorsHandler(srv, cors)
|
||||||
|
handler = newVHostHandler(vhosts, handler)
|
||||||
|
return newGzipHandler(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
|
||||||
|
// disable CORS support if user has not specified a custom CORS configuration
|
||||||
|
if len(allowedOrigins) == 0 {
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
c := cors.New(cors.Options{
|
||||||
|
AllowedOrigins: allowedOrigins,
|
||||||
|
AllowedMethods: []string{http.MethodPost, http.MethodGet},
|
||||||
|
MaxAge: 600,
|
||||||
|
AllowedHeaders: []string{"*"},
|
||||||
|
})
|
||||||
|
return c.Handler(srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
// virtualHostHandler is a handler which validates the Host-header of incoming requests.
|
||||||
|
// Using virtual hosts can help prevent DNS rebinding attacks, where a 'random' domain name points to
|
||||||
|
// the service ip address (but without CORS headers). By verifying the targeted virtual host, we can
|
||||||
|
// ensure that it's a destination that the node operator has defined.
|
||||||
|
type virtualHostHandler struct {
|
||||||
|
vhosts map[string]struct{}
|
||||||
|
next http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVHostHandler(vhosts []string, next http.Handler) http.Handler {
|
||||||
|
vhostMap := make(map[string]struct{})
|
||||||
|
for _, allowedHost := range vhosts {
|
||||||
|
vhostMap[strings.ToLower(allowedHost)] = struct{}{}
|
||||||
|
}
|
||||||
|
return &virtualHostHandler{vhostMap, next}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler
|
||||||
|
func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// if r.Host is not set, we can continue serving since a browser would set the Host header
|
||||||
|
if r.Host == "" {
|
||||||
|
h.next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(r.Host)
|
||||||
|
if err != nil {
|
||||||
|
// Either invalid (too many colons) or no port specified
|
||||||
|
host = r.Host
|
||||||
|
}
|
||||||
|
if ipAddr := net.ParseIP(host); ipAddr != nil {
|
||||||
|
// It's an IP address, we can serve that
|
||||||
|
h.next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
// Not an IP address, but a hostname. Need to validate
|
||||||
|
if _, exist := h.vhosts["*"]; exist {
|
||||||
|
h.next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exist := h.vhosts[host]; exist {
|
||||||
|
h.next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Error(w, "invalid host specified", http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gzPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
w := gzip.NewWriter(ioutil.Discard)
|
||||||
|
return w
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type gzipResponseWriter struct {
|
||||||
|
io.Writer
|
||||||
|
http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *gzipResponseWriter) WriteHeader(status int) {
|
||||||
|
w.Header().Del("Content-Length")
|
||||||
|
w.ResponseWriter.WriteHeader(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
return w.Writer.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGzipHandler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
|
|
||||||
|
gz := gzPool.Get().(*gzip.Writer)
|
||||||
|
defer gzPool.Put(gz)
|
||||||
|
|
||||||
|
gz.Reset(w)
|
||||||
|
defer gz.Close()
|
||||||
|
|
||||||
|
next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebsocketUpgradeHandler returns a websocket handler that serves an incoming request only if it contains an upgrade
|
||||||
|
// request to the websocket protocol. If not, serves the the request with the http handler.
|
||||||
|
func NewWebsocketUpgradeHandler(h http.Handler, ws http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if isWebsocket(r) {
|
||||||
|
ws.ServeHTTP(w, r)
|
||||||
|
log.Debug("serving websocket request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// isWebsocket checks the header of an http request for a websocket upgrade request.
|
||||||
|
func isWebsocket(r *http.Request) bool {
|
||||||
|
return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" &&
|
||||||
|
strings.ToLower(r.Header.Get("Connection")) == "upgrade"
|
||||||
|
}
|
38
node/rpcstack_test.go
Normal file
38
node/rpcstack_test.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/rpc"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewWebsocketUpgradeHandler_websocket(t *testing.T) {
|
||||||
|
srv := rpc.NewServer()
|
||||||
|
|
||||||
|
handler := NewWebsocketUpgradeHandler(nil, srv.WebsocketHandler([]string{}))
|
||||||
|
ts := httptest.NewServer(handler)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
responses := make(chan *http.Response)
|
||||||
|
go func(responses chan *http.Response) {
|
||||||
|
client := &http.Client{}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, ts.URL, nil)
|
||||||
|
req.Header.Set("Connection", "upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||||
|
req.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("could not issue a GET request to the test http server", err)
|
||||||
|
}
|
||||||
|
responses <- resp
|
||||||
|
}(responses)
|
||||||
|
|
||||||
|
response := <-responses
|
||||||
|
assert.Equal(t, "websocket", response.Header.Get("Upgrade"))
|
||||||
|
}
|
@ -22,89 +22,6 @@ import (
|
|||||||
"github.com/ethereum/go-ethereum/log"
|
"github.com/ethereum/go-ethereum/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// checkModuleAvailability checks that all names given in modules are actually
|
|
||||||
// available API services. It assumes that the MetadataApi module ("rpc") is always available;
|
|
||||||
// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints.
|
|
||||||
func checkModuleAvailability(modules []string, apis []API) (bad, available []string) {
|
|
||||||
availableSet := make(map[string]struct{})
|
|
||||||
for _, api := range apis {
|
|
||||||
if _, ok := availableSet[api.Namespace]; !ok {
|
|
||||||
availableSet[api.Namespace] = struct{}{}
|
|
||||||
available = append(available, api.Namespace)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, name := range modules {
|
|
||||||
if _, ok := availableSet[name]; !ok && name != MetadataApi {
|
|
||||||
bad = append(bad, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return bad, available
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartHTTPEndpoint starts the HTTP RPC endpoint, configured with cors/vhosts/modules.
|
|
||||||
func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts) (net.Listener, *Server, error) {
|
|
||||||
if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
|
|
||||||
log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available)
|
|
||||||
}
|
|
||||||
// Generate the whitelist based on the allowed modules
|
|
||||||
whitelist := make(map[string]bool)
|
|
||||||
for _, module := range modules {
|
|
||||||
whitelist[module] = true
|
|
||||||
}
|
|
||||||
// Register all the APIs exposed by the services
|
|
||||||
handler := NewServer()
|
|
||||||
for _, api := range apis {
|
|
||||||
if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
|
|
||||||
if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
log.Debug("HTTP registered", "namespace", api.Namespace)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// All APIs registered, start the HTTP listener
|
|
||||||
var (
|
|
||||||
listener net.Listener
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if listener, err = net.Listen("tcp", endpoint); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
go NewHTTPServer(cors, vhosts, timeouts, handler).Serve(listener)
|
|
||||||
return listener, handler, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartWSEndpoint starts a websocket endpoint.
|
|
||||||
func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []string, exposeAll bool) (net.Listener, *Server, error) {
|
|
||||||
if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
|
|
||||||
log.Error("Unavailable modules in WS API list", "unavailable", bad, "available", available)
|
|
||||||
}
|
|
||||||
// Generate the whitelist based on the allowed modules
|
|
||||||
whitelist := make(map[string]bool)
|
|
||||||
for _, module := range modules {
|
|
||||||
whitelist[module] = true
|
|
||||||
}
|
|
||||||
// Register all the APIs exposed by the services
|
|
||||||
handler := NewServer()
|
|
||||||
for _, api := range apis {
|
|
||||||
if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
|
|
||||||
if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
log.Debug("WebSocket registered", "service", api.Service, "namespace", api.Namespace)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// All APIs registered, start the HTTP listener
|
|
||||||
var (
|
|
||||||
listener net.Listener
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if listener, err = net.Listen("tcp", endpoint); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
go NewWSServer(wsOrigins, handler).Serve(listener)
|
|
||||||
return listener, handler, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartIPCEndpoint starts an IPC endpoint.
|
// StartIPCEndpoint starts an IPC endpoint.
|
||||||
func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) {
|
func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) {
|
||||||
// Register all the APIs exposed by the services.
|
// Register all the APIs exposed by the services.
|
||||||
|
66
rpc/gzip.go
66
rpc/gzip.go
@ -1,66 +0,0 @@
|
|||||||
// Copyright 2019 The go-ethereum Authors
|
|
||||||
// This file is part of the go-ethereum library.
|
|
||||||
//
|
|
||||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
|
||||||
// it under the terms of the GNU Lesser General Public License as published by
|
|
||||||
// the Free Software Foundation, either version 3 of the License, or
|
|
||||||
// (at your option) any later version.
|
|
||||||
//
|
|
||||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
|
||||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
// GNU Lesser General Public License for more details.
|
|
||||||
//
|
|
||||||
// You should have received a copy of the GNU Lesser General Public License
|
|
||||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
package rpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"compress/gzip"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
var gzPool = sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
w := gzip.NewWriter(ioutil.Discard)
|
|
||||||
return w
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
type gzipResponseWriter struct {
|
|
||||||
io.Writer
|
|
||||||
http.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *gzipResponseWriter) WriteHeader(status int) {
|
|
||||||
w.Header().Del("Content-Length")
|
|
||||||
w.ResponseWriter.WriteHeader(status)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
|
||||||
return w.Writer.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGzipHandler(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Encoding", "gzip")
|
|
||||||
|
|
||||||
gz := gzPool.Get().(*gzip.Writer)
|
|
||||||
defer gzPool.Put(gz)
|
|
||||||
|
|
||||||
gz.Reset(w)
|
|
||||||
defer gz.Close()
|
|
||||||
|
|
||||||
next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
|
|
||||||
})
|
|
||||||
}
|
|
97
rpc/http.go
97
rpc/http.go
@ -25,14 +25,9 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"mime"
|
"mime"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/log"
|
|
||||||
"github.com/rs/cors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -209,37 +204,6 @@ func (t *httpServerConn) RemoteAddr() string {
|
|||||||
// SetWriteDeadline does nothing and always returns nil.
|
// SetWriteDeadline does nothing and always returns nil.
|
||||||
func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil }
|
func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil }
|
||||||
|
|
||||||
// NewHTTPServer creates a new HTTP RPC server around an API provider.
|
|
||||||
//
|
|
||||||
// Deprecated: Server implements http.Handler
|
|
||||||
func NewHTTPServer(cors []string, vhosts []string, timeouts HTTPTimeouts, srv http.Handler) *http.Server {
|
|
||||||
// Wrap the CORS-handler within a host-handler
|
|
||||||
handler := newCorsHandler(srv, cors)
|
|
||||||
handler = newVHostHandler(vhosts, handler)
|
|
||||||
handler = newGzipHandler(handler)
|
|
||||||
|
|
||||||
// Make sure timeout values are meaningful
|
|
||||||
if timeouts.ReadTimeout < time.Second {
|
|
||||||
log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", DefaultHTTPTimeouts.ReadTimeout)
|
|
||||||
timeouts.ReadTimeout = DefaultHTTPTimeouts.ReadTimeout
|
|
||||||
}
|
|
||||||
if timeouts.WriteTimeout < time.Second {
|
|
||||||
log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", DefaultHTTPTimeouts.WriteTimeout)
|
|
||||||
timeouts.WriteTimeout = DefaultHTTPTimeouts.WriteTimeout
|
|
||||||
}
|
|
||||||
if timeouts.IdleTimeout < time.Second {
|
|
||||||
log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", DefaultHTTPTimeouts.IdleTimeout)
|
|
||||||
timeouts.IdleTimeout = DefaultHTTPTimeouts.IdleTimeout
|
|
||||||
}
|
|
||||||
// Bundle and start the HTTP server
|
|
||||||
return &http.Server{
|
|
||||||
Handler: handler,
|
|
||||||
ReadTimeout: timeouts.ReadTimeout,
|
|
||||||
WriteTimeout: timeouts.WriteTimeout,
|
|
||||||
IdleTimeout: timeouts.IdleTimeout,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP serves JSON-RPC requests over HTTP.
|
// ServeHTTP serves JSON-RPC requests over HTTP.
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// Permit dumb empty requests for remote health-checks (AWS)
|
// Permit dumb empty requests for remote health-checks (AWS)
|
||||||
@ -296,64 +260,3 @@ func validateRequest(r *http.Request) (int, error) {
|
|||||||
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
|
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
|
||||||
return http.StatusUnsupportedMediaType, err
|
return http.StatusUnsupportedMediaType, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
|
|
||||||
// disable CORS support if user has not specified a custom CORS configuration
|
|
||||||
if len(allowedOrigins) == 0 {
|
|
||||||
return srv
|
|
||||||
}
|
|
||||||
c := cors.New(cors.Options{
|
|
||||||
AllowedOrigins: allowedOrigins,
|
|
||||||
AllowedMethods: []string{http.MethodPost, http.MethodGet},
|
|
||||||
MaxAge: 600,
|
|
||||||
AllowedHeaders: []string{"*"},
|
|
||||||
})
|
|
||||||
return c.Handler(srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
// virtualHostHandler is a handler which validates the Host-header of incoming requests.
|
|
||||||
// The virtualHostHandler can prevent DNS rebinding attacks, which do not utilize CORS-headers,
|
|
||||||
// since they do in-domain requests against the RPC api. Instead, we can see on the Host-header
|
|
||||||
// which domain was used, and validate that against a whitelist.
|
|
||||||
type virtualHostHandler struct {
|
|
||||||
vhosts map[string]struct{}
|
|
||||||
next http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler
|
|
||||||
func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// if r.Host is not set, we can continue serving since a browser would set the Host header
|
|
||||||
if r.Host == "" {
|
|
||||||
h.next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
|
||||||
if err != nil {
|
|
||||||
// Either invalid (too many colons) or no port specified
|
|
||||||
host = r.Host
|
|
||||||
}
|
|
||||||
if ipAddr := net.ParseIP(host); ipAddr != nil {
|
|
||||||
// It's an IP address, we can serve that
|
|
||||||
h.next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
|
|
||||||
}
|
|
||||||
// Not an IP address, but a hostname. Need to validate
|
|
||||||
if _, exist := h.vhosts["*"]; exist {
|
|
||||||
h.next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, exist := h.vhosts[host]; exist {
|
|
||||||
h.next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
http.Error(w, "invalid host specified", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newVHostHandler(vhosts []string, next http.Handler) http.Handler {
|
|
||||||
vhostMap := make(map[string]struct{})
|
|
||||||
for _, allowedHost := range vhosts {
|
|
||||||
vhostMap[strings.ToLower(allowedHost)] = struct{}{}
|
|
||||||
}
|
|
||||||
return &virtualHostHandler{vhostMap, next}
|
|
||||||
}
|
|
||||||
|
@ -38,13 +38,6 @@ const (
|
|||||||
|
|
||||||
var wsBufferPool = new(sync.Pool)
|
var wsBufferPool = new(sync.Pool)
|
||||||
|
|
||||||
// NewWSServer creates a new websocket RPC server around an API provider.
|
|
||||||
//
|
|
||||||
// Deprecated: use Server.WebsocketHandler
|
|
||||||
func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
|
|
||||||
return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
|
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
|
||||||
//
|
//
|
||||||
// allowedOrigins should be a comma-separated list of allowed origin URLs.
|
// allowedOrigins should be a comma-separated list of allowed origin URLs.
|
||||||
|
Loading…
Reference in New Issue
Block a user