Context cancellation over websockets
This commit is contained in:
parent
b93d71e8cb
commit
1153f050bb
@ -151,14 +151,40 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser
|
||||
req: req,
|
||||
ready: rchan,
|
||||
}
|
||||
resp := <- rchan
|
||||
var ctxDone <-chan struct{}
|
||||
var resp clientResponse
|
||||
|
||||
if hasCtx == 1 {
|
||||
ctxDone = args[0].Interface().(context.Context).Done()
|
||||
}
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case resp = <-rchan:
|
||||
break loop
|
||||
case <-ctxDone: // send cancel request
|
||||
ctxDone = nil
|
||||
|
||||
requests <- clientRequest{
|
||||
req: request{
|
||||
Jsonrpc: "2.0",
|
||||
Method: wsCancel,
|
||||
Params: []param{{v: reflect.ValueOf(id)}},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
var rval reflect.Value
|
||||
|
||||
if valOut != -1 {
|
||||
log.Debugw("rpc result", "type", ftyp.Out(valOut))
|
||||
rval = reflect.New(ftyp.Out(valOut))
|
||||
if err := json.Unmarshal(resp.Result, rval.Interface()); err != nil {
|
||||
return processError(err)
|
||||
|
||||
if resp.Result != nil {
|
||||
log.Debugw("rpc result", "type", ftyp.Out(valOut))
|
||||
if err := json.Unmarshal(resp.Result, rval.Interface()); err != nil {
|
||||
return processError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,19 +92,25 @@ func (h handlers) register(namespace string, r interface{}) {
|
||||
|
||||
// Handle
|
||||
|
||||
type rpcErrFunc func(w io.Writer, req *request, code int, err error)
|
||||
type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error)
|
||||
|
||||
func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
|
||||
wf := func(cb func(io.Writer)) {
|
||||
cb(w)
|
||||
}
|
||||
|
||||
var req request
|
||||
if err := json.NewDecoder(r).Decode(&req); err != nil {
|
||||
rpcError(w, &req, rpcParseError, err)
|
||||
rpcError(wf, &req, rpcParseError, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.handle(ctx, req, w, rpcError)
|
||||
h.handle(ctx, req, wf, rpcError, func() {})
|
||||
}
|
||||
|
||||
func (h handlers) handle(ctx context.Context, req request, w io.Writer, rpcError rpcErrFunc) {
|
||||
func (h handlers) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func()) {
|
||||
defer done()
|
||||
|
||||
handler, ok := h[req.Method]
|
||||
if !ok {
|
||||
rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not found", req.Method))
|
||||
@ -159,8 +165,10 @@ func (h handlers) handle(ctx context.Context, req request, w io.Writer, rpcError
|
||||
resp.Result = callResult[handler.valOut].Interface()
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
w(func(w io.Writer) {
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -47,25 +47,30 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
s.methods.handleReader(r.Context(), r.Body, w, s.rpcError)
|
||||
s.methods.handleReader(r.Context(), r.Body, w, rpcError)
|
||||
}
|
||||
|
||||
func (s *RPCServer) rpcError(w io.Writer, req *request, code int, err error) {
|
||||
w.(http.ResponseWriter).WriteHeader(500)
|
||||
if req.ID == nil { // notification
|
||||
return
|
||||
}
|
||||
func rpcError(wf func(func(io.Writer)), req *request, code int, err error) {
|
||||
wf(func(w io.Writer) {
|
||||
if hw, ok := w.(http.ResponseWriter); ok {
|
||||
hw.WriteHeader(500)
|
||||
}
|
||||
|
||||
resp := response{
|
||||
Jsonrpc: "2.0",
|
||||
ID: *req.ID,
|
||||
Error: &respError{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
},
|
||||
}
|
||||
if req.ID == nil { // notification
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
resp := response{
|
||||
Jsonrpc: "2.0",
|
||||
ID: *req.ID,
|
||||
Error: &respError{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
}
|
||||
|
||||
// Register registers new RPC handler
|
||||
|
@ -5,10 +5,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const wsCancel = "xrpc.cancel"
|
||||
|
||||
type frame struct {
|
||||
// common
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
@ -42,9 +46,50 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
|
||||
incoming <- r
|
||||
}
|
||||
|
||||
var writeLk sync.Mutex
|
||||
nextWriter := func(cb func(io.Writer)) {
|
||||
writeLk.Lock()
|
||||
defer writeLk.Unlock()
|
||||
|
||||
wcl, err := conn.NextWriter(websocket.TextMessage)
|
||||
if err != nil {
|
||||
log.Error("handle me:", err)
|
||||
return
|
||||
}
|
||||
|
||||
cb(wcl)
|
||||
|
||||
if err := wcl.Close(); err != nil {
|
||||
log.Error("handle me:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go nextMessage()
|
||||
|
||||
inflight := map[int64]clientRequest{}
|
||||
handling := map[int64]context.CancelFunc{}
|
||||
var handlingLk sync.Mutex
|
||||
|
||||
cancelCtx := func(req frame) {
|
||||
if req.ID != nil {
|
||||
log.Warnf("%s call with ID set, won't respond", wsCancel)
|
||||
}
|
||||
|
||||
var id int64
|
||||
if err := json.Unmarshal(req.Params[0].data, &id); err != nil {
|
||||
log.Error("handle me:", err)
|
||||
return
|
||||
}
|
||||
|
||||
handlingLk.Lock()
|
||||
defer handlingLk.Unlock()
|
||||
|
||||
cf, ok := handling[id]
|
||||
if ok {
|
||||
cf()
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
@ -62,33 +107,8 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
|
||||
return
|
||||
}
|
||||
|
||||
if frame.Method != "" {
|
||||
// call
|
||||
req := request{
|
||||
Jsonrpc: frame.Jsonrpc,
|
||||
ID: frame.ID,
|
||||
Method: frame.Method,
|
||||
Params: frame.Params,
|
||||
}
|
||||
|
||||
// TODO: ignore ID
|
||||
wcl, err := conn.NextWriter(websocket.TextMessage)
|
||||
if err != nil {
|
||||
log.Error("handle me:", err)
|
||||
return
|
||||
}
|
||||
|
||||
handler.handle(ctx, req, wcl, func(w io.Writer, req *request, code int, err error) {
|
||||
log.Error("handle me:", err) // TODO: seriously
|
||||
return
|
||||
})
|
||||
|
||||
if err := wcl.Close(); err != nil {
|
||||
log.Error("handle me:", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// response
|
||||
switch frame.Method {
|
||||
case "": // Response to our call
|
||||
req, ok := inflight[*frame.ID]
|
||||
if !ok {
|
||||
log.Error("client got unknown ID in response")
|
||||
@ -102,11 +122,48 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
|
||||
Error: frame.Error,
|
||||
}
|
||||
delete(inflight, *frame.ID)
|
||||
case wsCancel:
|
||||
cancelCtx(frame)
|
||||
default: // Remote call
|
||||
req := request{
|
||||
Jsonrpc: frame.Jsonrpc,
|
||||
ID: frame.ID,
|
||||
Method: frame.Method,
|
||||
Params: frame.Params,
|
||||
}
|
||||
|
||||
ctx, cf := context.WithCancel(ctx)
|
||||
|
||||
nw := func(cb func(io.Writer)) {
|
||||
cb(ioutil.Discard)
|
||||
}
|
||||
done := func(){}
|
||||
if frame.ID != nil {
|
||||
nw = nextWriter
|
||||
|
||||
|
||||
handlingLk.Lock()
|
||||
handling[*frame.ID] = cf
|
||||
handlingLk.Unlock()
|
||||
|
||||
done = func() {
|
||||
handlingLk.Lock()
|
||||
defer handlingLk.Unlock()
|
||||
|
||||
cf := handling[*frame.ID]
|
||||
cf()
|
||||
delete(handling, *frame.ID)
|
||||
}
|
||||
}
|
||||
|
||||
go handler.handle(ctx, req, nw, rpcError, done)
|
||||
}
|
||||
|
||||
go nextMessage()
|
||||
case req := <-requests:
|
||||
inflight[*req.req.ID] = req
|
||||
if req.req.ID != nil {
|
||||
inflight[*req.req.ID] = req
|
||||
}
|
||||
if err := conn.WriteJSON(req.req); err != nil {
|
||||
log.Error("handle me:", err)
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user