diff --git a/node/config.go b/node/config.go index 2047299fb..49959d5ec 100644 --- a/node/config.go +++ b/node/config.go @@ -201,7 +201,7 @@ type Config struct { // AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC. AllowUnprotectedTxs bool `toml:",omitempty"` - // JWTSecret is the hex-encoded jwt secret. + // JWTSecret is the path to the hex-encoded jwt secret. JWTSecret string `toml:",omitempty"` } diff --git a/node/jwt_auth.go b/node/jwt_auth.go new file mode 100644 index 000000000..d4f8193ca --- /dev/null +++ b/node/jwt_auth.go @@ -0,0 +1,45 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "fmt" + "net/http" + "time" + + "github.com/ethereum/go-ethereum/rpc" + "github.com/golang-jwt/jwt/v4" +) + +// NewJWTAuth creates an rpc client authentication provider that uses JWT. The +// secret MUST be 32 bytes (256 bits) as defined by the Engine-API authentication spec. +// +// See https://github.com/ethereum/execution-apis/blob/main/src/engine/authentication.md +// for more details about this authentication scheme. +func NewJWTAuth(jwtsecret [32]byte) rpc.HTTPAuth { + return func(h http.Header) error { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iat": &jwt.NumericDate{Time: time.Now()}, + }) + s, err := token.SignedString(jwtsecret[:]) + if err != nil { + return fmt.Errorf("failed to create JWT token: %w", err) + } + h.Set("Authorization", "Bearer "+s) + return nil + } +} diff --git a/node/node.go b/node/node.go index b60e32f22..3cbefef02 100644 --- a/node/node.go +++ b/node/node.go @@ -668,6 +668,19 @@ func (n *Node) WSEndpoint() string { return "ws://" + n.ws.listenAddr() + n.ws.wsConfig.prefix } +// HTTPAuthEndpoint returns the URL of the authenticated HTTP server. +func (n *Node) HTTPAuthEndpoint() string { + return "http://" + n.httpAuth.listenAddr() +} + +// WSAuthEndpoint returns the current authenticated JSON-RPC over WebSocket endpoint. +func (n *Node) WSAuthEndpoint() string { + if n.httpAuth.wsAllowed() { + return "ws://" + n.httpAuth.listenAddr() + n.httpAuth.wsConfig.prefix + } + return "ws://" + n.wsAuth.listenAddr() + n.wsAuth.wsConfig.prefix +} + // EventMux retrieves the event multiplexer used by all the network services in // the current protocol stack. func (n *Node) EventMux() *event.TypeMux { diff --git a/node/node_auth_test.go b/node/node_auth_test.go new file mode 100644 index 000000000..597cd8531 --- /dev/null +++ b/node/node_auth_test.go @@ -0,0 +1,237 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "context" + crand "crypto/rand" + "fmt" + "net/http" + "os" + "path" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/rpc" + "github.com/golang-jwt/jwt/v4" +) + +type helloRPC string + +func (ta helloRPC) HelloWorld() (string, error) { + return string(ta), nil +} + +type authTest struct { + name string + endpoint string + prov rpc.HTTPAuth + expectDialFail bool + expectCall1Fail bool + expectCall2Fail bool +} + +func (at *authTest) Run(t *testing.T) { + ctx := context.Background() + cl, err := rpc.DialOptions(ctx, at.endpoint, rpc.WithHTTPAuth(at.prov)) + if at.expectDialFail { + if err == nil { + t.Fatal("expected initial dial to fail") + } else { + return + } + } + if err != nil { + t.Fatalf("failed to dial rpc endpoint: %v", err) + } + + var x string + err = cl.CallContext(ctx, &x, "engine_helloWorld") + if at.expectCall1Fail { + if err == nil { + t.Fatal("expected call 1 to fail") + } else { + return + } + } + if err != nil { + t.Fatalf("failed to call rpc endpoint: %v", err) + } + if x != "hello engine" { + t.Fatalf("method was silent but did not return expected value: %q", x) + } + + err = cl.CallContext(ctx, &x, "eth_helloWorld") + if at.expectCall2Fail { + if err == nil { + t.Fatal("expected call 2 to fail") + } else { + return + } + } + if err != nil { + t.Fatalf("failed to call rpc endpoint: %v", err) + } + if x != "hello eth" { + t.Fatalf("method was silent but did not return expected value: %q", x) + } +} + +func TestAuthEndpoints(t *testing.T) { + var secret [32]byte + if _, err := crand.Read(secret[:]); err != nil { + t.Fatalf("failed to create jwt secret: %v", err) + } + // Geth must read it from a file, and does not support in-memory JWT secrets, so we create a temporary file. + jwtPath := path.Join(t.TempDir(), "jwt_secret") + if err := os.WriteFile(jwtPath, []byte(hexutil.Encode(secret[:])), 0600); err != nil { + t.Fatalf("failed to prepare jwt secret file: %v", err) + } + // We get ports assigned by the node automatically + conf := &Config{ + HTTPHost: "127.0.0.1", + HTTPPort: 0, + WSHost: "127.0.0.1", + WSPort: 0, + AuthAddr: "127.0.0.1", + AuthPort: 0, + JWTSecret: jwtPath, + + WSModules: []string{"eth", "engine"}, + HTTPModules: []string{"eth", "engine"}, + } + node, err := New(conf) + if err != nil { + t.Fatalf("could not create a new node: %v", err) + } + // register dummy apis so we can test the modules are available and reachable with authentication + node.RegisterAPIs([]rpc.API{ + { + Namespace: "engine", + Version: "1.0", + Service: helloRPC("hello engine"), + Public: true, + Authenticated: true, + }, + { + Namespace: "eth", + Version: "1.0", + Service: helloRPC("hello eth"), + Public: true, + Authenticated: true, + }, + }) + if err := node.Start(); err != nil { + t.Fatalf("failed to start test node: %v", err) + } + defer node.Close() + + // sanity check we are running different endpoints + if a, b := node.WSEndpoint(), node.WSAuthEndpoint(); a == b { + t.Fatalf("expected ws and auth-ws endpoints to be different, got: %q and %q", a, b) + } + if a, b := node.HTTPEndpoint(), node.HTTPAuthEndpoint(); a == b { + t.Fatalf("expected http and auth-http endpoints to be different, got: %q and %q", a, b) + } + + goodAuth := NewJWTAuth(secret) + var otherSecret [32]byte + if _, err := crand.Read(otherSecret[:]); err != nil { + t.Fatalf("failed to create jwt secret: %v", err) + } + badAuth := NewJWTAuth(otherSecret) + + notTooLong := time.Second * 57 + tooLong := time.Second * 60 + requestDelay := time.Second + + testCases := []authTest{ + // Auth works + {name: "ws good", endpoint: node.WSAuthEndpoint(), prov: goodAuth, expectCall1Fail: false}, + {name: "http good", endpoint: node.HTTPAuthEndpoint(), prov: goodAuth, expectCall1Fail: false}, + + // Try a bad auth + {name: "ws bad", endpoint: node.WSAuthEndpoint(), prov: badAuth, expectDialFail: true}, // ws auth is immediate + {name: "http bad", endpoint: node.HTTPAuthEndpoint(), prov: badAuth, expectCall1Fail: true}, // http auth is on first call + + // A common mistake with JWT is to allow the "none" algorithm, which is a valid JWT but not secure. + {name: "ws none", endpoint: node.WSAuthEndpoint(), prov: noneAuth(secret), expectDialFail: true}, + {name: "http none", endpoint: node.HTTPAuthEndpoint(), prov: noneAuth(secret), expectCall1Fail: true}, + + // claims of 5 seconds or more, older or newer, are not allowed + {name: "ws too old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectDialFail: true}, + {name: "http too old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectCall1Fail: true}, + // note: for it to be too long we need to add a delay, so that once we receive the request, the difference has not dipped below the "tooLong" + {name: "ws too new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectDialFail: true}, + {name: "http too new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectCall1Fail: true}, + + // Try offset the time, but stay just within bounds + {name: "ws old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)}, + {name: "http old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)}, + {name: "ws new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)}, + {name: "http new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)}, + + // ws only authenticates on initial dial, then continues communication + {name: "ws single auth", endpoint: node.WSAuthEndpoint(), prov: changingAuth(goodAuth, badAuth)}, + {name: "http call fail auth", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, badAuth), expectCall2Fail: true}, + {name: "http call fail time", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, offsetTimeAuth(secret, tooLong+requestDelay)), expectCall2Fail: true}, + } + + for _, testCase := range testCases { + t.Run(testCase.name, testCase.Run) + } +} + +func noneAuth(secret [32]byte) rpc.HTTPAuth { + return func(header http.Header) error { + token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ + "iat": &jwt.NumericDate{Time: time.Now()}, + }) + s, err := token.SignedString(secret[:]) + if err != nil { + return fmt.Errorf("failed to create JWT token: %w", err) + } + header.Set("Authorization", "Bearer "+s) + return nil + } +} + +func changingAuth(provs ...rpc.HTTPAuth) rpc.HTTPAuth { + i := 0 + return func(header http.Header) error { + i += 1 + if i > len(provs) { + i = len(provs) + } + return provs[i-1](header) + } +} + +func offsetTimeAuth(secret [32]byte, offset time.Duration) rpc.HTTPAuth { + return func(header http.Header) error { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iat": &jwt.NumericDate{Time: time.Now().Add(offset)}, + }) + s, err := token.SignedString(secret[:]) + if err != nil { + return fmt.Errorf("failed to create JWT token: %w", err) + } + header.Set("Authorization", "Bearer "+s) + return nil + } +} diff --git a/rpc/client.go b/rpc/client.go index d3ce02977..8288f976e 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "net/url" + "os" "reflect" "strconv" "sync/atomic" @@ -99,7 +100,7 @@ type Client struct { reqTimeout chan *requestOp // removes response IDs when call timeout expires } -type reconnectFunc func(ctx context.Context) (ServerCodec, error) +type reconnectFunc func(context.Context) (ServerCodec, error) type clientContextKey struct{} @@ -153,14 +154,16 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro // // The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a // file name with no URL scheme, a local socket connection is established using UNIX -// domain sockets on supported platforms and named pipes on Windows. If you want to -// configure transport options, use DialHTTP, DialWebsocket or DialIPC instead. +// domain sockets on supported platforms and named pipes on Windows. +// +// If you want to further configure the transport, use DialOptions instead of this +// function. // // For websocket connections, the origin is set to the local host name. // -// The client reconnects automatically if the connection is lost. +// The client reconnects automatically when the connection is lost. func Dial(rawurl string) (*Client, error) { - return DialContext(context.Background(), rawurl) + return DialOptions(context.Background(), rawurl) } // DialContext creates a new RPC client, just like Dial. @@ -168,22 +171,46 @@ func Dial(rawurl string) (*Client, error) { // The context is used to cancel or time out the initial connection establishment. It does // not affect subsequent interactions with the client. func DialContext(ctx context.Context, rawurl string) (*Client, error) { + return DialOptions(ctx, rawurl) +} + +// DialOptions creates a new RPC client for the given URL. You can supply any of the +// pre-defined client options to configure the underlying transport. +// +// The context is used to cancel or time out the initial connection establishment. It does +// not affect subsequent interactions with the client. +// +// The client reconnects automatically when the connection is lost. +func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*Client, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err } + + cfg := new(clientConfig) + for _, opt := range options { + opt.applyOption(cfg) + } + + var reconnect reconnectFunc switch u.Scheme { case "http", "https": - return DialHTTP(rawurl) + reconnect = newClientTransportHTTP(rawurl, cfg) case "ws", "wss": - return DialWebsocket(ctx, rawurl, "") + rc, err := newClientTransportWS(rawurl, cfg) + if err != nil { + return nil, err + } + reconnect = rc case "stdio": - return DialStdIO(ctx) + reconnect = newClientTransportIO(os.Stdin, os.Stdout) case "": - return DialIPC(ctx, rawurl) + reconnect = newClientTransportIPC(rawurl) default: return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) } + + return newClient(ctx, reconnect) } // ClientFromContext retrieves the client from the context, if any. This can be used to perform diff --git a/rpc/client_opt.go b/rpc/client_opt.go new file mode 100644 index 000000000..5ad7c22b3 --- /dev/null +++ b/rpc/client_opt.go @@ -0,0 +1,106 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rpc + +import ( + "net/http" + + "github.com/gorilla/websocket" +) + +// ClientOption is a configuration option for the RPC client. +type ClientOption interface { + applyOption(*clientConfig) +} + +type clientConfig struct { + httpClient *http.Client + httpHeaders http.Header + httpAuth HTTPAuth + + wsDialer *websocket.Dialer +} + +func (cfg *clientConfig) initHeaders() { + if cfg.httpHeaders == nil { + cfg.httpHeaders = make(http.Header) + } +} + +func (cfg *clientConfig) setHeader(key, value string) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) +} + +type optionFunc func(*clientConfig) + +func (fn optionFunc) applyOption(opt *clientConfig) { + fn(opt) +} + +// WithWebsocketDialer configures the websocket.Dialer used by the RPC client. +func WithWebsocketDialer(dialer websocket.Dialer) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.wsDialer = &dialer + }) +} + +// WithHeader configures HTTP headers set by the RPC client. Headers set using this option +// will be used for both HTTP and WebSocket connections. +func WithHeader(key, value string) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) + }) +} + +// WithHeaders configures HTTP headers set by the RPC client. Headers set using this +// option will be used for both HTTP and WebSocket connections. +func WithHeaders(headers http.Header) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + for k, vs := range headers { + cfg.httpHeaders[k] = vs + } + }) +} + +// WithHTTPClient configures the http.Client used by the RPC client. +func WithHTTPClient(c *http.Client) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.httpClient = c + }) +} + +// WithHTTPAuth configures HTTP request authentication. The given provider will be called +// whenever a request is made. Note that only one authentication provider can be active at +// any time. +func WithHTTPAuth(a HTTPAuth) ClientOption { + if a == nil { + panic("nil auth") + } + return optionFunc(func(cfg *clientConfig) { + cfg.httpAuth = a + }) +} + +// A HTTPAuth function is called by the client whenever a HTTP request is sent. +// The function must be safe for concurrent use. +// +// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add +// auth information to the request. +type HTTPAuth func(h http.Header) error diff --git a/rpc/client_opt_test.go b/rpc/client_opt_test.go new file mode 100644 index 000000000..d7cc2572a --- /dev/null +++ b/rpc/client_opt_test.go @@ -0,0 +1,25 @@ +package rpc_test + +import ( + "context" + "net/http" + "time" + + "github.com/ethereum/go-ethereum/rpc" +) + +// This example configures a HTTP-based RPC client with two options - one setting the +// overall request timeout, the other adding a custom HTTP header to all requests. +func ExampleDialOptions() { + tokenHeader := rpc.WithHeader("x-token", "foo") + httpClient := rpc.WithHTTPClient(&http.Client{ + Timeout: 10 * time.Second, + }) + + ctx := context.Background() + c, err := rpc.DialOptions(ctx, "http://rpc.example.com", httpClient, tokenHeader) + if err != nil { + panic(err) + } + c.Close() +} diff --git a/rpc/http.go b/rpc/http.go index 858d80858..8595959af 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -45,6 +45,7 @@ type httpConn struct { closeCh chan interface{} mu sync.Mutex // protects headers headers http.Header + auth HTTPAuth } // httpConn implements ServerCodec, but it is treated specially by Client @@ -117,8 +118,15 @@ var DefaultHTTPTimeouts = HTTPTimeouts{ IdleTimeout: 120 * time.Second, } +// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. +func DialHTTP(endpoint string) (*Client, error) { + return DialHTTPWithClient(endpoint, new(http.Client)) +} + // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP // using the provided HTTP Client. +// +// Deprecated: use DialOptions and the WithHTTPClient option. func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { // Sanity check URL so we don't end up with a client that will fail every request. _, err := url.Parse(endpoint) @@ -126,24 +134,35 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { return nil, err } - initctx := context.Background() - headers := make(http.Header, 2) - headers.Set("accept", contentType) - headers.Set("content-type", contentType) - return newClient(initctx, func(context.Context) (ServerCodec, error) { - hc := &httpConn{ - client: client, - headers: headers, - url: endpoint, - closeCh: make(chan interface{}), - } - return hc, nil - }) + var cfg clientConfig + fn := newClientTransportHTTP(endpoint, &cfg) + return newClient(context.Background(), fn) } -// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. -func DialHTTP(endpoint string) (*Client, error) { - return DialHTTPWithClient(endpoint, new(http.Client)) +func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { + headers := make(http.Header, 2+len(cfg.httpHeaders)) + headers.Set("accept", contentType) + headers.Set("content-type", contentType) + for key, values := range cfg.httpHeaders { + headers[key] = values + } + + client := cfg.httpClient + if client == nil { + client = new(http.Client) + } + + hc := &httpConn{ + client: client, + headers: headers, + url: endpoint, + auth: cfg.httpAuth, + closeCh: make(chan interface{}), + } + + return func(ctx context.Context) (ServerCodec, error) { + return hc, nil + } } func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error { @@ -195,6 +214,11 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos hc.mu.Lock() req.Header = hc.headers.Clone() hc.mu.Unlock() + if hc.auth != nil { + if err := hc.auth(req.Header); err != nil { + return nil, err + } + } // do request resp, err := hc.client.Do(req) diff --git a/rpc/ipc.go b/rpc/ipc.go index 07a211c62..d9e0de62e 100644 --- a/rpc/ipc.go +++ b/rpc/ipc.go @@ -46,11 +46,15 @@ func (s *Server) ServeListener(l net.Listener) error { // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialIPC(ctx context.Context, endpoint string) (*Client, error) { - return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { + return newClient(ctx, newClientTransportIPC(endpoint)) +} + +func newClientTransportIPC(endpoint string) reconnectFunc { + return func(ctx context.Context) (ServerCodec, error) { conn, err := newIPCConnection(ctx, endpoint) if err != nil { return nil, err } return NewCodec(conn), err - }) + } } diff --git a/rpc/stdio.go b/rpc/stdio.go index be2bab1c9..ae32db26e 100644 --- a/rpc/stdio.go +++ b/rpc/stdio.go @@ -32,12 +32,16 @@ func DialStdIO(ctx context.Context) (*Client, error) { // DialIO creates a client which uses the given IO channels func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { - return newClient(ctx, func(_ context.Context) (ServerCodec, error) { + return newClient(ctx, newClientTransportIO(in, out)) +} + +func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc { + return func(context.Context) (ServerCodec, error) { return NewCodec(stdioConn{ in: in, out: out, }), nil - }) + } } type stdioConn struct { diff --git a/rpc/websocket.go b/rpc/websocket.go index 28380d8aa..f2a923446 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -181,24 +181,23 @@ func parseOriginURL(origin string) (string, string, string, error) { return scheme, hostname, port, nil } -// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server -// that is listening on the given endpoint using the provided dialer. +// DialWebsocketWithDialer creates a new RPC client using WebSocket. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. +// +// Deprecated: use DialOptions and the WithWebsocketDialer option. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { - endpoint, header, err := wsClientHeaders(endpoint, origin) + cfg := new(clientConfig) + cfg.wsDialer = &dialer + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) if err != nil { return nil, err } - return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { - conn, resp, err := dialer.DialContext(ctx, endpoint, header) - if err != nil { - hErr := wsHandshakeError{err: err} - if resp != nil { - hErr.status = resp.Status - } - return nil, hErr - } - return newWebsocketCodec(conn, endpoint, header), nil - }) + return newClient(ctx, connect) } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server @@ -207,12 +206,53 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { - dialer := websocket.Dialer{ - ReadBufferSize: wsReadBuffer, - WriteBufferSize: wsWriteBuffer, - WriteBufferPool: wsBufferPool, + cfg := new(clientConfig) + if origin != "" { + cfg.setHeader("origin", origin) } - return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) + connect, err := newClientTransportWS(endpoint, cfg) + if err != nil { + return nil, err + } + return newClient(ctx, connect) +} + +func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { + dialer := cfg.wsDialer + if dialer == nil { + dialer = &websocket.Dialer{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + } + } + + dialURL, header, err := wsClientHeaders(endpoint, "") + if err != nil { + return nil, err + } + for key, values := range cfg.httpHeaders { + header[key] = values + } + + connect := func(ctx context.Context) (ServerCodec, error) { + header := header.Clone() + if cfg.httpAuth != nil { + if err := cfg.httpAuth(header); err != nil { + return nil, err + } + } + conn, resp, err := dialer.DialContext(ctx, dialURL, header) + if err != nil { + hErr := wsHandshakeError{err: err} + if resp != nil { + hErr.status = resp.Status + } + return nil, hErr + } + return newWebsocketCodec(conn, dialURL, header), nil + } + return connect, nil } func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {