diff --git a/rpc/client.go b/rpc/client.go index 198ce6357..e9deb3f6d 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -59,6 +59,12 @@ const ( maxClientSubscriptionBuffer = 20000 ) +const ( + httpScheme = "http" + wsScheme = "ws" + ipcScheme = "ipc" +) + // BatchElem is an element in a batch request. type BatchElem struct { Method string @@ -75,7 +81,7 @@ type BatchElem struct { // Client represents a connection to an RPC server. type Client struct { idgen func() ID // for subscriptions - isHTTP bool + scheme string // connection type: http, ws or ipc services *serviceRegistry idCounter uint32 @@ -111,6 +117,10 @@ type clientConn struct { func (c *Client) newClientConn(conn ServerCodec) *clientConn { ctx := context.WithValue(context.Background(), clientContextKey{}, c) + // Http connections have already set the scheme + if !c.isHTTP() && c.scheme != "" { + ctx = context.WithValue(ctx, "scheme", c.scheme) + } handler := newHandler(ctx, conn, c.idgen, c.services) return &clientConn{conn, handler} } @@ -136,7 +146,7 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro select { case <-ctx.Done(): // Send the timeout to dispatch so it can remove the request IDs. - if !c.isHTTP { + if !c.isHTTP() { select { case c.reqTimeout <- op: case <-c.closing: @@ -203,10 +213,18 @@ func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) } func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { - _, isHTTP := conn.(*httpConn) + scheme := "" + switch conn.(type) { + case *httpConn: + scheme = httpScheme + case *websocketCodec: + scheme = wsScheme + case *jsonCodec: + scheme = ipcScheme + } c := &Client{ idgen: idgen, - isHTTP: isHTTP, + scheme: scheme, services: services, writeConn: conn, close: make(chan struct{}), @@ -219,7 +237,7 @@ func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *C reqSent: make(chan error, 1), reqTimeout: make(chan *requestOp), } - if !isHTTP { + if !c.isHTTP() { go c.dispatch(conn) } return c @@ -250,7 +268,7 @@ func (c *Client) SupportedModules() (map[string]string, error) { // Close closes the client, aborting any in-flight requests. func (c *Client) Close() { - if c.isHTTP { + if c.isHTTP() { return } select { @@ -264,7 +282,7 @@ func (c *Client) Close() { // This method only works for clients using HTTP, it doesn't have // any effect for clients using another transport. func (c *Client) SetHeader(key, value string) { - if !c.isHTTP { + if !c.isHTTP() { return } conn := c.writeConn.(*httpConn) @@ -298,7 +316,7 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str } op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} - if c.isHTTP { + if c.isHTTP() { err = c.sendHTTP(ctx, op, msg) } else { err = c.send(ctx, op, msg) @@ -357,7 +375,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { } var err error - if c.isHTTP { + if c.isHTTP() { err = c.sendBatchHTTP(ctx, op, msgs) } else { err = c.send(ctx, op, msgs) @@ -402,7 +420,7 @@ func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) } msg.ID = nil - if c.isHTTP { + if c.isHTTP() { return c.sendHTTP(ctx, op, msg) } return c.send(ctx, op, msg) @@ -440,7 +458,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf if chanVal.IsNil() { panic("channel given to Subscribe must not be nil") } - if c.isHTTP { + if c.isHTTP() { return nil, ErrNotificationsUnsupported } @@ -642,3 +660,7 @@ func (c *Client) read(codec ServerCodec) { c.readOp <- readOp{msgs, batch} } } + +func (c *Client) isHTTP() bool { + return c.scheme == httpScheme +}