clean up output channel handling logic to prevent send on closed channels

This commit is contained in:
Jeromy 2020-05-13 16:31:17 -07:00
parent 0533cf22e7
commit 8d85aedeff
2 changed files with 85 additions and 56 deletions

View File

@ -8,7 +8,6 @@ import (
"fmt"
"net/http"
"reflect"
"sync"
"sync/atomic"
"time"
@ -161,14 +160,60 @@ func (c *client) makeOutChan(ctx context.Context, ftyp reflect.Type, valOut int)
chCtx, chCancel := context.WithCancel(ctx)
retVal = ch.Convert(ftyp.Out(valOut))
incoming := make(chan reflect.Value, 32)
// gorotuine to handle buffering of items
go func() {
buf := (&list.List{}).Init()
var bufLk sync.Mutex
for {
front := buf.Front()
cases := []reflect.SelectCase{
{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(chCtx.Done()),
},
{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(incoming),
},
}
if front != nil {
cases = append(cases, reflect.SelectCase{
Dir: reflect.SelectSend,
Chan: ch,
Send: front.Value.(reflect.Value).Elem(),
})
}
chosen, val, _ := reflect.Select(cases)
switch chosen {
case 0:
ch.Close()
return
case 1:
vvval := val.Interface().(reflect.Value)
buf.PushBack(vvval)
if buf.Len() > 1 {
if buf.Len() > 10 {
log.Warnw("rpc output message buffer", "n", buf.Len())
} else {
log.Infow("rpc output message buffer", "n", buf.Len())
}
}
case 2:
buf.Remove(front)
}
}
}()
return ctx, func(result []byte, ok bool) {
if !ok {
chCancel()
// remote channel closed, close ours too
ch.Close()
return
}
@ -178,56 +223,15 @@ func (c *client) makeOutChan(ctx context.Context, ftyp reflect.Type, valOut int)
return
}
bufLk.Lock()
if ctx.Err() != nil {
log.Errorf("got rpc message with cancelled context: %s", ctx.Err())
bufLk.Unlock()
return
}
buf.PushBack(val)
if buf.Len() > 1 {
if buf.Len() > 10 {
log.Warnw("rpc output message buffer", "n", buf.Len())
} else {
log.Infow("rpc output message buffer", "n", buf.Len())
select {
case incoming <- val:
case <-chCtx.Done():
}
bufLk.Unlock()
return
}
go func() {
for buf.Len() > 0 {
front := buf.Front()
bufLk.Unlock()
cases := []reflect.SelectCase{
{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(chCtx.Done()),
},
{
Dir: reflect.SelectSend,
Chan: ch,
Send: front.Value.(reflect.Value).Elem(),
},
}
chosen, _, _ := reflect.Select(cases)
bufLk.Lock()
switch chosen {
case 0:
buf.Init()
case 1:
buf.Remove(front)
}
}
bufLk.Unlock()
}()
}
}
@ -293,7 +297,7 @@ type rpcFunc struct {
errOut int
hasCtx int
retCh bool
returnValueIsChannel bool
retry bool
}
@ -350,7 +354,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
// if the function returns a channel, we need to provide a sink for the
// messages
var chCtor makeChanSink
if fn.retCh {
if fn.returnValueIsChannel {
retVal, chCtor = fn.client.makeOutChan(ctx, fn.ftyp, fn.valOut)
}
@ -385,7 +389,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
return fn.processError(xerrors.New("request and response id didn't match"))
}
if fn.valOut != -1 && !fn.retCh {
if fn.valOut != -1 && !fn.returnValueIsChannel {
val := reflect.New(fn.ftyp.Out(fn.valOut))
if resp.Result != nil {
@ -425,7 +429,7 @@ func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) {
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
fun.hasCtx = 1
}
fun.retCh = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil
}

View File

@ -13,9 +13,14 @@ import (
"time"
"github.com/gorilla/websocket"
logging "github.com/ipfs/go-log/v2"
"github.com/stretchr/testify/require"
)
func init() {
logging.SetLogLevel("rpc", "DEBUG")
}
type SimpleServerHandler struct {
n int
}
@ -283,6 +288,9 @@ func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error
out := make(chan int)
h.ctxdone = ctx.Done()
wait := h.wait
log.Warnf("SERVER SUB!")
go func() {
defer close(out)
var n int
@ -292,7 +300,7 @@ func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error
case <-ctx.Done():
fmt.Println("ctxdone1")
return
case <-h.wait:
case <-wait:
}
n += i
@ -365,15 +373,19 @@ func TestChan(t *testing.T) {
// sub (again)
serverHandler.wait = make(chan struct{}, 5)
serverHandler.wait <- struct{}{}
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
log.Warnf("last sub")
sub, err = client.Sub(ctx, 3, 6)
require.NoError(t, err)
log.Warnf("waiting for value now")
require.Equal(t, 3, <-sub)
log.Warnf("not equal")
// close (remote)
serverHandler.wait <- struct{}{}
@ -535,12 +547,25 @@ func testControlChanDeadlock(t *testing.T) {
sub, err := client.Sub(ctx, 1, -1)
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
for i := 0; i < n; i++ {
require.Equal(t, i+1, <-sub)
if <-sub != i+1 {
panic("bad!")
//require.Equal(t, i+1, <-sub)
}
}
}()
// reset this channel so its not shared between the sub requests...
serverHandler.wait = make(chan struct{}, n)
for i := 0; i < n; i++ {
serverHandler.wait <- struct{}{}
}
_, err = client.Sub(ctx, 2, -1)
require.NoError(t, err)
<-done
}