339 lines
6.8 KiB
Go
339 lines
6.8 KiB
Go
package jsonrpc
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"io/ioutil"
|
|
"reflect"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Method names reserved for the transport's built-in calls. Frames carrying
// one of these methods are meant for the connection handler itself rather
// than for a user-registered RPC method.
const (
	// wsCancel asks the remote to cancel the context of a call it is handling.
	wsCancel = "xrpc.cancel"
	// chValue carries a single value for a channel-returning call.
	chValue = "xrpc.ch.val"
	// chClose is the method name reserved for channel-close signalling.
	chClose = "xrpc.ch.close"
)
|
|
|
|
type frame struct {
|
|
// common
|
|
Jsonrpc string `json:"jsonrpc"`
|
|
ID *int64 `json:"id,omitempty"`
|
|
|
|
// request
|
|
Method string `json:"method,omitempty"`
|
|
Params []param `json:"params,omitempty"`
|
|
|
|
// response
|
|
Result json.RawMessage `json:"result,omitempty"`
|
|
Error *respError `json:"error,omitempty"`
|
|
}
|
|
|
|
func handleWsConn(ctx context.Context, conn *websocket.Conn, handler handlers, requests <-chan clientRequest, stop <-chan struct{}) {
|
|
incoming := make(chan io.Reader)
|
|
var incErr error
|
|
|
|
// nextMessage wait for one message and puts it to the incoming channel
|
|
nextMessage := func() {
|
|
mtype, r, err := conn.NextReader()
|
|
if err != nil {
|
|
incErr = err
|
|
close(incoming)
|
|
return
|
|
}
|
|
if mtype != websocket.BinaryMessage && mtype != websocket.TextMessage {
|
|
incErr = errors.New("unsupported message type")
|
|
close(incoming)
|
|
return
|
|
}
|
|
incoming <- r
|
|
}
|
|
|
|
var writeLk sync.Mutex
|
|
|
|
// nextWriter waits for writeLk and invokes the cb callback with WS message
|
|
// writer when the lock is acquired
|
|
nextWriter := func(cb func(io.Writer)) {
|
|
writeLk.Lock()
|
|
defer writeLk.Unlock()
|
|
|
|
wcl, err := conn.NextWriter(websocket.TextMessage)
|
|
if err != nil {
|
|
log.Error("handle me:", err)
|
|
return
|
|
}
|
|
|
|
cb(wcl)
|
|
|
|
if err := wcl.Close(); err != nil {
|
|
log.Error("handle me:", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
sendReq := func(req request) {
|
|
writeLk.Lock()
|
|
if err := conn.WriteJSON(req); err != nil {
|
|
log.Error("handle me:", err)
|
|
writeLk.Unlock()
|
|
return
|
|
}
|
|
writeLk.Unlock()
|
|
}
|
|
|
|
// wait for the first message
|
|
go nextMessage()
|
|
|
|
// inflight are requests we sent to the remote
|
|
inflight := map[int64]clientRequest{}
|
|
|
|
// handling are the calls we handle
|
|
handling := map[int64]context.CancelFunc{}
|
|
var handlingLk sync.Mutex
|
|
|
|
// ////
|
|
// Subscriptions (func() <-chan Typ - like methods)
|
|
|
|
var chOnce sync.Once
|
|
var outId uint64
|
|
type chReg struct {
|
|
id uint64
|
|
ch reflect.Value
|
|
}
|
|
registerCh := make(chan chReg)
|
|
defer close(registerCh)
|
|
|
|
handleOutChans := func() {
|
|
regV := reflect.ValueOf(registerCh)
|
|
|
|
cases := []reflect.SelectCase{
|
|
{ // registration chan always 0
|
|
Dir: reflect.SelectRecv,
|
|
Chan: regV,
|
|
},
|
|
}
|
|
var caseToId []uint64
|
|
|
|
for {
|
|
chosen, val, ok := reflect.Select(cases)
|
|
|
|
if chosen == 0 { // control channel
|
|
if !ok {
|
|
// not closing any channels as we're on receiving end.
|
|
// Also, context cancellation below should take care of any running
|
|
// requests
|
|
return
|
|
}
|
|
|
|
registration := val.Interface().(chReg)
|
|
|
|
caseToId = append(caseToId, registration.id)
|
|
cases = append(cases, reflect.SelectCase{
|
|
Dir: reflect.SelectRecv,
|
|
Chan: registration.ch,
|
|
})
|
|
|
|
continue
|
|
}
|
|
|
|
if !ok {
|
|
n := len(caseToId)
|
|
if n > 0 {
|
|
cases[chosen] = cases[n]
|
|
caseToId[chosen-1] = caseToId[n-1]
|
|
}
|
|
|
|
cases = cases[:n]
|
|
caseToId = caseToId[:n-1]
|
|
continue
|
|
}
|
|
|
|
sendReq(request{
|
|
Jsonrpc: "2.0",
|
|
ID: nil, // notification
|
|
Method: chValue,
|
|
Params: []param{{v: reflect.ValueOf(caseToId[chosen-1])}, {v: val}},
|
|
})
|
|
}
|
|
}
|
|
|
|
handleChanOut := func(ch reflect.Value) interface{} {
|
|
chOnce.Do(func() {
|
|
go handleOutChans()
|
|
})
|
|
id := atomic.AddUint64(&outId, 1)
|
|
|
|
registerCh <- chReg{
|
|
id: id,
|
|
ch: ch,
|
|
}
|
|
|
|
return id
|
|
}
|
|
|
|
// client side subs
|
|
|
|
chanHandlers := map[uint64]func(m []byte, ok bool){}
|
|
|
|
// ////
|
|
|
|
// on close, make sure to return from all pending calls, and cancel context
|
|
// on all calls we handle
|
|
defer func() {
|
|
for id, req := range inflight {
|
|
req.ready <- clientResponse{
|
|
Jsonrpc: "2.0",
|
|
ID: id,
|
|
Error: &respError{
|
|
Message: "handler: websocket connection closed",
|
|
},
|
|
}
|
|
|
|
handlingLk.Lock()
|
|
for _, cancel := range handling {
|
|
cancel()
|
|
}
|
|
handlingLk.Unlock()
|
|
}
|
|
}()
|
|
|
|
// cancelCtx is a built-in rpc which handles context cancellation over rpc
|
|
cancelCtx := func(req frame) {
|
|
if req.ID != nil {
|
|
log.Warnf("%s call with ID set, won't respond", wsCancel)
|
|
}
|
|
|
|
var id int64
|
|
if err := json.Unmarshal(req.Params[0].data, &id); err != nil {
|
|
log.Error("handle me:", err)
|
|
return
|
|
}
|
|
|
|
handlingLk.Lock()
|
|
defer handlingLk.Unlock()
|
|
|
|
cf, ok := handling[id]
|
|
if ok {
|
|
cf()
|
|
}
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case r, ok := <-incoming:
|
|
if !ok {
|
|
if incErr != nil {
|
|
log.Debugf("websocket error", "error", incErr)
|
|
}
|
|
return // remote closed
|
|
}
|
|
|
|
var frame frame
|
|
if err := json.NewDecoder(r).Decode(&frame); err != nil {
|
|
log.Error("handle me:", err)
|
|
return
|
|
}
|
|
|
|
// Get message type by method name:
|
|
// "" - response
|
|
// "xrpc.*" - builtin
|
|
// anything else - incoming remote call
|
|
switch frame.Method {
|
|
case "": // Response to our call
|
|
req, ok := inflight[*frame.ID]
|
|
if !ok {
|
|
log.Error("client got unknown ID in response")
|
|
continue
|
|
}
|
|
|
|
if req.retCh != nil {
|
|
// output is channel
|
|
var chid uint64
|
|
if err := json.Unmarshal(frame.Result, &chid); err != nil {
|
|
log.Error("failed to unmarshal channel id response: %s", err)
|
|
continue
|
|
}
|
|
|
|
chanHandlers[chid] = req.retCh()
|
|
}
|
|
|
|
req.ready <- clientResponse{
|
|
Jsonrpc: frame.Jsonrpc,
|
|
Result: frame.Result,
|
|
ID: *frame.ID,
|
|
Error: frame.Error,
|
|
}
|
|
delete(inflight, *frame.ID)
|
|
case wsCancel:
|
|
cancelCtx(frame)
|
|
case chValue:
|
|
var chid uint64
|
|
if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil {
|
|
log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err)
|
|
continue
|
|
}
|
|
|
|
hnd, ok := chanHandlers[chid]
|
|
if !ok {
|
|
log.Error("xrpc.ch.val: handler %d not found", chid)
|
|
continue
|
|
}
|
|
|
|
hnd(frame.Params[1].data, true)
|
|
default: // Remote call
|
|
req := request{
|
|
Jsonrpc: frame.Jsonrpc,
|
|
ID: frame.ID,
|
|
Method: frame.Method,
|
|
Params: frame.Params,
|
|
}
|
|
|
|
ctx, cf := context.WithCancel(ctx)
|
|
|
|
nw := func(cb func(io.Writer)) {
|
|
cb(ioutil.Discard)
|
|
}
|
|
done := func(keepctx bool) {
|
|
if !keepctx {
|
|
cf()
|
|
}
|
|
}
|
|
if frame.ID != nil {
|
|
nw = nextWriter
|
|
|
|
handlingLk.Lock()
|
|
handling[*frame.ID] = cf
|
|
handlingLk.Unlock()
|
|
|
|
done = func(keepctx bool) {
|
|
handlingLk.Lock()
|
|
defer handlingLk.Unlock()
|
|
|
|
if !keepctx {
|
|
cf()
|
|
}
|
|
|
|
delete(handling, *frame.ID)
|
|
}
|
|
}
|
|
|
|
go handler.handle(ctx, req, nw, rpcError, done, handleChanOut)
|
|
}
|
|
|
|
go nextMessage() // TODO: fix on errors
|
|
case req := <-requests:
|
|
if req.req.ID != nil {
|
|
inflight[*req.req.ID] = req
|
|
}
|
|
sendReq(req.req)
|
|
case <-stop:
|
|
if err := conn.Close(); err != nil {
|
|
log.Debugf("websocket close error", "error", err)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|