Merge pull request #1123 from filecoin-project/fix/jsonrpc-chclose-panic

jsonrpc: Fix channel closing race
This commit is contained in:
Łukasz Magiera 2020-01-21 17:14:55 +01:00 committed by GitHub
commit b6563ffc29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 33 deletions

View File

@ -99,7 +99,7 @@ func (h handlers) register(namespace string, r interface{}) {
// Handle // Handle
type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error) type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error)
type chanOut func(reflect.Value) interface{} type chanOut func(reflect.Value, int64) error
func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) { func (h handlers) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
wf := func(cb func(io.Writer)) { wf := func(cb func(io.Writer)) {
@ -222,16 +222,25 @@ func (h handlers) handle(ctx context.Context, req request, w func(func(io.Writer
if handler.valOut != -1 { if handler.valOut != -1 {
resp.Result = callResult[handler.valOut].Interface() resp.Result = callResult[handler.valOut].Interface()
} }
w(func(w io.Writer) {
if resp.Result != nil && reflect.TypeOf(resp.Result).Kind() == reflect.Chan { if resp.Result != nil && reflect.TypeOf(resp.Result).Kind() == reflect.Chan {
// this must happen in the writer callback, otherwise we may start sending // Channel responses are sent from channel control goroutine.
// channel messages before we send this response // Sending responses here could cause deadlocks on writeLk, or allow
// sending channel messages before this rpc call returns
//noinspection GoNilness // already checked above //noinspection GoNilness // already checked above
resp.Result = chOut(callResult[handler.valOut]) err = chOut(callResult[handler.valOut], *req.ID)
if err == nil {
return // channel goroutine handles responding
} }
log.Warnf("failed to setup channel in RPC call to '%s': %+v", req.Method, err)
resp.Error = &respError{
Code: 1,
Message: err.(error).Error(),
}
}
w(func(w io.Writer) {
if err := json.NewEncoder(w).Encode(resp); err != nil { if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Error(err) log.Error(err)
return return

View File

@ -375,5 +375,51 @@ func TestChan(t *testing.T) {
serverHandler.wait <- struct{}{} serverHandler.wait <- struct{}{}
_, ok = <-sub _, ok = <-sub
require.Equal(t, false, ok) require.Equal(t, false, ok)
}
func TestControlChanDeadlock(t *testing.T) {
for r := 0; r < 20; r++ {
testControlChanDeadlock(t)
}
}
func testControlChanDeadlock(t *testing.T) {
var client struct {
Sub func(context.Context, int, int) (<-chan int, error)
}
n := 5000
serverHandler := &ChanHandler{
wait: make(chan struct{}, n),
}
rpcServer := NewServer()
rpcServer.Register("ChanHandler", serverHandler)
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()
closer, err := NewClient("ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client, nil)
require.NoError(t, err)
defer closer()
for i := 0; i < n; i++ {
serverHandler.wait <- struct{}{}
}
ctx, _ := context.WithCancel(context.Background())
sub, err := client.Sub(ctx, 1, -1)
require.NoError(t, err)
go func() {
for i := 0; i < n; i++ {
require.Equal(t, i+1, <-sub)
}
}()
_, err = client.Sub(ctx, 2, -1)
require.NoError(t, err)
} }

View File

@ -11,6 +11,7 @@ import (
"sync/atomic" "sync/atomic"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"golang.org/x/xerrors"
) )
const wsCancel = "xrpc.cancel" const wsCancel = "xrpc.cancel"
@ -33,7 +34,9 @@ type frame struct {
} }
type outChanReg struct { type outChanReg struct {
id uint64 reqID int64
chID uint64
ch reflect.Value ch reflect.Value
} }
@ -134,51 +137,80 @@ func (c *wsConn) sendRequest(req request) {
// (forwards channel messages to client) // (forwards channel messages to client)
func (c *wsConn) handleOutChans() { func (c *wsConn) handleOutChans() {
regV := reflect.ValueOf(c.registerCh) regV := reflect.ValueOf(c.registerCh)
exitV := reflect.ValueOf(c.exiting)
cases := []reflect.SelectCase{ cases := []reflect.SelectCase{
{ // registration chan always 0 { // registration chan always 0
Dir: reflect.SelectRecv, Dir: reflect.SelectRecv,
Chan: regV, Chan: regV,
}, },
{ // exit chan always 1
Dir: reflect.SelectRecv,
Chan: exitV,
},
} }
internal := len(cases)
var caseToID []uint64 var caseToID []uint64
for { for {
chosen, val, ok := reflect.Select(cases) chosen, val, ok := reflect.Select(cases)
if chosen == 0 { // control channel switch chosen {
case 0: // registration channel
if !ok { if !ok {
// control channel closed - signals closed connection // control channel closed - signals closed connection
// This shouldn't happen, instead the exiting channel should get closed
log.Warn("control channel closed")
return
}
registration := val.Interface().(outChanReg)
caseToID = append(caseToID, registration.chID)
cases = append(cases, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: registration.ch,
})
c.nextWriter(func(w io.Writer) {
resp := &response{
Jsonrpc: "2.0",
ID: registration.reqID,
Result: registration.chID,
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Error(err)
return
}
})
continue
case 1: // exiting channel
if !ok {
// exiting channel closed - signals closed connection
// //
// We're not closing any channels as we're on receiving end. // We're not closing any channels as we're on receiving end.
// Also, context cancellation below should take care of any running // Also, context cancellation below should take care of any running
// requests // requests
return return
} }
log.Warn("exiting channel received a message")
registration := val.Interface().(outChanReg)
caseToID = append(caseToID, registration.id)
cases = append(cases, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: registration.ch,
})
continue continue
} }
if !ok { if !ok {
// Output channel closed, cleanup, and tell remote that this happened // Output channel closed, cleanup, and tell remote that this happened
n := len(caseToID) n := len(cases) - 1
if n > 0 { if n > 0 {
cases[chosen] = cases[n] cases[chosen] = cases[n]
caseToID[chosen-1] = caseToID[n-1] caseToID[chosen-internal] = caseToID[n-internal]
} }
id := caseToID[chosen-1] id := caseToID[chosen-internal]
cases = cases[:n] cases = cases[:n]
caseToID = caseToID[:n-1] caseToID = caseToID[:n-internal]
c.sendRequest(request{ c.sendRequest(request{
Jsonrpc: "2.0", Jsonrpc: "2.0",
@ -194,24 +226,29 @@ func (c *wsConn) handleOutChans() {
Jsonrpc: "2.0", Jsonrpc: "2.0",
ID: nil, // notification ID: nil, // notification
Method: chValue, Method: chValue,
Params: []param{{v: reflect.ValueOf(caseToID[chosen-1])}, {v: val}}, Params: []param{{v: reflect.ValueOf(caseToID[chosen-internal])}, {v: val}},
}) })
} }
} }
// handleChanOut registers output channel for forwarding to client // handleChanOut registers output channel for forwarding to client
func (c *wsConn) handleChanOut(ch reflect.Value) interface{} { func (c *wsConn) handleChanOut(ch reflect.Value, req int64) error {
c.spawnOutChanHandlerOnce.Do(func() { c.spawnOutChanHandlerOnce.Do(func() {
go c.handleOutChans() go c.handleOutChans()
}) })
id := atomic.AddUint64(&c.chanCtr, 1) id := atomic.AddUint64(&c.chanCtr, 1)
c.registerCh <- outChanReg{ select {
id: id, case c.registerCh <- outChanReg{
ch: ch, reqID: req,
}
return id chID: id,
ch: ch,
}:
return nil
case <-c.exiting:
return xerrors.New("connection closing")
}
} }
// // // //
@ -389,7 +426,6 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
c.chanHandlers = map[uint64]func(m []byte, ok bool){} c.chanHandlers = map[uint64]func(m []byte, ok bool){}
c.registerCh = make(chan outChanReg) c.registerCh = make(chan outChanReg)
defer close(c.registerCh)
defer close(c.exiting) defer close(c.exiting)
// //// // ////