194 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			194 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package jsonrpc
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"encoding/json"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"reflect"
 | 
						|
 | 
						|
	"golang.org/x/xerrors"
 | 
						|
)
 | 
						|
 | 
						|
// rpcHandler describes one registered RPC method: the reflected function to
// call, the object it is called on, and where its inputs and outputs sit.
type rpcHandler struct {
	// paramReceivers holds the Go type of every RPC parameter, in call order.
	// The receiver and the optional context argument are not included.
	paramReceivers []reflect.Type

	// nParams is the number of RPC parameters the method expects.
	nParams int

	// receiver is the registered object and handlerFunc the method function
	// taken from its type; receiver is passed as the first call argument.
	receiver, handlerFunc reflect.Value

	// hasCtx is 1 when the method takes a context right after the receiver
	// and 0 otherwise. It doubles as an index offset into the argument list.
	hasCtx int

	// errOut and valOut are the positions of the error and value results in
	// the method's return list; -1 means the method has no such result.
	errOut, valOut int
}

// handlers maps a "namespace.Method" name to the handler registered for it.
type handlers map[string]rpcHandler
 | 
						|
 | 
						|
// Request / response
 | 
						|
 | 
						|
type request struct {
 | 
						|
	Jsonrpc string  `json:"jsonrpc"`
 | 
						|
	ID      *int64  `json:"id,omitempty"`
 | 
						|
	Method  string  `json:"method"`
 | 
						|
	Params  []param `json:"params"`
 | 
						|
}
 | 
						|
 | 
						|
// respError is the "error" member of a JSON-RPC response.
type respError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// Error implements the error interface.
//
// Codes inside [-32768, -32000] are rendered with an "RPC error (<code>)"
// prefix; any other code yields the bare message.
func (e *respError) Error() string {
	const (
		lowest  = -32768
		highest = -32000
	)

	if e.Code < lowest || e.Code > highest {
		return e.Message
	}
	return fmt.Sprintf("RPC error (%d): %s", e.Code, e.Message)
}
 | 
						|
 | 
						|
type response struct {
 | 
						|
	Jsonrpc string      `json:"jsonrpc"`
 | 
						|
	Result  interface{} `json:"result,omitempty"`
 | 
						|
	ID      int64       `json:"id"`
 | 
						|
	Error   *respError  `json:"error,omitempty"`
 | 
						|
}
 | 
						|
 | 
						|
// Register
 | 
						|
 | 
						|
func (h handlers) register(namespace string, r interface{}) {
 | 
						|
	val := reflect.ValueOf(r)
 | 
						|
	//TODO: expect ptr
 | 
						|
 | 
						|
	for i := 0; i < val.NumMethod(); i++ {
 | 
						|
		method := val.Type().Method(i)
 | 
						|
 | 
						|
		funcType := method.Func.Type()
 | 
						|
		hasCtx := 0
 | 
						|
		if funcType.NumIn() >= 2 && funcType.In(1) == contextType {
 | 
						|
			hasCtx = 1
 | 
						|
		}
 | 
						|
 | 
						|
		ins := funcType.NumIn() - 1 - hasCtx
 | 
						|
		recvs := make([]reflect.Type, ins)
 | 
						|
		for i := 0; i < ins; i++ {
 | 
						|
			recvs[i] = method.Type.In(i + 1 + hasCtx)
 | 
						|
		}
 | 
						|
 | 
						|
		valOut, errOut, _ := processFuncOut(funcType)
 | 
						|
 | 
						|
		h[namespace+"."+method.Name] = rpcHandler{
 | 
						|
			paramReceivers: recvs,
 | 
						|
			nParams:        ins,
 | 
						|
 | 
						|
			handlerFunc: method.Func,
 | 
						|
			receiver:    val,
 | 
						|
 | 
						|
			hasCtx: hasCtx,
 | 
						|
 | 
						|
			errOut: errOut,
 | 
						|
			valOut: valOut,
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Handle
 | 
						|
 | 
						|
type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error)
 | 
						|
type chanOut func(reflect.Value) interface{}
 | 
						|
 | 
						|
func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
 | 
						|
	wf := func(cb func(io.Writer)) {
 | 
						|
		cb(w)
 | 
						|
	}
 | 
						|
 | 
						|
	var req request
 | 
						|
	if err := json.NewDecoder(r).Decode(&req); err != nil {
 | 
						|
		rpcError(wf, &req, rpcParseError, xerrors.Errorf("unmarshaling request: %w", err))
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	h.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
 | 
						|
}
 | 
						|
 | 
						|
func (h handlers) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func(keepCtx bool), chOut chanOut) {
 | 
						|
	handler, ok := h[req.Method]
 | 
						|
	if !ok {
 | 
						|
		rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not found", req.Method))
 | 
						|
		done(false)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	if len(req.Params) != handler.nParams {
 | 
						|
		rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count"))
 | 
						|
		done(false)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	outCh := handler.valOut != -1 && handler.handlerFunc.Type().Out(handler.valOut).Kind() == reflect.Chan
 | 
						|
	defer done(outCh)
 | 
						|
 | 
						|
	if chOut == nil && outCh {
 | 
						|
		rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not supported in this mode (no out channel support)", req.Method))
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	callParams := make([]reflect.Value, 1+handler.hasCtx+handler.nParams)
 | 
						|
	callParams[0] = handler.receiver
 | 
						|
	if handler.hasCtx == 1 {
 | 
						|
		callParams[1] = reflect.ValueOf(ctx)
 | 
						|
	}
 | 
						|
 | 
						|
	for i := 0; i < handler.nParams; i++ {
 | 
						|
		rp := reflect.New(handler.paramReceivers[i])
 | 
						|
		if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil {
 | 
						|
			rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s': %w", handler.handlerFunc, err))
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
		callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Elem().Interface())
 | 
						|
	}
 | 
						|
 | 
						|
	///////////////////
 | 
						|
 | 
						|
	callResult := handler.handlerFunc.Call(callParams)
 | 
						|
	if req.ID == nil {
 | 
						|
		return // notification
 | 
						|
	}
 | 
						|
 | 
						|
	///////////////////
 | 
						|
 | 
						|
	resp := response{
 | 
						|
		Jsonrpc: "2.0",
 | 
						|
		ID:      *req.ID,
 | 
						|
	}
 | 
						|
 | 
						|
	if handler.errOut != -1 {
 | 
						|
		err := callResult[handler.errOut].Interface()
 | 
						|
		if err != nil {
 | 
						|
			resp.Error = &respError{
 | 
						|
				Code:    1,
 | 
						|
				Message: err.(error).Error(),
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if handler.valOut != -1 {
 | 
						|
		resp.Result = callResult[handler.valOut].Interface()
 | 
						|
	}
 | 
						|
 | 
						|
	w(func(w io.Writer) {
 | 
						|
		if resp.Result != nil && reflect.TypeOf(resp.Result).Kind() == reflect.Chan {
 | 
						|
			// this must happen in the writer callback, otherwise we may start sending
 | 
						|
			// channel messages before we send this response
 | 
						|
 | 
						|
			//noinspection GoNilness // already checked above
 | 
						|
			resp.Result = chOut(callResult[handler.valOut])
 | 
						|
		}
 | 
						|
 | 
						|
		if err := json.NewEncoder(w).Encode(resp); err != nil {
 | 
						|
			fmt.Println(err)
 | 
						|
			return
 | 
						|
		}
 | 
						|
	})
 | 
						|
}
 |