rpc: Contexts

License: MIT
Signed-off-by: Jakub Sztandera <kubuxu@protonmail.ch>
This commit is contained in:
Łukasz Magiera 2019-06-28 16:53:01 +02:00 committed by Jakub Sztandera
parent 46407e2033
commit 5238872c02
3 changed files with 108 additions and 15 deletions

View File

@ -2,6 +2,7 @@ package rpclib
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -10,6 +11,11 @@ import (
"sync/atomic" "sync/atomic"
) )
var (
errorType = reflect.TypeOf(new(error)).Elem()
contextType = reflect.TypeOf(new(context.Context)).Elem()
)
type ErrClient struct { type ErrClient struct {
err error err error
} }
@ -65,7 +71,7 @@ func NewClient(addr string, namespace string, handler interface{}) {
out[valOut] = reflect.Value(resp.Result).Elem() out[valOut] = reflect.Value(resp.Result).Elem()
} }
if errOut != -1 { if errOut != -1 {
out[errOut] = reflect.New(reflect.TypeOf(new(error)).Elem()).Elem() out[errOut] = reflect.New(errorType).Elem()
if resp.Error != nil { if resp.Error != nil {
out[errOut].Set(reflect.ValueOf(errors.New(resp.Error.Message))) out[errOut].Set(reflect.ValueOf(errors.New(resp.Error.Message)))
} }
@ -81,17 +87,22 @@ func NewClient(addr string, namespace string, handler interface{}) {
out[valOut] = reflect.New(ftyp.Out(valOut)).Elem() out[valOut] = reflect.New(ftyp.Out(valOut)).Elem()
} }
if errOut != -1 { if errOut != -1 {
out[errOut] = reflect.New(reflect.TypeOf(new(error)).Elem()).Elem() out[errOut] = reflect.New(errorType).Elem()
out[errOut].Set(reflect.ValueOf(&ErrClient{err})) out[errOut].Set(reflect.ValueOf(&ErrClient{err}))
} }
return out return out
} }
hasCtx := 0
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
hasCtx = 1
}
fn := reflect.MakeFunc(ftyp, func(args []reflect.Value) (results []reflect.Value) { fn := reflect.MakeFunc(ftyp, func(args []reflect.Value) (results []reflect.Value) {
id := atomic.AddInt64(&idCtr, 1) id := atomic.AddInt64(&idCtr, 1)
params := make([]param, len(args)) params := make([]param, len(args) - hasCtx)
for i, arg := range args { for i, arg := range args[hasCtx:] {
params[i] = param{ params[i] = param{
v: arg, v: arg,
} }
@ -109,12 +120,25 @@ func NewClient(addr string, namespace string, handler interface{}) {
return processError(err) return processError(err)
} }
httpResp, err := http.Post(addr, "application/json", bytes.NewReader(b)) // prepare / execute http request
hreq, err := http.NewRequest("POST", addr, bytes.NewReader(b))
if err != nil {
return processError(err)
}
if hasCtx == 1 {
hreq = hreq.WithContext(args[0].Interface().(context.Context))
}
hreq.Header.Set("Content-Type", "application/json")
httpResp, err := http.DefaultClient.Do(hreq)
if err != nil { if err != nil {
return processError(err) return processError(err)
} }
defer httpResp.Body.Close() defer httpResp.Body.Close()
// process response
// TODO: check error codes in spec // TODO: check error codes in spec
if httpResp.StatusCode != 200 { if httpResp.StatusCode != 200 {
return processError(errors.New("non 200 response code")) return processError(errors.New("non 200 response code"))

View File

@ -15,6 +15,8 @@ type rpcHandler struct {
receiver reflect.Value receiver reflect.Value
handlerFunc reflect.Value handlerFunc reflect.Value
hasCtx int
errOut int errOut int
valOut int valOut int
} }
@ -82,8 +84,12 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
callParams := make([]reflect.Value, 1+handler.nParams) callParams := make([]reflect.Value, 1+handler.hasCtx+handler.nParams)
callParams[0] = handler.receiver callParams[0] = handler.receiver
if handler.hasCtx == 1 {
callParams[1] = reflect.ValueOf(r.Context())
}
for i := 0; i < handler.nParams; i++ { for i := 0; i < handler.nParams; i++ {
rp := reflect.New(handler.paramReceivers[i]) rp := reflect.New(handler.paramReceivers[i])
if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil { if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil {
@ -92,7 +98,7 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
callParams[i+1] = reflect.ValueOf(rp.Elem().Interface()) callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Elem().Interface())
} }
callResult := handler.handlerFunc.Call(callParams) callResult := handler.handlerFunc.Call(callParams)
@ -133,13 +139,16 @@ func (s *RPCServer) Register(r interface{}) {
for i := 0; i < val.NumMethod(); i++ { for i := 0; i < val.NumMethod(); i++ {
method := val.Type().Method(i) method := val.Type().Method(i)
fmt.Println(name + "." + method.Name)
funcType := method.Func.Type() funcType := method.Func.Type()
ins := funcType.NumIn() - 1 hasCtx := 0
if funcType.NumIn() >= 2 && funcType.In(1) == contextType {
hasCtx = 1
}
ins := funcType.NumIn() - 1 - hasCtx
recvs := make([]reflect.Type, ins) recvs := make([]reflect.Type, ins)
for i := 0; i < ins; i++ { for i := 0; i < ins; i++ {
recvs[i] = method.Type.In(i + 1) recvs[i] = method.Type.In(i + 1 + hasCtx)
} }
valOut, errOut, _ := processFuncOut(funcType) valOut, errOut, _ := processFuncOut(funcType)
@ -151,6 +160,8 @@ func (s *RPCServer) Register(r interface{}) {
handlerFunc: method.Func, handlerFunc: method.Func,
receiver: val, receiver: val,
hasCtx: hasCtx,
errOut: errOut, errOut: errOut,
valOut: valOut, valOut: valOut,
} }
@ -165,7 +176,7 @@ func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
switch n { switch n {
case 0: case 0:
case 1: case 1:
if funcType.Out(0) == reflect.TypeOf(new(error)).Elem() { if funcType.Out(0) == errorType {
errOut = 0 errOut = 0
} else { } else {
valOut = 0 valOut = 0
@ -173,7 +184,7 @@ func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
case 2: case 2:
valOut = 0 valOut = 0
errOut = 1 errOut = 1
if funcType.Out(1) != reflect.TypeOf(new(error)).Elem() { if funcType.Out(1) != errorType {
panic("expected error as second return value") panic("expected error as second return value")
} }
default: default:

View File

@ -1,10 +1,12 @@
package rpclib package rpclib
import ( import (
"context"
"errors" "errors"
"net/http/httptest" "net/http/httptest"
"strconv" "strconv"
"testing" "testing"
"time"
) )
type SimpleServerHandler struct { type SimpleServerHandler struct {
@ -143,11 +145,11 @@ func TestRPC(t *testing.T) {
noparam.Add() noparam.Add()
var erronly struct { var erronly struct {
Add func() error AddGet func() (int, error)
} }
NewClient(testServ.URL, "SimpleServerHandler", &erronly) NewClient(testServ.URL, "SimpleServerHandler", &erronly)
err = erronly.Add() _, err = erronly.AddGet()
if err == nil || err.Error() != "RPC client error: non 200 response code" { if err == nil || err.Error() != "RPC client error: non 200 response code" {
t.Error("wrong error") t.Error("wrong error")
} }
@ -162,3 +164,59 @@ func TestRPC(t *testing.T) {
t.Error("wrong error") t.Error("wrong error")
} }
} }
type CtxHandler struct {
cancelled bool
i int
}
func (h *CtxHandler) Test(ctx context.Context) {
timeout := time.After(300 * time.Millisecond)
h.i++
select {
case <-timeout:
case <-ctx.Done():
h.cancelled = true
}
}
func TestCtx(t *testing.T) {
// setup server
serverHandler := &CtxHandler{}
rpcServer := NewServer()
rpcServer.Register(serverHandler)
// httptest stuff
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()
// setup client
var client struct {
Test func(ctx context.Context)
}
NewClient(testServ.URL, "CtxHandler", &client)
ctx, cancel := context.WithTimeout(context.Background(), 20 * time.Millisecond)
defer cancel()
client.Test(ctx)
if !serverHandler.cancelled {
t.Error("expected cancellation on the server side")
}
serverHandler.cancelled = false
var noCtxClient struct {
Test func()
}
NewClient(testServ.URL, "CtxHandler", &noCtxClient)
noCtxClient.Test()
if serverHandler.cancelled || serverHandler.i != 2 {
t.Error("wrong serverHandler state")
}
}