clean up output channel handling logic to prevent send on closed channels
This commit is contained in:
parent
0533cf22e7
commit
8d85aedeff
@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@ -161,14 +160,60 @@ func (c *client) makeOutChan(ctx context.Context, ftyp reflect.Type, valOut int)
|
||||
chCtx, chCancel := context.WithCancel(ctx)
|
||||
retVal = ch.Convert(ftyp.Out(valOut))
|
||||
|
||||
incoming := make(chan reflect.Value, 32)
|
||||
|
||||
// gorotuine to handle buffering of items
|
||||
go func() {
|
||||
buf := (&list.List{}).Init()
|
||||
var bufLk sync.Mutex
|
||||
|
||||
for {
|
||||
front := buf.Front()
|
||||
|
||||
cases := []reflect.SelectCase{
|
||||
{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(chCtx.Done()),
|
||||
},
|
||||
{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(incoming),
|
||||
},
|
||||
}
|
||||
|
||||
if front != nil {
|
||||
cases = append(cases, reflect.SelectCase{
|
||||
Dir: reflect.SelectSend,
|
||||
Chan: ch,
|
||||
Send: front.Value.(reflect.Value).Elem(),
|
||||
})
|
||||
}
|
||||
|
||||
chosen, val, _ := reflect.Select(cases)
|
||||
|
||||
switch chosen {
|
||||
case 0:
|
||||
ch.Close()
|
||||
return
|
||||
case 1:
|
||||
vvval := val.Interface().(reflect.Value)
|
||||
buf.PushBack(vvval)
|
||||
if buf.Len() > 1 {
|
||||
if buf.Len() > 10 {
|
||||
log.Warnw("rpc output message buffer", "n", buf.Len())
|
||||
} else {
|
||||
log.Infow("rpc output message buffer", "n", buf.Len())
|
||||
}
|
||||
}
|
||||
|
||||
case 2:
|
||||
buf.Remove(front)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, func(result []byte, ok bool) {
|
||||
if !ok {
|
||||
chCancel()
|
||||
// remote channel closed, close ours too
|
||||
ch.Close()
|
||||
return
|
||||
}
|
||||
|
||||
@ -178,56 +223,15 @@ func (c *client) makeOutChan(ctx context.Context, ftyp reflect.Type, valOut int)
|
||||
return
|
||||
}
|
||||
|
||||
bufLk.Lock()
|
||||
if ctx.Err() != nil {
|
||||
log.Errorf("got rpc message with cancelled context: %s", ctx.Err())
|
||||
bufLk.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
buf.PushBack(val)
|
||||
|
||||
if buf.Len() > 1 {
|
||||
if buf.Len() > 10 {
|
||||
log.Warnw("rpc output message buffer", "n", buf.Len())
|
||||
} else {
|
||||
log.Infow("rpc output message buffer", "n", buf.Len())
|
||||
select {
|
||||
case incoming <- val:
|
||||
case <-chCtx.Done():
|
||||
}
|
||||
bufLk.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for buf.Len() > 0 {
|
||||
front := buf.Front()
|
||||
bufLk.Unlock()
|
||||
|
||||
cases := []reflect.SelectCase{
|
||||
{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(chCtx.Done()),
|
||||
},
|
||||
{
|
||||
Dir: reflect.SelectSend,
|
||||
Chan: ch,
|
||||
Send: front.Value.(reflect.Value).Elem(),
|
||||
},
|
||||
}
|
||||
|
||||
chosen, _, _ := reflect.Select(cases)
|
||||
bufLk.Lock()
|
||||
|
||||
switch chosen {
|
||||
case 0:
|
||||
buf.Init()
|
||||
case 1:
|
||||
buf.Remove(front)
|
||||
}
|
||||
}
|
||||
|
||||
bufLk.Unlock()
|
||||
}()
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -293,7 +297,7 @@ type rpcFunc struct {
|
||||
errOut int
|
||||
|
||||
hasCtx int
|
||||
retCh bool
|
||||
returnValueIsChannel bool
|
||||
|
||||
retry bool
|
||||
}
|
||||
@ -350,7 +354,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
|
||||
// if the function returns a channel, we need to provide a sink for the
|
||||
// messages
|
||||
var chCtor makeChanSink
|
||||
if fn.retCh {
|
||||
if fn.returnValueIsChannel {
|
||||
retVal, chCtor = fn.client.makeOutChan(ctx, fn.ftyp, fn.valOut)
|
||||
}
|
||||
|
||||
@ -385,7 +389,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
|
||||
return fn.processError(xerrors.New("request and response id didn't match"))
|
||||
}
|
||||
|
||||
if fn.valOut != -1 && !fn.retCh {
|
||||
if fn.valOut != -1 && !fn.returnValueIsChannel {
|
||||
val := reflect.New(fn.ftyp.Out(fn.valOut))
|
||||
|
||||
if resp.Result != nil {
|
||||
@ -425,7 +429,7 @@ func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) {
|
||||
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
|
||||
fun.hasCtx = 1
|
||||
}
|
||||
fun.retCh = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
|
||||
fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
|
||||
|
||||
return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil
|
||||
}
|
||||
|
@ -13,9 +13,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logging.SetLogLevel("rpc", "DEBUG")
|
||||
}
|
||||
|
||||
type SimpleServerHandler struct {
|
||||
n int
|
||||
}
|
||||
@ -283,6 +288,9 @@ func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error
|
||||
out := make(chan int)
|
||||
h.ctxdone = ctx.Done()
|
||||
|
||||
wait := h.wait
|
||||
|
||||
log.Warnf("SERVER SUB!")
|
||||
go func() {
|
||||
defer close(out)
|
||||
var n int
|
||||
@ -292,7 +300,7 @@ func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error
|
||||
case <-ctx.Done():
|
||||
fmt.Println("ctxdone1")
|
||||
return
|
||||
case <-h.wait:
|
||||
case <-wait:
|
||||
}
|
||||
|
||||
n += i
|
||||
@ -365,15 +373,19 @@ func TestChan(t *testing.T) {
|
||||
|
||||
// sub (again)
|
||||
|
||||
serverHandler.wait = make(chan struct{}, 5)
|
||||
serverHandler.wait <- struct{}{}
|
||||
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
log.Warnf("last sub")
|
||||
sub, err = client.Sub(ctx, 3, 6)
|
||||
require.NoError(t, err)
|
||||
|
||||
log.Warnf("waiting for value now")
|
||||
require.Equal(t, 3, <-sub)
|
||||
log.Warnf("not equal")
|
||||
|
||||
// close (remote)
|
||||
serverHandler.wait <- struct{}{}
|
||||
@ -535,12 +547,25 @@ func testControlChanDeadlock(t *testing.T) {
|
||||
sub, err := client.Sub(ctx, 1, -1)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
for i := 0; i < n; i++ {
|
||||
require.Equal(t, i+1, <-sub)
|
||||
if <-sub != i+1 {
|
||||
panic("bad!")
|
||||
//require.Equal(t, i+1, <-sub)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// reset this channel so its not shared between the sub requests...
|
||||
serverHandler.wait = make(chan struct{}, n)
|
||||
for i := 0; i < n; i++ {
|
||||
serverHandler.wait <- struct{}{}
|
||||
}
|
||||
|
||||
_, err = client.Sub(ctx, 2, -1)
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user