diff --git a/chain/sub/bcast/consistent.go b/chain/sub/bcast/consistent.go index 918a6819f..af31389ee 100644 --- a/chain/sub/bcast/consistent.go +++ b/chain/sub/bcast/consistent.go @@ -27,9 +27,23 @@ type blksInfo struct { } type bcastDict struct { - // TODO: Consider making this a KeyMutexed map - lk sync.RWMutex - blks map[cid.Cid]*blksInfo // map[epoch + VRFProof]blksInfo + // thread-safe map impl for the dictionary + // sync.Map accepts `any` as keys and values. + // To make it type safe and only support the right + // types we use this auxiliary type. + m *sync.Map +} + +func (bd *bcastDict) load(key multihash.Multihash) (*blksInfo, bool) { + v, ok := bd.m.Load(key.String()) + if !ok { + return nil, ok + } + return v.(*blksInfo), ok +} + +func (bd *bcastDict) store(key multihash.Multihash, d *blksInfo) { + bd.m.Store(key.String(), d) } type ConsistentBCast struct { @@ -40,16 +54,14 @@ type ConsistentBCast struct { } func newBcastDict(delay time.Duration) *bcastDict { - return &bcastDict{ - blks: make(map[cid.Cid]*blksInfo), - } + return &bcastDict{new(sync.Map)} } // TODO: What if the VRFProof is already small?? We donĀ“t need the CID. Useless computation. -func BCastKey(bh *types.BlockHeader) cid.Cid { +func BCastKey(bh *types.BlockHeader) (multihash.Multihash, error) { proof := bh.Ticket.VRFProof binary.PutVarint(proof, int64(bh.Height)) - return cid.NewCidV0(multihash.Multihash(proof)) + return multihash.Sum(proof, multihash.SHA2_256, -1) } func NewConsistentBCast(delay time.Duration) *ConsistentBCast { @@ -78,14 +90,16 @@ func (cb *ConsistentBCast) RcvBlock(ctx context.Context, blk *types.BlockMsg) er bcastDict, ok := cb.m[blk.Header.Height] if !ok { bcastDict = newBcastDict(cb.delay) + cb.m[blk.Header.Height] = bcastDict } cb.lk.Unlock() - key := BCastKey(blk.Header) + key, err := BCastKey(blk.Header) + if err != nil { + return err + } blkCid := blk.Cid() - bcastDict.lk.Lock() - defer bcastDict.lk.Unlock() - bInfo, ok := bcastDict.blks[key] + bInfo, ok := bcastDict.load(key) if ok { if len(bInfo.blks) > 1 { return bInfo.eqErr() @@ -98,17 +112,18 @@ func (cb *ConsistentBCast) RcvBlock(ctx context.Context, blk *types.BlockMsg) er return nil } - ctx, cancel := context.WithTimeout(ctx, cb.delay) - bcastDict.blks[key] = &blksInfo{ctx, cancel, []cid.Cid{blkCid}} + ctx, cancel := context.WithTimeout(ctx, cb.delay*time.Second) + bcastDict.store(key, &blksInfo{ctx, cancel, []cid.Cid{blkCid}}) return nil } func (cb *ConsistentBCast) WaitForDelivery(bh *types.BlockHeader) error { bcastDict := cb.m[bh.Height] - key := BCastKey(bh) - bcastDict.lk.RLock() - defer bcastDict.lk.RUnlock() - bInfo, ok := bcastDict.blks[key] + key, err := BCastKey(bh) + if err != nil { + return err + } + bInfo, ok := bcastDict.load(key) if !ok { return fmt.Errorf("something went wrong, unknown block with Epoch + VRFProof (cid=%s) in consistent broadcast storage", key) } @@ -126,6 +141,10 @@ func (cb *ConsistentBCast) GarbageCollect(currEpoch abi.ChainEpoch) { // keep currEpoch-2 and delete a few more in the past // as a sanity-check + // Garbage collection is triggered before block delivery, + // and we use the sanity-check in case there were a few rounds + // without delivery, and the garbage collection wasn't triggered + // for a few epochs. for i := 0; i < GC_SANITY_CHECK; i++ { delete(cb.m, currEpoch-abi.ChainEpoch(2-i)) } diff --git a/chain/sub/bcast/consistent_test.go b/chain/sub/bcast/consistent_test.go index 2f3a9e4de..a2945b46b 100644 --- a/chain/sub/bcast/consistent_test.go +++ b/chain/sub/bcast/consistent_test.go @@ -1,27 +1,217 @@ package bcast_test import ( + "context" "crypto/rand" + "fmt" + mrand "math/rand" + "strconv" + "sync" "testing" + "time" + "github.com/filecoin-project/go-address" "github.com/filecoin-project/go-state-types/abi" + "github.com/filecoin-project/lotus/chain/sub/bcast" "github.com/filecoin-project/lotus/chain/types" + "github.com/ipfs/go-cid" + "github.com/multiformats/go-multihash" + "github.com/stretchr/testify/require" ) +const TEST_DELAY = 1 + func TestSimpleDelivery(t *testing.T) { + cb := bcast.NewConsistentBCast(TEST_DELAY) + // Check that we wait for delivery. + start := time.Now() + testSimpleDelivery(t, cb, 100, 5) + since := time.Since(start) + require.GreaterOrEqual(t, since, TEST_DELAY*time.Second) } -func newBlock(t *testing.T, epoch abi.ChainEpoch) *types.BlockMsg { +func testSimpleDelivery(t *testing.T, cb *bcast.ConsistentBCast, epoch abi.ChainEpoch, numBlocks int) { + ctx := context.Background() + + wg := new(sync.WaitGroup) + errs := make([]error, 0) + wg.Add(numBlocks) + for i := 0; i < numBlocks; i++ { + go func(i int) { + // Add a random delay in block reception + r := mrand.Intn(200) + time.Sleep(time.Duration(r) * time.Millisecond) + blk := newBlock(t, epoch, randomProof(t), []byte("test"+strconv.Itoa(i))) + cb.RcvBlock(ctx, blk) + err := cb.WaitForDelivery(blk.Header) + if err != nil { + errs = append(errs, err) + } + wg.Done() + }(i) + } + wg.Wait() + + for _, v := range errs { + t.Fatalf("error in delivery: %s", v) + } +} + +func TestSeveralEpochs(t *testing.T) { + cb := bcast.NewConsistentBCast(TEST_DELAY) + numEpochs := 5 + wg := new(sync.WaitGroup) + wg.Add(numEpochs) + for i := 0; i < numEpochs; i++ { + go func(i int) { + // Add a random delay between epochs + r := mrand.Intn(500) + time.Sleep(time.Duration(i*TEST_DELAY)*time.Second + time.Duration(r)*time.Millisecond) + rNumBlocks := mrand.Intn(5) + flip, err := flipCoin(0.7) + require.NoError(t, err) + t.Logf("Running epoch %d with %d with equivocation=%v", i, rNumBlocks, !flip) + if flip { + testSimpleDelivery(t, cb, abi.ChainEpoch(i), rNumBlocks) + } else { + testEquivocation(t, cb, abi.ChainEpoch(i), rNumBlocks) + } + wg.Done() + }(i) + } + wg.Wait() +} + +// bias is expected to be 0-1 +func flipCoin(bias float32) (bool, error) { + if bias > 1 || bias < 0 { + return false, fmt.Errorf("wrong bias. expected (0,1)") + } + r := mrand.Intn(100) + return r < int(bias*100), nil +} + +func testEquivocation(t *testing.T, cb *bcast.ConsistentBCast, epoch abi.ChainEpoch, numBlocks int) { + ctx := context.Background() + + wg := new(sync.WaitGroup) + errs := make([]error, 0) + wg.Add(numBlocks + 1) + for i := 0; i < numBlocks; i++ { + proof := randomProof(t) + // Valid blocks + go func(i int, proof []byte) { + r := mrand.Intn(200) + time.Sleep(time.Duration(r) * time.Millisecond) + blk := newBlock(t, 100, proof, []byte("valid"+strconv.Itoa(i))) + cb.RcvBlock(ctx, blk) + err := cb.WaitForDelivery(blk.Header) + if err != nil { + errs = append(errs, err) + } + wg.Done() + }(i, proof) + + // Equivocation for the last block + if i == numBlocks-1 { + // Attempting equivocation + go func(i int, proof []byte) { + // Use the same proof and the same epoch + blk := newBlock(t, 100, proof, []byte("invalid"+strconv.Itoa(i))) + cb.RcvBlock(ctx, blk) + err := cb.WaitForDelivery(blk.Header) + // Equivocation detected + require.Error(t, err) + wg.Done() + }(i, proof) + } + } + wg.Wait() + + // The equivocated block arrived too late, so + // we delivered all the valid blocks. + require.Len(t, errs, 1) +} + +func TestEquivocation(t *testing.T) { + cb := bcast.NewConsistentBCast(TEST_DELAY) + testEquivocation(t, cb, 100, 5) +} + +func TestFailedEquivocation(t *testing.T) { + cb := bcast.NewConsistentBCast(TEST_DELAY) + ctx := context.Background() + numBlocks := 5 + + wg := new(sync.WaitGroup) + errs := make([]error, 0) + wg.Add(numBlocks + 1) + for i := 0; i < numBlocks; i++ { + proof := randomProof(t) + // Valid blocks + go func(i int, proof []byte) { + r := mrand.Intn(200) + time.Sleep(time.Duration(r) * time.Millisecond) + blk := newBlock(t, 100, proof, []byte("valid"+strconv.Itoa(i))) + cb.RcvBlock(ctx, blk) + err := cb.WaitForDelivery(blk.Header) + if err != nil { + errs = append(errs, err) + } + wg.Done() + }(i, proof) + + // Equivocation for the last block + if i == numBlocks-1 { + // Attempting equivocation + go func(i int, proof []byte) { + // The equivocated block arrives late + time.Sleep(2 * TEST_DELAY * time.Second) + // Use the same proof and the same epoch + blk := newBlock(t, 100, proof, []byte("invalid"+strconv.Itoa(i))) + cb.RcvBlock(ctx, blk) + err := cb.WaitForDelivery(blk.Header) + // Equivocation detected + require.Error(t, err) + wg.Done() + }(i, proof) + } + } + wg.Wait() + + // The equivocated block arrived too late, so + // we delivered all the valid blocks. + require.Len(t, errs, 0) +} + +func randomProof(t *testing.T) []byte { proof := make([]byte, 10) _, err := rand.Read(proof) - if err != err { + if err != nil { + t.Fatal(err) + } + return proof +} + +func newBlock(t *testing.T, epoch abi.ChainEpoch, proof []byte, mCidSeed []byte) *types.BlockMsg { + h, err := multihash.Sum(mCidSeed, multihash.SHA2_256, -1) + if err != nil { + t.Fatal(err) + } + testCid := cid.NewCidV0(h) + addr, err := address.NewIDAddress(10) + if err != nil { t.Fatal(err) } bh := &types.BlockHeader{ + Miner: addr, + ParentStateRoot: testCid, + ParentMessageReceipts: testCid, Ticket: &types.Ticket{ - VRFProof: []byte("vrf proof0000000vrf proof0000000"), + VRFProof: proof, }, - Height: 85919298723, + Height: epoch, + Messages: testCid, } return &types.BlockMsg{ Header: bh,