clean up output channel handling logic to prevent send on closed channels
This commit is contained in:
parent
0533cf22e7
commit
8d85aedeff
@ -8,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -161,14 +160,60 @@ func (c *client) makeOutChan(ctx context.Context, ftyp reflect.Type, valOut int)
|
|||||||
chCtx, chCancel := context.WithCancel(ctx)
|
chCtx, chCancel := context.WithCancel(ctx)
|
||||||
retVal = ch.Convert(ftyp.Out(valOut))
|
retVal = ch.Convert(ftyp.Out(valOut))
|
||||||
|
|
||||||
buf := (&list.List{}).Init()
|
incoming := make(chan reflect.Value, 32)
|
||||||
var bufLk sync.Mutex
|
|
||||||
|
// gorotuine to handle buffering of items
|
||||||
|
go func() {
|
||||||
|
buf := (&list.List{}).Init()
|
||||||
|
|
||||||
|
for {
|
||||||
|
front := buf.Front()
|
||||||
|
|
||||||
|
cases := []reflect.SelectCase{
|
||||||
|
{
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(chCtx.Done()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(incoming),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if front != nil {
|
||||||
|
cases = append(cases, reflect.SelectCase{
|
||||||
|
Dir: reflect.SelectSend,
|
||||||
|
Chan: ch,
|
||||||
|
Send: front.Value.(reflect.Value).Elem(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
chosen, val, _ := reflect.Select(cases)
|
||||||
|
|
||||||
|
switch chosen {
|
||||||
|
case 0:
|
||||||
|
ch.Close()
|
||||||
|
return
|
||||||
|
case 1:
|
||||||
|
vvval := val.Interface().(reflect.Value)
|
||||||
|
buf.PushBack(vvval)
|
||||||
|
if buf.Len() > 1 {
|
||||||
|
if buf.Len() > 10 {
|
||||||
|
log.Warnw("rpc output message buffer", "n", buf.Len())
|
||||||
|
} else {
|
||||||
|
log.Infow("rpc output message buffer", "n", buf.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case 2:
|
||||||
|
buf.Remove(front)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return ctx, func(result []byte, ok bool) {
|
return ctx, func(result []byte, ok bool) {
|
||||||
if !ok {
|
if !ok {
|
||||||
chCancel()
|
chCancel()
|
||||||
// remote channel closed, close ours too
|
|
||||||
ch.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,56 +223,15 @@ func (c *client) makeOutChan(ctx context.Context, ftyp reflect.Type, valOut int)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bufLk.Lock()
|
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
log.Errorf("got rpc message with cancelled context: %s", ctx.Err())
|
log.Errorf("got rpc message with cancelled context: %s", ctx.Err())
|
||||||
bufLk.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
buf.PushBack(val)
|
select {
|
||||||
|
case incoming <- val:
|
||||||
if buf.Len() > 1 {
|
case <-chCtx.Done():
|
||||||
if buf.Len() > 10 {
|
|
||||||
log.Warnw("rpc output message buffer", "n", buf.Len())
|
|
||||||
} else {
|
|
||||||
log.Infow("rpc output message buffer", "n", buf.Len())
|
|
||||||
}
|
|
||||||
bufLk.Unlock()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
|
||||||
for buf.Len() > 0 {
|
|
||||||
front := buf.Front()
|
|
||||||
bufLk.Unlock()
|
|
||||||
|
|
||||||
cases := []reflect.SelectCase{
|
|
||||||
{
|
|
||||||
Dir: reflect.SelectRecv,
|
|
||||||
Chan: reflect.ValueOf(chCtx.Done()),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Dir: reflect.SelectSend,
|
|
||||||
Chan: ch,
|
|
||||||
Send: front.Value.(reflect.Value).Elem(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
chosen, _, _ := reflect.Select(cases)
|
|
||||||
bufLk.Lock()
|
|
||||||
|
|
||||||
switch chosen {
|
|
||||||
case 0:
|
|
||||||
buf.Init()
|
|
||||||
case 1:
|
|
||||||
buf.Remove(front)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bufLk.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -292,8 +296,8 @@ type rpcFunc struct {
|
|||||||
valOut int
|
valOut int
|
||||||
errOut int
|
errOut int
|
||||||
|
|
||||||
hasCtx int
|
hasCtx int
|
||||||
retCh bool
|
returnValueIsChannel bool
|
||||||
|
|
||||||
retry bool
|
retry bool
|
||||||
}
|
}
|
||||||
@ -350,7 +354,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
|
|||||||
// if the function returns a channel, we need to provide a sink for the
|
// if the function returns a channel, we need to provide a sink for the
|
||||||
// messages
|
// messages
|
||||||
var chCtor makeChanSink
|
var chCtor makeChanSink
|
||||||
if fn.retCh {
|
if fn.returnValueIsChannel {
|
||||||
retVal, chCtor = fn.client.makeOutChan(ctx, fn.ftyp, fn.valOut)
|
retVal, chCtor = fn.client.makeOutChan(ctx, fn.ftyp, fn.valOut)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,7 +389,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
|
|||||||
return fn.processError(xerrors.New("request and response id didn't match"))
|
return fn.processError(xerrors.New("request and response id didn't match"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if fn.valOut != -1 && !fn.retCh {
|
if fn.valOut != -1 && !fn.returnValueIsChannel {
|
||||||
val := reflect.New(fn.ftyp.Out(fn.valOut))
|
val := reflect.New(fn.ftyp.Out(fn.valOut))
|
||||||
|
|
||||||
if resp.Result != nil {
|
if resp.Result != nil {
|
||||||
@ -425,7 +429,7 @@ func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) {
|
|||||||
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
|
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
|
||||||
fun.hasCtx = 1
|
fun.hasCtx = 1
|
||||||
}
|
}
|
||||||
fun.retCh = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
|
fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
|
||||||
|
|
||||||
return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil
|
return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil
|
||||||
}
|
}
|
||||||
|
@ -13,9 +13,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
logging "github.com/ipfs/go-log/v2"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
logging.SetLogLevel("rpc", "DEBUG")
|
||||||
|
}
|
||||||
|
|
||||||
type SimpleServerHandler struct {
|
type SimpleServerHandler struct {
|
||||||
n int
|
n int
|
||||||
}
|
}
|
||||||
@ -283,6 +288,9 @@ func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error
|
|||||||
out := make(chan int)
|
out := make(chan int)
|
||||||
h.ctxdone = ctx.Done()
|
h.ctxdone = ctx.Done()
|
||||||
|
|
||||||
|
wait := h.wait
|
||||||
|
|
||||||
|
log.Warnf("SERVER SUB!")
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
var n int
|
var n int
|
||||||
@ -292,7 +300,7 @@ func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
fmt.Println("ctxdone1")
|
fmt.Println("ctxdone1")
|
||||||
return
|
return
|
||||||
case <-h.wait:
|
case <-wait:
|
||||||
}
|
}
|
||||||
|
|
||||||
n += i
|
n += i
|
||||||
@ -365,15 +373,19 @@ func TestChan(t *testing.T) {
|
|||||||
|
|
||||||
// sub (again)
|
// sub (again)
|
||||||
|
|
||||||
|
serverHandler.wait = make(chan struct{}, 5)
|
||||||
serverHandler.wait <- struct{}{}
|
serverHandler.wait <- struct{}{}
|
||||||
|
|
||||||
ctx, cancel = context.WithCancel(context.Background())
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
log.Warnf("last sub")
|
||||||
sub, err = client.Sub(ctx, 3, 6)
|
sub, err = client.Sub(ctx, 3, 6)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
log.Warnf("waiting for value now")
|
||||||
require.Equal(t, 3, <-sub)
|
require.Equal(t, 3, <-sub)
|
||||||
|
log.Warnf("not equal")
|
||||||
|
|
||||||
// close (remote)
|
// close (remote)
|
||||||
serverHandler.wait <- struct{}{}
|
serverHandler.wait <- struct{}{}
|
||||||
@ -535,12 +547,25 @@ func testControlChanDeadlock(t *testing.T) {
|
|||||||
sub, err := client.Sub(ctx, 1, -1)
|
sub, err := client.Sub(ctx, 1, -1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(done)
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
require.Equal(t, i+1, <-sub)
|
if <-sub != i+1 {
|
||||||
|
panic("bad!")
|
||||||
|
//require.Equal(t, i+1, <-sub)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// reset this channel so its not shared between the sub requests...
|
||||||
|
serverHandler.wait = make(chan struct{}, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
serverHandler.wait <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
_, err = client.Sub(ctx, 2, -1)
|
_, err = client.Sub(ctx, 2, -1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
<-done
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user