diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 86dc6f40f..11829cbad 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -181,6 +181,7 @@ var ( utils.GraphQLCORSDomainFlag, utils.GraphQLVirtualHostsFlag, utils.HTTPApiFlag, + utils.HTTPPathPrefixFlag, utils.LegacyRPCApiFlag, utils.WSEnabledFlag, utils.WSListenAddrFlag, @@ -190,6 +191,7 @@ var ( utils.WSApiFlag, utils.LegacyWSApiFlag, utils.WSAllowedOriginsFlag, + utils.WSPathPrefixFlag, utils.LegacyWSAllowedOriginsFlag, utils.IPCDisabledFlag, utils.IPCPathFlag, diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index ba311bf7f..0ed31d7da 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -138,12 +138,14 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.HTTPListenAddrFlag, utils.HTTPPortFlag, utils.HTTPApiFlag, + utils.HTTPPathPrefixFlag, utils.HTTPCORSDomainFlag, utils.HTTPVirtualHostsFlag, utils.WSEnabledFlag, utils.WSListenAddrFlag, utils.WSPortFlag, utils.WSApiFlag, + utils.WSPathPrefixFlag, utils.WSAllowedOriginsFlag, utils.GraphQLEnabledFlag, utils.GraphQLCORSDomainFlag, diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 8f9f68ba6..aa5180dd9 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -531,6 +531,11 @@ var ( Usage: "API's offered over the HTTP-RPC interface", Value: "", } + HTTPPathPrefixFlag = cli.StringFlag{ + Name: "http.rpcprefix", + Usage: "HTTP path path prefix on which JSON-RPC is served. Use '/' to serve on all paths.", + Value: "", + } GraphQLEnabledFlag = cli.BoolFlag{ Name: "graphql", Usage: "Enable GraphQL on the HTTP-RPC server. Note that GraphQL can only be started if an HTTP server is started as well.", @@ -569,6 +574,11 @@ var ( Usage: "Origins from which to accept websockets requests", Value: "", } + WSPathPrefixFlag = cli.StringFlag{ + Name: "ws.rpcprefix", + Usage: "HTTP path prefix on which JSON-RPC is served. Use '/' to serve on all paths.", + Value: "", + } ExecFlag = cli.StringFlag{ Name: "exec", Usage: "Execute JavaScript statement", @@ -946,6 +956,10 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { if ctx.GlobalIsSet(HTTPVirtualHostsFlag.Name) { cfg.HTTPVirtualHosts = SplitAndTrim(ctx.GlobalString(HTTPVirtualHostsFlag.Name)) } + + if ctx.GlobalIsSet(HTTPPathPrefixFlag.Name) { + cfg.HTTPPathPrefix = ctx.GlobalString(HTTPPathPrefixFlag.Name) + } } // setGraphQL creates the GraphQL listener interface string from the set @@ -995,6 +1009,10 @@ func setWS(ctx *cli.Context, cfg *node.Config) { if ctx.GlobalIsSet(WSApiFlag.Name) { cfg.WSModules = SplitAndTrim(ctx.GlobalString(WSApiFlag.Name)) } + + if ctx.GlobalIsSet(WSPathPrefixFlag.Name) { + cfg.WSPathPrefix = ctx.GlobalString(WSPathPrefixFlag.Name) + } } // setIPC creates an IPC path configuration from the set command line flags, diff --git a/go.sum b/go.sum index 4b799d77f..2c16d81f0 100644 --- a/go.sum +++ b/go.sum @@ -433,7 +433,6 @@ golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4Iltr golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -501,7 +500,6 @@ golang.org/x/tools v0.0.0-20200108203644-89082a384178/go.mod h1:TB2adYChydJhpapK golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= @@ -549,15 +547,12 @@ google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyz google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= gopkg.in/olebedev/go-duktape.v3 v3.0.0-20200619000410-60c24ae608a6 h1:a6cXbcDDUkSBlpnkWV1bJ+vv3mOgQEltEJ2rPxroVu0= gopkg.in/olebedev/go-duktape.v3 v3.0.0-20200619000410-60c24ae608a6/go.mod h1:uAJfkITjFhyEEuUfm7bsmCZRbW5WRq8s9EY8HZ6hCns= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/urfave/cli.v1 v1.20.0 h1:NdAVW6RYxDif9DhDHaAortIu956m2c0v+09AZBPTbE0= gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0= diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go index 71320012d..a88c9b30b 100644 --- a/graphql/graphql_test.go +++ b/graphql/graphql_test.go @@ -30,6 +30,8 @@ import ( "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/params" + + "github.com/stretchr/testify/assert" ) func TestBuildSchema(t *testing.T) { @@ -166,18 +168,8 @@ func TestGraphQLHTTPOnSamePort_GQLRequest_Unsuccessful(t *testing.T) { if err != nil { t.Fatalf("could not post: %v", err) } - bodyBytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("could not read from response body: %v", err) - } - resp.Body.Close() // make sure the request is not handled successfully - if want, have := "404 page not found\n", string(bodyBytes); have != want { - t.Errorf("have:\n%v\nwant:\n%v", have, want) - } - if want, have := 404, resp.StatusCode; want != have { - t.Errorf("wrong statuscode, have:\n%v\nwant:%v", have, want) - } + assert.Equal(t, http.StatusNotFound, resp.StatusCode) } func createNode(t *testing.T, gqlEnabled bool) *node.Node { diff --git a/node/api_test.go b/node/api_test.go index a07ce833c..9c3fa3a31 100644 --- a/node/api_test.go +++ b/node/api_test.go @@ -244,7 +244,10 @@ func TestStartRPC(t *testing.T) { } for _, test := range tests { + test := test t.Run(test.name, func(t *testing.T) { + t.Parallel() + // Apply some sane defaults. config := test.cfg // config.Logger = testlog.Logger(t, log.LvlDebug) diff --git a/node/config.go b/node/config.go index 61e41cd7d..447a69505 100644 --- a/node/config.go +++ b/node/config.go @@ -139,6 +139,9 @@ type Config struct { // interface. HTTPTimeouts rpc.HTTPTimeouts + // HTTPPathPrefix specifies a path prefix on which http-rpc is to be served. + HTTPPathPrefix string `toml:",omitempty"` + // WSHost is the host interface on which to start the websocket RPC server. If // this field is empty, no websocket API endpoint will be started. WSHost string @@ -148,6 +151,9 @@ type Config struct { // ephemeral nodes). WSPort int `toml:",omitempty"` + // WSPathPrefix specifies a path prefix on which ws-rpc is to be served. + WSPathPrefix string `toml:",omitempty"` + // WSOrigins is the list of domain to accept websocket requests from. Please be // aware that the server can only act upon the HTTP request the client sends and // cannot verify the validity of the request header. diff --git a/node/node.go b/node/node.go index b58594ef1..2ed4c31f6 100644 --- a/node/node.go +++ b/node/node.go @@ -135,6 +135,14 @@ func New(conf *Config) (*Node, error) { node.server.Config.NodeDatabase = node.config.NodeDB() } + // Check HTTP/WS prefixes are valid. + if err := validatePrefix("HTTP", conf.HTTPPathPrefix); err != nil { + return nil, err + } + if err := validatePrefix("WebSocket", conf.WSPathPrefix); err != nil { + return nil, err + } + // Configure RPC servers. node.http = newHTTPServer(node.log, conf.HTTPTimeouts) node.ws = newHTTPServer(node.log, rpc.DefaultHTTPTimeouts) @@ -346,6 +354,7 @@ func (n *Node) startRPC() error { CorsAllowedOrigins: n.config.HTTPCors, Vhosts: n.config.HTTPVirtualHosts, Modules: n.config.HTTPModules, + prefix: n.config.HTTPPathPrefix, } if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil { return err @@ -361,6 +370,7 @@ func (n *Node) startRPC() error { config := wsConfig{ Modules: n.config.WSModules, Origins: n.config.WSOrigins, + prefix: n.config.WSPathPrefix, } if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil { return err @@ -457,6 +467,7 @@ func (n *Node) RegisterHandler(name, path string, handler http.Handler) { if n.state != initializingState { panic("can't register HTTP handler on running/stopped node") } + n.http.mux.Handle(path, handler) n.http.handlerNames[path] = name } @@ -513,17 +524,18 @@ func (n *Node) IPCEndpoint() string { return n.ipc.endpoint } -// HTTPEndpoint returns the URL of the HTTP server. +// HTTPEndpoint returns the URL of the HTTP server. Note that this URL does not +// contain the JSON-RPC path prefix set by HTTPPathPrefix. func (n *Node) HTTPEndpoint() string { return "http://" + n.http.listenAddr() } -// WSEndpoint retrieves the current WS endpoint used by the protocol stack. +// WSEndpoint returns the current JSON-RPC over WebSocket endpoint. func (n *Node) WSEndpoint() string { if n.http.wsAllowed() { - return "ws://" + n.http.listenAddr() + return "ws://" + n.http.listenAddr() + n.http.wsConfig.prefix } - return "ws://" + n.ws.listenAddr() + return "ws://" + n.ws.listenAddr() + n.ws.wsConfig.prefix } // EventMux retrieves the event multiplexer used by all the network services in diff --git a/node/node_test.go b/node/node_test.go index 8f306ef02..6731dbac1 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -390,7 +390,7 @@ func TestLifecycleTerminationGuarantee(t *testing.T) { } // Tests whether a handler can be successfully mounted on the canonical HTTP server -// on the given path +// on the given prefix func TestRegisterHandler_Successful(t *testing.T) { node := createNode(t, 7878, 7979) @@ -483,7 +483,112 @@ func TestWebsocketHTTPOnSeparatePort_WSRequest(t *testing.T) { if !checkRPC(node.HTTPEndpoint()) { t.Fatalf("http request failed") } +} +type rpcPrefixTest struct { + httpPrefix, wsPrefix string + // These lists paths on which JSON-RPC should be served / not served. + wantHTTP []string + wantNoHTTP []string + wantWS []string + wantNoWS []string +} + +func TestNodeRPCPrefix(t *testing.T) { + t.Parallel() + + tests := []rpcPrefixTest{ + // both off + { + httpPrefix: "", wsPrefix: "", + wantHTTP: []string{"/", "/?p=1"}, + wantNoHTTP: []string{"/test", "/test?p=1"}, + wantWS: []string{"/", "/?p=1"}, + wantNoWS: []string{"/test", "/test?p=1"}, + }, + // only http prefix + { + httpPrefix: "/testprefix", wsPrefix: "", + wantHTTP: []string{"/testprefix", "/testprefix?p=1", "/testprefix/x", "/testprefix/x?p=1"}, + wantNoHTTP: []string{"/", "/?p=1", "/test", "/test?p=1"}, + wantWS: []string{"/", "/?p=1"}, + wantNoWS: []string{"/testprefix", "/testprefix?p=1", "/test", "/test?p=1"}, + }, + // only ws prefix + { + httpPrefix: "", wsPrefix: "/testprefix", + wantHTTP: []string{"/", "/?p=1"}, + wantNoHTTP: []string{"/testprefix", "/testprefix?p=1", "/test", "/test?p=1"}, + wantWS: []string{"/testprefix", "/testprefix?p=1", "/testprefix/x", "/testprefix/x?p=1"}, + wantNoWS: []string{"/", "/?p=1", "/test", "/test?p=1"}, + }, + // both set + { + httpPrefix: "/testprefix", wsPrefix: "/testprefix", + wantHTTP: []string{"/testprefix", "/testprefix?p=1", "/testprefix/x", "/testprefix/x?p=1"}, + wantNoHTTP: []string{"/", "/?p=1", "/test", "/test?p=1"}, + wantWS: []string{"/testprefix", "/testprefix?p=1", "/testprefix/x", "/testprefix/x?p=1"}, + wantNoWS: []string{"/", "/?p=1", "/test", "/test?p=1"}, + }, + } + + for _, test := range tests { + test := test + name := fmt.Sprintf("http=%s ws=%s", test.httpPrefix, test.wsPrefix) + t.Run(name, func(t *testing.T) { + cfg := &Config{ + HTTPHost: "127.0.0.1", + HTTPPathPrefix: test.httpPrefix, + WSHost: "127.0.0.1", + WSPathPrefix: test.wsPrefix, + } + node, err := New(cfg) + if err != nil { + t.Fatal("can't create node:", err) + } + defer node.Close() + if err := node.Start(); err != nil { + t.Fatal("can't start node:", err) + } + test.check(t, node) + }) + } +} + +func (test rpcPrefixTest) check(t *testing.T, node *Node) { + t.Helper() + httpBase := "http://" + node.http.listenAddr() + wsBase := "ws://" + node.http.listenAddr() + + if node.WSEndpoint() != wsBase+test.wsPrefix { + t.Errorf("Error: node has wrong WSEndpoint %q", node.WSEndpoint()) + } + + for _, path := range test.wantHTTP { + resp := rpcRequest(t, httpBase+path) + if resp.StatusCode != 200 { + t.Errorf("Error: %s: bad status code %d, want 200", path, resp.StatusCode) + } + } + for _, path := range test.wantNoHTTP { + resp := rpcRequest(t, httpBase+path) + if resp.StatusCode != 404 { + t.Errorf("Error: %s: bad status code %d, want 404", path, resp.StatusCode) + } + } + for _, path := range test.wantWS { + err := wsRequest(t, wsBase+path, "") + if err != nil { + t.Errorf("Error: %s: WebSocket connection failed: %v", path, err) + } + } + for _, path := range test.wantNoWS { + err := wsRequest(t, wsBase+path, "") + if err == nil { + t.Errorf("Error: %s: WebSocket connection succeeded for path in wantNoWS", path) + } + + } } func createNode(t *testing.T, httpPort, wsPort int) *Node { diff --git a/node/rpcstack.go b/node/rpcstack.go index 81e054ec9..d693bb0bb 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -39,12 +39,14 @@ type httpConfig struct { Modules []string CorsAllowedOrigins []string Vhosts []string + prefix string // path prefix on which to mount http handler } // wsConfig is the JSON-RPC/Websocket configuration type wsConfig struct { Origins []string Modules []string + prefix string // path prefix on which to mount ws handler } type rpcHandler struct { @@ -62,6 +64,7 @@ type httpServer struct { listener net.Listener // non-nil when server is running // HTTP RPC handler things. + httpConfig httpConfig httpHandler atomic.Value // *rpcHandler @@ -79,6 +82,7 @@ type httpServer struct { func newHTTPServer(log log.Logger, timeouts rpc.HTTPTimeouts) *httpServer { h := &httpServer{log: log, timeouts: timeouts, handlerNames: make(map[string]string)} + h.httpHandler.Store((*rpcHandler)(nil)) h.wsHandler.Store((*rpcHandler)(nil)) return h @@ -142,12 +146,17 @@ func (h *httpServer) start() error { // if server is websocket only, return after logging if h.wsAllowed() && !h.rpcAllowed() { - h.log.Info("WebSocket enabled", "url", fmt.Sprintf("ws://%v", listener.Addr())) + url := fmt.Sprintf("ws://%v", listener.Addr()) + if h.wsConfig.prefix != "" { + url += h.wsConfig.prefix + } + h.log.Info("WebSocket enabled", "url", url) return nil } // Log http endpoint. h.log.Info("HTTP server started", "endpoint", listener.Addr(), + "prefix", h.httpConfig.prefix, "cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","), "vhosts", strings.Join(h.httpConfig.Vhosts, ","), ) @@ -170,26 +179,60 @@ func (h *httpServer) start() error { } func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - rpc := h.httpHandler.Load().(*rpcHandler) - if r.RequestURI == "/" { - // Serve JSON-RPC on the root path. - ws := h.wsHandler.Load().(*rpcHandler) - if ws != nil && isWebsocket(r) { + // check if ws request and serve if ws enabled + ws := h.wsHandler.Load().(*rpcHandler) + if ws != nil && isWebsocket(r) { + if checkPath(r, h.wsConfig.prefix) { ws.ServeHTTP(w, r) - return } - if rpc != nil { - rpc.ServeHTTP(w, r) - return - } - } else if rpc != nil { + return + } + // if http-rpc is enabled, try to serve request + rpc := h.httpHandler.Load().(*rpcHandler) + if rpc != nil { + // First try to route in the mux. // Requests to a path below root are handled by the mux, // which has all the handlers registered via Node.RegisterHandler. // These are made available when RPC is enabled. - h.mux.ServeHTTP(w, r) - return + muxHandler, pattern := h.mux.Handler(r) + if pattern != "" { + muxHandler.ServeHTTP(w, r) + return + } + + if checkPath(r, h.httpConfig.prefix) { + rpc.ServeHTTP(w, r) + return + } } - w.WriteHeader(404) + w.WriteHeader(http.StatusNotFound) +} + +// checkPath checks whether a given request URL matches a given path prefix. +func checkPath(r *http.Request, path string) bool { + // if no prefix has been specified, request URL must be on root + if path == "" { + return r.URL.Path == "/" + } + // otherwise, check to make sure prefix matches + return len(r.URL.Path) >= len(path) && r.URL.Path[:len(path)] == path +} + +// validatePrefix checks if 'path' is a valid configuration value for the RPC prefix option. +func validatePrefix(what, path string) error { + if path == "" { + return nil + } + if path[0] != '/' { + return fmt.Errorf(`%s RPC path prefix %q does not contain leading "/"`, what, path) + } + if strings.ContainsAny(path, "?#") { + // This is just to avoid confusion. While these would match correctly (i.e. they'd + // match if URL-escaped into path), it's not easy to understand for users when + // setting that on the command line. + return fmt.Errorf("%s RPC path prefix %q contains URL meta-characters", what, path) + } + return nil } // stop shuts down the HTTP server. diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 8267fb2f1..f92f0ba39 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -18,7 +18,10 @@ package node import ( "bytes" + "fmt" "net/http" + "net/url" + "strconv" "strings" "testing" @@ -31,25 +34,27 @@ import ( // TestCorsHandler makes sure CORS are properly handled on the http server. func TestCorsHandler(t *testing.T) { - srv := createAndStartServer(t, httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, wsConfig{}) + srv := createAndStartServer(t, &httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, &wsConfig{}) defer srv.stop() + url := "http://" + srv.listenAddr() - resp := testRequest(t, "origin", "test.com", "", srv) + resp := rpcRequest(t, url, "origin", "test.com") assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin")) - resp2 := testRequest(t, "origin", "bad", "", srv) + resp2 := rpcRequest(t, url, "origin", "bad") assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin")) } // TestVhosts makes sure vhosts are properly handled on the http server. func TestVhosts(t *testing.T) { - srv := createAndStartServer(t, httpConfig{Vhosts: []string{"test"}}, false, wsConfig{}) + srv := createAndStartServer(t, &httpConfig{Vhosts: []string{"test"}}, false, &wsConfig{}) defer srv.stop() + url := "http://" + srv.listenAddr() - resp := testRequest(t, "", "", "test", srv) + resp := rpcRequest(t, url, "host", "test") assert.Equal(t, resp.StatusCode, http.StatusOK) - resp2 := testRequest(t, "", "", "bad", srv) + resp2 := rpcRequest(t, url, "host", "bad") assert.Equal(t, resp2.StatusCode, http.StatusForbidden) } @@ -138,14 +143,15 @@ func TestWebsocketOrigins(t *testing.T) { }, } for _, tc := range tests { - srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: splitAndTrim(tc.spec)}) + srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}) + url := fmt.Sprintf("ws://%v", srv.listenAddr()) for _, origin := range tc.expOk { - if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err != nil { + if err := wsRequest(t, url, origin); err != nil { t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err) } } for _, origin := range tc.expFail { - if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err == nil { + if err := wsRequest(t, url, origin); err == nil { t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin) } } @@ -168,47 +174,118 @@ func TestIsWebsocket(t *testing.T) { assert.True(t, isWebsocket(r)) } -func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer { +func Test_checkPath(t *testing.T) { + tests := []struct { + req *http.Request + prefix string + expected bool + }{ + { + req: &http.Request{URL: &url.URL{Path: "/test"}}, + prefix: "/test", + expected: true, + }, + { + req: &http.Request{URL: &url.URL{Path: "/testing"}}, + prefix: "/test", + expected: true, + }, + { + req: &http.Request{URL: &url.URL{Path: "/"}}, + prefix: "/test", + expected: false, + }, + { + req: &http.Request{URL: &url.URL{Path: "/fail"}}, + prefix: "/test", + expected: false, + }, + { + req: &http.Request{URL: &url.URL{Path: "/"}}, + prefix: "", + expected: true, + }, + { + req: &http.Request{URL: &url.URL{Path: "/fail"}}, + prefix: "", + expected: false, + }, + { + req: &http.Request{URL: &url.URL{Path: "/"}}, + prefix: "/", + expected: true, + }, + { + req: &http.Request{URL: &url.URL{Path: "/testing"}}, + prefix: "/", + expected: true, + }, + } + + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + assert.Equal(t, tt.expected, checkPath(tt.req, tt.prefix)) + }) + } +} + +func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig) *httpServer { t.Helper() srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts) - - assert.NoError(t, srv.enableRPC(nil, conf)) + assert.NoError(t, srv.enableRPC(nil, *conf)) if ws { - assert.NoError(t, srv.enableWS(nil, wsConf)) + assert.NoError(t, srv.enableWS(nil, *wsConf)) } assert.NoError(t, srv.setListenAddr("localhost", 0)) assert.NoError(t, srv.start()) - return srv } -func attemptWebsocketConnectionFromOrigin(t *testing.T, srv *httpServer, browserOrigin string) error { +// wsRequest attempts to open a WebSocket connection to the given URL. +func wsRequest(t *testing.T, url, browserOrigin string) error { t.Helper() - dialer := websocket.DefaultDialer - _, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{ - "Content-type": []string{"application/json"}, - "Sec-WebSocket-Version": []string{"13"}, - "Origin": []string{browserOrigin}, - }) + t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin) + + headers := make(http.Header) + if browserOrigin != "" { + headers.Set("Origin", browserOrigin) + } + conn, _, err := websocket.DefaultDialer.Dial(url, headers) + if conn != nil { + conn.Close() + } return err } -func testRequest(t *testing.T, key, value, host string, srv *httpServer) *http.Response { +// rpcRequest performs a JSON-RPC request to the given URL. +func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response { t.Helper() - body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,method":"rpc_modules"}`)) - req, _ := http.NewRequest("POST", "http://"+srv.listenAddr(), body) - req.Header.Set("content-type", "application/json") - if key != "" && value != "" { - req.Header.Set(key, value) + // Create the request. + body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules","params":[]}`)) + req, err := http.NewRequest("POST", url, body) + if err != nil { + t.Fatal("could not create http request:", err) } - if host != "" { - req.Host = host + req.Header.Set("content-type", "application/json") + + // Apply extra headers. + if len(extraHeaders)%2 != 0 { + panic("odd extraHeaders length") + } + for i := 0; i < len(extraHeaders); i += 2 { + key, value := extraHeaders[i], extraHeaders[i+1] + if strings.ToLower(key) == "host" { + req.Host = value + } else { + req.Header.Set(key, value) + } } - client := http.DefaultClient - resp, err := client.Do(req) + // Perform the request. + t.Logf("checking RPC/HTTP on %s %v", url, extraHeaders) + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) }