rpc: Contexts
License: MIT Signed-off-by: Jakub Sztandera <kubuxu@protonmail.ch>
This commit is contained in:
parent
46407e2033
commit
5238872c02
@ -2,6 +2,7 @@ package rpclib
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -10,6 +11,11 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errorType = reflect.TypeOf(new(error)).Elem()
|
||||||
|
contextType = reflect.TypeOf(new(context.Context)).Elem()
|
||||||
|
)
|
||||||
|
|
||||||
type ErrClient struct {
|
type ErrClient struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
@ -65,7 +71,7 @@ func NewClient(addr string, namespace string, handler interface{}) {
|
|||||||
out[valOut] = reflect.Value(resp.Result).Elem()
|
out[valOut] = reflect.Value(resp.Result).Elem()
|
||||||
}
|
}
|
||||||
if errOut != -1 {
|
if errOut != -1 {
|
||||||
out[errOut] = reflect.New(reflect.TypeOf(new(error)).Elem()).Elem()
|
out[errOut] = reflect.New(errorType).Elem()
|
||||||
if resp.Error != nil {
|
if resp.Error != nil {
|
||||||
out[errOut].Set(reflect.ValueOf(errors.New(resp.Error.Message)))
|
out[errOut].Set(reflect.ValueOf(errors.New(resp.Error.Message)))
|
||||||
}
|
}
|
||||||
@ -81,17 +87,22 @@ func NewClient(addr string, namespace string, handler interface{}) {
|
|||||||
out[valOut] = reflect.New(ftyp.Out(valOut)).Elem()
|
out[valOut] = reflect.New(ftyp.Out(valOut)).Elem()
|
||||||
}
|
}
|
||||||
if errOut != -1 {
|
if errOut != -1 {
|
||||||
out[errOut] = reflect.New(reflect.TypeOf(new(error)).Elem()).Elem()
|
out[errOut] = reflect.New(errorType).Elem()
|
||||||
out[errOut].Set(reflect.ValueOf(&ErrClient{err}))
|
out[errOut].Set(reflect.ValueOf(&ErrClient{err}))
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hasCtx := 0
|
||||||
|
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
|
||||||
|
hasCtx = 1
|
||||||
|
}
|
||||||
|
|
||||||
fn := reflect.MakeFunc(ftyp, func(args []reflect.Value) (results []reflect.Value) {
|
fn := reflect.MakeFunc(ftyp, func(args []reflect.Value) (results []reflect.Value) {
|
||||||
id := atomic.AddInt64(&idCtr, 1)
|
id := atomic.AddInt64(&idCtr, 1)
|
||||||
params := make([]param, len(args))
|
params := make([]param, len(args) - hasCtx)
|
||||||
for i, arg := range args {
|
for i, arg := range args[hasCtx:] {
|
||||||
params[i] = param{
|
params[i] = param{
|
||||||
v: arg,
|
v: arg,
|
||||||
}
|
}
|
||||||
@ -109,12 +120,25 @@ func NewClient(addr string, namespace string, handler interface{}) {
|
|||||||
return processError(err)
|
return processError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpResp, err := http.Post(addr, "application/json", bytes.NewReader(b))
|
// prepare / execute http request
|
||||||
|
|
||||||
|
hreq, err := http.NewRequest("POST", addr, bytes.NewReader(b))
|
||||||
|
if err != nil {
|
||||||
|
return processError(err)
|
||||||
|
}
|
||||||
|
if hasCtx == 1 {
|
||||||
|
hreq = hreq.WithContext(args[0].Interface().(context.Context))
|
||||||
|
}
|
||||||
|
hreq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
httpResp, err := http.DefaultClient.Do(hreq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return processError(err)
|
return processError(err)
|
||||||
}
|
}
|
||||||
defer httpResp.Body.Close()
|
defer httpResp.Body.Close()
|
||||||
|
|
||||||
|
// process response
|
||||||
|
|
||||||
// TODO: check error codes in spec
|
// TODO: check error codes in spec
|
||||||
if httpResp.StatusCode != 200 {
|
if httpResp.StatusCode != 200 {
|
||||||
return processError(errors.New("non 200 response code"))
|
return processError(errors.New("non 200 response code"))
|
||||||
|
@ -15,6 +15,8 @@ type rpcHandler struct {
|
|||||||
receiver reflect.Value
|
receiver reflect.Value
|
||||||
handlerFunc reflect.Value
|
handlerFunc reflect.Value
|
||||||
|
|
||||||
|
hasCtx int
|
||||||
|
|
||||||
errOut int
|
errOut int
|
||||||
valOut int
|
valOut int
|
||||||
}
|
}
|
||||||
@ -82,8 +84,12 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
callParams := make([]reflect.Value, 1+handler.nParams)
|
callParams := make([]reflect.Value, 1+handler.hasCtx+handler.nParams)
|
||||||
callParams[0] = handler.receiver
|
callParams[0] = handler.receiver
|
||||||
|
if handler.hasCtx == 1 {
|
||||||
|
callParams[1] = reflect.ValueOf(r.Context())
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < handler.nParams; i++ {
|
for i := 0; i < handler.nParams; i++ {
|
||||||
rp := reflect.New(handler.paramReceivers[i])
|
rp := reflect.New(handler.paramReceivers[i])
|
||||||
if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil {
|
if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil {
|
||||||
@ -92,7 +98,7 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
callParams[i+1] = reflect.ValueOf(rp.Elem().Interface())
|
callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Elem().Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
callResult := handler.handlerFunc.Call(callParams)
|
callResult := handler.handlerFunc.Call(callParams)
|
||||||
@ -133,13 +139,16 @@ func (s *RPCServer) Register(r interface{}) {
|
|||||||
for i := 0; i < val.NumMethod(); i++ {
|
for i := 0; i < val.NumMethod(); i++ {
|
||||||
method := val.Type().Method(i)
|
method := val.Type().Method(i)
|
||||||
|
|
||||||
fmt.Println(name + "." + method.Name)
|
|
||||||
|
|
||||||
funcType := method.Func.Type()
|
funcType := method.Func.Type()
|
||||||
ins := funcType.NumIn() - 1
|
hasCtx := 0
|
||||||
|
if funcType.NumIn() >= 2 && funcType.In(1) == contextType {
|
||||||
|
hasCtx = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
ins := funcType.NumIn() - 1 - hasCtx
|
||||||
recvs := make([]reflect.Type, ins)
|
recvs := make([]reflect.Type, ins)
|
||||||
for i := 0; i < ins; i++ {
|
for i := 0; i < ins; i++ {
|
||||||
recvs[i] = method.Type.In(i + 1)
|
recvs[i] = method.Type.In(i + 1 + hasCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
valOut, errOut, _ := processFuncOut(funcType)
|
valOut, errOut, _ := processFuncOut(funcType)
|
||||||
@ -151,6 +160,8 @@ func (s *RPCServer) Register(r interface{}) {
|
|||||||
handlerFunc: method.Func,
|
handlerFunc: method.Func,
|
||||||
receiver: val,
|
receiver: val,
|
||||||
|
|
||||||
|
hasCtx: hasCtx,
|
||||||
|
|
||||||
errOut: errOut,
|
errOut: errOut,
|
||||||
valOut: valOut,
|
valOut: valOut,
|
||||||
}
|
}
|
||||||
@ -165,7 +176,7 @@ func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
|
|||||||
switch n {
|
switch n {
|
||||||
case 0:
|
case 0:
|
||||||
case 1:
|
case 1:
|
||||||
if funcType.Out(0) == reflect.TypeOf(new(error)).Elem() {
|
if funcType.Out(0) == errorType {
|
||||||
errOut = 0
|
errOut = 0
|
||||||
} else {
|
} else {
|
||||||
valOut = 0
|
valOut = 0
|
||||||
@ -173,7 +184,7 @@ func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) {
|
|||||||
case 2:
|
case 2:
|
||||||
valOut = 0
|
valOut = 0
|
||||||
errOut = 1
|
errOut = 1
|
||||||
if funcType.Out(1) != reflect.TypeOf(new(error)).Elem() {
|
if funcType.Out(1) != errorType {
|
||||||
panic("expected error as second return value")
|
panic("expected error as second return value")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package rpclib
|
package rpclib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SimpleServerHandler struct {
|
type SimpleServerHandler struct {
|
||||||
@ -143,11 +145,11 @@ func TestRPC(t *testing.T) {
|
|||||||
noparam.Add()
|
noparam.Add()
|
||||||
|
|
||||||
var erronly struct {
|
var erronly struct {
|
||||||
Add func() error
|
AddGet func() (int, error)
|
||||||
}
|
}
|
||||||
NewClient(testServ.URL, "SimpleServerHandler", &erronly)
|
NewClient(testServ.URL, "SimpleServerHandler", &erronly)
|
||||||
|
|
||||||
err = erronly.Add()
|
_, err = erronly.AddGet()
|
||||||
if err == nil || err.Error() != "RPC client error: non 200 response code" {
|
if err == nil || err.Error() != "RPC client error: non 200 response code" {
|
||||||
t.Error("wrong error")
|
t.Error("wrong error")
|
||||||
}
|
}
|
||||||
@ -162,3 +164,59 @@ func TestRPC(t *testing.T) {
|
|||||||
t.Error("wrong error")
|
t.Error("wrong error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CtxHandler struct {
|
||||||
|
cancelled bool
|
||||||
|
i int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CtxHandler) Test(ctx context.Context) {
|
||||||
|
timeout := time.After(300 * time.Millisecond)
|
||||||
|
h.i++
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
case <-ctx.Done():
|
||||||
|
h.cancelled = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCtx(t *testing.T) {
|
||||||
|
// setup server
|
||||||
|
|
||||||
|
serverHandler := &CtxHandler{}
|
||||||
|
|
||||||
|
rpcServer := NewServer()
|
||||||
|
rpcServer.Register(serverHandler)
|
||||||
|
|
||||||
|
// httptest stuff
|
||||||
|
testServ := httptest.NewServer(rpcServer)
|
||||||
|
defer testServ.Close()
|
||||||
|
|
||||||
|
// setup client
|
||||||
|
|
||||||
|
var client struct {
|
||||||
|
Test func(ctx context.Context)
|
||||||
|
}
|
||||||
|
NewClient(testServ.URL, "CtxHandler", &client)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20 * time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client.Test(ctx)
|
||||||
|
if !serverHandler.cancelled {
|
||||||
|
t.Error("expected cancellation on the server side")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverHandler.cancelled = false
|
||||||
|
|
||||||
|
var noCtxClient struct {
|
||||||
|
Test func()
|
||||||
|
}
|
||||||
|
NewClient(testServ.URL, "CtxHandler", &noCtxClient)
|
||||||
|
|
||||||
|
noCtxClient.Test()
|
||||||
|
if serverHandler.cancelled || serverHandler.i != 2 {
|
||||||
|
t.Error("wrong serverHandler state")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user