Context cancellation over websockets
This commit is contained in:
parent
b93d71e8cb
commit
1153f050bb
@ -151,16 +151,42 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser
|
|||||||
req: req,
|
req: req,
|
||||||
ready: rchan,
|
ready: rchan,
|
||||||
}
|
}
|
||||||
resp := <- rchan
|
var ctxDone <-chan struct{}
|
||||||
|
var resp clientResponse
|
||||||
|
|
||||||
|
if hasCtx == 1 {
|
||||||
|
ctxDone = args[0].Interface().(context.Context).Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case resp = <-rchan:
|
||||||
|
break loop
|
||||||
|
case <-ctxDone: // send cancel request
|
||||||
|
ctxDone = nil
|
||||||
|
|
||||||
|
requests <- clientRequest{
|
||||||
|
req: request{
|
||||||
|
Jsonrpc: "2.0",
|
||||||
|
Method: wsCancel,
|
||||||
|
Params: []param{{v: reflect.ValueOf(id)}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
var rval reflect.Value
|
var rval reflect.Value
|
||||||
|
|
||||||
if valOut != -1 {
|
if valOut != -1 {
|
||||||
log.Debugw("rpc result", "type", ftyp.Out(valOut))
|
|
||||||
rval = reflect.New(ftyp.Out(valOut))
|
rval = reflect.New(ftyp.Out(valOut))
|
||||||
|
|
||||||
|
if resp.Result != nil {
|
||||||
|
log.Debugw("rpc result", "type", ftyp.Out(valOut))
|
||||||
if err := json.Unmarshal(resp.Result, rval.Interface()); err != nil {
|
if err := json.Unmarshal(resp.Result, rval.Interface()); err != nil {
|
||||||
return processError(err)
|
return processError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if resp.ID != *req.ID {
|
if resp.ID != *req.ID {
|
||||||
return processError(errors.New("request and response id didn't match"))
|
return processError(errors.New("request and response id didn't match"))
|
||||||
|
@ -92,19 +92,25 @@ func (h handlers) register(namespace string, r interface{}) {
|
|||||||
|
|
||||||
// Handle
|
// Handle
|
||||||
|
|
||||||
type rpcErrFunc func(w io.Writer, req *request, code int, err error)
|
type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error)
|
||||||
|
|
||||||
func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
|
func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
|
||||||
|
wf := func(cb func(io.Writer)) {
|
||||||
|
cb(w)
|
||||||
|
}
|
||||||
|
|
||||||
var req request
|
var req request
|
||||||
if err := json.NewDecoder(r).Decode(&req); err != nil {
|
if err := json.NewDecoder(r).Decode(&req); err != nil {
|
||||||
rpcError(w, &req, rpcParseError, err)
|
rpcError(wf, &req, rpcParseError, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.handle(ctx, req, w, rpcError)
|
h.handle(ctx, req, wf, rpcError, func() {})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h handlers) handle(ctx context.Context, req request, w io.Writer, rpcError rpcErrFunc) {
|
func (h handlers) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func()) {
|
||||||
|
defer done()
|
||||||
|
|
||||||
handler, ok := h[req.Method]
|
handler, ok := h[req.Method]
|
||||||
if !ok {
|
if !ok {
|
||||||
rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not found", req.Method))
|
rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not found", req.Method))
|
||||||
@ -159,8 +165,10 @@ func (h handlers) handle(ctx context.Context, req request, w io.Writer, rpcError
|
|||||||
resp.Result = callResult[handler.valOut].Interface()
|
resp.Result = callResult[handler.valOut].Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w(func(w io.Writer) {
|
||||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -47,11 +47,15 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.methods.handleReader(r.Context(), r.Body, w, s.rpcError)
|
s.methods.handleReader(r.Context(), r.Body, w, rpcError)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RPCServer) rpcError(w io.Writer, req *request, code int, err error) {
|
func rpcError(wf func(func(io.Writer)), req *request, code int, err error) {
|
||||||
w.(http.ResponseWriter).WriteHeader(500)
|
wf(func(w io.Writer) {
|
||||||
|
if hw, ok := w.(http.ResponseWriter); ok {
|
||||||
|
hw.WriteHeader(500)
|
||||||
|
}
|
||||||
|
|
||||||
if req.ID == nil { // notification
|
if req.ID == nil { // notification
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -66,6 +70,7 @@ func (s *RPCServer) rpcError(w io.Writer, req *request, code int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_ = json.NewEncoder(w).Encode(resp)
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register registers new RPC handler
|
// Register registers new RPC handler
|
||||||
|
@ -5,10 +5,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const wsCancel = "xrpc.cancel"
|
||||||
|
|
||||||
type frame struct {
|
type frame struct {
|
||||||
// common
|
// common
|
||||||
Jsonrpc string `json:"jsonrpc"`
|
Jsonrpc string `json:"jsonrpc"`
|
||||||
@ -42,9 +46,50 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
|
|||||||
incoming <- r
|
incoming <- r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var writeLk sync.Mutex
|
||||||
|
nextWriter := func(cb func(io.Writer)) {
|
||||||
|
writeLk.Lock()
|
||||||
|
defer writeLk.Unlock()
|
||||||
|
|
||||||
|
wcl, err := conn.NextWriter(websocket.TextMessage)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("handle me:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(wcl)
|
||||||
|
|
||||||
|
if err := wcl.Close(); err != nil {
|
||||||
|
log.Error("handle me:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
go nextMessage()
|
go nextMessage()
|
||||||
|
|
||||||
inflight := map[int64]clientRequest{}
|
inflight := map[int64]clientRequest{}
|
||||||
|
handling := map[int64]context.CancelFunc{}
|
||||||
|
var handlingLk sync.Mutex
|
||||||
|
|
||||||
|
cancelCtx := func(req frame) {
|
||||||
|
if req.ID != nil {
|
||||||
|
log.Warnf("%s call with ID set, won't respond", wsCancel)
|
||||||
|
}
|
||||||
|
|
||||||
|
var id int64
|
||||||
|
if err := json.Unmarshal(req.Params[0].data, &id); err != nil {
|
||||||
|
log.Error("handle me:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
handlingLk.Lock()
|
||||||
|
defer handlingLk.Unlock()
|
||||||
|
|
||||||
|
cf, ok := handling[id]
|
||||||
|
if ok {
|
||||||
|
cf()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@ -62,33 +107,8 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if frame.Method != "" {
|
switch frame.Method {
|
||||||
// call
|
case "": // Response to our call
|
||||||
req := request{
|
|
||||||
Jsonrpc: frame.Jsonrpc,
|
|
||||||
ID: frame.ID,
|
|
||||||
Method: frame.Method,
|
|
||||||
Params: frame.Params,
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: ignore ID
|
|
||||||
wcl, err := conn.NextWriter(websocket.TextMessage)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("handle me:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.handle(ctx, req, wcl, func(w io.Writer, req *request, code int, err error) {
|
|
||||||
log.Error("handle me:", err) // TODO: seriously
|
|
||||||
return
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := wcl.Close(); err != nil {
|
|
||||||
log.Error("handle me:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// response
|
|
||||||
req, ok := inflight[*frame.ID]
|
req, ok := inflight[*frame.ID]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error("client got unknown ID in response")
|
log.Error("client got unknown ID in response")
|
||||||
@ -102,11 +122,48 @@ func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, r
|
|||||||
Error: frame.Error,
|
Error: frame.Error,
|
||||||
}
|
}
|
||||||
delete(inflight, *frame.ID)
|
delete(inflight, *frame.ID)
|
||||||
|
case wsCancel:
|
||||||
|
cancelCtx(frame)
|
||||||
|
default: // Remote call
|
||||||
|
req := request{
|
||||||
|
Jsonrpc: frame.Jsonrpc,
|
||||||
|
ID: frame.ID,
|
||||||
|
Method: frame.Method,
|
||||||
|
Params: frame.Params,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cf := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
nw := func(cb func(io.Writer)) {
|
||||||
|
cb(ioutil.Discard)
|
||||||
|
}
|
||||||
|
done := func(){}
|
||||||
|
if frame.ID != nil {
|
||||||
|
nw = nextWriter
|
||||||
|
|
||||||
|
|
||||||
|
handlingLk.Lock()
|
||||||
|
handling[*frame.ID] = cf
|
||||||
|
handlingLk.Unlock()
|
||||||
|
|
||||||
|
done = func() {
|
||||||
|
handlingLk.Lock()
|
||||||
|
defer handlingLk.Unlock()
|
||||||
|
|
||||||
|
cf := handling[*frame.ID]
|
||||||
|
cf()
|
||||||
|
delete(handling, *frame.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go handler.handle(ctx, req, nw, rpcError, done)
|
||||||
}
|
}
|
||||||
|
|
||||||
go nextMessage()
|
go nextMessage()
|
||||||
case req := <-requests:
|
case req := <-requests:
|
||||||
|
if req.req.ID != nil {
|
||||||
inflight[*req.req.ID] = req
|
inflight[*req.req.ID] = req
|
||||||
|
}
|
||||||
if err := conn.WriteJSON(req.req); err != nil {
|
if err := conn.WriteJSON(req.req); err != nil {
|
||||||
log.Error("handle me:", err)
|
log.Error("handle me:", err)
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user