Context cancellation over websockets

This commit is contained in:
Łukasz Magiera 2019-07-15 18:21:48 +02:00
parent b93d71e8cb
commit 1153f050bb
4 changed files with 151 additions and 55 deletions

View File

@ -151,14 +151,40 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser
req: req,
ready: rchan,
}
resp := <- rchan
var ctxDone <-chan struct{}
var resp clientResponse
if hasCtx == 1 {
ctxDone = args[0].Interface().(context.Context).Done()
}
loop:
for {
select {
case resp = <-rchan:
break loop
case <-ctxDone: // send cancel request
ctxDone = nil
requests <- clientRequest{
req: request{
Jsonrpc: "2.0",
Method: wsCancel,
Params: []param{{v: reflect.ValueOf(id)}},
},
}
}
}
var rval reflect.Value
if valOut != -1 {
log.Debugw("rpc result", "type", ftyp.Out(valOut))
rval = reflect.New(ftyp.Out(valOut))
if err := json.Unmarshal(resp.Result, rval.Interface()); err != nil {
return processError(err)
if resp.Result != nil {
log.Debugw("rpc result", "type", ftyp.Out(valOut))
if err := json.Unmarshal(resp.Result, rval.Interface()); err != nil {
return processError(err)
}
}
}

View File

@ -92,19 +92,25 @@ func (h handlers) register(namespace string, r interface{}) {
// Handle
type rpcErrFunc func(w io.Writer, req *request, code int, err error)
type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error)
func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
wf := func(cb func(io.Writer)) {
cb(w)
}
var req request
if err := json.NewDecoder(r).Decode(&req); err != nil {
rpcError(w, &req, rpcParseError, err)
rpcError(wf, &req, rpcParseError, err)
return
}
h.handle(ctx, req, w, rpcError)
h.handle(ctx, req, wf, rpcError, func() {})
}
func (h handlers) handle(ctx context.Context, req request, w io.Writer, rpcError rpcErrFunc) {
func (h handlers) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func()) {
defer done()
handler, ok := h[req.Method]
if !ok {
rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not found", req.Method))
@ -159,8 +165,10 @@ func (h handlers) handle(ctx context.Context, req request, w io.Writer, rpcError
resp.Result = callResult[handler.valOut].Interface()
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
fmt.Println(err)
return
}
w(func(w io.Writer) {
if err := json.NewEncoder(w).Encode(resp); err != nil {
fmt.Println(err)
return
}
})
}

View File

@ -47,25 +47,30 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
s.methods.handleReader(r.Context(), r.Body, w, s.rpcError)
s.methods.handleReader(r.Context(), r.Body, w, rpcError)
}
func (s *RPCServer) rpcError(w io.Writer, req *request, code int, err error) {
w.(http.ResponseWriter).WriteHeader(500)
if req.ID == nil { // notification
return
}
func rpcError(wf func(func(io.Writer)), req *request, code int, err error) {
wf(func(w io.Writer) {
if hw, ok := w.(http.ResponseWriter); ok {
hw.WriteHeader(500)
}
resp := response{
Jsonrpc: "2.0",
ID: *req.ID,
Error: &respError{
Code: code,
Message: err.Error(),
},
}
if req.ID == nil { // notification
return
}
_ = json.NewEncoder(w).Encode(resp)
resp := response{
Jsonrpc: "2.0",
ID: *req.ID,
Error: &respError{
Code: code,
Message: err.Error(),
},
}
_ = json.NewEncoder(w).Encode(resp)
})
}
// Register registers new RPC handler

View File

@ -5,10 +5,14 @@ import (
"encoding/json"
"errors"
"io"
"io/ioutil"
"sync"
"github.com/gorilla/websocket"
)
const wsCancel = "xrpc.cancel"
type frame struct {
// common
Jsonrpc string `json:"jsonrpc"`
@ -42,9 +46,50 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
incoming <- r
}
var writeLk sync.Mutex
nextWriter := func(cb func(io.Writer)) {
writeLk.Lock()
defer writeLk.Unlock()
wcl, err := conn.NextWriter(websocket.TextMessage)
if err != nil {
log.Error("handle me:", err)
return
}
cb(wcl)
if err := wcl.Close(); err != nil {
log.Error("handle me:", err)
return
}
}
go nextMessage()
inflight := map[int64]clientRequest{}
handling := map[int64]context.CancelFunc{}
var handlingLk sync.Mutex
cancelCtx := func(req frame) {
if req.ID != nil {
log.Warnf("%s call with ID set, won't respond", wsCancel)
}
var id int64
if err := json.Unmarshal(req.Params[0].data, &id); err != nil {
log.Error("handle me:", err)
return
}
handlingLk.Lock()
defer handlingLk.Unlock()
cf, ok := handling[id]
if ok {
cf()
}
}
for {
select {
@ -62,33 +107,8 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
return
}
if frame.Method != "" {
// call
req := request{
Jsonrpc: frame.Jsonrpc,
ID: frame.ID,
Method: frame.Method,
Params: frame.Params,
}
// TODO: ignore ID
wcl, err := conn.NextWriter(websocket.TextMessage)
if err != nil {
log.Error("handle me:", err)
return
}
handler.handle(ctx, req, wcl, func(w io.Writer, req *request, code int, err error) {
log.Error("handle me:", err) // TODO: seriously
return
})
if err := wcl.Close(); err != nil {
log.Error("handle me:", err)
return
}
} else {
// response
switch frame.Method {
case "": // Response to our call
req, ok := inflight[*frame.ID]
if !ok {
log.Error("client got unknown ID in response")
@ -102,11 +122,48 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
Error: frame.Error,
}
delete(inflight, *frame.ID)
case wsCancel:
cancelCtx(frame)
default: // Remote call
req := request{
Jsonrpc: frame.Jsonrpc,
ID: frame.ID,
Method: frame.Method,
Params: frame.Params,
}
ctx, cf := context.WithCancel(ctx)
nw := func(cb func(io.Writer)) {
cb(ioutil.Discard)
}
done := func(){}
if frame.ID != nil {
nw = nextWriter
handlingLk.Lock()
handling[*frame.ID] = cf
handlingLk.Unlock()
done = func() {
handlingLk.Lock()
defer handlingLk.Unlock()
cf := handling[*frame.ID]
cf()
delete(handling, *frame.ID)
}
}
go handler.handle(ctx, req, nw, rpcError, done)
}
go nextMessage()
case req := <-requests:
inflight[*req.req.ID] = req
if req.req.ID != nil {
inflight[*req.req.ID] = req
}
if err := conn.WriteJSON(req.req); err != nil {
log.Error("handle me:", err)
return