rpc: Contexts

License: MIT
Signed-off-by: Jakub Sztandera <kubuxu@protonmail.ch>
This commit is contained in:
Łukasz Magiera 2019-06-28 16:53:01 +02:00 committed by Jakub Sztandera
parent 46407e2033
commit 5238872c02
3 changed files with 108 additions and 15 deletions

View File

@ -2,6 +2,7 @@ package rpclib
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@ -10,6 +11,11 @@ import (
"sync/atomic"
)
var (
errorType = reflect.TypeOf(new(error)).Elem()
contextType = reflect.TypeOf(new(context.Context)).Elem()
)
type ErrClient struct {
err error
}
@ -65,7 +71,7 @@ func NewClient(addr string, namespace string, handler interface{}) {
out[valOut] = reflect.Value(resp.Result).Elem()
}
if errOut != -1 {
out[errOut] = reflect.New(reflect.TypeOf(new(error)).Elem()).Elem()
out[errOut] = reflect.New(errorType).Elem()
if resp.Error != nil {
out[errOut].Set(reflect.ValueOf(errors.New(resp.Error.Message)))
}
@ -81,17 +87,22 @@ func NewClient(addr string, namespace string, handler interface{}) {
out[valOut] = reflect.New(ftyp.Out(valOut)).Elem()
}
if errOut != -1 {
out[errOut] = reflect.New(reflect.TypeOf(new(error)).Elem()).Elem()
out[errOut] = reflect.New(errorType).Elem()
out[errOut].Set(reflect.ValueOf(&ErrClient{err}))
}
return out
}
hasCtx := 0
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
hasCtx = 1
}
fn := reflect.MakeFunc(ftyp, func(args []reflect.Value) (results []reflect.Value) {
id := atomic.AddInt64(&idCtr, 1)
params := make([]param, len(args))
for i, arg := range args {
params := make([]param, len(args) - hasCtx)
for i, arg := range args[hasCtx:] {
params[i] = param{
v: arg,
}
@ -109,12 +120,25 @@ func NewClient(addr string, namespace string, handler interface{}) {
return processError(err)
}
httpResp, err := http.Post(addr, "application/json", bytes.NewReader(b))
// prepare / execute http request
hreq, err := http.NewRequest("POST", addr, bytes.NewReader(b))
if err != nil {
return processError(err)
}
if hasCtx == 1 {
hreq = hreq.WithContext(args[0].Interface().(context.Context))
}
hreq.Header.Set("Content-Type", "application/json")
httpResp, err := http.DefaultClient.Do(hreq)
if err != nil {
return processError(err)
}
defer httpResp.Body.Close()
// process response
// TODO: check error codes in spec
if httpResp.StatusCode != 200 {
return processError(errors.New("non 200 response code"))

View File

@ -15,6 +15,8 @@ type rpcHandler struct {
receiver reflect.Value
handlerFunc reflect.Value
hasCtx int
errOut int
valOut int
}
@ -82,8 +84,12 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
callParams := make([]reflect.Value, 1+handler.nParams)
callParams := make([]reflect.Value, 1+handler.hasCtx+handler.nParams)
callParams[0] = handler.receiver
if handler.hasCtx == 1 {
callParams[1] = reflect.ValueOf(r.Context())
}
for i := 0; i < handler.nParams; i++ {
rp := reflect.New(handler.paramReceivers[i])
if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil {
@ -92,7 +98,7 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
callParams[i+1] = reflect.ValueOf(rp.Elem().Interface())
callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Elem().Interface())
}
callResult := handler.handlerFunc.Call(callParams)
@ -133,13 +139,16 @@ func (s *RPCServer) Register(r interface{}) {
for i := 0; i < val.NumMethod(); i++ {
method := val.Type().Method(i)
fmt.Println(name + "." + method.Name)
funcType := method.Func.Type()
ins := funcType.NumIn() - 1
hasCtx := 0
if funcType.NumIn() >= 2 && funcType.In(1) == contextType {
hasCtx = 1
}
ins := funcType.NumIn() - 1 - hasCtx
recvs := make([]reflect.Type, ins)
for i := 0; i < ins; i++ {
recvs[i] = method.Type.In(i + 1)
recvs[i] = method.Type.In(i + 1 + hasCtx)
}
valOut, errOut, _ := processFuncOut(funcType)
@ -151,6 +160,8 @@ func (s *RPCServer) Register(r interface{}) {
handlerFunc: method.Func,
receiver: val,
hasCtx: hasCtx,
errOut: errOut,
valOut: valOut,
}
@ -165,7 +176,7 @@ func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
switch n {
case 0:
case 1:
if funcType.Out(0) == reflect.TypeOf(new(error)).Elem() {
if funcType.Out(0) == errorType {
errOut = 0
} else {
valOut = 0
@ -173,7 +184,7 @@ func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
case 2:
valOut = 0
errOut = 1
if funcType.Out(1) != reflect.TypeOf(new(error)).Elem() {
if funcType.Out(1) != errorType {
panic("expected error as second return value")
}
default:

View File

@ -1,10 +1,12 @@
package rpclib
import (
"context"
"errors"
"net/http/httptest"
"strconv"
"testing"
"time"
)
type SimpleServerHandler struct {
@ -143,11 +145,11 @@ func TestRPC(t *testing.T) {
noparam.Add()
var erronly struct {
Add func() error
AddGet func() (int, error)
}
NewClient(testServ.URL, "SimpleServerHandler", &erronly)
err = erronly.Add()
_, err = erronly.AddGet()
if err == nil || err.Error() != "RPC client error: non 200 response code" {
t.Error("wrong error")
}
@ -162,3 +164,59 @@ func TestRPC(t *testing.T) {
t.Error("wrong error")
}
}
type CtxHandler struct {
cancelled bool
i int
}
func (h *CtxHandler) Test(ctx context.Context) {
timeout := time.After(300 * time.Millisecond)
h.i++
select {
case <-timeout:
case <-ctx.Done():
h.cancelled = true
}
}
func TestCtx(t *testing.T) {
// setup server
serverHandler := &CtxHandler{}
rpcServer := NewServer()
rpcServer.Register(serverHandler)
// httptest stuff
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()
// setup client
var client struct {
Test func(ctx context.Context)
}
NewClient(testServ.URL, "CtxHandler", &client)
ctx, cancel := context.WithTimeout(context.Background(), 20 * time.Millisecond)
defer cancel()
client.Test(ctx)
if !serverHandler.cancelled {
t.Error("expected cancellation on the server side")
}
serverHandler.cancelled = false
var noCtxClient struct {
Test func()
}
NewClient(testServ.URL, "CtxHandler", &noCtxClient)
noCtxClient.Test()
if serverHandler.cancelled || serverHandler.i != 2 {
t.Error("wrong serverHandler state")
}
}