package jsonrpc import ( "context" "encoding/json" "errors" "io" "io/ioutil" "reflect" "sync" "sync/atomic" "time" "github.com/gorilla/websocket" "golang.org/x/xerrors" ) const wsCancel = "xrpc.cancel" const chValue = "xrpc.ch.val" const chClose = "xrpc.ch.close" type frame struct { // common Jsonrpc string `json:"jsonrpc"` ID *int64 `json:"id,omitempty"` Meta map[string]string `json:"meta,omitempty"` // request Method string `json:"method,omitempty"` Params []param `json:"params,omitempty"` // response Result json.RawMessage `json:"result,omitempty"` Error *respError `json:"error,omitempty"` } type outChanReg struct { reqID int64 chID uint64 ch reflect.Value } type wsConn struct { // outside params conn *websocket.Conn connFactory func() (*websocket.Conn, error) reconnectInterval time.Duration handler handlers requests <-chan clientRequest stop <-chan struct{} exiting chan struct{} // incoming messages incoming chan io.Reader incomingErr error // outgoing messages writeLk sync.Mutex // //// // Client related // inflight are requests we've sent to the remote inflight map[int64]clientRequest // chanHandlers is a map of client-side channel handlers chanHandlers map[uint64]func(m []byte, ok bool) // //// // Server related // handling are the calls we handle handling map[int64]context.CancelFunc handlingLk sync.Mutex spawnOutChanHandlerOnce sync.Once // chanCtr is a counter used for identifying output channels on the server side chanCtr uint64 registerCh chan outChanReg } // // // WebSocket Message utils // // // // nextMessage wait for one message and puts it to the incoming channel func (c *wsConn) nextMessage() { msgType, r, err := c.conn.NextReader() if err != nil { c.incomingErr = err close(c.incoming) return } if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage { c.incomingErr = errors.New("unsupported message type") close(c.incoming) return } c.incoming <- r } // nextWriter waits for writeLk and invokes the cb callback with WS message // writer when the lock is acquired func (c *wsConn) nextWriter(cb func(io.Writer)) { c.writeLk.Lock() defer c.writeLk.Unlock() wcl, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { log.Error("handle me:", err) return } cb(wcl) if err := wcl.Close(); err != nil { log.Error("handle me:", err) return } } func (c *wsConn) sendRequest(req request) error { c.writeLk.Lock() defer c.writeLk.Unlock() if err := c.conn.WriteJSON(req); err != nil { return err } return nil } // // // Output channels // // // // handleOutChans handles channel communication on the server side // (forwards channel messages to client) func (c *wsConn) handleOutChans() { regV := reflect.ValueOf(c.registerCh) exitV := reflect.ValueOf(c.exiting) cases := []reflect.SelectCase{ { // registration chan always 0 Dir: reflect.SelectRecv, Chan: regV, }, { // exit chan always 1 Dir: reflect.SelectRecv, Chan: exitV, }, } internal := len(cases) var caseToID []uint64 for { chosen, val, ok := reflect.Select(cases) switch chosen { case 0: // registration channel if !ok { // control channel closed - signals closed connection // This shouldn't happen, instead the exiting channel should get closed log.Warn("control channel closed") return } registration := val.Interface().(outChanReg) caseToID = append(caseToID, registration.chID) cases = append(cases, reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: registration.ch, }) c.nextWriter(func(w io.Writer) { resp := &response{ Jsonrpc: "2.0", ID: registration.reqID, Result: registration.chID, } if err := json.NewEncoder(w).Encode(resp); err != nil { log.Error(err) return } }) continue case 1: // exiting channel if !ok { // exiting channel closed - signals closed connection // // We're not closing any channels as we're on receiving end. // Also, context cancellation below should take care of any running // requests return } log.Warn("exiting channel received a message") continue } if !ok { // Output channel closed, cleanup, and tell remote that this happened n := len(cases) - 1 if n > 0 { cases[chosen] = cases[n] caseToID[chosen-internal] = caseToID[n-internal] } id := caseToID[chosen-internal] cases = cases[:n] caseToID = caseToID[:n-internal] if err := c.sendRequest(request{ Jsonrpc: "2.0", ID: nil, // notification Method: chClose, Params: []param{{v: reflect.ValueOf(id)}}, }); err != nil { log.Warnf("closed out channel sendRequest failed: %s", err) } continue } // forward message if err := c.sendRequest(request{ Jsonrpc: "2.0", ID: nil, // notification Method: chValue, Params: []param{{v: reflect.ValueOf(caseToID[chosen-internal])}, {v: val}}, }); err != nil { log.Warnf("sendRequest failed: %s", err) return } } } // handleChanOut registers output channel for forwarding to client func (c *wsConn) handleChanOut(ch reflect.Value, req int64) error { c.spawnOutChanHandlerOnce.Do(func() { go c.handleOutChans() }) id := atomic.AddUint64(&c.chanCtr, 1) select { case c.registerCh <- outChanReg{ reqID: req, chID: id, ch: ch, }: return nil case <-c.exiting: return xerrors.New("connection closing") } } // // // Context.Done propagation // // // // handleCtxAsync handles context lifetimes for client // TODO: this should be aware of events going through chanHandlers, and quit // when the related channel is closed. // This should also probably be a single goroutine, // Note that not doing this should be fine for now as long as we are using // contexts correctly (cancelling when async functions are no longer is use) func (c *wsConn) handleCtxAsync(actx context.Context, id int64) { <-actx.Done() c.sendRequest(request{ Jsonrpc: "2.0", Method: wsCancel, Params: []param{{v: reflect.ValueOf(id)}}, }) } // cancelCtx is a built-in rpc which handles context cancellation over rpc func (c *wsConn) cancelCtx(req frame) { if req.ID != nil { log.Warnf("%s call with ID set, won't respond", wsCancel) } var id int64 if err := json.Unmarshal(req.Params[0].data, &id); err != nil { log.Error("handle me:", err) return } c.handlingLk.Lock() defer c.handlingLk.Unlock() cf, ok := c.handling[id] if ok { cf() } } // // // Main Handling logic // // // func (c *wsConn) handleChanMessage(frame frame) { var chid uint64 if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil { log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) return } hnd, ok := c.chanHandlers[chid] if !ok { log.Errorf("xrpc.ch.val: handler %d not found", chid) return } hnd(frame.Params[1].data, true) } func (c *wsConn) handleChanClose(frame frame) { var chid uint64 if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil { log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) return } hnd, ok := c.chanHandlers[chid] if !ok { log.Errorf("xrpc.ch.val: handler %d not found", chid) return } delete(c.chanHandlers, chid) hnd(nil, false) } func (c *wsConn) handleResponse(frame frame) { req, ok := c.inflight[*frame.ID] if !ok { log.Error("client got unknown ID in response") return } if req.retCh != nil && frame.Result != nil { // output is channel var chid uint64 if err := json.Unmarshal(frame.Result, &chid); err != nil { log.Errorf("failed to unmarshal channel id response: %s, data '%s'", err, string(frame.Result)) return } var chanCtx context.Context chanCtx, c.chanHandlers[chid] = req.retCh() go c.handleCtxAsync(chanCtx, *frame.ID) } req.ready <- clientResponse{ Jsonrpc: frame.Jsonrpc, Result: frame.Result, ID: *frame.ID, Error: frame.Error, } delete(c.inflight, *frame.ID) } func (c *wsConn) handleCall(ctx context.Context, frame frame) { req := request{ Jsonrpc: frame.Jsonrpc, ID: frame.ID, Meta: frame.Meta, Method: frame.Method, Params: frame.Params, } ctx, cancel := context.WithCancel(ctx) nextWriter := func(cb func(io.Writer)) { cb(ioutil.Discard) } done := func(keepCtx bool) { if !keepCtx { cancel() } } if frame.ID != nil { nextWriter = c.nextWriter c.handlingLk.Lock() c.handling[*frame.ID] = cancel c.handlingLk.Unlock() done = func(keepctx bool) { c.handlingLk.Lock() defer c.handlingLk.Unlock() if !keepctx { cancel() delete(c.handling, *frame.ID) } } } go c.handler.handle(ctx, req, nextWriter, rpcError, done, c.handleChanOut) } // handleFrame handles all incoming messages (calls and responses) func (c *wsConn) handleFrame(ctx context.Context, frame frame) { // Get message type by method name: // "" - response // "xrpc.*" - builtin // anything else - incoming remote call switch frame.Method { case "": // Response to our call c.handleResponse(frame) case wsCancel: c.cancelCtx(frame) case chValue: c.handleChanMessage(frame) case chClose: c.handleChanClose(frame) default: // Remote call c.handleCall(ctx, frame) } } func (c *wsConn) closeInFlight() { for id, req := range c.inflight { req.ready <- clientResponse{ Jsonrpc: "2.0", ID: id, Error: &respError{ Message: "handler: websocket connection closed", Code: 2, }, } } c.handlingLk.Lock() for _, cancel := range c.handling { cancel() } c.handlingLk.Unlock() c.inflight = map[int64]clientRequest{} c.handling = map[int64]context.CancelFunc{} } func (c *wsConn) closeChans() { for chid := range c.chanHandlers { hnd := c.chanHandlers[chid] delete(c.chanHandlers, chid) hnd(nil, false) } } func (c *wsConn) handleWsConn(ctx context.Context) { c.incoming = make(chan io.Reader) c.inflight = map[int64]clientRequest{} c.handling = map[int64]context.CancelFunc{} c.chanHandlers = map[uint64]func(m []byte, ok bool){} c.registerCh = make(chan outChanReg) defer close(c.exiting) // //// // on close, make sure to return from all pending calls, and cancel context // on all calls we handle defer c.closeInFlight() // wait for the first message go c.nextMessage() for { select { case r, ok := <-c.incoming: if !ok { if c.incomingErr != nil { if !websocket.IsCloseError(c.incomingErr, websocket.CloseNormalClosure) { log.Debugw("websocket error", "error", c.incomingErr) // connection dropped unexpectedly, do our best to recover it c.closeInFlight() c.closeChans() c.incoming = make(chan io.Reader) // listen again for responses go func() { if c.connFactory == nil { // likely the server side, don't try to reconnect return } var conn *websocket.Conn for conn == nil { time.Sleep(c.reconnectInterval) var err error if conn, err = c.connFactory(); err != nil { log.Debugw("websocket connection retry failed", "error", err) } } c.writeLk.Lock() c.conn = conn c.incomingErr = nil c.writeLk.Unlock() go c.nextMessage() }() continue } } return // remote closed } // debug util - dump all messages to stderr // r = io.TeeReader(r, os.Stderr) var frame frame if err := json.NewDecoder(r).Decode(&frame); err != nil { log.Error("handle me:", err) return } c.handleFrame(ctx, frame) go c.nextMessage() case req := <-c.requests: c.writeLk.Lock() if req.req.ID != nil { if c.incomingErr != nil { // No conn?, immediate fail req.ready <- clientResponse{ Jsonrpc: "2.0", ID: *req.req.ID, Error: &respError{ Message: "handler: websocket connection closed", Code: 2, }, } c.writeLk.Unlock() break } c.inflight[*req.req.ID] = req } c.writeLk.Unlock() if err := c.sendRequest(req.req); err != nil { log.Errorf("sendReqest failed (Handle me): %s", err) } case <-c.stop: c.writeLk.Lock() cmsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") if err := c.conn.WriteMessage(websocket.CloseMessage, cmsg); err != nil { log.Warn("failed to write close message: ", err) } if err := c.conn.Close(); err != nil { log.Warnw("websocket close error", "error", err) } c.writeLk.Unlock() return } } }