jsonrpc: Move ws handler state to a struct

This commit is contained in:
Łukasz Magiera 2019-07-23 03:20:48 +02:00
parent 0d5d6cd1c2
commit 2f0a088b18
3 changed files with 83 additions and 54 deletions

View File

@ -80,7 +80,12 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser
requests := make(chan clientRequest) requests := make(chan clientRequest)
handlers := map[string]rpcHandler{} handlers := map[string]rpcHandler{}
go handleWsConn(context.TODO(), conn, handlers, requests, stop) go (&wsConn{
conn: conn,
handler: handlers,
requests: requests,
stop: stop,
}).handleWsConn(context.TODO())
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i) f := typ.Field(i)

View File

@ -36,7 +36,10 @@ func (s *RPCServer) handleWS(w http.ResponseWriter, r *http.Request) {
return return
} }
handleWsConn(r.Context(), c, s.methods, nil, nil) (&wsConn{
conn: c,
handler: s.methods,
}).handleWsConn(r.Context())
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
log.Error(err) log.Error(err)

View File

@ -31,27 +31,48 @@ type frame struct {
Error *respError `json:"error,omitempty"` Error *respError `json:"error,omitempty"`
} }
func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, requests <-chan clientRequest, stop <-chan struct{}) { type outChanReg struct {
var incoming = make(chan io.Reader) id uint64
var incomingErr error ch reflect.Value
}
var writeLk sync.Mutex type wsConn struct {
// outside params
conn *websocket.Conn
handler handlers
requests <-chan clientRequest
stop <-chan struct{}
// inflight are requests we sent to the remote // incoming messages
var inflight = map[int64]clientRequest{} incoming chan io.Reader
incomingErr error
// outgoing messages
writeLk sync.Mutex
// ////
// Client related
// inflight are requests we've sent to the remote
inflight map[int64]clientRequest
// ////
// Server related
// handling are the calls we handle // handling are the calls we handle
var handling = map[int64]context.CancelFunc{} handling map[int64]context.CancelFunc
var handlingLk sync.Mutex handlingLk sync.Mutex
type outChanReg struct { spawnOutChanHandlerOnce sync.Once
id uint64
ch reflect.Value
}
var chOnce sync.Once
// chanCtr is a counter used for identifying output channels on the server side // chanCtr is a counter used for identifying output channels on the server side
var chanCtr uint64 chanCtr uint64
}
func (c *wsConn) handleWsConn(ctx context.Context) {
c.incoming = make(chan io.Reader)
c.inflight = map[int64]clientRequest{}
c.handling = map[int64]context.CancelFunc{}
var registerCh = make(chan outChanReg) var registerCh = make(chan outChanReg)
defer close(registerCh) defer close(registerCh)
@ -61,27 +82,27 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
// nextMessage wait for one message and puts it to the incoming channel // nextMessage wait for one message and puts it to the incoming channel
nextMessage := func() { nextMessage := func() {
msgType, r, err := conn.NextReader() msgType, r, err := c.conn.NextReader()
if err != nil { if err != nil {
incomingErr = err c.incomingErr = err
close(incoming) close(c.incoming)
return return
} }
if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage { if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage {
incomingErr = errors.New("unsupported message type") c.incomingErr = errors.New("unsupported message type")
close(incoming) close(c.incoming)
return return
} }
incoming <- r c.incoming <- r
} }
// nextWriter waits for writeLk and invokes the cb callback with WS message // nextWriter waits for writeLk and invokes the cb callback with WS message
// writer when the lock is acquired // writer when the lock is acquired
nextWriter := func(cb func(io.Writer)) { nextWriter := func(cb func(io.Writer)) {
writeLk.Lock() c.writeLk.Lock()
defer writeLk.Unlock() defer c.writeLk.Unlock()
wcl, err := conn.NextWriter(websocket.TextMessage) wcl, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil { if err != nil {
log.Error("handle me:", err) log.Error("handle me:", err)
return return
@ -96,13 +117,13 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
} }
sendReq := func(req request) { sendReq := func(req request) {
writeLk.Lock() c.writeLk.Lock()
if err := conn.WriteJSON(req); err != nil { if err := c.conn.WriteJSON(req); err != nil {
log.Error("handle me:", err) log.Error("handle me:", err)
writeLk.Unlock() c.writeLk.Unlock()
return return
} }
writeLk.Unlock() c.writeLk.Unlock()
} }
// //// // ////
@ -179,10 +200,10 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
// handleChanOut registers output channel for forwarding to client // handleChanOut registers output channel for forwarding to client
handleChanOut := func(ch reflect.Value) interface{} { handleChanOut := func(ch reflect.Value) interface{} {
chOnce.Do(func() { c.spawnOutChanHandlerOnce.Do(func() {
go handleOutChans() go handleOutChans()
}) })
id := atomic.AddUint64(&chanCtr, 1) id := atomic.AddUint64(&c.chanCtr, 1)
registerCh <- outChanReg{ registerCh <- outChanReg{
id: id, id: id,
@ -197,7 +218,7 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
// on close, make sure to return from all pending calls, and cancel context // on close, make sure to return from all pending calls, and cancel context
// on all calls we handle // on all calls we handle
defer func() { defer func() {
for id, req := range inflight { for id, req := range c.inflight {
req.ready <- clientResponse{ req.ready <- clientResponse{
Jsonrpc: "2.0", Jsonrpc: "2.0",
ID: id, ID: id,
@ -206,11 +227,11 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
}, },
} }
handlingLk.Lock() c.handlingLk.Lock()
for _, cancel := range handling { for _, cancel := range c.handling {
cancel() cancel()
} }
handlingLk.Unlock() c.handlingLk.Unlock()
} }
}() }()
@ -242,10 +263,10 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
return return
} }
handlingLk.Lock() c.handlingLk.Lock()
defer handlingLk.Unlock() defer c.handlingLk.Unlock()
cf, ok := handling[id] cf, ok := c.handling[id]
if ok { if ok {
cf() cf()
} }
@ -262,10 +283,10 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
} }
select { select {
case r, ok := <-incoming: case r, ok := <-c.incoming:
if !ok { if !ok {
if incomingErr != nil { if c.incomingErr != nil {
log.Debugf("websocket error", "error", incomingErr) log.Debugf("websocket error", "error", c.incomingErr)
} }
return // remote closed return // remote closed
} }
@ -286,7 +307,7 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
// anything else - incoming remote call // anything else - incoming remote call
switch frame.Method { switch frame.Method {
case "": // Response to our call case "": // Response to our call
req, ok := inflight[*frame.ID] req, ok := c.inflight[*frame.ID]
if !ok { if !ok {
log.Error("client got unknown ID in response") log.Error("client got unknown ID in response")
continue continue
@ -311,7 +332,7 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
ID: *frame.ID, ID: *frame.ID,
Error: frame.Error, Error: frame.Error,
} }
delete(inflight, *frame.ID) delete(c.inflight, *frame.ID)
case wsCancel: case wsCancel:
cancelCtx(frame) cancelCtx(frame)
case chValue: case chValue:
@ -365,30 +386,30 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
if frame.ID != nil { if frame.ID != nil {
nw = nextWriter nw = nextWriter
handlingLk.Lock() c.handlingLk.Lock()
handling[*frame.ID] = cf c.handling[*frame.ID] = cf
handlingLk.Unlock() c.handlingLk.Unlock()
done = func(keepctx bool) { done = func(keepctx bool) {
handlingLk.Lock() c.handlingLk.Lock()
defer handlingLk.Unlock() defer c.handlingLk.Unlock()
if !keepctx { if !keepctx {
cf() cf()
delete(handling, *frame.ID) delete(c.handling, *frame.ID)
} }
} }
} }
go handler.handle(ctx, req, nw, rpcError, done, handleChanOut) go c.handler.handle(ctx, req, nw, rpcError, done, handleChanOut)
} }
case req := <-requests: case req := <-c.requests:
if req.req.ID != nil { if req.req.ID != nil {
inflight[*req.req.ID] = req c.inflight[*req.req.ID] = req
} }
sendReq(req.req) sendReq(req.req)
case <-stop: case <-c.stop:
if err := conn.Close(); err != nil { if err := c.conn.Close(); err != nil {
log.Debugf("websocket close error", "error", err) log.Debugf("websocket close error", "error", err)
} }
return return