jsonrpc: cleanup websocket handling logic a bit

This commit is contained in:
Łukasz Magiera 2019-07-23 02:23:19 +02:00
parent 40fa1becb5
commit 1b1ec2b812

View File

@ -32,27 +32,51 @@ type frame struct {
} }
func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, requests <-chan clientRequest, stop <-chan struct{}) { func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, requests <-chan clientRequest, stop <-chan struct{}) {
incoming := make(chan io.Reader) var incoming = make(chan io.Reader)
var incErr error var incomingErr error
var writeLk sync.Mutex
// inflight are requests we sent to the remote
var inflight = map[int64]clientRequest{}
// handling are the calls we handle
var handling = map[int64]context.CancelFunc{}
var handlingLk sync.Mutex
type outChanReg struct {
id uint64
ch reflect.Value
}
var chOnce sync.Once
// chanCtr is a counter used for identifying output channels on the server side
var chanCtr uint64
var registerCh = make(chan outChanReg)
defer close(registerCh)
// chanHandlers is a map of client-side channel handlers
chanHandlers := map[uint64]func(m []byte, ok bool){}
// nextMessage wait for one message and puts it to the incoming channel // nextMessage wait for one message and puts it to the incoming channel
nextMessage := func() { nextMessage := func() {
mtype, r, err := conn.NextReader() msgType, r, err := conn.NextReader()
if err != nil { if err != nil {
incErr = err incomingErr = err
close(incoming) close(incoming)
return return
} }
if mtype != websocket.BinaryMessage && mtype != websocket.TextMessage { if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage {
incErr = errors.New("unsupported message type") incomingErr = errors.New("unsupported message type")
close(incoming) close(incoming)
return return
} }
incoming <- r incoming <- r
} }
var writeLk sync.Mutex
// nextWriter waits for writeLk and invokes the cb callback with WS message // nextWriter waits for writeLk and invokes the cb callback with WS message
// writer when the lock is acquired // writer when the lock is acquired
nextWriter := func(cb func(io.Writer)) { nextWriter := func(cb func(io.Writer)) {
@ -83,28 +107,11 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
writeLk.Unlock() writeLk.Unlock()
} }
// wait for the first message
go nextMessage()
// inflight are requests we sent to the remote
inflight := map[int64]clientRequest{}
// handling are the calls we handle
handling := map[int64]context.CancelFunc{}
var handlingLk sync.Mutex
// //// // ////
// Subscriptions (func() <-chan Typ - like methods) // Subscriptions (func() <-chan Typ - like methods)
var chOnce sync.Once // handleOutChans handles channel communication on the server side
var outId uint64 // (forwards channel messages to client)
type chReg struct {
id uint64
ch reflect.Value
}
registerCh := make(chan chReg)
defer close(registerCh)
handleOutChans := func() { handleOutChans := func() {
regV := reflect.ValueOf(registerCh) regV := reflect.ValueOf(registerCh)
@ -121,13 +128,15 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
if chosen == 0 { // control channel if chosen == 0 { // control channel
if !ok { if !ok {
// not closing any channels as we're on receiving end. // control channel closed - signals closed connection
//
// We're not closing any channels as we're on receiving end.
// Also, context cancellation below should take care of any running // Also, context cancellation below should take care of any running
// requests // requests
return return
} }
registration := val.Interface().(chReg) registration := val.Interface().(outChanReg)
caseToId = append(caseToId, registration.id) caseToId = append(caseToId, registration.id)
cases = append(cases, reflect.SelectCase{ cases = append(cases, reflect.SelectCase{
@ -139,6 +148,8 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
} }
if !ok { if !ok {
// Output channel closed, cleanup, and tell remote that this happened
n := len(caseToId) n := len(caseToId)
if n > 0 { if n > 0 {
cases[chosen] = cases[n] cases[chosen] = cases[n]
@ -158,6 +169,7 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
continue continue
} }
// forward message
sendReq(request{ sendReq(request{
Jsonrpc: "2.0", Jsonrpc: "2.0",
ID: nil, // notification ID: nil, // notification
@ -167,13 +179,14 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
} }
} }
// handleChanOut registers output channel for forwarding to client
handleChanOut := func(ch reflect.Value) interface{} { handleChanOut := func(ch reflect.Value) interface{} {
chOnce.Do(func() { chOnce.Do(func() {
go handleOutChans() go handleOutChans()
}) })
id := atomic.AddUint64(&outId, 1) id := atomic.AddUint64(&chanCtr, 1)
registerCh <- chReg{ registerCh <- outChanReg{
id: id, id: id,
ch: ch, ch: ch,
} }
@ -181,10 +194,6 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
return id return id
} }
// client side subs
chanHandlers := map[uint64]func(m []byte, ok bool){}
// //// // ////
// on close, make sure to return from all pending calls, and cancel context // on close, make sure to return from all pending calls, and cancel context
@ -207,6 +216,12 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
} }
}() }()
// handleCtxAsync handles context lifetimes for client
// TODO: this should be aware of events going through chanHandlers, and quit
// when the related channel is closed.
// This should also probably be a single goroutine,
// Note that not doing this should be fine for now as long as we are using
// contexts correctly (cancelling when async functions are no longer is use)
handleCtxAsync := func(actx context.Context, id int64) { handleCtxAsync := func(actx context.Context, id int64) {
<-actx.Done() <-actx.Done()
@ -238,15 +253,25 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
} }
} }
// wait for the first message
go nextMessage()
var msgConsumed bool
for { for {
if msgConsumed {
msgConsumed = false
go nextMessage()
}
select { select {
case r, ok := <-incoming: case r, ok := <-incoming:
if !ok { if !ok {
if incErr != nil { if incomingErr != nil {
log.Debugf("websocket error", "error", incErr) log.Debugf("websocket error", "error", incomingErr)
} }
return // remote closed return // remote closed
} }
msgConsumed = true
// debug util - dump all messages to stderr // debug util - dump all messages to stderr
// r = io.TeeReader(r, os.Stderr) // r = io.TeeReader(r, os.Stderr)
@ -359,8 +384,6 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
go handler.handle(ctx, req, nw, rpcError, done, handleChanOut) go handler.handle(ctx, req, nw, rpcError, done, handleChanOut)
} }
go nextMessage() // TODO: fix on errors
case req := <-requests: case req := <-requests:
if req.req.ID != nil { if req.req.ID != nil {
inflight[*req.req.ID] = req inflight[*req.req.ID] = req