rpc: Contexts
License: MIT Signed-off-by: Jakub Sztandera <kubuxu@protonmail.ch>
This commit is contained in:
parent
46407e2033
commit
5238872c02
@ -2,6 +2,7 @@ package rpclib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -10,6 +11,11 @@ import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
errorType = reflect.TypeOf(new(error)).Elem()
|
||||
contextType = reflect.TypeOf(new(context.Context)).Elem()
|
||||
)
|
||||
|
||||
type ErrClient struct {
|
||||
err error
|
||||
}
|
||||
@ -65,7 +71,7 @@ func NewClient(addr string, namespace string, handler interface{}) {
|
||||
out[valOut] = reflect.Value(resp.Result).Elem()
|
||||
}
|
||||
if errOut != -1 {
|
||||
out[errOut] = reflect.New(reflect.TypeOf(new(error)).Elem()).Elem()
|
||||
out[errOut] = reflect.New(errorType).Elem()
|
||||
if resp.Error != nil {
|
||||
out[errOut].Set(reflect.ValueOf(errors.New(resp.Error.Message)))
|
||||
}
|
||||
@ -81,17 +87,22 @@ func NewClient(addr string, namespace string, handler interface{}) {
|
||||
out[valOut] = reflect.New(ftyp.Out(valOut)).Elem()
|
||||
}
|
||||
if errOut != -1 {
|
||||
out[errOut] = reflect.New(reflect.TypeOf(new(error)).Elem()).Elem()
|
||||
out[errOut] = reflect.New(errorType).Elem()
|
||||
out[errOut].Set(reflect.ValueOf(&ErrClient{err}))
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
hasCtx := 0
|
||||
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
|
||||
hasCtx = 1
|
||||
}
|
||||
|
||||
fn := reflect.MakeFunc(ftyp, func(args []reflect.Value) (results []reflect.Value) {
|
||||
id := atomic.AddInt64(&idCtr, 1)
|
||||
params := make([]param, len(args))
|
||||
for i, arg := range args {
|
||||
params := make([]param, len(args) - hasCtx)
|
||||
for i, arg := range args[hasCtx:] {
|
||||
params[i] = param{
|
||||
v: arg,
|
||||
}
|
||||
@ -109,12 +120,25 @@ func NewClient(addr string, namespace string, handler interface{}) {
|
||||
return processError(err)
|
||||
}
|
||||
|
||||
httpResp, err := http.Post(addr, "application/json", bytes.NewReader(b))
|
||||
// prepare / execute http request
|
||||
|
||||
hreq, err := http.NewRequest("POST", addr, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return processError(err)
|
||||
}
|
||||
if hasCtx == 1 {
|
||||
hreq = hreq.WithContext(args[0].Interface().(context.Context))
|
||||
}
|
||||
hreq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
httpResp, err := http.DefaultClient.Do(hreq)
|
||||
if err != nil {
|
||||
return processError(err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
// process response
|
||||
|
||||
// TODO: check error codes in spec
|
||||
if httpResp.StatusCode != 200 {
|
||||
return processError(errors.New("non 200 response code"))
|
||||
|
@ -15,6 +15,8 @@ type rpcHandler struct {
|
||||
receiver reflect.Value
|
||||
handlerFunc reflect.Value
|
||||
|
||||
hasCtx int
|
||||
|
||||
errOut int
|
||||
valOut int
|
||||
}
|
||||
@ -82,8 +84,12 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
callParams := make([]reflect.Value, 1+handler.nParams)
|
||||
callParams := make([]reflect.Value, 1+handler.hasCtx+handler.nParams)
|
||||
callParams[0] = handler.receiver
|
||||
if handler.hasCtx == 1 {
|
||||
callParams[1] = reflect.ValueOf(r.Context())
|
||||
}
|
||||
|
||||
for i := 0; i < handler.nParams; i++ {
|
||||
rp := reflect.New(handler.paramReceivers[i])
|
||||
if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil {
|
||||
@ -92,7 +98,7 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
callParams[i+1] = reflect.ValueOf(rp.Elem().Interface())
|
||||
callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Elem().Interface())
|
||||
}
|
||||
|
||||
callResult := handler.handlerFunc.Call(callParams)
|
||||
@ -133,13 +139,16 @@ func (s *RPCServer) Register(r interface{}) {
|
||||
for i := 0; i < val.NumMethod(); i++ {
|
||||
method := val.Type().Method(i)
|
||||
|
||||
fmt.Println(name + "." + method.Name)
|
||||
|
||||
funcType := method.Func.Type()
|
||||
ins := funcType.NumIn() - 1
|
||||
hasCtx := 0
|
||||
if funcType.NumIn() >= 2 && funcType.In(1) == contextType {
|
||||
hasCtx = 1
|
||||
}
|
||||
|
||||
ins := funcType.NumIn() - 1 - hasCtx
|
||||
recvs := make([]reflect.Type, ins)
|
||||
for i := 0; i < ins; i++ {
|
||||
recvs[i] = method.Type.In(i + 1)
|
||||
recvs[i] = method.Type.In(i + 1 + hasCtx)
|
||||
}
|
||||
|
||||
valOut, errOut, _ := processFuncOut(funcType)
|
||||
@ -151,6 +160,8 @@ func (s *RPCServer) Register(r interface{}) {
|
||||
handlerFunc: method.Func,
|
||||
receiver: val,
|
||||
|
||||
hasCtx: hasCtx,
|
||||
|
||||
errOut: errOut,
|
||||
valOut: valOut,
|
||||
}
|
||||
@ -165,7 +176,7 @@ func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
|
||||
switch n {
|
||||
case 0:
|
||||
case 1:
|
||||
if funcType.Out(0) == reflect.TypeOf(new(error)).Elem() {
|
||||
if funcType.Out(0) == errorType {
|
||||
errOut = 0
|
||||
} else {
|
||||
valOut = 0
|
||||
@ -173,7 +184,7 @@ func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
|
||||
case 2:
|
||||
valOut = 0
|
||||
errOut = 1
|
||||
if funcType.Out(1) != reflect.TypeOf(new(error)).Elem() {
|
||||
if funcType.Out(1) != errorType {
|
||||
panic("expected error as second return value")
|
||||
}
|
||||
default:
|
||||
|
@ -1,10 +1,12 @@
|
||||
package rpclib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SimpleServerHandler struct {
|
||||
@ -143,11 +145,11 @@ func TestRPC(t *testing.T) {
|
||||
noparam.Add()
|
||||
|
||||
var erronly struct {
|
||||
Add func() error
|
||||
AddGet func() (int, error)
|
||||
}
|
||||
NewClient(testServ.URL, "SimpleServerHandler", &erronly)
|
||||
|
||||
err = erronly.Add()
|
||||
_, err = erronly.AddGet()
|
||||
if err == nil || err.Error() != "RPC client error: non 200 response code" {
|
||||
t.Error("wrong error")
|
||||
}
|
||||
@ -162,3 +164,59 @@ func TestRPC(t *testing.T) {
|
||||
t.Error("wrong error")
|
||||
}
|
||||
}
|
||||
|
||||
type CtxHandler struct {
|
||||
cancelled bool
|
||||
i int
|
||||
}
|
||||
|
||||
func (h *CtxHandler) Test(ctx context.Context) {
|
||||
timeout := time.After(300 * time.Millisecond)
|
||||
h.i++
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
case <-ctx.Done():
|
||||
h.cancelled = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestCtx(t *testing.T) {
|
||||
// setup server
|
||||
|
||||
serverHandler := &CtxHandler{}
|
||||
|
||||
rpcServer := NewServer()
|
||||
rpcServer.Register(serverHandler)
|
||||
|
||||
// httptest stuff
|
||||
testServ := httptest.NewServer(rpcServer)
|
||||
defer testServ.Close()
|
||||
|
||||
// setup client
|
||||
|
||||
var client struct {
|
||||
Test func(ctx context.Context)
|
||||
}
|
||||
NewClient(testServ.URL, "CtxHandler", &client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20 * time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
client.Test(ctx)
|
||||
if !serverHandler.cancelled {
|
||||
t.Error("expected cancellation on the server side")
|
||||
}
|
||||
|
||||
serverHandler.cancelled = false
|
||||
|
||||
var noCtxClient struct {
|
||||
Test func()
|
||||
}
|
||||
NewClient(testServ.URL, "CtxHandler", &noCtxClient)
|
||||
|
||||
noCtxClient.Test()
|
||||
if serverHandler.cancelled || serverHandler.i != 2 {
|
||||
t.Error("wrong serverHandler state")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user