diff --git a/daemon/cmd.go b/daemon/cmd.go index a46264728..4029f581e 100644 --- a/daemon/cmd.go +++ b/daemon/cmd.go @@ -4,6 +4,7 @@ package daemon import ( "context" + "github.com/multiformats/go-multiaddr" "gopkg.in/urfave/cli.v2" diff --git a/lib/jsonrpc/client.go b/lib/jsonrpc/client.go index d065c6a2c..f165d1a9f 100644 --- a/lib/jsonrpc/client.go +++ b/lib/jsonrpc/client.go @@ -34,24 +34,19 @@ func (e *ErrClient) Unwrap(err error) error { return e.err } -type result []byte - -func (p *result) UnmarshalJSON(raw []byte) error { - *p = make([]byte, len(raw)) - copy(*p, raw) - return nil -} - type clientResponse struct { - Jsonrpc string `json:"jsonrpc"` - Result result `json:"result"` - ID int64 `json:"id"` - Error *respError `json:"error,omitempty"` + Jsonrpc string `json:"jsonrpc"` + Result json.RawMessage `json:"result"` + ID int64 `json:"id"` + Error *respError `json:"error,omitempty"` } type clientRequest struct { req request ready chan clientResponse + + // retCh provides a context and sink for handling incoming channel messages + retCh func() (context.Context, func([]byte, bool)) } // ClientCloser is used to close Client from further use @@ -65,11 +60,11 @@ type ClientCloser func() func NewClient(addr string, namespace string, handler interface{}) (ClientCloser, error) { htyp := reflect.TypeOf(handler) if htyp.Kind() != reflect.Ptr { - panic("expected handler to be a pointer") + return nil, xerrors.New("expected handler to be a pointer") } typ := htyp.Elem() if typ.Kind() != reflect.Struct { - panic("handler should be a struct") + return nil, xerrors.New("handler should be a struct") } val := reflect.ValueOf(handler) @@ -85,13 +80,18 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser requests := make(chan clientRequest) handlers := map[string]rpcHandler{} - go handleWsConn(context.TODO(), conn, handlers, requests, stop) + go (&wsConn{ + conn: conn, + handler: handlers, + requests: requests, + stop: stop, + }).handleWsConn(context.TODO()) for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) ftyp := f.Type if ftyp.Kind() != reflect.Func { - panic("handler field not a func") + return nil, xerrors.New("handler field not a func") } valOut, errOut, nout := processFuncOut(ftyp) @@ -100,7 +100,7 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser out := make([]reflect.Value, nout) if valOut != -1 { - out[valOut] = rval.Elem() + out[valOut] = rval } if errOut != -1 { out[errOut] = reflect.New(errorType).Elem() @@ -130,6 +130,7 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser if ftyp.NumIn() > 0 && ftyp.In(0) == contextType { hasCtx = 1 } + retCh := valOut != -1 && ftyp.Out(valOut).Kind() == reflect.Chan fn := reflect.MakeFunc(ftyp, func(args []reflect.Value) (results []reflect.Value) { id := atomic.AddInt64(&idCtr, 1) @@ -140,6 +141,44 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser } } + var ctx context.Context + if hasCtx == 1 { + ctx = args[0].Interface().(context.Context) + } + + var retVal reflect.Value + + // if the function returns a channel, we need to provide a sink for the + // messages + var chCtor func() (context.Context, func([]byte, bool)) + + if retCh { + retVal = reflect.Zero(ftyp.Out(valOut)) + + chCtor = func() (context.Context, func([]byte, bool)) { + // unpack chan type to make sure it's reflect.BothDir + ctyp := reflect.ChanOf(reflect.BothDir, ftyp.Out(valOut).Elem()) + ch := reflect.MakeChan(ctyp, 0) // todo: buffer? + retVal = ch.Convert(ftyp.Out(valOut)) + + return ctx, func(result []byte, ok bool) { + if !ok { + // remote channel closed, close ours too + ch.Close() + return + } + + val := reflect.New(ftyp.Out(valOut).Elem()) + if err := json.Unmarshal(result, val.Interface()); err != nil { + log.Errorf("error unmarshaling chan response: %s", err) + return + } + + ch.Send(val.Elem()) // todo: select on ctx is probably a good idea + } + } + } + req := request{ Jsonrpc: "2.0", ID: &id, @@ -151,14 +190,17 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser requests <- clientRequest{ req: req, ready: rchan, + + retCh: chCtor, } var ctxDone <-chan struct{} var resp clientResponse - if hasCtx == 1 { - ctxDone = args[0].Interface().(context.Context).Done() + if ctx != nil { + ctxDone = ctx.Done() } + // wait for response, handle context cancellation loop: for { select { @@ -176,24 +218,25 @@ func NewClient(addr string, namespace string, handler interface{}) (ClientCloser } } } - var rval reflect.Value - if valOut != -1 { - rval = reflect.New(ftyp.Out(valOut)) + if valOut != -1 && !retCh { + retVal = reflect.New(ftyp.Out(valOut)) if resp.Result != nil { log.Debugw("rpc result", "type", ftyp.Out(valOut)) - if err := json.Unmarshal(resp.Result, rval.Interface()); err != nil { + if err := json.Unmarshal(resp.Result, retVal.Interface()); err != nil { return processError(xerrors.Errorf("unmarshaling result: %w", err)) } } + + retVal = retVal.Elem() } if resp.ID != *req.ID { return processError(errors.New("request and response id didn't match")) } - return processResponse(resp, rval) + return processResponse(resp, retVal) }) val.Elem().Field(i).Set(fn) diff --git a/lib/jsonrpc/handler.go b/lib/jsonrpc/handler.go index 97607dc50..b56d08fa9 100644 --- a/lib/jsonrpc/handler.go +++ b/lib/jsonrpc/handler.go @@ -95,6 +95,7 @@ func (h handlers) register(namespace string, r interface{}) { // Handle type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error) +type chanOut func(reflect.Value) interface{} func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) { wf := func(cb func(io.Writer)) { @@ -107,20 +108,28 @@ func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp return } - h.handle(ctx, req, wf, rpcError, func() {}) + h.handle(ctx, req, wf, rpcError, func(bool) {}, nil) } -func (h handlers) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func()) { - defer done() - +func (h handlers) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func(keepCtx bool), chOut chanOut) { handler, ok := h[req.Method] if !ok { rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not found", req.Method)) + done(false) return } if len(req.Params) != handler.nParams { rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count")) + done(false) + return + } + + outCh := handler.valOut != -1 && handler.handlerFunc.Type().Out(handler.valOut).Kind() == reflect.Chan + defer done(outCh) + + if chOut == nil && outCh { + rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not supported in this mode (no out channel support)", req.Method)) return } @@ -168,6 +177,14 @@ func (h handlers) handle(ctx context.Context, req request, w func(func(io.Writer } w(func(w io.Writer) { + if resp.Result != nil && reflect.TypeOf(resp.Result).Kind() == reflect.Chan { + // this must happen in the writer callback, otherwise we may start sending + // channel messages before we send this response + + //noinspection GoNilness // already checked above + resp.Result = chOut(callResult[handler.valOut]) + } + if err := json.NewEncoder(w).Encode(resp); err != nil { fmt.Println(err) return diff --git a/lib/jsonrpc/rpc_test.go b/lib/jsonrpc/rpc_test.go index 053b89ba3..427cf092d 100644 --- a/lib/jsonrpc/rpc_test.go +++ b/lib/jsonrpc/rpc_test.go @@ -3,12 +3,15 @@ package jsonrpc import ( "context" "errors" + "fmt" "net/http/httptest" "strconv" "strings" "sync" "testing" "time" + + "github.com/stretchr/testify/require" ) type SimpleServerHandler struct { @@ -71,62 +74,37 @@ func TestRPC(t *testing.T) { StringMatch func(t TestType, i2 int64) (out TestOut, err error) } closer, err := NewClient("ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &client) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer closer() // Add(int) error - if err := client.Add(2); err != nil { - t.Fatal(err) - } - - if serverHandler.n != 2 { - t.Error("expected 2") - } + require.NoError(t, client.Add(2)) + require.Equal(t, 2, serverHandler.n) err = client.Add(-3546) - if err == nil { - t.Fatal("expected error") - } - if err.Error() != "test" { - t.Fatal("wrong error", err) - } + require.EqualError(t, err, "test") // AddGet(int) int n := client.AddGet(3) - if n != 5 { - t.Error("wrong n") - } - - if serverHandler.n != 5 { - t.Error("expected 5") - } + require.Equal(t, 5, n) + require.Equal(t, 5, serverHandler.n) // StringMatch o, err := client.StringMatch(TestType{S: "0"}, 0) - if err != nil { - t.Error(err) - } - if o.S != "0" || o.I != 0 { - t.Error("wrong result") - } + require.NoError(t, err) + require.Equal(t, "0", o.S) + require.Equal(t, 0, o.I) _, err = client.StringMatch(TestType{S: "5"}, 5) - if err == nil || err.Error() != ":(" { - t.Error("wrong err") - } + require.EqualError(t, err, ":(") o, err = client.StringMatch(TestType{S: "8", I: 8}, 8) - if err != nil { - t.Error(err) - } - if o.S != "8" || o.I != 8 { - t.Error("wrong result") - } + require.NoError(t, err) + require.Equal(t, "8", o.S) + require.Equal(t, 8, o.I) // Invalid client handlers @@ -134,24 +112,18 @@ func TestRPC(t *testing.T) { Add func(int) } closer, err = NewClient("ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &noret) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // this one should actually work noret.Add(4) - if serverHandler.n != 9 { - t.Error("expected 9") - } + require.Equal(t, 9, serverHandler.n) closer() var noparam struct { Add func() } closer, err = NewClient("ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &noparam) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // shouldn't panic noparam.Add() @@ -161,9 +133,7 @@ func TestRPC(t *testing.T) { AddGet func() (int, error) } closer, err = NewClient("ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &erronly) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) _, err = erronly.AddGet() if err == nil || err.Error() != "RPC error (-32602): wrong param count" { @@ -175,9 +145,7 @@ func TestRPC(t *testing.T) { Add func(string) error } closer, err = NewClient("ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &wrongtype) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) err = wrongtype.Add("not an int") if err == nil || !strings.Contains(err.Error(), "RPC error (-32700):") || !strings.Contains(err.Error(), "json: cannot unmarshal string into Go value of type int") { @@ -189,9 +157,7 @@ func TestRPC(t *testing.T) { NotThere func(string) error } closer, err = NewClient("ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", ¬found) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) err = notfound.NotThere("hello?") if err == nil || err.Error() != "RPC error (-32601): method 'SimpleServerHandler.NotThere' not found" { @@ -238,9 +204,7 @@ func TestCtx(t *testing.T) { Test func(ctx context.Context) } closer, err := NewClient("ws://"+testServ.Listener.Addr().String(), "CtxHandler", &client) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -276,3 +240,140 @@ func TestCtx(t *testing.T) { serverHandler.lk.Unlock() closer() } + +type UnUnmarshalable int + +func (*UnUnmarshalable) UnmarshalJSON([]byte) error { + return errors.New("nope") +} + +type UnUnmarshalableHandler struct{} + +func (*UnUnmarshalableHandler) GetUnUnmarshalableStuff() (UnUnmarshalable, error) { + return UnUnmarshalable(5), nil +} + +func TestUnmarshalableResult(t *testing.T) { + var client struct { + GetUnUnmarshalableStuff func() (UnUnmarshalable, error) + } + + rpcServer := NewServer() + rpcServer.Register("Handler", &UnUnmarshalableHandler{}) + + testServ := httptest.NewServer(rpcServer) + defer testServ.Close() + + closer, err := NewClient("ws://"+testServ.Listener.Addr().String(), "Handler", &client) + require.NoError(t, err) + defer closer() + + _, err = client.GetUnUnmarshalableStuff() + require.EqualError(t, err, "RPC client error: unmarshaling result: nope") +} + +type ChanHandler struct { + wait chan struct{} +} + +func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error) { + out := make(chan int) + + go func() { + defer close(out) + var n int + + for { + select { + case <-ctx.Done(): + fmt.Println("ctxdone1") + return + case <-h.wait: + } + + n += i + + if n == eq { + fmt.Println("eq") + return + } + + select { + case <-ctx.Done(): + fmt.Println("ctxdone2") + return + case out <- n: + } + } + }() + + return out, nil +} + +func TestChan(t *testing.T) { + var client struct { + Sub func(context.Context, int, int) (<-chan int, error) + } + + serverHandler := &ChanHandler{ + wait: make(chan struct{}, 5), + } + + rpcServer := NewServer() + rpcServer.Register("ChanHandler", serverHandler) + + testServ := httptest.NewServer(rpcServer) + defer testServ.Close() + + closer, err := NewClient("ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client) + require.NoError(t, err) + + defer closer() + + serverHandler.wait <- struct{}{} + + ctx, cancel := context.WithCancel(context.Background()) + + // sub + + sub, err := client.Sub(ctx, 2, -1) + require.NoError(t, err) + + // recv one + + require.Equal(t, 2, <-sub) + + // recv many (order) + + serverHandler.wait <- struct{}{} + serverHandler.wait <- struct{}{} + serverHandler.wait <- struct{}{} + + require.Equal(t, 4, <-sub) + require.Equal(t, 6, <-sub) + require.Equal(t, 8, <-sub) + + // close (through ctx) + cancel() + + _, ok := <-sub + require.Equal(t, false, ok) + + // sub (again) + + serverHandler.wait <- struct{}{} + + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + + sub, err = client.Sub(ctx, 3, 6) + require.NoError(t, err) + + require.Equal(t, 3, <-sub) + + // close (remote) + serverHandler.wait <- struct{}{} + _, ok = <-sub + require.Equal(t, false, ok) + +} diff --git a/lib/jsonrpc/server.go b/lib/jsonrpc/server.go index fdf3a5a67..d6111e0d7 100644 --- a/lib/jsonrpc/server.go +++ b/lib/jsonrpc/server.go @@ -36,7 +36,10 @@ func (s *RPCServer) handleWS(w http.ResponseWriter, r *http.Request) { return } - handleWsConn(r.Context(), c, s.methods, nil, nil) + (&wsConn{ + conn: c, + handler: s.methods, + }).handleWsConn(r.Context()) if err := c.Close(); err != nil { log.Error(err) @@ -60,6 +63,8 @@ func rpcError(wf func(func(io.Writer)), req *request, code int, err error) { hw.WriteHeader(500) } + log.Warnf("rpc error: %s", err) + if req.ID == nil { // notification return } @@ -73,7 +78,11 @@ func rpcError(wf func(func(io.Writer)), req *request, code int, err error) { }, } - _ = json.NewEncoder(w).Encode(resp) + err = json.NewEncoder(w).Encode(resp) + if err != nil { + log.Warnf("failed to write rpc error: %s", err) + return + } }) } diff --git a/lib/jsonrpc/websocket.go b/lib/jsonrpc/websocket.go index f52c0591c..c2778e7d2 100644 --- a/lib/jsonrpc/websocket.go +++ b/lib/jsonrpc/websocket.go @@ -6,12 +6,16 @@ import ( "errors" "io" "io/ioutil" + "reflect" "sync" + "sync/atomic" "github.com/gorilla/websocket" ) const wsCancel = "xrpc.cancel" +const chValue = "xrpc.ch.val" +const chClose = "xrpc.ch.close" type frame struct { // common @@ -23,169 +27,420 @@ type frame struct { Params []param `json:"params,omitempty"` // response - Result result `json:"result,omitempty"` - Error *respError `json:"error,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *respError `json:"error,omitempty"` } -func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, requests <-chan clientRequest, stop <-chan struct{}) { - incoming := make(chan io.Reader) - var incErr error +type outChanReg struct { + id uint64 + ch reflect.Value +} - nextMessage := func() { - mtype, r, err := conn.NextReader() - if err != nil { - incErr = err - close(incoming) - return - } - if mtype != websocket.BinaryMessage && mtype != websocket.TextMessage { - incErr = errors.New("unsupported message type") - close(incoming) - return - } - incoming <- r +type wsConn struct { + // outside params + conn *websocket.Conn + handler handlers + requests <-chan clientRequest + stop <-chan struct{} + + // incoming messages + incoming chan io.Reader + incomingErr error + + // outgoing messages + writeLk sync.Mutex + + // //// + // Client related + + // inflight are requests we've sent to the remote + inflight map[int64]clientRequest + + // chanHandlers is a map of client-side channel handlers + chanHandlers map[uint64]func(m []byte, ok bool) + + // //// + // Server related + + // handling are the calls we handle + handling map[int64]context.CancelFunc + handlingLk sync.Mutex + + spawnOutChanHandlerOnce sync.Once + + // chanCtr is a counter used for identifying output channels on the server side + chanCtr uint64 + + registerCh chan outChanReg +} + +// // +// WebSocket Message utils // +// // + +// nextMessage wait for one message and puts it to the incoming channel +func (c *wsConn) nextMessage() { + msgType, r, err := c.conn.NextReader() + if err != nil { + c.incomingErr = err + close(c.incoming) + return + } + if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage { + c.incomingErr = errors.New("unsupported message type") + close(c.incoming) + return + } + c.incoming <- r +} + +// nextWriter waits for writeLk and invokes the cb callback with WS message +// writer when the lock is acquired +func (c *wsConn) nextWriter(cb func(io.Writer)) { + c.writeLk.Lock() + defer c.writeLk.Unlock() + + wcl, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + log.Error("handle me:", err) + return } - var writeLk sync.Mutex + cb(wcl) + + if err := wcl.Close(); err != nil { + log.Error("handle me:", err) + return + } +} + +func (c *wsConn) sendRequest(req request) { + c.writeLk.Lock() + if err := c.conn.WriteJSON(req); err != nil { + log.Error("handle me:", err) + c.writeLk.Unlock() + return + } + c.writeLk.Unlock() +} + +// // +// Output channels // +// // + +// handleOutChans handles channel communication on the server side +// (forwards channel messages to client) +func (c *wsConn) handleOutChans() { + regV := reflect.ValueOf(c.registerCh) + + cases := []reflect.SelectCase{ + { // registration chan always 0 + Dir: reflect.SelectRecv, + Chan: regV, + }, + } + var caseToID []uint64 + + for { + chosen, val, ok := reflect.Select(cases) + + if chosen == 0 { // control channel + if !ok { + // control channel closed - signals closed connection + // + // We're not closing any channels as we're on receiving end. + // Also, context cancellation below should take care of any running + // requests + return + } + + registration := val.Interface().(outChanReg) + + caseToID = append(caseToID, registration.id) + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: registration.ch, + }) + + continue + } + + if !ok { + // Output channel closed, cleanup, and tell remote that this happened + + n := len(caseToID) + if n > 0 { + cases[chosen] = cases[n] + caseToID[chosen-1] = caseToID[n-1] + } + + id := caseToID[chosen-1] + cases = cases[:n] + caseToID = caseToID[:n-1] + + c.sendRequest(request{ + Jsonrpc: "2.0", + ID: nil, // notification + Method: chClose, + Params: []param{{v: reflect.ValueOf(id)}}, + }) + continue + } + + // forward message + c.sendRequest(request{ + Jsonrpc: "2.0", + ID: nil, // notification + Method: chValue, + Params: []param{{v: reflect.ValueOf(caseToID[chosen-1])}, {v: val}}, + }) + } +} + +// handleChanOut registers output channel for forwarding to client +func (c *wsConn) handleChanOut(ch reflect.Value) interface{} { + c.spawnOutChanHandlerOnce.Do(func() { + go c.handleOutChans() + }) + id := atomic.AddUint64(&c.chanCtr, 1) + + c.registerCh <- outChanReg{ + id: id, + ch: ch, + } + + return id +} + +// // +// Context.Done propagation // +// // + +// handleCtxAsync handles context lifetimes for client +// TODO: this should be aware of events going through chanHandlers, and quit +// when the related channel is closed. +// This should also probably be a single goroutine, +// Note that not doing this should be fine for now as long as we are using +// contexts correctly (cancelling when async functions are no longer is use) +func (c *wsConn) handleCtxAsync(actx context.Context, id int64) { + <-actx.Done() + + c.sendRequest(request{ + Jsonrpc: "2.0", + Method: wsCancel, + Params: []param{{v: reflect.ValueOf(id)}}, + }) +} + +// cancelCtx is a built-in rpc which handles context cancellation over rpc +func (c *wsConn) cancelCtx(req frame) { + if req.ID != nil { + log.Warnf("%s call with ID set, won't respond", wsCancel) + } + + var id int64 + if err := json.Unmarshal(req.Params[0].data, &id); err != nil { + log.Error("handle me:", err) + return + } + + c.handlingLk.Lock() + defer c.handlingLk.Unlock() + + cf, ok := c.handling[id] + if ok { + cf() + } +} + +// // +// Main Handling logic // +// // + +func (c *wsConn) handleChanMessage(frame frame) { + var chid uint64 + if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil { + log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) + return + } + + hnd, ok := c.chanHandlers[chid] + if !ok { + log.Errorf("xrpc.ch.val: handler %d not found", chid) + return + } + + hnd(frame.Params[1].data, true) +} + +func (c *wsConn) handleChanClose(frame frame) { + var chid uint64 + if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil { + log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) + return + } + + hnd, ok := c.chanHandlers[chid] + if !ok { + log.Errorf("xrpc.ch.val: handler %d not found", chid) + return + } + + delete(c.chanHandlers, chid) + + hnd(nil, false) +} + +func (c *wsConn) handleResponse(frame frame) { + req, ok := c.inflight[*frame.ID] + if !ok { + log.Error("client got unknown ID in response") + return + } + + if req.retCh != nil && frame.Result != nil { + // output is channel + var chid uint64 + if err := json.Unmarshal(frame.Result, &chid); err != nil { + log.Errorf("failed to unmarshal channel id response: %s, data '%s'", err, string(frame.Result)) + return + } + + var chanCtx context.Context + chanCtx, c.chanHandlers[chid] = req.retCh() + go c.handleCtxAsync(chanCtx, *frame.ID) + } + + req.ready <- clientResponse{ + Jsonrpc: frame.Jsonrpc, + Result: frame.Result, + ID: *frame.ID, + Error: frame.Error, + } + delete(c.inflight, *frame.ID) +} + +func (c *wsConn) handleCall(ctx context.Context, frame frame) { + req := request{ + Jsonrpc: frame.Jsonrpc, + ID: frame.ID, + Method: frame.Method, + Params: frame.Params, + } + + ctx, cancel := context.WithCancel(ctx) + nextWriter := func(cb func(io.Writer)) { - writeLk.Lock() - defer writeLk.Unlock() - - wcl, err := conn.NextWriter(websocket.TextMessage) - if err != nil { - log.Error("handle me:", err) - return + cb(ioutil.Discard) + } + done := func(keepCtx bool) { + if !keepCtx { + cancel() } + } + if frame.ID != nil { + nextWriter = c.nextWriter - cb(wcl) + c.handlingLk.Lock() + c.handling[*frame.ID] = cancel + c.handlingLk.Unlock() - if err := wcl.Close(); err != nil { - log.Error("handle me:", err) - return + done = func(keepctx bool) { + c.handlingLk.Lock() + defer c.handlingLk.Unlock() + + if !keepctx { + cancel() + delete(c.handling, *frame.ID) + } } } - go nextMessage() + go c.handler.handle(ctx, req, nextWriter, rpcError, done, c.handleChanOut) +} - inflight := map[int64]clientRequest{} - handling := map[int64]context.CancelFunc{} - var handlingLk sync.Mutex +// handleFrame handles all incoming messages (calls and responses) +func (c *wsConn) handleFrame(ctx context.Context, frame frame) { + // Get message type by method name: + // "" - response + // "xrpc.*" - builtin + // anything else - incoming remote call + switch frame.Method { + case "": // Response to our call + c.handleResponse(frame) + case wsCancel: + c.cancelCtx(frame) + case chValue: + c.handleChanMessage(frame) + case chClose: + c.handleChanClose(frame) + default: // Remote call + c.handleCall(ctx, frame) + } +} +func (c *wsConn) handleWsConn(ctx context.Context) { + c.incoming = make(chan io.Reader) + c.inflight = map[int64]clientRequest{} + c.handling = map[int64]context.CancelFunc{} + c.chanHandlers = map[uint64]func(m []byte, ok bool){} + + c.registerCh = make(chan outChanReg) + defer close(c.registerCh) + + // //// + + // on close, make sure to return from all pending calls, and cancel context + // on all calls we handle defer func() { - for id, req := range inflight { + for id, req := range c.inflight { req.ready <- clientResponse{ Jsonrpc: "2.0", - ID: id, + ID: id, Error: &respError{ Message: "handler: websocket connection closed", }, } - handlingLk.Lock() - for _, cancel := range handling { + c.handlingLk.Lock() + for _, cancel := range c.handling { cancel() } - handlingLk.Unlock() + c.handlingLk.Unlock() } }() - cancelCtx := func(req frame) { - if req.ID != nil { - log.Warnf("%s call with ID set, won't respond", wsCancel) - } - - var id int64 - if err := json.Unmarshal(req.Params[0].data, &id); err != nil { - log.Error("handle me:", err) - return - } - - handlingLk.Lock() - defer handlingLk.Unlock() - - cf, ok := handling[id] - if ok { - cf() - } - } + // wait for the first message + go c.nextMessage() for { select { - case r, ok := <-incoming: + case r, ok := <-c.incoming: if !ok { - if incErr != nil { - log.Debugf("websocket error", "error", incErr) + if c.incomingErr != nil { + log.Debugf("websocket error", "error", c.incomingErr) } return // remote closed } + // debug util - dump all messages to stderr + // r = io.TeeReader(r, os.Stderr) + var frame frame if err := json.NewDecoder(r).Decode(&frame); err != nil { log.Error("handle me:", err) return } - switch frame.Method { - case "": // Response to our call - req, ok := inflight[*frame.ID] - if !ok { - log.Error("client got unknown ID in response") - continue - } - - req.ready <- clientResponse{ - Jsonrpc: frame.Jsonrpc, - Result: frame.Result, - ID: *frame.ID, - Error: frame.Error, - } - delete(inflight, *frame.ID) - case wsCancel: - cancelCtx(frame) - default: // Remote call - req := request{ - Jsonrpc: frame.Jsonrpc, - ID: frame.ID, - Method: frame.Method, - Params: frame.Params, - } - - ctx, cf := context.WithCancel(ctx) - - nw := func(cb func(io.Writer)) { - cb(ioutil.Discard) - } - done := cf - if frame.ID != nil { - nw = nextWriter - - handlingLk.Lock() - handling[*frame.ID] = cf - handlingLk.Unlock() - - done = func() { - handlingLk.Lock() - defer handlingLk.Unlock() - - cf() - delete(handling, *frame.ID) - } - } - - go handler.handle(ctx, req, nw, rpcError, done) - } - - go nextMessage() - case req := <-requests: + c.handleFrame(ctx, frame) + go c.nextMessage() + case req := <-c.requests: if req.req.ID != nil { - inflight[*req.req.ID] = req + c.inflight[*req.req.ID] = req } - if err := conn.WriteJSON(req.req); err != nil { - log.Error("handle me:", err) - return - } - case <-stop: - if err := conn.Close(); err != nil { + c.sendRequest(req.req) + case <-c.stop: + if err := c.conn.Close(); err != nil { log.Debugf("websocket close error", "error", err) } return