node, rpc: add JWT auth support in client (#24911)

This adds a generic mechanism for 'dial options' in the RPC client,
and also implements a specific dial option for the JWT authentication
mechanism used by the engine API. Some real tests for the server-side
authentication handling are also added.

Co-authored-by: Joshua Gutow <jgutow@optimism.io>
Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
protolambda 2022-09-02 17:40:41 +02:00 committed by GitHub
parent 7f2890a9be
commit 90711efb0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 574 additions and 49 deletions

View File

@ -201,7 +201,7 @@ type Config struct {
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC. // AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
AllowUnprotectedTxs bool `toml:",omitempty"` AllowUnprotectedTxs bool `toml:",omitempty"`
// JWTSecret is the hex-encoded jwt secret. // JWTSecret is the path to the hex-encoded jwt secret.
JWTSecret string `toml:",omitempty"` JWTSecret string `toml:",omitempty"`
} }

45
node/jwt_auth.go Normal file
View File

@ -0,0 +1,45 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node
import (
"fmt"
"net/http"
"time"
"github.com/ethereum/go-ethereum/rpc"
"github.com/golang-jwt/jwt/v4"
)
// NewJWTAuth creates an rpc client authentication provider that uses JWT. The
// secret MUST be 32 bytes (256 bits) as defined by the Engine-API authentication spec.
//
// See https://github.com/ethereum/execution-apis/blob/main/src/engine/authentication.md
// for more details about this authentication scheme.
func NewJWTAuth(jwtsecret [32]byte) rpc.HTTPAuth {
return func(h http.Header) error {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iat": &jwt.NumericDate{Time: time.Now()},
})
s, err := token.SignedString(jwtsecret[:])
if err != nil {
return fmt.Errorf("failed to create JWT token: %w", err)
}
h.Set("Authorization", "Bearer "+s)
return nil
}
}

View File

@ -668,6 +668,19 @@ func (n *Node) WSEndpoint() string {
return "ws://" + n.ws.listenAddr() + n.ws.wsConfig.prefix return "ws://" + n.ws.listenAddr() + n.ws.wsConfig.prefix
} }
// HTTPAuthEndpoint returns the URL of the authenticated HTTP server.
func (n *Node) HTTPAuthEndpoint() string {
return "http://" + n.httpAuth.listenAddr()
}
// WSAuthEndpoint returns the current authenticated JSON-RPC over WebSocket endpoint.
func (n *Node) WSAuthEndpoint() string {
if n.httpAuth.wsAllowed() {
return "ws://" + n.httpAuth.listenAddr() + n.httpAuth.wsConfig.prefix
}
return "ws://" + n.wsAuth.listenAddr() + n.wsAuth.wsConfig.prefix
}
// EventMux retrieves the event multiplexer used by all the network services in // EventMux retrieves the event multiplexer used by all the network services in
// the current protocol stack. // the current protocol stack.
func (n *Node) EventMux() *event.TypeMux { func (n *Node) EventMux() *event.TypeMux {

237
node/node_auth_test.go Normal file
View File

@ -0,0 +1,237 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node
import (
"context"
crand "crypto/rand"
"fmt"
"net/http"
"os"
"path"
"testing"
"time"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/rpc"
"github.com/golang-jwt/jwt/v4"
)
type helloRPC string
func (ta helloRPC) HelloWorld() (string, error) {
return string(ta), nil
}
type authTest struct {
name string
endpoint string
prov rpc.HTTPAuth
expectDialFail bool
expectCall1Fail bool
expectCall2Fail bool
}
func (at *authTest) Run(t *testing.T) {
ctx := context.Background()
cl, err := rpc.DialOptions(ctx, at.endpoint, rpc.WithHTTPAuth(at.prov))
if at.expectDialFail {
if err == nil {
t.Fatal("expected initial dial to fail")
} else {
return
}
}
if err != nil {
t.Fatalf("failed to dial rpc endpoint: %v", err)
}
var x string
err = cl.CallContext(ctx, &x, "engine_helloWorld")
if at.expectCall1Fail {
if err == nil {
t.Fatal("expected call 1 to fail")
} else {
return
}
}
if err != nil {
t.Fatalf("failed to call rpc endpoint: %v", err)
}
if x != "hello engine" {
t.Fatalf("method was silent but did not return expected value: %q", x)
}
err = cl.CallContext(ctx, &x, "eth_helloWorld")
if at.expectCall2Fail {
if err == nil {
t.Fatal("expected call 2 to fail")
} else {
return
}
}
if err != nil {
t.Fatalf("failed to call rpc endpoint: %v", err)
}
if x != "hello eth" {
t.Fatalf("method was silent but did not return expected value: %q", x)
}
}
func TestAuthEndpoints(t *testing.T) {
var secret [32]byte
if _, err := crand.Read(secret[:]); err != nil {
t.Fatalf("failed to create jwt secret: %v", err)
}
// Geth must read it from a file, and does not support in-memory JWT secrets, so we create a temporary file.
jwtPath := path.Join(t.TempDir(), "jwt_secret")
if err := os.WriteFile(jwtPath, []byte(hexutil.Encode(secret[:])), 0600); err != nil {
t.Fatalf("failed to prepare jwt secret file: %v", err)
}
// We get ports assigned by the node automatically
conf := &Config{
HTTPHost: "127.0.0.1",
HTTPPort: 0,
WSHost: "127.0.0.1",
WSPort: 0,
AuthAddr: "127.0.0.1",
AuthPort: 0,
JWTSecret: jwtPath,
WSModules: []string{"eth", "engine"},
HTTPModules: []string{"eth", "engine"},
}
node, err := New(conf)
if err != nil {
t.Fatalf("could not create a new node: %v", err)
}
// register dummy apis so we can test the modules are available and reachable with authentication
node.RegisterAPIs([]rpc.API{
{
Namespace: "engine",
Version: "1.0",
Service: helloRPC("hello engine"),
Public: true,
Authenticated: true,
},
{
Namespace: "eth",
Version: "1.0",
Service: helloRPC("hello eth"),
Public: true,
Authenticated: true,
},
})
if err := node.Start(); err != nil {
t.Fatalf("failed to start test node: %v", err)
}
defer node.Close()
// sanity check we are running different endpoints
if a, b := node.WSEndpoint(), node.WSAuthEndpoint(); a == b {
t.Fatalf("expected ws and auth-ws endpoints to be different, got: %q and %q", a, b)
}
if a, b := node.HTTPEndpoint(), node.HTTPAuthEndpoint(); a == b {
t.Fatalf("expected http and auth-http endpoints to be different, got: %q and %q", a, b)
}
goodAuth := NewJWTAuth(secret)
var otherSecret [32]byte
if _, err := crand.Read(otherSecret[:]); err != nil {
t.Fatalf("failed to create jwt secret: %v", err)
}
badAuth := NewJWTAuth(otherSecret)
notTooLong := time.Second * 57
tooLong := time.Second * 60
requestDelay := time.Second
testCases := []authTest{
// Auth works
{name: "ws good", endpoint: node.WSAuthEndpoint(), prov: goodAuth, expectCall1Fail: false},
{name: "http good", endpoint: node.HTTPAuthEndpoint(), prov: goodAuth, expectCall1Fail: false},
// Try a bad auth
{name: "ws bad", endpoint: node.WSAuthEndpoint(), prov: badAuth, expectDialFail: true}, // ws auth is immediate
{name: "http bad", endpoint: node.HTTPAuthEndpoint(), prov: badAuth, expectCall1Fail: true}, // http auth is on first call
// A common mistake with JWT is to allow the "none" algorithm, which is a valid JWT but not secure.
{name: "ws none", endpoint: node.WSAuthEndpoint(), prov: noneAuth(secret), expectDialFail: true},
{name: "http none", endpoint: node.HTTPAuthEndpoint(), prov: noneAuth(secret), expectCall1Fail: true},
// claims of 5 seconds or more, older or newer, are not allowed
{name: "ws too old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectDialFail: true},
{name: "http too old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -tooLong), expectCall1Fail: true},
// note: for it to be too long we need to add a delay, so that once we receive the request, the difference has not dipped below the "tooLong"
{name: "ws too new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectDialFail: true},
{name: "http too new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, tooLong+requestDelay), expectCall1Fail: true},
// Try offset the time, but stay just within bounds
{name: "ws old", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)},
{name: "http old", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, -notTooLong)},
{name: "ws new", endpoint: node.WSAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)},
{name: "http new", endpoint: node.HTTPAuthEndpoint(), prov: offsetTimeAuth(secret, notTooLong)},
// ws only authenticates on initial dial, then continues communication
{name: "ws single auth", endpoint: node.WSAuthEndpoint(), prov: changingAuth(goodAuth, badAuth)},
{name: "http call fail auth", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, badAuth), expectCall2Fail: true},
{name: "http call fail time", endpoint: node.HTTPAuthEndpoint(), prov: changingAuth(goodAuth, offsetTimeAuth(secret, tooLong+requestDelay)), expectCall2Fail: true},
}
for _, testCase := range testCases {
t.Run(testCase.name, testCase.Run)
}
}
func noneAuth(secret [32]byte) rpc.HTTPAuth {
return func(header http.Header) error {
token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{
"iat": &jwt.NumericDate{Time: time.Now()},
})
s, err := token.SignedString(secret[:])
if err != nil {
return fmt.Errorf("failed to create JWT token: %w", err)
}
header.Set("Authorization", "Bearer "+s)
return nil
}
}
func changingAuth(provs ...rpc.HTTPAuth) rpc.HTTPAuth {
i := 0
return func(header http.Header) error {
i += 1
if i > len(provs) {
i = len(provs)
}
return provs[i-1](header)
}
}
func offsetTimeAuth(secret [32]byte, offset time.Duration) rpc.HTTPAuth {
return func(header http.Header) error {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iat": &jwt.NumericDate{Time: time.Now().Add(offset)},
})
s, err := token.SignedString(secret[:])
if err != nil {
return fmt.Errorf("failed to create JWT token: %w", err)
}
header.Set("Authorization", "Bearer "+s)
return nil
}
}

View File

@ -22,6 +22,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"os"
"reflect" "reflect"
"strconv" "strconv"
"sync/atomic" "sync/atomic"
@ -99,7 +100,7 @@ type Client struct {
reqTimeout chan *requestOp // removes response IDs when call timeout expires reqTimeout chan *requestOp // removes response IDs when call timeout expires
} }
type reconnectFunc func(ctx context.Context) (ServerCodec, error) type reconnectFunc func(context.Context) (ServerCodec, error)
type clientContextKey struct{} type clientContextKey struct{}
@ -153,14 +154,16 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro
// //
// The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a // The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a
// file name with no URL scheme, a local socket connection is established using UNIX // file name with no URL scheme, a local socket connection is established using UNIX
// domain sockets on supported platforms and named pipes on Windows. If you want to // domain sockets on supported platforms and named pipes on Windows.
// configure transport options, use DialHTTP, DialWebsocket or DialIPC instead. //
// If you want to further configure the transport, use DialOptions instead of this
// function.
// //
// For websocket connections, the origin is set to the local host name. // For websocket connections, the origin is set to the local host name.
// //
// The client reconnects automatically if the connection is lost. // The client reconnects automatically when the connection is lost.
func Dial(rawurl string) (*Client, error) { func Dial(rawurl string) (*Client, error) {
return DialContext(context.Background(), rawurl) return DialOptions(context.Background(), rawurl)
} }
// DialContext creates a new RPC client, just like Dial. // DialContext creates a new RPC client, just like Dial.
@ -168,22 +171,46 @@ func Dial(rawurl string) (*Client, error) {
// The context is used to cancel or time out the initial connection establishment. It does // The context is used to cancel or time out the initial connection establishment. It does
// not affect subsequent interactions with the client. // not affect subsequent interactions with the client.
func DialContext(ctx context.Context, rawurl string) (*Client, error) { func DialContext(ctx context.Context, rawurl string) (*Client, error) {
return DialOptions(ctx, rawurl)
}
// DialOptions creates a new RPC client for the given URL. You can supply any of the
// pre-defined client options to configure the underlying transport.
//
// The context is used to cancel or time out the initial connection establishment. It does
// not affect subsequent interactions with the client.
//
// The client reconnects automatically when the connection is lost.
func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*Client, error) {
u, err := url.Parse(rawurl) u, err := url.Parse(rawurl)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cfg := new(clientConfig)
for _, opt := range options {
opt.applyOption(cfg)
}
var reconnect reconnectFunc
switch u.Scheme { switch u.Scheme {
case "http", "https": case "http", "https":
return DialHTTP(rawurl) reconnect = newClientTransportHTTP(rawurl, cfg)
case "ws", "wss": case "ws", "wss":
return DialWebsocket(ctx, rawurl, "") rc, err := newClientTransportWS(rawurl, cfg)
if err != nil {
return nil, err
}
reconnect = rc
case "stdio": case "stdio":
return DialStdIO(ctx) reconnect = newClientTransportIO(os.Stdin, os.Stdout)
case "": case "":
return DialIPC(ctx, rawurl) reconnect = newClientTransportIPC(rawurl)
default: default:
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
} }
return newClient(ctx, reconnect)
} }
// ClientFromContext retrieves the client from the context, if any. This can be used to perform // ClientFromContext retrieves the client from the context, if any. This can be used to perform

106
rpc/client_opt.go Normal file
View File

@ -0,0 +1,106 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc
import (
"net/http"
"github.com/gorilla/websocket"
)
// ClientOption is a configuration option for the RPC client.
type ClientOption interface {
applyOption(*clientConfig)
}
type clientConfig struct {
httpClient *http.Client
httpHeaders http.Header
httpAuth HTTPAuth
wsDialer *websocket.Dialer
}
func (cfg *clientConfig) initHeaders() {
if cfg.httpHeaders == nil {
cfg.httpHeaders = make(http.Header)
}
}
func (cfg *clientConfig) setHeader(key, value string) {
cfg.initHeaders()
cfg.httpHeaders.Set(key, value)
}
type optionFunc func(*clientConfig)
func (fn optionFunc) applyOption(opt *clientConfig) {
fn(opt)
}
// WithWebsocketDialer configures the websocket.Dialer used by the RPC client.
func WithWebsocketDialer(dialer websocket.Dialer) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.wsDialer = &dialer
})
}
// WithHeader configures HTTP headers set by the RPC client. Headers set using this option
// will be used for both HTTP and WebSocket connections.
func WithHeader(key, value string) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.initHeaders()
cfg.httpHeaders.Set(key, value)
})
}
// WithHeaders configures HTTP headers set by the RPC client. Headers set using this
// option will be used for both HTTP and WebSocket connections.
func WithHeaders(headers http.Header) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.initHeaders()
for k, vs := range headers {
cfg.httpHeaders[k] = vs
}
})
}
// WithHTTPClient configures the http.Client used by the RPC client.
func WithHTTPClient(c *http.Client) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.httpClient = c
})
}
// WithHTTPAuth configures HTTP request authentication. The given provider will be called
// whenever a request is made. Note that only one authentication provider can be active at
// any time.
func WithHTTPAuth(a HTTPAuth) ClientOption {
if a == nil {
panic("nil auth")
}
return optionFunc(func(cfg *clientConfig) {
cfg.httpAuth = a
})
}
// A HTTPAuth function is called by the client whenever a HTTP request is sent.
// The function must be safe for concurrent use.
//
// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
// auth information to the request.
type HTTPAuth func(h http.Header) error

25
rpc/client_opt_test.go Normal file
View File

@ -0,0 +1,25 @@
package rpc_test
import (
"context"
"net/http"
"time"
"github.com/ethereum/go-ethereum/rpc"
)
// This example configures a HTTP-based RPC client with two options - one setting the
// overall request timeout, the other adding a custom HTTP header to all requests.
func ExampleDialOptions() {
tokenHeader := rpc.WithHeader("x-token", "foo")
httpClient := rpc.WithHTTPClient(&http.Client{
Timeout: 10 * time.Second,
})
ctx := context.Background()
c, err := rpc.DialOptions(ctx, "http://rpc.example.com", httpClient, tokenHeader)
if err != nil {
panic(err)
}
c.Close()
}

View File

@ -45,6 +45,7 @@ type httpConn struct {
closeCh chan interface{} closeCh chan interface{}
mu sync.Mutex // protects headers mu sync.Mutex // protects headers
headers http.Header headers http.Header
auth HTTPAuth
} }
// httpConn implements ServerCodec, but it is treated specially by Client // httpConn implements ServerCodec, but it is treated specially by Client
@ -117,8 +118,15 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
} }
// DialHTTP creates a new RPC client that connects to an RPC server over HTTP.
func DialHTTP(endpoint string) (*Client, error) {
return DialHTTPWithClient(endpoint, new(http.Client))
}
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
// using the provided HTTP Client. // using the provided HTTP Client.
//
// Deprecated: use DialOptions and the WithHTTPClient option.
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
// Sanity check URL so we don't end up with a client that will fail every request. // Sanity check URL so we don't end up with a client that will fail every request.
_, err := url.Parse(endpoint) _, err := url.Parse(endpoint)
@ -126,24 +134,35 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
return nil, err return nil, err
} }
initctx := context.Background() var cfg clientConfig
headers := make(http.Header, 2) fn := newClientTransportHTTP(endpoint, &cfg)
headers.Set("accept", contentType) return newClient(context.Background(), fn)
headers.Set("content-type", contentType)
return newClient(initctx, func(context.Context) (ServerCodec, error) {
hc := &httpConn{
client: client,
headers: headers,
url: endpoint,
closeCh: make(chan interface{}),
}
return hc, nil
})
} }
// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
func DialHTTP(endpoint string) (*Client, error) { headers := make(http.Header, 2+len(cfg.httpHeaders))
return DialHTTPWithClient(endpoint, new(http.Client)) headers.Set("accept", contentType)
headers.Set("content-type", contentType)
for key, values := range cfg.httpHeaders {
headers[key] = values
}
client := cfg.httpClient
if client == nil {
client = new(http.Client)
}
hc := &httpConn{
client: client,
headers: headers,
url: endpoint,
auth: cfg.httpAuth,
closeCh: make(chan interface{}),
}
return func(ctx context.Context) (ServerCodec, error) {
return hc, nil
}
} }
func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error { func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error {
@ -195,6 +214,11 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
hc.mu.Lock() hc.mu.Lock()
req.Header = hc.headers.Clone() req.Header = hc.headers.Clone()
hc.mu.Unlock() hc.mu.Unlock()
if hc.auth != nil {
if err := hc.auth(req.Header); err != nil {
return nil, err
}
}
// do request // do request
resp, err := hc.client.Do(req) resp, err := hc.client.Do(req)

View File

@ -46,11 +46,15 @@ func (s *Server) ServeListener(l net.Listener) error {
// The context is used for the initial connection establishment. It does not // The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client. // affect subsequent interactions with the client.
func DialIPC(ctx context.Context, endpoint string) (*Client, error) { func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { return newClient(ctx, newClientTransportIPC(endpoint))
}
func newClientTransportIPC(endpoint string) reconnectFunc {
return func(ctx context.Context) (ServerCodec, error) {
conn, err := newIPCConnection(ctx, endpoint) conn, err := newIPCConnection(ctx, endpoint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewCodec(conn), err return NewCodec(conn), err
}) }
} }

View File

@ -32,12 +32,16 @@ func DialStdIO(ctx context.Context) (*Client, error) {
// DialIO creates a client which uses the given IO channels // DialIO creates a client which uses the given IO channels
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
return newClient(ctx, func(_ context.Context) (ServerCodec, error) { return newClient(ctx, newClientTransportIO(in, out))
}
func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {
return func(context.Context) (ServerCodec, error) {
return NewCodec(stdioConn{ return NewCodec(stdioConn{
in: in, in: in,
out: out, out: out,
}), nil }), nil
}) }
} }
type stdioConn struct { type stdioConn struct {

View File

@ -181,24 +181,23 @@ func parseOriginURL(origin string) (string, string, string, error) {
return scheme, hostname, port, nil return scheme, hostname, port, nil
} }
// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server // DialWebsocketWithDialer creates a new RPC client using WebSocket.
// that is listening on the given endpoint using the provided dialer. //
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
//
// Deprecated: use DialOptions and the WithWebsocketDialer option.
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
endpoint, header, err := wsClientHeaders(endpoint, origin) cfg := new(clientConfig)
cfg.wsDialer = &dialer
if origin != "" {
cfg.setHeader("origin", origin)
}
connect, err := newClientTransportWS(endpoint, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { return newClient(ctx, connect)
conn, resp, err := dialer.DialContext(ctx, endpoint, header)
if err != nil {
hErr := wsHandshakeError{err: err}
if resp != nil {
hErr.status = resp.Status
}
return nil, hErr
}
return newWebsocketCodec(conn, endpoint, header), nil
})
} }
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
@ -207,12 +206,53 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
// The context is used for the initial connection establishment. It does not // The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client. // affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
dialer := websocket.Dialer{ cfg := new(clientConfig)
ReadBufferSize: wsReadBuffer, if origin != "" {
WriteBufferSize: wsWriteBuffer, cfg.setHeader("origin", origin)
WriteBufferPool: wsBufferPool,
} }
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) connect, err := newClientTransportWS(endpoint, cfg)
if err != nil {
return nil, err
}
return newClient(ctx, connect)
}
func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {
dialer := cfg.wsDialer
if dialer == nil {
dialer = &websocket.Dialer{
ReadBufferSize: wsReadBuffer,
WriteBufferSize: wsWriteBuffer,
WriteBufferPool: wsBufferPool,
}
}
dialURL, header, err := wsClientHeaders(endpoint, "")
if err != nil {
return nil, err
}
for key, values := range cfg.httpHeaders {
header[key] = values
}
connect := func(ctx context.Context) (ServerCodec, error) {
header := header.Clone()
if cfg.httpAuth != nil {
if err := cfg.httpAuth(header); err != nil {
return nil, err
}
}
conn, resp, err := dialer.DialContext(ctx, dialURL, header)
if err != nil {
hErr := wsHandshakeError{err: err}
if resp != nil {
hErr.status = resp.Status
}
return nil, hErr
}
return newWebsocketCodec(conn, dialURL, header), nil
}
return connect, nil
} }
func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {