diff --git a/lib/jsonrpc/client.go b/lib/jsonrpc/client.go index 33f5697bf..f165d1a9f 100644 --- a/lib/jsonrpc/client.go +++ b/lib/jsonrpc/client.go @@ -80,7 +80,12 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser requests := make(chan clientRequest) handlers := map[string]rpcHandler{} - go handleWsConn(context.TODO(), conn, handlers, requests, stop) + go (&wsConn{ + conn: conn, + handler: handlers, + requests: requests, + stop: stop, + }).handleWsConn(context.TODO()) for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) diff --git a/lib/jsonrpc/server.go b/lib/jsonrpc/server.go index ce6b3ca84..021b2ebf3 100644 --- a/lib/jsonrpc/server.go +++ b/lib/jsonrpc/server.go @@ -36,7 +36,10 @@ func (s *RPCServer) handleWS(w http.ResponseWriter, r *http.Request) { return } - handleWsConn(r.Context(), c, s.methods, nil, nil) + (&wsConn{ + conn: c, + handler: s.methods, + }).handleWsConn(r.Context()) if err := c.Close(); err != nil { log.Error(err) diff --git a/lib/jsonrpc/websocket.go b/lib/jsonrpc/websocket.go index 0196b46cb..5198b1f19 100644 --- a/lib/jsonrpc/websocket.go +++ b/lib/jsonrpc/websocket.go @@ -31,27 +31,48 @@ type frame struct { Error *respError `json:"error,omitempty"` } -func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, requests <-chan clientRequest, stop <-chan struct{}) { - var incoming = make(chan io.Reader) - var incomingErr error +type outChanReg struct { + id uint64 + ch reflect.Value +} - var writeLk sync.Mutex +type wsConn struct { + // outside params + conn *websocket.Conn + handler handlers + requests <-chan clientRequest + stop <-chan struct{} - // inflight are requests we sent to the remote - var inflight = map[int64]clientRequest{} + // incoming messages + incoming chan io.Reader + incomingErr error + + // outgoing messages + writeLk sync.Mutex + + // //// + // Client related + + // inflight are requests we've sent to the remote + inflight map[int64]clientRequest + + // //// + // Server related // handling are the calls we handle - var handling = map[int64]context.CancelFunc{} - var handlingLk sync.Mutex + handling map[int64]context.CancelFunc + handlingLk sync.Mutex - type outChanReg struct { - id uint64 - ch reflect.Value - } - var chOnce sync.Once + spawnOutChanHandlerOnce sync.Once // chanCtr is a counter used for identifying output channels on the server side - var chanCtr uint64 + chanCtr uint64 +} + +func (c *wsConn) handleWsConn(ctx context.Context) { + c.incoming = make(chan io.Reader) + c.inflight = map[int64]clientRequest{} + c.handling = map[int64]context.CancelFunc{} var registerCh = make(chan outChanReg) defer close(registerCh) @@ -61,27 +82,27 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r // nextMessage wait for one message and puts it to the incoming channel nextMessage := func() { - msgType, r, err := conn.NextReader() + msgType, r, err := c.conn.NextReader() if err != nil { - incomingErr = err - close(incoming) + c.incomingErr = err + close(c.incoming) return } if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage { - incomingErr = errors.New("unsupported message type") - close(incoming) + c.incomingErr = errors.New("unsupported message type") + close(c.incoming) return } - incoming <- r + c.incoming <- r } // nextWriter waits for writeLk and invokes the cb callback with WS message // writer when the lock is acquired nextWriter := func(cb func(io.Writer)) { - writeLk.Lock() - defer writeLk.Unlock() + c.writeLk.Lock() + defer c.writeLk.Unlock() - wcl, err := conn.NextWriter(websocket.TextMessage) + wcl, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { log.Error("handle me:", err) return @@ -96,13 +117,13 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r } sendReq := func(req request) { - writeLk.Lock() - if err := conn.WriteJSON(req); err != nil { + c.writeLk.Lock() + if err := c.conn.WriteJSON(req); err != nil { log.Error("handle me:", err) - writeLk.Unlock() + c.writeLk.Unlock() return } - writeLk.Unlock() + c.writeLk.Unlock() } // //// @@ -179,10 +200,10 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r // handleChanOut registers output channel for forwarding to client handleChanOut := func(ch reflect.Value) interface{} { - chOnce.Do(func() { + c.spawnOutChanHandlerOnce.Do(func() { go handleOutChans() }) - id := atomic.AddUint64(&chanCtr, 1) + id := atomic.AddUint64(&c.chanCtr, 1) registerCh <- outChanReg{ id: id, @@ -197,7 +218,7 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r // on close, make sure to return from all pending calls, and cancel context // on all calls we handle defer func() { - for id, req := range inflight { + for id, req := range c.inflight { req.ready <- clientResponse{ Jsonrpc: "2.0", ID: id, @@ -206,11 +227,11 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r }, } - handlingLk.Lock() - for _, cancel := range handling { + c.handlingLk.Lock() + for _, cancel := range c.handling { cancel() } - handlingLk.Unlock() + c.handlingLk.Unlock() } }() @@ -242,10 +263,10 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r return } - handlingLk.Lock() - defer handlingLk.Unlock() + c.handlingLk.Lock() + defer c.handlingLk.Unlock() - cf, ok := handling[id] + cf, ok := c.handling[id] if ok { cf() } @@ -262,10 +283,10 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r } select { - case r, ok := <-incoming: + case r, ok := <-c.incoming: if !ok { - if incomingErr != nil { - log.Debugf("websocket error", "error", incomingErr) + if c.incomingErr != nil { + log.Debugf("websocket error", "error", c.incomingErr) } return // remote closed } @@ -286,7 +307,7 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r // anything else - incoming remote call switch frame.Method { case "": // Response to our call - req, ok := inflight[*frame.ID] + req, ok := c.inflight[*frame.ID] if !ok { log.Error("client got unknown ID in response") continue @@ -311,7 +332,7 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r ID: *frame.ID, Error: frame.Error, } - delete(inflight, *frame.ID) + delete(c.inflight, *frame.ID) case wsCancel: cancelCtx(frame) case chValue: @@ -365,30 +386,30 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r if frame.ID != nil { nw = nextWriter - handlingLk.Lock() - handling[*frame.ID] = cf - handlingLk.Unlock() + c.handlingLk.Lock() + c.handling[*frame.ID] = cf + c.handlingLk.Unlock() done = func(keepctx bool) { - handlingLk.Lock() - defer handlingLk.Unlock() + c.handlingLk.Lock() + defer c.handlingLk.Unlock() if !keepctx { cf() - delete(handling, *frame.ID) + delete(c.handling, *frame.ID) } } } - go handler.handle(ctx, req, nw, rpcError, done, handleChanOut) + go c.handler.handle(ctx, req, nw, rpcError, done, handleChanOut) } - case req := <-requests: + case req := <-c.requests: if req.req.ID != nil { - inflight[*req.req.ID] = req + c.inflight[*req.req.ID] = req } sendReq(req.req) - case <-stop: - if err := conn.Close(); err != nil { + case <-c.stop: + if err := c.conn.Close(); err != nil { log.Debugf("websocket close error", "error", err) } return