diff --git a/node/config.go b/node/config.go
index 2047299fb..49959d5ec 100644
--- a/node/config.go
+++ b/node/config.go
@@ -201,7 +201,7 @@ type Config struct {
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
AllowUnprotectedTxs bool `toml:",omitempty"`
- // JWTSecret is the hex-encoded jwt secret.
+ // JWTSecret is the path to the hex-encoded jwt secret.
JWTSecret string `toml:",omitempty"`
}
diff --git a/node/jwt_auth.go b/node/jwt_auth.go
new file mode 100644
index 000000000..d4f8193ca
--- /dev/null
+++ b/node/jwt_auth.go
@@ -0,0 +1,45 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package node
+
+import (
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/ethereum/go-ethereum/rpc"
+ "github.com/golang-jwt/jwt/v4"
+)
+
+// NewJWTAuth creates an rpc client authentication provider that uses JWT. The
+// secret MUST be 32 bytes (256 bits) as defined by the Engine-API authentication spec.
+//
+// See https://github.com/ethereum/execution-apis/blob/main/src/engine/authentication.md
+// for more details about this authentication scheme.
+func NewJWTAuth(jwtsecret [32]byte) rpc.HTTPAuth {
+ return func(h http.Header) error {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "iat": &jwt.NumericDate{Time: time.Now()},
+ })
+ s, err := token.SignedString(jwtsecret[:])
+ if err != nil {
+ return fmt.Errorf("failed to create JWT token: %w", err)
+ }
+ h.Set("Authorization", "Bearer "+s)
+ return nil
+ }
+}
diff --git a/node/node.go b/node/node.go
index b60e32f22..3cbefef02 100644
--- a/node/node.go
+++ b/node/node.go
@@ -668,6 +668,19 @@ func (n *Node) WSEndpoint() string {
return "ws://" + n.ws.listenAddr() + n.ws.wsConfig.prefix
}
+// HTTPAuthEndpoint returns the URL of the authenticated HTTP server.
+func (n *Node) HTTPAuthEndpoint() string {
+ return "http://" + n.httpAuth.listenAddr()
+}
+
+// WSAuthEndpoint returns the current authenticated JSON-RPC over WebSocket endpoint.
+func (n *Node) WSAuthEndpoint() string {
+ if n.httpAuth.wsAllowed() {
+ return "ws://" + n.httpAuth.listenAddr() + n.httpAuth.wsConfig.prefix
+ }
+ return "ws://" + n.wsAuth.listenAddr() + n.wsAuth.wsConfig.prefix
+}
+
// EventMux retrieves the event multiplexer used by all the network services in
// the current protocol stack.
func (n *Node) EventMux() *event.TypeMux {
diff --git a/node/node_auth_test.go b/node/node_auth_test.go
new file mode 100644
index 000000000..597cd8531
--- /dev/null
+++ b/node/node_auth_test.go
@@ -0,0 +1,237 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package node
+
+import (
+ "context"
+ crand "crypto/rand"
+ "fmt"
+ "net/http"
+ "os"
+ "path"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/hexutil"
+ "github.com/ethereum/go-ethereum/rpc"
+ "github.com/golang-jwt/jwt/v4"
+)
+
+type helloRPC string
+
+func (ta helloRPC) HelloWorld() (string, error) {
+ return string(ta), nil
+}
+
+type authTest struct {
+ name string
+ endpoint string
+ prov rpc.HTTPAuth
+ expectDialFail bool
+ expectCall1Fail bool
+ expectCall2Fail bool
+}
+
+func (at *authTest) Run(t *testing.T) {
+ ctx := context.Background()
+ cl, err := rpc.DialOptions(ctx, at.endpoint, rpc.WithHTTPAuth(at.prov))
+ if at.expectDialFail {
+ if err == nil {
+ t.Fatal("expected initial dial to fail")
+ } else {
+ return
+ }
+ }
+ if err != nil {
+ t.Fatalf("failed to dial rpc endpoint: %v", err)
+ }
+
+ var x string
+ err = cl.CallContext(ctx, &x, "engine_helloWorld")
+ if at.expectCall1Fail {
+ if err == nil {
+ t.Fatal("expected call 1 to fail")
+ } else {
+ return
+ }
+ }
+ if err != nil {
+ t.Fatalf("failed to call rpc endpoint: %v", err)
+ }
+ if x != "hello engine" {
+ t.Fatalf("method was silent but did not return expected value: %q", x)
+ }
+
+ err = cl.CallContext(ctx, &x, "eth_helloWorld")
+ if at.expectCall2Fail {
+ if err == nil {
+ t.Fatal("expected call 2 to fail")
+ } else {
+ return
+ }
+ }
+ if err != nil {
+ t.Fatalf("failed to call rpc endpoint: %v", err)
+ }
+ if x != "hello eth" {
+ t.Fatalf("method was silent but did not return expected value: %q", x)
+ }
+}
+
+func TestAuthEndpoints(t *testing.T) {
+ var secret [32]byte
+ if _, err := crand.Read(secret[:]); err != nil {
+ t.Fatalf("failed to create jwt secret: %v", err)
+ }
+ // Geth must read it from a file, and does not support in-memory JWT secrets, so we create a temporary file.
+ jwtPath := path.Join(t.TempDir(), "jwt_secret")
+ if err := os.WriteFile(jwtPath, []byte(hexutil.Encode(secret[:])), 0600); err != nil {
+ t.Fatalf("failed to prepare jwt secret file: %v", err)
+ }
+ // We get ports assigned by the node automatically
+ conf := &Config{
+ HTTPHost: "127.0.0.1",
+ HTTPPort: 0,
+ WSHost: "127.0.0.1",
+ WSPort: 0,
+ AuthAddr: "127.0.0.1",
+ AuthPort: 0,
+ JWTSecret: jwtPath,
+
+ WSModules: []string{"eth", "engine"},
+ HTTPModules: []string{"eth", "engine"},
+ }
+ node, err := New(conf)
+ if err != nil {
+ t.Fatalf("could not create a new node: %v", err)
+ }
+ // register dummy apis so we can test the modules are available and reachable with authentication
+ node.RegisterAPIs([]rpc.API{
+ {
+ Namespace: "engine",
+ Version: "1.0",
+ Service: helloRPC("hello engine"),
+ Public: true,
+ Authenticated: true,
+ },
+ {
+ Namespace: "eth",
+ Version: "1.0",
+ Service: helloRPC("hello eth"),
+ Public: true,
+ Authenticated: true,
+ },
+ })
+ if err := node.Start(); err != nil {
+ t.Fatalf("failed to start test node: %v", err)
+ }
+ defer node.Close()
+
+ // sanity check we are running different endpoints
+ if a, b := node.WSEndpoint(), node.WSAuthEndpoint(); a == b {
+ t.Fatalf("expected ws and auth-ws endpoints to be different, got: %q and %q", a, b)
+ }
+ if a, b := node.HTTPEndpoint(), node.HTTPAuthEndpoint(); a == b {
+ t.Fatalf("expected http and auth-http endpoints to be different, got: %q and %q", a, b)
+ }
+
+ goodAuth := NewJWTAuth(secret)
+ var otherSecret [32]byte
+ if _, err := crand.Read(otherSecret[:]); err != nil {
+ t.Fatalf("failed to create jwt secret: %v", err)
+ }
+ badAuth := NewJWTAuth(otherSecret)
+
+ notTooLong := time.Second * 57
+ tooLong := time.Second * 60
+ requestDelay := time.Second
+
+ testCases := []authTest{
+ // Auth works
+ {name: "ws good", endpoint: node.WSAuthEndpoint(), prov: goodAuth, expectCall1Fail: false},
+ {name: "http good", endpoint: node.HTTPAuthEndpoint(), prov: goodAuth, expectCall1Fail: false},
+
+ // Try a bad auth
+ {name: "ws bad", endpoint: node.WSAuthEndpoint(), prov: badAuth, expectDialFail: true}, // ws auth is immediate
+ {name: "http bad", endpoint: node.HTTPAuthEndpoint(), prov: badAuth, expectCall1Fail: true}, // http auth is on first call
+
+ // A common mistake with JWT is to allow the "none" algorithm, which is a valid JWT but not secure.
+ {name: "ws none", endpoint: node.WSAuthEndpoint(), prov: noneAuth(secret), expectDialFail: true},
+ {name: "http none", endpoint: node.HTTPAuthEndpoint(), prov: noneAuth(secret), expectCall1Fail: true},
+
+ // claims of 5 seconds or more, older or newer, are not allowed
+ {name: "ws too old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectDialFail: true},
+ {name: "http too old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectCall1Fail: true},
+ // note: for it to be too long we need to add a delay, so that once we receive the request, the difference has not dipped below the "tooLong"
+ {name: "ws too new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectDialFail: true},
+ {name: "http too new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectCall1Fail: true},
+
+ // Try offset the time, but stay just within bounds
+ {name: "ws old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)},
+ {name: "http old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)},
+ {name: "ws new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)},
+ {name: "http new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)},
+
+ // ws only authenticates on initial dial, then continues communication
+ {name: "ws single auth", endpoint: node.WSAuthEndpoint(), prov: changingAuth(goodAuth, badAuth)},
+ {name: "http call fail auth", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, badAuth), expectCall2Fail: true},
+ {name: "http call fail time", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, offsetTimeAuth(secret, tooLong+requestDelay)), expectCall2Fail: true},
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, testCase.Run)
+ }
+}
+
+func noneAuth(secret [32]byte) rpc.HTTPAuth {
+ return func(header http.Header) error {
+ token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{
+ "iat": &jwt.NumericDate{Time: time.Now()},
+ })
+ s, err := token.SignedString(secret[:])
+ if err != nil {
+ return fmt.Errorf("failed to create JWT token: %w", err)
+ }
+ header.Set("Authorization", "Bearer "+s)
+ return nil
+ }
+}
+
+func changingAuth(provs ...rpc.HTTPAuth) rpc.HTTPAuth {
+ i := 0
+ return func(header http.Header) error {
+ i += 1
+ if i > len(provs) {
+ i = len(provs)
+ }
+ return provs[i-1](header)
+ }
+}
+
+func offsetTimeAuth(secret [32]byte, offset time.Duration) rpc.HTTPAuth {
+ return func(header http.Header) error {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "iat": &jwt.NumericDate{Time: time.Now().Add(offset)},
+ })
+ s, err := token.SignedString(secret[:])
+ if err != nil {
+ return fmt.Errorf("failed to create JWT token: %w", err)
+ }
+ header.Set("Authorization", "Bearer "+s)
+ return nil
+ }
+}
diff --git a/rpc/client.go b/rpc/client.go
index d3ce02977..8288f976e 100644
--- a/rpc/client.go
+++ b/rpc/client.go
@@ -22,6 +22,7 @@ import (
"errors"
"fmt"
"net/url"
+ "os"
"reflect"
"strconv"
"sync/atomic"
@@ -99,7 +100,7 @@ type Client struct {
reqTimeout chan *requestOp // removes response IDs when call timeout expires
}
-type reconnectFunc func(ctx context.Context) (ServerCodec, error)
+type reconnectFunc func(context.Context) (ServerCodec, error)
type clientContextKey struct{}
@@ -153,14 +154,16 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro
//
// The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a
// file name with no URL scheme, a local socket connection is established using UNIX
-// domain sockets on supported platforms and named pipes on Windows. If you want to
-// configure transport options, use DialHTTP, DialWebsocket or DialIPC instead.
+// domain sockets on supported platforms and named pipes on Windows.
+//
+// If you want to further configure the transport, use DialOptions instead of this
+// function.
//
// For websocket connections, the origin is set to the local host name.
//
-// The client reconnects automatically if the connection is lost.
+// The client reconnects automatically when the connection is lost.
func Dial(rawurl string) (*Client, error) {
- return DialContext(context.Background(), rawurl)
+ return DialOptions(context.Background(), rawurl)
}
// DialContext creates a new RPC client, just like Dial.
@@ -168,22 +171,46 @@ func Dial(rawurl string) (*Client, error) {
// The context is used to cancel or time out the initial connection establishment. It does
// not affect subsequent interactions with the client.
func DialContext(ctx context.Context, rawurl string) (*Client, error) {
+ return DialOptions(ctx, rawurl)
+}
+
+// DialOptions creates a new RPC client for the given URL. You can supply any of the
+// pre-defined client options to configure the underlying transport.
+//
+// The context is used to cancel or time out the initial connection establishment. It does
+// not affect subsequent interactions with the client.
+//
+// The client reconnects automatically when the connection is lost.
+func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*Client, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
}
+
+ cfg := new(clientConfig)
+ for _, opt := range options {
+ opt.applyOption(cfg)
+ }
+
+ var reconnect reconnectFunc
switch u.Scheme {
case "http", "https":
- return DialHTTP(rawurl)
+ reconnect = newClientTransportHTTP(rawurl, cfg)
case "ws", "wss":
- return DialWebsocket(ctx, rawurl, "")
+ rc, err := newClientTransportWS(rawurl, cfg)
+ if err != nil {
+ return nil, err
+ }
+ reconnect = rc
case "stdio":
- return DialStdIO(ctx)
+ reconnect = newClientTransportIO(os.Stdin, os.Stdout)
case "":
- return DialIPC(ctx, rawurl)
+ reconnect = newClientTransportIPC(rawurl)
default:
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
}
+
+ return newClient(ctx, reconnect)
}
// ClientFromContext retrieves the client from the context, if any. This can be used to perform
diff --git a/rpc/client_opt.go b/rpc/client_opt.go
new file mode 100644
index 000000000..5ad7c22b3
--- /dev/null
+++ b/rpc/client_opt.go
@@ -0,0 +1,106 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package rpc
+
+import (
+ "net/http"
+
+ "github.com/gorilla/websocket"
+)
+
+// ClientOption is a configuration option for the RPC client.
+type ClientOption interface {
+ applyOption(*clientConfig)
+}
+
+type clientConfig struct {
+ httpClient *http.Client
+ httpHeaders http.Header
+ httpAuth HTTPAuth
+
+ wsDialer *websocket.Dialer
+}
+
+func (cfg *clientConfig) initHeaders() {
+ if cfg.httpHeaders == nil {
+ cfg.httpHeaders = make(http.Header)
+ }
+}
+
+func (cfg *clientConfig) setHeader(key, value string) {
+ cfg.initHeaders()
+ cfg.httpHeaders.Set(key, value)
+}
+
+type optionFunc func(*clientConfig)
+
+func (fn optionFunc) applyOption(opt *clientConfig) {
+ fn(opt)
+}
+
+// WithWebsocketDialer configures the websocket.Dialer used by the RPC client.
+func WithWebsocketDialer(dialer websocket.Dialer) ClientOption {
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.wsDialer = &dialer
+ })
+}
+
+// WithHeader configures HTTP headers set by the RPC client. Headers set using this option
+// will be used for both HTTP and WebSocket connections.
+func WithHeader(key, value string) ClientOption {
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.initHeaders()
+ cfg.httpHeaders.Set(key, value)
+ })
+}
+
+// WithHeaders configures HTTP headers set by the RPC client. Headers set using this
+// option will be used for both HTTP and WebSocket connections.
+func WithHeaders(headers http.Header) ClientOption {
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.initHeaders()
+ for k, vs := range headers {
+ cfg.httpHeaders[k] = vs
+ }
+ })
+}
+
+// WithHTTPClient configures the http.Client used by the RPC client.
+func WithHTTPClient(c *http.Client) ClientOption {
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.httpClient = c
+ })
+}
+
+// WithHTTPAuth configures HTTP request authentication. The given provider will be called
+// whenever a request is made. Note that only one authentication provider can be active at
+// any time.
+func WithHTTPAuth(a HTTPAuth) ClientOption {
+ if a == nil {
+ panic("nil auth")
+ }
+ return optionFunc(func(cfg *clientConfig) {
+ cfg.httpAuth = a
+ })
+}
+
+// A HTTPAuth function is called by the client whenever a HTTP request is sent.
+// The function must be safe for concurrent use.
+//
+// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
+// auth information to the request.
+type HTTPAuth func(h http.Header) error
diff --git a/rpc/client_opt_test.go b/rpc/client_opt_test.go
new file mode 100644
index 000000000..d7cc2572a
--- /dev/null
+++ b/rpc/client_opt_test.go
@@ -0,0 +1,25 @@
+package rpc_test
+
+import (
+ "context"
+ "net/http"
+ "time"
+
+ "github.com/ethereum/go-ethereum/rpc"
+)
+
+// This example configures a HTTP-based RPC client with two options - one setting the
+// overall request timeout, the other adding a custom HTTP header to all requests.
+func ExampleDialOptions() {
+ tokenHeader := rpc.WithHeader("x-token", "foo")
+ httpClient := rpc.WithHTTPClient(&http.Client{
+ Timeout: 10 * time.Second,
+ })
+
+ ctx := context.Background()
+ c, err := rpc.DialOptions(ctx, "http://rpc.example.com", httpClient, tokenHeader)
+ if err != nil {
+ panic(err)
+ }
+ c.Close()
+}
diff --git a/rpc/http.go b/rpc/http.go
index 858d80858..8595959af 100644
--- a/rpc/http.go
+++ b/rpc/http.go
@@ -45,6 +45,7 @@ type httpConn struct {
closeCh chan interface{}
mu sync.Mutex // protects headers
headers http.Header
+ auth HTTPAuth
}
// httpConn implements ServerCodec, but it is treated specially by Client
@@ -117,8 +118,15 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
IdleTimeout: 120 * time.Second,
}
+// DialHTTP creates a new RPC client that connects to an RPC server over HTTP.
+func DialHTTP(endpoint string) (*Client, error) {
+ return DialHTTPWithClient(endpoint, new(http.Client))
+}
+
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
// using the provided HTTP Client.
+//
+// Deprecated: use DialOptions and the WithHTTPClient option.
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
// Sanity check URL so we don't end up with a client that will fail every request.
_, err := url.Parse(endpoint)
@@ -126,24 +134,35 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
return nil, err
}
- initctx := context.Background()
- headers := make(http.Header, 2)
- headers.Set("accept", contentType)
- headers.Set("content-type", contentType)
- return newClient(initctx, func(context.Context) (ServerCodec, error) {
- hc := &httpConn{
- client: client,
- headers: headers,
- url: endpoint,
- closeCh: make(chan interface{}),
- }
- return hc, nil
- })
+ var cfg clientConfig
+ fn := newClientTransportHTTP(endpoint, &cfg)
+ return newClient(context.Background(), fn)
}
-// DialHTTP creates a new RPC client that connects to an RPC server over HTTP.
-func DialHTTP(endpoint string) (*Client, error) {
- return DialHTTPWithClient(endpoint, new(http.Client))
+func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
+ headers := make(http.Header, 2+len(cfg.httpHeaders))
+ headers.Set("accept", contentType)
+ headers.Set("content-type", contentType)
+ for key, values := range cfg.httpHeaders {
+ headers[key] = values
+ }
+
+ client := cfg.httpClient
+ if client == nil {
+ client = new(http.Client)
+ }
+
+ hc := &httpConn{
+ client: client,
+ headers: headers,
+ url: endpoint,
+ auth: cfg.httpAuth,
+ closeCh: make(chan interface{}),
+ }
+
+ return func(ctx context.Context) (ServerCodec, error) {
+ return hc, nil
+ }
}
func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error {
@@ -195,6 +214,11 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
hc.mu.Lock()
req.Header = hc.headers.Clone()
hc.mu.Unlock()
+ if hc.auth != nil {
+ if err := hc.auth(req.Header); err != nil {
+ return nil, err
+ }
+ }
// do request
resp, err := hc.client.Do(req)
diff --git a/rpc/ipc.go b/rpc/ipc.go
index 07a211c62..d9e0de62e 100644
--- a/rpc/ipc.go
+++ b/rpc/ipc.go
@@ -46,11 +46,15 @@ func (s *Server) ServeListener(l net.Listener) error {
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
- return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
+ return newClient(ctx, newClientTransportIPC(endpoint))
+}
+
+func newClientTransportIPC(endpoint string) reconnectFunc {
+ return func(ctx context.Context) (ServerCodec, error) {
conn, err := newIPCConnection(ctx, endpoint)
if err != nil {
return nil, err
}
return NewCodec(conn), err
- })
+ }
}
diff --git a/rpc/stdio.go b/rpc/stdio.go
index be2bab1c9..ae32db26e 100644
--- a/rpc/stdio.go
+++ b/rpc/stdio.go
@@ -32,12 +32,16 @@ func DialStdIO(ctx context.Context) (*Client, error) {
// DialIO creates a client which uses the given IO channels
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
- return newClient(ctx, func(_ context.Context) (ServerCodec, error) {
+ return newClient(ctx, newClientTransportIO(in, out))
+}
+
+func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {
+ return func(context.Context) (ServerCodec, error) {
return NewCodec(stdioConn{
in: in,
out: out,
}), nil
- })
+ }
}
type stdioConn struct {
diff --git a/rpc/websocket.go b/rpc/websocket.go
index 28380d8aa..f2a923446 100644
--- a/rpc/websocket.go
+++ b/rpc/websocket.go
@@ -181,24 +181,23 @@ func parseOriginURL(origin string) (string, string, string, error) {
return scheme, hostname, port, nil
}
-// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
-// that is listening on the given endpoint using the provided dialer.
+// DialWebsocketWithDialer creates a new RPC client using WebSocket.
+//
+// The context is used for the initial connection establishment. It does not
+// affect subsequent interactions with the client.
+//
+// Deprecated: use DialOptions and the WithWebsocketDialer option.
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
- endpoint, header, err := wsClientHeaders(endpoint, origin)
+ cfg := new(clientConfig)
+ cfg.wsDialer = &dialer
+ if origin != "" {
+ cfg.setHeader("origin", origin)
+ }
+ connect, err := newClientTransportWS(endpoint, cfg)
if err != nil {
return nil, err
}
- return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
- conn, resp, err := dialer.DialContext(ctx, endpoint, header)
- if err != nil {
- hErr := wsHandshakeError{err: err}
- if resp != nil {
- hErr.status = resp.Status
- }
- return nil, hErr
- }
- return newWebsocketCodec(conn, endpoint, header), nil
- })
+ return newClient(ctx, connect)
}
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
@@ -207,12 +206,53 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
- dialer := websocket.Dialer{
- ReadBufferSize: wsReadBuffer,
- WriteBufferSize: wsWriteBuffer,
- WriteBufferPool: wsBufferPool,
+ cfg := new(clientConfig)
+ if origin != "" {
+ cfg.setHeader("origin", origin)
}
- return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
+ connect, err := newClientTransportWS(endpoint, cfg)
+ if err != nil {
+ return nil, err
+ }
+ return newClient(ctx, connect)
+}
+
+func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
+ dialer := cfg.wsDialer
+ if dialer == nil {
+ dialer = &websocket.Dialer{
+ ReadBufferSize: wsReadBuffer,
+ WriteBufferSize: wsWriteBuffer,
+ WriteBufferPool: wsBufferPool,
+ }
+ }
+
+ dialURL, header, err := wsClientHeaders(endpoint, "")
+ if err != nil {
+ return nil, err
+ }
+ for key, values := range cfg.httpHeaders {
+ header[key] = values
+ }
+
+ connect := func(ctx context.Context) (ServerCodec, error) {
+ header := header.Clone()
+ if cfg.httpAuth != nil {
+ if err := cfg.httpAuth(header); err != nil {
+ return nil, err
+ }
+ }
+ conn, resp, err := dialer.DialContext(ctx, dialURL, header)
+ if err != nil {
+ hErr := wsHandshakeError{err: err}
+ if resp != nil {
+ hErr.status = resp.Status
+ }
+ return nil, hErr
+ }
+ return newWebsocketCodec(conn, dialURL, header), nil
+ }
+ return connect, nil
}
func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {