jsonrpc: Break handleWsConn into smaller functions

This commit is contained in:
Łukasz Magiera 2019-07-23 03:45:10 +02:00
parent 2f0a088b18
commit 9b3ae45f61
2 changed files with 319 additions and 288 deletions

View File

@ -56,6 +56,9 @@ type wsConn struct {
// inflight are requests we've sent to the remote // inflight are requests we've sent to the remote
inflight map[int64]clientRequest inflight map[int64]clientRequest
// chanHandlers is a map of client-side channel handlers
chanHandlers map[uint64]func(m []byte, ok bool)
// //// // ////
// Server related // Server related
@ -67,21 +70,16 @@ type wsConn struct {
// chanCtr is a counter used for identifying output channels on the server side // chanCtr is a counter used for identifying output channels on the server side
chanCtr uint64 chanCtr uint64
registerCh chan outChanReg
} }
func (c *wsConn) handleWsConn(ctx context.Context) { // //
c.incoming = make(chan io.Reader) // WebSocket Message utils //
c.inflight = map[int64]clientRequest{} // //
c.handling = map[int64]context.CancelFunc{}
var registerCh = make(chan outChanReg)
defer close(registerCh)
// chanHandlers is a map of client-side channel handlers
chanHandlers := map[uint64]func(m []byte, ok bool){}
// nextMessage wait for one message and puts it to the incoming channel // nextMessage wait for one message and puts it to the incoming channel
nextMessage := func() { func (c *wsConn) nextMessage() {
msgType, r, err := c.conn.NextReader() msgType, r, err := c.conn.NextReader()
if err != nil { if err != nil {
c.incomingErr = err c.incomingErr = err
@ -98,7 +96,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
// nextWriter waits for writeLk and invokes the cb callback with WS message // nextWriter waits for writeLk and invokes the cb callback with WS message
// writer when the lock is acquired // writer when the lock is acquired
nextWriter := func(cb func(io.Writer)) { func (c *wsConn) nextWriter(cb func(io.Writer)) {
c.writeLk.Lock() c.writeLk.Lock()
defer c.writeLk.Unlock() defer c.writeLk.Unlock()
@ -116,7 +114,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
} }
} }
sendReq := func(req request) { func (c *wsConn) sendRequest(req request) {
c.writeLk.Lock() c.writeLk.Lock()
if err := c.conn.WriteJSON(req); err != nil { if err := c.conn.WriteJSON(req); err != nil {
log.Error("handle me:", err) log.Error("handle me:", err)
@ -126,13 +124,14 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
c.writeLk.Unlock() c.writeLk.Unlock()
} }
// //// // //
// Subscriptions (func() <-chan Typ - like methods) // Output channels //
// //
// handleOutChans handles channel communication on the server side // handleOutChans handles channel communication on the server side
// (forwards channel messages to client) // (forwards channel messages to client)
handleOutChans := func() { func (c *wsConn) handleOutChans() {
regV := reflect.ValueOf(registerCh) regV := reflect.ValueOf(c.registerCh)
cases := []reflect.SelectCase{ cases := []reflect.SelectCase{
{ // registration chan always 0 { // registration chan always 0
@ -179,7 +178,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
cases = cases[:n] cases = cases[:n]
caseToID = caseToID[:n-1] caseToID = caseToID[:n-1]
sendReq(request{ c.sendRequest(request{
Jsonrpc: "2.0", Jsonrpc: "2.0",
ID: nil, // notification ID: nil, // notification
Method: chClose, Method: chClose,
@ -189,7 +188,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
} }
// forward message // forward message
sendReq(request{ c.sendRequest(request{
Jsonrpc: "2.0", Jsonrpc: "2.0",
ID: nil, // notification ID: nil, // notification
Method: chValue, Method: chValue,
@ -199,13 +198,13 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
} }
// handleChanOut registers output channel for forwarding to client // handleChanOut registers output channel for forwarding to client
handleChanOut := func(ch reflect.Value) interface{} { func (c *wsConn) handleChanOut(ch reflect.Value) interface{} {
c.spawnOutChanHandlerOnce.Do(func() { c.spawnOutChanHandlerOnce.Do(func() {
go handleOutChans() go c.handleOutChans()
}) })
id := atomic.AddUint64(&c.chanCtr, 1) id := atomic.AddUint64(&c.chanCtr, 1)
registerCh <- outChanReg{ c.registerCh <- outChanReg{
id: id, id: id,
ch: ch, ch: ch,
} }
@ -213,6 +212,182 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
return id return id
} }
// //
// Context.Done propagation //
// //
// handleCtxAsync handles context lifetimes for client
// TODO: this should be aware of events going through chanHandlers, and quit
// when the related channel is closed.
// This should also probably be a single goroutine,
// Note that not doing this should be fine for now as long as we are using
// contexts correctly (cancelling when async functions are no longer is use)
func (c *wsConn) handleCtxAsync(actx context.Context, id int64) {
<-actx.Done()
c.sendRequest(request{
Jsonrpc: "2.0",
Method: wsCancel,
Params: []param{{v: reflect.ValueOf(id)}},
})
}
// cancelCtx is a built-in rpc which handles context cancellation over rpc
func (c *wsConn) cancelCtx(req frame) {
if req.ID != nil {
log.Warnf("%s call with ID set, won't respond", wsCancel)
}
var id int64
if err := json.Unmarshal(req.Params[0].data, &id); err != nil {
log.Error("handle me:", err)
return
}
c.handlingLk.Lock()
defer c.handlingLk.Unlock()
cf, ok := c.handling[id]
if ok {
cf()
}
}
// //
// Main Handling logic //
// //
func (c *wsConn) handleChanMessage(frame frame) {
var chid uint64
if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil {
log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err)
return
}
hnd, ok := c.chanHandlers[chid]
if !ok {
log.Errorf("xrpc.ch.val: handler %d not found", chid)
return
}
hnd(frame.Params[1].data, true)
}
func (c *wsConn) handleChanClose(frame frame) {
var chid uint64
if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil {
log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err)
return
}
hnd, ok := c.chanHandlers[chid]
if !ok {
log.Errorf("xrpc.ch.val: handler %d not found", chid)
return
}
delete(c.chanHandlers, chid)
hnd(nil, false)
}
func (c *wsConn) handleResponse(frame frame) {
req, ok := c.inflight[*frame.ID]
if !ok {
log.Error("client got unknown ID in response")
return
}
if req.retCh != nil && frame.Result != nil {
// output is channel
var chid uint64
if err := json.Unmarshal(frame.Result, &chid); err != nil {
log.Errorf("failed to unmarshal channel id response: %s, data '%s'", err, string(frame.Result))
return
}
var chanCtx context.Context
chanCtx, c.chanHandlers[chid] = req.retCh()
go c.handleCtxAsync(chanCtx, *frame.ID)
}
req.ready <- clientResponse{
Jsonrpc: frame.Jsonrpc,
Result: frame.Result,
ID: *frame.ID,
Error: frame.Error,
}
delete(c.inflight, *frame.ID)
}
func (c *wsConn) handleCall(ctx context.Context, frame frame) {
req := request{
Jsonrpc: frame.Jsonrpc,
ID: frame.ID,
Method: frame.Method,
Params: frame.Params,
}
ctx, cancel := context.WithCancel(ctx)
nextWriter := func(cb func(io.Writer)) {
cb(ioutil.Discard)
}
done := func(keepCtx bool) {
if !keepCtx {
cancel()
}
}
if frame.ID != nil {
nextWriter = c.nextWriter
c.handlingLk.Lock()
c.handling[*frame.ID] = cancel
c.handlingLk.Unlock()
done = func(keepctx bool) {
c.handlingLk.Lock()
defer c.handlingLk.Unlock()
if !keepctx {
cancel()
delete(c.handling, *frame.ID)
}
}
}
go c.handler.handle(ctx, req, nextWriter, rpcError, done, c.handleChanOut)
}
// handleFrame handles all incoming messages (calls and responses)
func (c *wsConn) handleFrame(ctx context.Context, frame frame) {
// Get message type by method name:
// "" - response
// "xrpc.*" - builtin
// anything else - incoming remote call
switch frame.Method {
case "": // Response to our call
c.handleResponse(frame)
case wsCancel:
c.cancelCtx(frame)
case chValue:
c.handleChanMessage(frame)
case chClose:
c.handleChanClose(frame)
default: // Remote call
c.handleCall(ctx, frame)
}
}
func (c *wsConn) handleWsConn(ctx context.Context) {
c.incoming = make(chan io.Reader)
c.inflight = map[int64]clientRequest{}
c.handling = map[int64]context.CancelFunc{}
c.chanHandlers = map[uint64]func(m []byte, ok bool){}
c.registerCh = make(chan outChanReg)
defer close(c.registerCh)
// //// // ////
// on close, make sure to return from all pending calls, and cancel context // on close, make sure to return from all pending calls, and cancel context
@ -235,53 +410,10 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
} }
}() }()
// handleCtxAsync handles context lifetimes for client
// TODO: this should be aware of events going through chanHandlers, and quit
// when the related channel is closed.
// This should also probably be a single goroutine,
// Note that not doing this should be fine for now as long as we are using
// contexts correctly (cancelling when async functions are no longer is use)
handleCtxAsync := func(actx context.Context, id int64) {
<-actx.Done()
sendReq(request{
Jsonrpc: "2.0",
Method: wsCancel,
Params: []param{{v: reflect.ValueOf(id)}},
})
}
// cancelCtx is a built-in rpc which handles context cancellation over rpc
cancelCtx := func(req frame) {
if req.ID != nil {
log.Warnf("%s call with ID set, won't respond", wsCancel)
}
var id int64
if err := json.Unmarshal(req.Params[0].data, &id); err != nil {
log.Error("handle me:", err)
return
}
c.handlingLk.Lock()
defer c.handlingLk.Unlock()
cf, ok := c.handling[id]
if ok {
cf()
}
}
// wait for the first message // wait for the first message
go nextMessage() go c.nextMessage()
var msgConsumed bool
for { for {
if msgConsumed {
msgConsumed = false
go nextMessage()
}
select { select {
case r, ok := <-c.incoming: case r, ok := <-c.incoming:
if !ok { if !ok {
@ -290,7 +422,6 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
} }
return // remote closed return // remote closed
} }
msgConsumed = true
// debug util - dump all messages to stderr // debug util - dump all messages to stderr
// r = io.TeeReader(r, os.Stderr) // r = io.TeeReader(r, os.Stderr)
@ -301,113 +432,13 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
return return
} }
// Get message type by method name: c.handleFrame(ctx, frame)
// "" - response go c.nextMessage()
// "xrpc.*" - builtin
// anything else - incoming remote call
switch frame.Method {
case "": // Response to our call
req, ok := c.inflight[*frame.ID]
if !ok {
log.Error("client got unknown ID in response")
continue
}
if req.retCh != nil && frame.Result != nil {
// output is channel
var chid uint64
if err := json.Unmarshal(frame.Result, &chid); err != nil {
log.Errorf("failed to unmarshal channel id response: %s, data '%s'", err, string(frame.Result))
continue
}
var chanCtx context.Context
chanCtx, chanHandlers[chid] = req.retCh()
go handleCtxAsync(chanCtx, *frame.ID)
}
req.ready <- clientResponse{
Jsonrpc: frame.Jsonrpc,
Result: frame.Result,
ID: *frame.ID,
Error: frame.Error,
}
delete(c.inflight, *frame.ID)
case wsCancel:
cancelCtx(frame)
case chValue:
var chid uint64
if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil {
log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err)
continue
}
hnd, ok := chanHandlers[chid]
if !ok {
log.Errorf("xrpc.ch.val: handler %d not found", chid)
continue
}
hnd(frame.Params[1].data, true)
case chClose:
var chid uint64
if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil {
log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err)
continue
}
hnd, ok := chanHandlers[chid]
if !ok {
log.Errorf("xrpc.ch.val: handler %d not found", chid)
continue
}
delete(chanHandlers, chid)
hnd(nil, false)
default: // Remote call
req := request{
Jsonrpc: frame.Jsonrpc,
ID: frame.ID,
Method: frame.Method,
Params: frame.Params,
}
ctx, cf := context.WithCancel(ctx)
nw := func(cb func(io.Writer)) {
cb(ioutil.Discard)
}
done := func(keepctx bool) {
if !keepctx {
cf()
}
}
if frame.ID != nil {
nw = nextWriter
c.handlingLk.Lock()
c.handling[*frame.ID] = cf
c.handlingLk.Unlock()
done = func(keepctx bool) {
c.handlingLk.Lock()
defer c.handlingLk.Unlock()
if !keepctx {
cf()
delete(c.handling, *frame.ID)
}
}
}
go c.handler.handle(ctx, req, nw, rpcError, done, handleChanOut)
}
case req := <-c.requests: case req := <-c.requests:
if req.req.ID != nil { if req.req.ID != nil {
c.inflight[*req.req.ID] = req c.inflight[*req.req.ID] = req
} }
sendReq(req.req) c.sendRequest(req.req)
case <-c.stop: case <-c.stop:
if err := c.conn.Close(); err != nil { if err := c.conn.Close(); err != nil {
log.Debugf("websocket close error", "error", err) log.Debugf("websocket close error", "error", err)