diff --git a/chain/sub/incoming.go b/chain/sub/incoming.go index fc2b7baff..3a11f7c98 100644 --- a/chain/sub/incoming.go +++ b/chain/sub/incoming.go @@ -529,9 +529,8 @@ func (v *IndexerMessageValidator) Validate(ctx context.Context, pid peer.ID, msg msgCid := idxrMsg.Cid - var msgInfo *peerMsgInfo - msgInfo, ok := v.peerCache.Get(minerAddr) - if !ok { + msgInfo, cached := v.peerCache.Get(minerAddr) + if !cached { msgInfo = &peerMsgInfo{} } @@ -539,17 +538,17 @@ func (v *IndexerMessageValidator) Validate(ctx context.Context, pid peer.ID, msg msgInfo.mutex.Lock() defer msgInfo.mutex.Unlock() - if ok { + var seqno uint64 + if cached { // Reject replayed messages. - seqno := binary.BigEndian.Uint64(msg.Message.GetSeqno()) + seqno = binary.BigEndian.Uint64(msg.Message.GetSeqno()) if seqno <= msgInfo.lastSeqno { log.Debugf("ignoring replayed indexer message") return pubsub.ValidationIgnore } - msgInfo.lastSeqno = seqno } - if !ok || originPeer != msgInfo.peerID { + if !cached || originPeer != msgInfo.peerID { // Check that the miner ID maps to the peer that sent the message. err = v.authenticateMessage(ctx, minerAddr, originPeer) if err != nil { @@ -558,7 +557,7 @@ func (v *IndexerMessageValidator) Validate(ctx context.Context, pid peer.ID, msg return pubsub.ValidationReject } msgInfo.peerID = originPeer - if !ok { + if !cached { // Add msgInfo to cache only after being authenticated. If two // messages from the same peer are handled concurrently, there is a // small chance that one msgInfo could replace the other here when @@ -567,6 +566,9 @@ func (v *IndexerMessageValidator) Validate(ctx context.Context, pid peer.ID, msg } } + // Update message info cache with the latest message's sequence number. + msgInfo.lastSeqno = seqno + // See if message needs to be ignored due to rate limiting. if v.rateLimitPeer(msgInfo, msgCid) { return pubsub.ValidationIgnore diff --git a/chain/sub/incoming_test.go b/chain/sub/incoming_test.go index f54e09049..d8ee99b7f 100644 --- a/chain/sub/incoming_test.go +++ b/chain/sub/incoming_test.go @@ -12,10 +12,12 @@ import ( "github.com/ipni/go-libipni/announce/message" pubsub "github.com/libp2p/go-libp2p-pubsub" pb "github.com/libp2p/go-libp2p-pubsub/pb" + "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/filecoin-project/go-address" + "github.com/filecoin-project/lotus/api" "github.com/filecoin-project/lotus/api/mocks" "github.com/filecoin-project/lotus/chain/types" ) @@ -134,3 +136,123 @@ func TestIndexerMessageValidator_Validate(t *testing.T) { }) } } + +func TestIdxValidator(t *testing.T) { + validCid, err := cid.Decode("QmbpDgg5kRLDgMxS8vPKNFXEcA6D5MC4CkuUdSWDVtHPGK") + if err != nil { + t.Fatal(err) + } + + addr, err := address.NewFromString("f01024") + if err != nil { + t.Fatal(err) + } + + buf1, err := addr.MarshalBinary() + if err != nil { + t.Fatal(err) + } + + selfPID := "12D3KooWQiCbqEStCkdqUvr69gQsrp9urYJZUCkzsQXia7mbqbFW" + senderPID := "12D3KooWE8yt84RVwW3sFcd6WMjbUdWrZer2YtT4dmtj3dHdahSZ" + extraData := buf1 + + mc := gomock.NewController(t) + node := mocks.NewMockFullNode(mc) + node.EXPECT().ChainHead(gomock.Any()).Return(nil, nil).AnyTimes() + + subject := NewIndexerMessageValidator(peer.ID(selfPID), node, node) + message := message.Message{ + Cid: validCid, + Addrs: nil, + ExtraData: extraData, + } + buf := bytes.NewBuffer(nil) + if err := message.MarshalCBOR(buf); err != nil { + t.Fatal(err) + } + + topic := "topic" + + privk, _, err := crypto.GenerateKeyPair(crypto.RSA, 2048) + if err != nil { + t.Fatal(err) + } + id, err := peer.IDFromPublicKey(privk.GetPublic()) + if err != nil { + t.Fatal(err) + } + + node.EXPECT().StateMinerInfo(gomock.Any(), gomock.Any(), gomock.Any()).Return(api.MinerInfo{PeerId: &id}, nil).AnyTimes() + + pbm := &pb.Message{ + Data: buf.Bytes(), + Topic: &topic, + From: []byte(id), + Seqno: []byte{1, 1, 1, 1, 2, 2, 2, 2}, + } + validate := subject.Validate(context.Background(), peer.ID(senderPID), &pubsub.Message{ + Message: pbm, + ReceivedFrom: peer.ID("f01024"), // peer.ID(senderPID), + ValidatorData: nil, + }) + if validate != pubsub.ValidationAccept { + t.Error("Expected to receive ValidationAccept") + } + msgInfo, cached := subject.peerCache.Get(addr) + if !cached { + t.Fatal("Message info should be in cache") + } + seqno := msgInfo.lastSeqno + msgInfo.rateLimit = nil // prevent interference from rate limiting + + t.Log("Sending DoS msg") + privk, _, err = crypto.GenerateKeyPair(crypto.RSA, 2048) + if err != nil { + t.Fatal(err) + } + id2, err := peer.IDFromPublicKey(privk.GetPublic()) + if err != nil { + t.Fatal(err) + } + pbm = &pb.Message{ + Data: buf.Bytes(), + Topic: &topic, + From: []byte(id2), + Seqno: []byte{255, 255, 255, 255, 255, 255, 255, 255}, + } + validate = subject.Validate(context.Background(), peer.ID(senderPID), &pubsub.Message{ + Message: pbm, + ReceivedFrom: peer.ID(senderPID), + ValidatorData: nil, + }) + if validate != pubsub.ValidationReject { + t.Error("Expected to get ValidationReject") + } + msgInfo, cached = subject.peerCache.Get(addr) + if !cached { + t.Fatal("Message info should be in cache") + } + msgInfo.rateLimit = nil // prevent interference from rate limiting + + // Check if DoS is possible. + if msgInfo.lastSeqno != seqno { + t.Fatal("Sequence number should not have been updated") + } + + t.Log("Sending another valid message from miner...") + pbm = &pb.Message{ + Data: buf.Bytes(), + Topic: &topic, + From: []byte(id), + Seqno: []byte{1, 1, 1, 1, 2, 2, 2, 3}, + } + validate = subject.Validate(context.Background(), peer.ID(senderPID), &pubsub.Message{ + Message: pbm, + ReceivedFrom: peer.ID("f01024"), // peer.ID(senderPID), + ValidatorData: nil, + }) + if validate != pubsub.ValidationAccept { + t.Fatal("Did not receive ValidationAccept") + } +}