diff --git a/chain/exchange/client.go b/chain/exchange/client.go index 57563d5b2..20a130c1d 100644 --- a/chain/exchange/client.go +++ b/chain/exchange/client.go @@ -210,31 +210,9 @@ func (c *client) processResponse(req *Request, res *Response) (*validatedRespons // If the headers were also returned check that the compression // indexes are valid before `toFullTipSets()` is called by the // consumer. - for tipsetIdx := 0; tipsetIdx < resLength; tipsetIdx++ { - msgs := res.Chain[tipsetIdx].Messages - blocksNum := len(res.Chain[tipsetIdx].Blocks) - if len(msgs.BlsIncludes) != blocksNum { - return nil, xerrors.Errorf("BlsIncludes (%d) does not match number of blocks (%d)", - len(msgs.BlsIncludes), blocksNum) - } - if len(msgs.SecpkIncludes) != blocksNum { - return nil, xerrors.Errorf("SecpkIncludes (%d) does not match number of blocks (%d)", - len(msgs.SecpkIncludes), blocksNum) - } - for blockIdx := 0; blockIdx < blocksNum; blockIdx++ { - for _, mi := range msgs.BlsIncludes[blockIdx] { - if int(mi) >= len(msgs.Bls) { - return nil, xerrors.Errorf("index in BlsIncludes (%d) exceeds number of messages (%d)", - mi, len(msgs.Bls)) - } - } - for _, mi := range msgs.SecpkIncludes[blockIdx] { - if int(mi) >= len(msgs.Secpk) { - return nil, xerrors.Errorf("index in SecpkIncludes (%d) exceeds number of messages (%d)", - mi, len(msgs.Secpk)) - } - } - } + err := c.validateCompressedIndices(res.Chain) + if err != nil { + return nil, err } } } @@ -242,6 +220,42 @@ func (c *client) processResponse(req *Request, res *Response) (*validatedRespons return validRes, nil } +func (c *client) validateCompressedIndices(chain []*BSTipSet) error { + resLength := len(chain) + for tipsetIdx := 0; tipsetIdx < resLength; tipsetIdx++ { + msgs := chain[tipsetIdx].Messages + blocksNum := len(chain[tipsetIdx].Blocks) + + if len(msgs.BlsIncludes) != blocksNum { + return xerrors.Errorf("BlsIncludes (%d) does not match number of blocks (%d)", + len(msgs.BlsIncludes), blocksNum) + } + + if len(msgs.SecpkIncludes) != blocksNum { + return xerrors.Errorf("SecpkIncludes (%d) does not match number of blocks (%d)", + len(msgs.SecpkIncludes), blocksNum) + } + + for blockIdx := 0; blockIdx < blocksNum; blockIdx++ { + for _, mi := range msgs.BlsIncludes[blockIdx] { + if int(mi) >= len(msgs.Bls) { + return xerrors.Errorf("index in BlsIncludes (%d) exceeds number of messages (%d)", + mi, len(msgs.Bls)) + } + } + + for _, mi := range msgs.SecpkIncludes[blockIdx] { + if int(mi) >= len(msgs.Secpk) { + return xerrors.Errorf("index in SecpkIncludes (%d) exceeds number of messages (%d)", + mi, len(msgs.Secpk)) + } + } + } + } + + return nil +} + // GetBlocks implements Client.GetBlocks(). Refer to the godocs there. func (c *client) GetBlocks(ctx context.Context, tsk types.TipSetKey, count int) ([]*types.TipSet, error) { ctx, span := trace.StartSpan(ctx, "bsync.GetBlocks")