diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 0ee120efd..8267fb2f1 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -19,6 +19,7 @@ package node import ( "bytes" "net/http" + "strings" "testing" "github.com/ethereum/go-ethereum/internal/testlog" @@ -52,25 +53,104 @@ func TestVhosts(t *testing.T) { assert.Equal(t, resp2.StatusCode, http.StatusForbidden) } +type originTest struct { + spec string + expOk []string + expFail []string +} + +// splitAndTrim splits input separated by a comma +// and trims excessive white space from the substrings. +// Copied over from flags.go +func splitAndTrim(input string) (ret []string) { + l := strings.Split(input, ",") + for _, r := range l { + r = strings.TrimSpace(r) + if len(r) > 0 { + ret = append(ret, r) + } + } + return ret +} + // TestWebsocketOrigins makes sure the websocket origins are properly handled on the websocket server. func TestWebsocketOrigins(t *testing.T) { - srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: []string{"test"}}) - defer srv.stop() - - dialer := websocket.DefaultDialer - _, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{ - "Content-type": []string{"application/json"}, - "Sec-WebSocket-Version": []string{"13"}, - "Origin": []string{"test"}, - }) - assert.NoError(t, err) - - _, _, err = dialer.Dial("ws://"+srv.listenAddr(), http.Header{ - "Content-type": []string{"application/json"}, - "Sec-WebSocket-Version": []string{"13"}, - "Origin": []string{"bad"}, - }) - assert.Error(t, err) + tests := []originTest{ + { + spec: "*", // allow all + expOk: []string{"", "http://test", "https://test", "http://test:8540", "https://test:8540", + "http://test.com", "https://foo.test", "http://testa", "http://atestb:8540", "https://atestb:8540"}, + }, + { + spec: "test", + expOk: []string{"http://test", "https://test", "http://test:8540", "https://test:8540"}, + expFail: []string{"http://test.com", "https://foo.test", "http://testa", "http://atestb:8540", "https://atestb:8540"}, + }, + // scheme tests + { + spec: "https://test", + expOk: []string{"https://test", "https://test:9999"}, + expFail: []string{ + "test", // no scheme, required by spec + "http://test", // wrong scheme + "http://test.foo", "https://a.test.x", // subdomain variatoins + "http://testx:8540", "https://xtest:8540"}, + }, + // ip tests + { + spec: "https://12.34.56.78", + expOk: []string{"https://12.34.56.78", "https://12.34.56.78:8540"}, + expFail: []string{ + "http://12.34.56.78", // wrong scheme + "http://12.34.56.78:443", // wrong scheme + "http://1.12.34.56.78", // wrong 'domain name' + "http://12.34.56.78.a", // wrong 'domain name' + "https://87.65.43.21", "http://87.65.43.21:8540", "https://87.65.43.21:8540"}, + }, + // port tests + { + spec: "test:8540", + expOk: []string{"http://test:8540", "https://test:8540"}, + expFail: []string{ + "http://test", "https://test", // spec says port required + "http://test:8541", "https://test:8541", // wrong port + "http://bad", "https://bad", "http://bad:8540", "https://bad:8540"}, + }, + // scheme and port + { + spec: "https://test:8540", + expOk: []string{"https://test:8540"}, + expFail: []string{ + "https://test", // missing port + "http://test", // missing port, + wrong scheme + "http://test:8540", // wrong scheme + "http://test:8541", "https://test:8541", // wrong port + "http://bad", "https://bad", "http://bad:8540", "https://bad:8540"}, + }, + // several allowed origins + { + spec: "localhost,http://127.0.0.1", + expOk: []string{"localhost", "http://localhost", "https://localhost:8443", + "http://127.0.0.1", "http://127.0.0.1:8080"}, + expFail: []string{ + "https://127.0.0.1", // wrong scheme + "http://bad", "https://bad", "http://bad:8540", "https://bad:8540"}, + }, + } + for _, tc := range tests { + srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: splitAndTrim(tc.spec)}) + for _, origin := range tc.expOk { + if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err != nil { + t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err) + } + } + for _, origin := range tc.expFail { + if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err == nil { + t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin) + } + } + srv.stop() + } } // TestIsWebsocket tests if an incoming websocket upgrade request is handled properly. @@ -103,6 +183,17 @@ func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfi return srv } +func attemptWebsocketConnectionFromOrigin(t *testing.T, srv *httpServer, browserOrigin string) error { + t.Helper() + dialer := websocket.DefaultDialer + _, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{ + "Content-type": []string{"application/json"}, + "Sec-WebSocket-Version": []string{"13"}, + "Origin": []string{browserOrigin}, + }) + return err +} + func testRequest(t *testing.T, key, value, host string, srv *httpServer) *http.Response { t.Helper() diff --git a/rpc/websocket.go b/rpc/websocket.go index a716383be..cd60eeb61 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -75,14 +75,14 @@ func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { allowAllOrigins = true } if origin != "" { - origins.Add(strings.ToLower(origin)) + origins.Add(origin) } } // allow localhost if no allowedOrigins are specified. if len(origins.ToSlice()) == 0 { origins.Add("http://localhost") if hostname, err := os.Hostname(); err == nil { - origins.Add("http://" + strings.ToLower(hostname)) + origins.Add("http://" + hostname) } } log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) @@ -97,7 +97,7 @@ func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { } // Verify origin against whitelist. origin := strings.ToLower(req.Header.Get("Origin")) - if allowAllOrigins || origins.Contains(origin) { + if allowAllOrigins || originIsAllowed(origins, origin) { return true } log.Warn("Rejected WebSocket connection", "origin", origin) @@ -120,6 +120,65 @@ func (e wsHandshakeError) Error() string { return s } +func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { + it := allowedOrigins.Iterator() + for origin := range it.C { + if ruleAllowsOrigin(origin.(string), browserOrigin) { + return true + } + } + return false +} + +func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { + var ( + allowedScheme, allowedHostname, allowedPort string + browserScheme, browserHostname, browserPort string + err error + ) + allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) + if err != nil { + log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) + return false + } + browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) + if err != nil { + log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) + return false + } + if allowedScheme != "" && allowedScheme != browserScheme { + return false + } + if allowedHostname != "" && allowedHostname != browserHostname { + return false + } + if allowedPort != "" && allowedPort != browserPort { + return false + } + return true +} + +func parseOriginURL(origin string) (string, string, string, error) { + parsedURL, err := url.Parse(strings.ToLower(origin)) + if err != nil { + return "", "", "", err + } + var scheme, hostname, port string + if strings.Contains(origin, "://") { + scheme = parsedURL.Scheme + hostname = parsedURL.Hostname() + port = parsedURL.Port() + } else { + scheme = "" + hostname = parsedURL.Scheme + port = parsedURL.Opaque + if hostname == "" { + hostname = origin + } + } + return scheme, hostname, port, nil +} + // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint using the provided dialer. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {