rpc: handle wrong HTTP batch response length (#26064)

This commit is contained in:
Jordan Krage 2022-11-02 09:29:33 -05:00 committed by Felix Lange
parent 27600a5b84
commit 211dbb7197
3 changed files with 52 additions and 0 deletions

View File

@ -31,6 +31,7 @@ import (
) )
var ( var (
ErrBadResult = errors.New("bad result in JSON-RPC response")
ErrClientQuit = errors.New("client is closed") ErrClientQuit = errors.New("client is closed")
ErrNoResult = errors.New("no result in JSON-RPC response") ErrNoResult = errors.New("no result in JSON-RPC response")
ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")

View File

@ -19,6 +19,7 @@ package rpc
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
@ -144,6 +145,53 @@ func TestClientBatchRequest(t *testing.T) {
} }
} }
func TestClientBatchRequest_len(t *testing.T) {
b, err := json.Marshal([]jsonrpcMessage{
{Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)},
{Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)},
})
if err != nil {
t.Fatal("failed to encode jsonrpc message:", err)
}
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(b)
if err != nil {
t.Error("failed to write response:", err)
}
}))
t.Cleanup(s.Close)
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
t.Run("too-few", func(t *testing.T) {
batch := []BatchElem{
{Method: "foo"},
{Method: "bar"},
{Method: "baz"},
}
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn()
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
t.Errorf("expected %q but got: %v", ErrBadResult, err)
}
})
t.Run("too-many", func(t *testing.T) {
batch := []BatchElem{
{Method: "foo"},
}
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn()
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
t.Errorf("expected %q but got: %v", ErrBadResult, err)
}
})
}
func TestClientNotify(t *testing.T) { func TestClientNotify(t *testing.T) {
server := newTestServer() server := newTestServer()
defer server.Stop() defer server.Stop()

View File

@ -173,6 +173,9 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr
if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil {
return err return err
} }
if len(respmsgs) != len(msgs) {
return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult)
}
for i := 0; i < len(respmsgs); i++ { for i := 0; i < len(respmsgs); i++ {
op.resp <- &respmsgs[i] op.resp <- &respmsgs[i]
} }