Add indexer pubsub message authentication and rate limiting

This commit is contained in:
gammazero 2022-02-08 02:53:25 -08:00
parent eb8296120b
commit 1dc6a2fea6
5 changed files with 413 additions and 11 deletions

View File

@@ -2,15 +2,17 @@ package sub
import (
"context"
-"fmt"
+"sync"
"time"
address "github.com/filecoin-project/go-address"
+"github.com/filecoin-project/lotus/api"
"github.com/filecoin-project/lotus/build"
"github.com/filecoin-project/lotus/chain"
"github.com/filecoin-project/lotus/chain/consensus"
"github.com/filecoin-project/lotus/chain/messagepool"
"github.com/filecoin-project/lotus/chain/store"
+"github.com/filecoin-project/lotus/chain/sub/ratelimit"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/metrics"
"github.com/filecoin-project/lotus/node/impl/client"
@@ -22,6 +24,7 @@ import (
connmgr "github.com/libp2p/go-libp2p-core/connmgr"
"github.com/libp2p/go-libp2p-core/peer"
pubsub "github.com/libp2p/go-libp2p-pubsub"
+"github.com/multiformats/go-varint"
"go.opencensus.io/stats"
"go.opencensus.io/tag"
"golang.org/x/xerrors"
@@ -168,12 +171,12 @@ func fetchCids(
cidIndex := make(map[cid.Cid]int)
for i, c := range cids {
if c.Prefix() != msgCidPrefix {
-return fmt.Errorf("invalid msg CID: %s", c)
+return xerrors.Errorf("invalid msg CID: %s", c)
}
cidIndex[c] = i
}

if len(cids) != len(cidIndex) {
-return fmt.Errorf("duplicate CIDs in fetchCids input")
+return xerrors.Errorf("duplicate CIDs in fetchCids input")
}

for block := range bserv.GetBlocks(ctx, cids) {
@@ -196,7 +199,7 @@ func fetchCids(
if len(cidIndex) > 0 {
err := ctx.Err()
if err == nil {
-err = fmt.Errorf("failed to fetch %d messages for unknown reasons", len(cidIndex))
+err = xerrors.Errorf("failed to fetch %d messages for unknown reasons", len(cidIndex))
}
return err
}
@@ -445,23 +448,199 @@ func recordFailure(ctx context.Context, metric *stats.Int64Measure, failureType
stats.Record(ctx, metric.M(1))
}
-type IndexerMessageValidator struct {
-self peer.ID
-}
-func NewIndexerMessageValidator(self peer.ID) *IndexerMessageValidator {
-return &IndexerMessageValidator{self: self}
-}
+type peerMsgInfo struct {
+peerID peer.ID
+lastCid cid.Cid
+rateLimit *ratelimit.Window
+mutex sync.Mutex
+}
+type IndexerMessageValidator struct {
+self peer.ID
+peerCache *lru.TwoQueueCache
+fullNode api.FullNode
+}
+func NewIndexerMessageValidator(self peer.ID, fullNode api.FullNode) *IndexerMessageValidator {
+peerCache, _ := lru.New2Q(1024)
+return &IndexerMessageValidator{
+self: self,
+peerCache: peerCache,
+fullNode: fullNode,
+}
+}
func (v *IndexerMessageValidator) Validate(ctx context.Context, pid peer.ID, msg *pubsub.Message) pubsub.ValidationResult {
// This chain-node should not be publishing its own messages. These are
-// relayed from miner-nodes or index publishers. If a node appears to be
-// local, reject it.
+// relayed from market-nodes. If a node appears to be local, reject it.
if pid == v.self {
log.Warnf("refusing to relay indexer message from self")
stats.Record(ctx, metrics.IndexerMessageValidationFailure.M(1))
return pubsub.ValidationReject
}
originPeer := msg.GetFrom()
if originPeer == v.self {
log.Warnf("refusing to relay indexer message originating from self")
stats.Record(ctx, metrics.IndexerMessageValidationFailure.M(1))
return pubsub.ValidationReject
}
// Decode the message CID and the originator's miner address from the message data.
minerID, msgCid, err := decodeIndexerMessage(msg.Data)
if err != nil {
log.Errorw("Could not decode pubsub message", "err", err)
return pubsub.ValidationReject
}
if minerID == "" {
log.Warnw("ignoring messsage missing miner id", "peer", originPeer)
return pubsub.ValidationIgnore
}
var msgInfo *peerMsgInfo
val, ok := v.peerCache.Get(minerID)
if !ok {
msgInfo = &peerMsgInfo{}
} else {
msgInfo = val.(*peerMsgInfo)
}
// Lock this peer's message info.
msgInfo.mutex.Lock()
defer msgInfo.mutex.Unlock()
if !ok || originPeer != msgInfo.peerID {
// Check that the message was signed by an authenticated peer.
err = v.authenticateMessage(ctx, minerID, originPeer)
if err != nil {
log.Warnw("cannot authenticate messsage", "err", err, "peer", originPeer, "minerID", minerID)
stats.Record(ctx, metrics.IndexerMessageValidationFailure.M(1))
return pubsub.ValidationReject
}
msgInfo.peerID = originPeer
if !ok {
// Add msgInfo to cache only after being authenticated. If two
// messages from the same peer are handled concurrently, there is a
// small chance that one msgInfo could replace the other here when
// the info is first cached. This is OK, so no need to prevent it.
v.peerCache.Add(minerID, msgInfo)
}
}
// See if message needs to be ignored due to rate limiting.
if v.rateLimitPeer(msgInfo, msgCid) {
return pubsub.ValidationIgnore
}
stats.Record(ctx, metrics.IndexerMessageValidationSuccess.M(1))
return pubsub.ValidationAccept
}
func (v *IndexerMessageValidator) rateLimitPeer(msgInfo *peerMsgInfo, msgCid cid.Cid) bool {
const (
msgLimit = 5
msgTimeLimit = 10 * time.Second
repeatTimeLimit = 2 * time.Hour
)
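// Together, these limits allow each miner ID at most msgLimit messages in
// any msgTimeLimit span, and a repeat of the previous message CID is
// ignored until repeatTimeLimit has elapsed.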
timeWindow := msgInfo.rateLimit
// Check overall message rate.
if timeWindow == nil {
timeWindow = ratelimit.NewWindow(msgLimit, msgTimeLimit)
msgInfo.rateLimit = timeWindow
} else if msgInfo.lastCid == msgCid {
// Check if this is a repeat of the previous message data.
if time.Since(timeWindow.Newest()) < repeatTimeLimit {
log.Warnw("ignoring repeated indexer message", "sender", msgInfo.peerID)
return true
}
}
err := timeWindow.Add()
if err != nil {
log.Warnw("ignoring indexer message", "sender", msgInfo.peerID, "err", err)
return true
}
msgInfo.lastCid = msgCid
return false
}
func decodeIndexerMessage(data []byte) (string, cid.Cid, error) {
n, msgCid, err := cid.CidFromBytes(data)
if err != nil {
return "", cid.Undef, err
}
if n > len(data) {
return "", cid.Undef, xerrors.New("bad cid length encoding")
}
data = data[n:]
var minerID string
if len(data) != 0 {
addrCount, n, err := varint.FromUvarint(data)
if err != nil {
return "", cid.Undef, xerrors.Errorf("cannot read number of multiaddrs: %w", err)
}
if n > len(data) {
return "", cid.Undef, xerrors.New("bad multiaddr count encoding")
}
data = data[n:]
if addrCount != 0 {
// Read multiaddrs if there is any more data in the message. This allows
// backward-compatibility with publishers that do not supply address data.
for i := 0; i < int(addrCount); i++ {
val, n, err := varint.FromUvarint(data)
if err != nil {
return "", cid.Undef, xerrors.Errorf("cannot read multiaddrs length: %w", err)
}
if n > len(data) {
return "", cid.Undef, xerrors.New("bad multiaddr length encoding")
}
data = data[n:]
if len(data) < int(val) {
return "", cid.Undef, xerrors.New("bad multiaddr encoding")
}
data = data[val:]
}
}
if len(data) != 0 {
minerID = string(data)
}
}
return minerID, msgCid, nil
}
func (v *IndexerMessageValidator) authenticateMessage(ctx context.Context, minerID string, peerID peer.ID) error {
// Get miner info from lotus
minerAddress, err := address.NewFromString(minerID)
if err != nil {
return xerrors.Errorf("invalid miner id: %w", err)
}
ts, err := v.fullNode.ChainHead(ctx)
if err != nil {
return err
}
minerInfo, err := v.fullNode.StateMinerInfo(ctx, minerAddress, ts.Key())
if err != nil {
return err
}
if minerInfo.PeerId == nil {
return xerrors.New("no peer id for miner")
}
if *minerInfo.PeerId != peerID {
return xerrors.New("message not signed by peer in miner info")
}
return nil
}
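
For reference, decodeIndexerMessage above implies this wire format: the message CID, a uvarint count of multiaddrs, each multiaddr as uvarint-length-prefixed bytes, and the miner ID as the remaining bytes. Below is a minimal sketch of a matching encoder; encodeIndexerMessage is a hypothetical helper written for illustration, not part of this commit.

package sub

import (
	"github.com/ipfs/go-cid"
	"github.com/multiformats/go-varint"
)

// encodeIndexerMessage (hypothetical) produces bytes in the layout that
// decodeIndexerMessage parses: CID bytes, a uvarint multiaddr count, each
// multiaddr as uvarint-length-prefixed bytes, then the miner ID string.
func encodeIndexerMessage(c cid.Cid, addrs [][]byte, minerID string) []byte {
	// The CID is self-delimiting, so no explicit length prefix is needed.
	buf := c.Bytes()
	// Number of multiaddrs that follow.
	buf = append(buf, varint.ToUvarint(uint64(len(addrs)))...)
	for _, a := range addrs {
		// Each multiaddr is prefixed with its byte length.
		buf = append(buf, varint.ToUvarint(uint64(len(a)))...)
		buf = append(buf, a...)
	}
	// Any remaining bytes are interpreted as the miner ID.
	return append(buf, []byte(minerID)...)
}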

View File

@@ -0,0 +1,89 @@
package ratelimit
import "errors"
var ErrRate = errors.New("rate exceeded")
type queue struct {
buf []int64
count int
head int
tail int
}
// cap returns the queue capacity
func (q *queue) cap() int {
return len(q.buf)
}
// len returns the number of items in the queue
func (q *queue) len() int {
return q.count
}
// push adds an element to the end of the queue.
func (q *queue) push(elem int64) error {
if q.count == len(q.buf) {
return ErrRate
}
q.buf[q.tail] = elem
// Calculate new tail position.
q.tail = q.next(q.tail)
q.count++
return nil
}
// pop removes and returns the element from the front of the queue.
func (q *queue) pop() int64 {
if q.count <= 0 {
panic("pop from empty queue")
}
ret := q.buf[q.head]
// Calculate new head position.
q.head = q.next(q.head)
q.count--
return ret
}
// front returns the element at the front of the queue. This is the element
// that would be returned by pop(). This call panics if the queue is empty.
func (q *queue) front() int64 {
if q.count <= 0 {
panic("front() called when empty")
}
return q.buf[q.head]
}
// back returns the element at the back of the queue. This call panics if the
// queue is empty.
func (q *queue) back() int64 {
if q.count <= 0 {
panic("back() called when empty")
}
return q.buf[q.prev(q.tail)]
}
// prev returns the previous buffer position, wrapping around the buffer.
func (q *queue) prev(i int) int {
if i == 0 {
return len(q.buf) - 1
}
return (i - 1) % len(q.buf)
}
// next returns the next buffer position, wrapping around the buffer.
func (q *queue) next(i int) int {
return (i + 1) % len(q.buf)
}
// truncate pops values that are less than or equal to the specified threshold.
func (q *queue) truncate(threshold int64) {
for q.count != 0 && q.buf[q.head] <= threshold {
// pop() without returning a value
q.head = q.next(q.head)
q.count--
}
}
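
A short in-package sketch of the ring buffer's wraparound behavior (demoQueue is a hypothetical illustration, not part of this commit): a push to a full queue fails with ErrRate, and once pop frees the head slot the tail index wraps around to reuse it.

package ratelimit

import "fmt"

// demoQueue walks the wraparound arithmetic on a capacity-3 queue.
func demoQueue() {
	q := &queue{buf: make([]int64, 3)}
	_ = q.push(1)          // head=0, tail=1
	_ = q.push(2)          // tail=2
	_ = q.push(3)          // tail wraps to 0; queue is now full
	fmt.Println(q.push(4)) // "rate exceeded": count == cap
	fmt.Println(q.pop())   // 1; head advances, freeing slot 0
	fmt.Println(q.push(4)) // <nil>; stored at the wrapped tail index 0
}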

View File

@@ -0,0 +1,70 @@
package ratelimit
import "time"
// Window is a time window for counting events within a span of time. The
// window slides forward in time so that it spans from the most recent event
// to size time in the past.
type Window struct {
q *queue
size int64
}
// NewWindow creates a new Window that limits the number of events to a
// maximum count within a duration of time. The capacity sets the maximum
// number of events, and size sets the span of time over which the events are
// counted.
func NewWindow(capacity int, size time.Duration) *Window {
return &Window{
q: &queue{
buf: make([]int64, capacity),
},
size: int64(size),
}
}
// Add attempts to append a new timestamp into the current window. Previously
// added values that are not within `size` difference from the value being
// added are first removed. Add fails if adding the value would cause the
// window to exceed capacity.
func (w *Window) Add() error {
now := time.Now().UnixNano()
if w.Len() != 0 {
w.q.truncate(now - w.size)
}
return w.q.push(now)
}
// Cap returns the maximum number of items the window can hold.
func (w *Window) Cap() int {
return w.q.cap()
}
// Len returns the number of elements currently in the window.
func (w *Window) Len() int {
return w.q.len()
}
// Span returns the time between the oldest and newest items in the window.
func (w *Window) Span() time.Duration {
if w.q.len() < 2 {
return 0
}
return time.Duration(w.q.back() - w.q.front())
}
// Oldest returns the oldest timestamp in the window.
func (w *Window) Oldest() time.Time {
if w.q.len() == 0 {
return time.Time{}
}
return time.Unix(0, w.q.front())
}
// Newest returns the newest timestamp in the window.
func (w *Window) Newest() time.Time {
if w.q.len() == 0 {
return time.Time{}
}
return time.Unix(0, w.q.back())
}
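
A minimal usage sketch of Window, using the same limits the indexer validator applies (5 events per 10 seconds); the import path matches this commit, while the surrounding main function is illustrative only.

package main

import (
	"fmt"
	"time"

	"github.com/filecoin-project/lotus/chain/sub/ratelimit"
)

func main() {
	// Allow at most 5 events within any sliding 10-second span.
	w := ratelimit.NewWindow(5, 10*time.Second)
	for i := 0; i < 6; i++ {
		if err := w.Add(); err != nil {
			// The sixth Add in quick succession fails with ErrRate.
			fmt.Println("rate limited:", err)
		}
	}
}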

View File

@@ -0,0 +1,61 @@
package ratelimit
import (
"testing"
"time"
)
func TestWindow(t *testing.T) {
const (
maxEvents = 3
timeLimit = 100 * time.Millisecond
)
w := NewWindow(maxEvents, timeLimit)
if w.Len() != 0 {
t.Fatal("w.Len() =", w.Len(), "expect 0")
}
if w.Cap() != maxEvents {
t.Fatal("w.Cap() =", w.Cap(), "expect", maxEvents)
}
if !w.Newest().IsZero() {
t.Fatal("expected newest to be zero time with empty window")
}
if !w.Oldest().IsZero() {
t.Fatal("expected oldest to be zero time with empty window")
}
if w.Span() != 0 {
t.Fatal("expected span to be zero time with empty window")
}
var err error
for i := 0; i < maxEvents; i++ {
err = w.Add()
if err != nil {
t.Fatalf("cannot add event %d", i)
}
}
if w.Len() != maxEvents {
t.Fatalf("q.Len() is %d, expected %d", w.Len(), maxEvents)
}
if err = w.Add(); err == nil {
t.Fatalf("add event %d within time limit should have failed", maxEvents+1)
}
time.Sleep(timeLimit)
if err = w.Add(); err != nil {
t.Fatalf("cannot add event after time limit: %s", err)
}
prev := w.Newest()
time.Sleep(timeLimit)
err = w.Add()
if err != nil {
t.Fatalf("cannot add event")
}
if w.Newest().Before(prev) {
t.Fatal("newest is before previous value")
}
if w.Oldest().Before(prev) {
t.Fatal("oldest is before previous value")
}
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/filecoin-project/go-fil-markets/discovery"
discoveryimpl "github.com/filecoin-project/go-fil-markets/discovery/impl"
+"github.com/filecoin-project/lotus/api"
"github.com/filecoin-project/lotus/build"
"github.com/filecoin-project/lotus/chain"
"github.com/filecoin-project/lotus/chain/beacon"
@@ -201,7 +202,9 @@ func HandleIncomingMessages(mctx helpers.MetricsCtx, lc fx.Lifecycle, ps *pubsub
func RelayIndexerMessages(lc fx.Lifecycle, ps *pubsub.PubSub, nn dtypes.NetworkName, h host.Host) error {
topicName := build.IndexerIngestTopic(nn)

-v := sub.NewIndexerMessageValidator(h.ID())
+// TODO: How does this get set?
+var fullNode api.FullNode
+v := sub.NewIndexerMessageValidator(h.ID(), fullNode)
if err := ps.RegisterTopicValidator(topicName, v.Validate); err != nil {
return xerrors.Errorf("failed to register validator for topic %s, err: %w", topicName, err)
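
One plausible resolution of the TODO above, sketched under the assumption that the full node API can be supplied as another fx-injected constructor parameter; the final signature in the merged change may differ.

// Hypothetical wiring: accept the full node as a dependency so the validator
// can query miner info. The parameter and its provider are assumptions.
func RelayIndexerMessages(lc fx.Lifecycle, ps *pubsub.PubSub, nn dtypes.NetworkName, h host.Host, fullNode api.FullNode) error {
	topicName := build.IndexerIngestTopic(nn)

	v := sub.NewIndexerMessageValidator(h.ID(), fullNode)
	if err := ps.RegisterTopicValidator(topicName, v.Validate); err != nil {
		return xerrors.Errorf("failed to register validator for topic %s, err: %w", topicName, err)
	}
	// ... remainder as in the function above.
}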