From 8f309b214b3c67237fa885276a69f3d1d632e4b6 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Wed, 28 Apr 2021 12:49:21 -0700 Subject: [PATCH] chain: move checkpoint logic into chainstore That way, checkpoints can be enforced by the chainstore, removing a potential race where an in-progress sync of a fork could bypass a sync checkpoint. --- chain/checkpoint.go | 83 +++++++---------------- chain/store/checkpoint_test.go | 89 ++++++++++++++++++++++++ chain/store/store.go | 119 ++++++++++++++++++++++++++++++++- chain/sync.go | 33 ++++----- chain/sync_test.go | 10 +++ 5 files changed, 251 insertions(+), 83 deletions(-) create mode 100644 chain/store/checkpoint_test.go diff --git a/chain/checkpoint.go b/chain/checkpoint.go index 4718b294d..a3660a45c 100644 --- a/chain/checkpoint.go +++ b/chain/checkpoint.go @@ -2,41 +2,12 @@ package chain import ( "context" - "encoding/json" "github.com/filecoin-project/lotus/chain/types" - "github.com/filecoin-project/lotus/node/modules/dtypes" - "github.com/ipfs/go-datastore" "golang.org/x/xerrors" ) -var CheckpointKey = datastore.NewKey("/chain/checks") - -func loadCheckpoint(ds dtypes.MetadataDS) (types.TipSetKey, error) { - haveChks, err := ds.Has(CheckpointKey) - if err != nil { - return types.EmptyTSK, err - } - - if !haveChks { - return types.EmptyTSK, nil - } - - tskBytes, err := ds.Get(CheckpointKey) - if err != nil { - return types.EmptyTSK, err - } - - var tsk types.TipSetKey - err = json.Unmarshal(tskBytes, &tsk) - if err != nil { - return types.EmptyTSK, err - } - - return tsk, err -} - func (syncer *Syncer) SyncCheckpoint(ctx context.Context, tsk types.TipSetKey) error { if tsk == types.EmptyTSK { return xerrors.Errorf("called with empty tsk") @@ -53,42 +24,34 @@ func (syncer *Syncer) SyncCheckpoint(ctx context.Context, tsk types.TipSetKey) e ts = tss[0] } - hts := syncer.ChainStore().GetHeaviestTipSet() - anc, err := syncer.ChainStore().IsAncestorOf(ts, hts) - if err != nil { - return xerrors.Errorf("cannot determine 
whether checkpoint tipset is in main-chain: %w", err) - } - if !hts.Equals(ts) && !anc { - if err := syncer.collectChain(ctx, ts, hts); err != nil { - return xerrors.Errorf("failed to collect chain for checkpoint: %w", err) - } - if err := syncer.ChainStore().SetHead(ts); err != nil { - return xerrors.Errorf("failed to set the chain head: %w", err) - } + if err := syncer.switchChain(ctx, ts); err != nil { + return xerrors.Errorf("failed to switch chain when syncing checkpoint: %w", err) } - syncer.checkptLk.Lock() - defer syncer.checkptLk.Unlock() - - tskBytes, err := json.Marshal(tsk) - if err != nil { - return err + if err := syncer.ChainStore().SetCheckpoint(ts); err != nil { + return xerrors.Errorf("failed to set the chain checkpoint: %w", err) } - err = syncer.ds.Put(CheckpointKey, tskBytes) - if err != nil { - return err - } - - // TODO: This is racy. as there may be a concurrent sync in progress. - // The only real solution is to checkpoint inside the chainstore, not here. - syncer.checkpt = tsk - return nil } -func (syncer *Syncer) GetCheckpoint() types.TipSetKey { - syncer.checkptLk.Lock() - defer syncer.checkptLk.Unlock() - return syncer.checkpt +func (syncer *Syncer) switchChain(ctx context.Context, ts *types.TipSet) error { + hts := syncer.ChainStore().GetHeaviestTipSet() + if hts.Equals(ts) { + return nil + } + + if anc, err := syncer.store.IsAncestorOf(ts, hts); err == nil && anc { + return nil + } + + // Otherwise, sync the chain and set the head. 
+ if err := syncer.collectChain(ctx, ts, hts, true); err != nil { + return xerrors.Errorf("failed to collect chain for checkpoint: %w", err) + } + + if err := syncer.ChainStore().SetHead(ts); err != nil { + return xerrors.Errorf("failed to set the chain head: %w", err) + } + return nil } diff --git a/chain/store/checkpoint_test.go b/chain/store/checkpoint_test.go new file mode 100644 index 000000000..320b76797 --- /dev/null +++ b/chain/store/checkpoint_test.go @@ -0,0 +1,89 @@ +package store_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/filecoin-project/lotus/chain/gen" +) + +func TestChainCheckpoint(t *testing.T) { + cg, err := gen.NewGenerator() + if err != nil { + t.Fatal(err) + } + + // Let the first miner mine some blocks. + last := cg.CurTipset.TipSet() + for i := 0; i < 4; i++ { + ts, err := cg.NextTipSetFromMiners(last, cg.Miners[:1]) + require.NoError(t, err) + + last = ts.TipSet.TipSet() + } + + cs := cg.ChainStore() + + checkpoint := last + checkpointParents, err := cs.GetTipSetFromKey(checkpoint.Parents()) + require.NoError(t, err) + + // Set the head to the block before the checkpoint. + err = cs.SetHead(checkpointParents) + require.NoError(t, err) + + // Verify it worked. + head := cs.GetHeaviestTipSet() + require.True(t, head.Equals(checkpointParents)) + + // Try to set the checkpoint in the future, it should fail. + err = cs.SetCheckpoint(checkpoint) + require.Error(t, err) + + // Then move the head back. + err = cs.SetHead(checkpoint) + require.NoError(t, err) + + // Verify it worked. + head = cs.GetHeaviestTipSet() + require.True(t, head.Equals(checkpoint)) + + // And checkpoint it. 
+ err = cs.SetCheckpoint(checkpoint) + require.NoError(t, err) + + // Let the second miner mine a fork + last = checkpointParents + for i := 0; i < 4; i++ { + ts, err := cg.NextTipSetFromMiners(last, cg.Miners[1:]) + require.NoError(t, err) + + last = ts.TipSet.TipSet() + } + + // See if the chain will take the fork, it shouldn't. + err = cs.MaybeTakeHeavierTipSet(context.Background(), last) + require.NoError(t, err) + head = cs.GetHeaviestTipSet() + require.True(t, head.Equals(checkpoint)) + + // Remove the checkpoint. + err = cs.RemoveCheckpoint() + require.NoError(t, err) + + // Now switch to the other fork. + err = cs.MaybeTakeHeavierTipSet(context.Background(), last) + require.NoError(t, err) + head = cs.GetHeaviestTipSet() + require.True(t, head.Equals(last)) + + // Setting a checkpoint on the other fork should fail. + err = cs.SetCheckpoint(checkpoint) + require.Error(t, err) + + // Setting a checkpoint on this fork should succeed. + err = cs.SetCheckpoint(checkpointParents) + require.NoError(t, err) +} diff --git a/chain/store/store.go b/chain/store/store.go index 7ebe31ec4..1e78ce73d 100644 --- a/chain/store/store.go +++ b/chain/store/store.go @@ -54,8 +54,11 @@ import ( var log = logging.Logger("chainstore") -var chainHeadKey = dstore.NewKey("head") -var blockValidationCacheKeyPrefix = dstore.NewKey("blockValidation") +var ( + chainHeadKey = dstore.NewKey("head") + checkpointKey = dstore.NewKey("/chain/checks") + blockValidationCacheKeyPrefix = dstore.NewKey("blockValidation") +) var DefaultTipSetCacheSize = 8192 var DefaultMsgMetaCacheSize = 2048 @@ -115,6 +118,7 @@ type ChainStore struct { heaviestLk sync.RWMutex heaviest *types.TipSet + checkpoint *types.TipSet bestTips *pubsub.PubSub pubLk sync.Mutex @@ -215,6 +219,15 @@ func (cs *ChainStore) Close() error { } func (cs *ChainStore) Load() error { + if err := cs.loadHead(); err != nil { + return err + } + if err := cs.loadCheckpoint(); err != nil { + return err + } + return nil +} +func (cs 
*ChainStore) loadHead() error { head, err := cs.metadataDs.Get(chainHeadKey) if err == dstore.ErrNotFound { log.Warn("no previous chain state found") @@ -239,6 +252,31 @@ func (cs *ChainStore) Load() error { return nil } +func (cs *ChainStore) loadCheckpoint() error { + tskBytes, err := cs.metadataDs.Get(checkpointKey) + if err == dstore.ErrNotFound { + return nil + } + if err != nil { + return xerrors.Errorf("failed to load checkpoint from datastore: %w", err) + } + + var tsk types.TipSetKey + err = json.Unmarshal(tskBytes, &tsk) + if err != nil { + return err + } + + ts, err := cs.LoadTipSet(tsk) + if err != nil { + return xerrors.Errorf("loading tipset: %w", err) + } + + cs.checkpoint = ts + + return nil +} + func (cs *ChainStore) writeHead(ts *types.TipSet) error { data, err := json.Marshal(ts.Cids()) if err != nil { @@ -439,6 +477,11 @@ func (cs *ChainStore) exceedsForkLength(synced, external *types.TipSet) (bool, e return false, nil } + // Now check to see if we've walked back to the checkpoint. + if synced.Equals(cs.checkpoint) { + return true, nil + } + // If we didn't, go back *one* tipset on the `synced` side (incrementing // the `forkLength`). if synced.Height() == 0 { @@ -467,6 +510,9 @@ func (cs *ChainStore) ForceHeadSilent(_ context.Context, ts *types.TipSet) error cs.heaviestLk.Lock() defer cs.heaviestLk.Unlock() + if err := cs.removeCheckpoint(); err != nil { + return err + } cs.heaviest = ts err := cs.writeHead(ts) @@ -642,13 +688,80 @@ func FlushValidationCache(ds datastore.Batching) error { } // SetHead sets the chainstores current 'best' head node. -// This should only be called if something is broken and needs fixing +// This should only be called if something is broken and needs fixing. +// +// This function will bypass and remove any checkpoints. 
func (cs *ChainStore) SetHead(ts *types.TipSet) error { cs.heaviestLk.Lock() defer cs.heaviestLk.Unlock() + if err := cs.removeCheckpoint(); err != nil { + return err + } return cs.takeHeaviestTipSet(context.TODO(), ts) } +// RemoveCheckpoint removes the current checkpoint. +func (cs *ChainStore) RemoveCheckpoint() error { + cs.heaviestLk.Lock() + defer cs.heaviestLk.Unlock() + return cs.removeCheckpoint() +} + +func (cs *ChainStore) removeCheckpoint() error { + if err := cs.metadataDs.Delete(checkpointKey); err != nil { + return err + } + cs.checkpoint = nil + return nil +} + +// SetCheckpoint will set a checkpoint past which the chainstore will not allow forks. +// +// NOTE: Checkpoints cannot be set beyond ForkLengthThreshold epochs in the past. +func (cs *ChainStore) SetCheckpoint(ts *types.TipSet) error { + tskBytes, err := json.Marshal(ts.Key()) + if err != nil { + return err + } + + cs.heaviestLk.Lock() + defer cs.heaviestLk.Unlock() + + if ts.Height() > cs.heaviest.Height() { + return xerrors.Errorf("cannot set a checkpoint in the future") + } + + // Otherwise, this operation could get _very_ expensive. + if cs.heaviest.Height()-ts.Height() > build.ForkLengthThreshold { + return xerrors.Errorf("cannot set a checkpoint before the fork threshold") + } + + if !ts.Equals(cs.heaviest) { + anc, err := cs.IsAncestorOf(ts, cs.heaviest) + if err != nil { + return xerrors.Errorf("cannot determine whether checkpoint tipset is in main-chain: %w", err) + } + + if !anc { + return xerrors.Errorf("cannot mark tipset as checkpoint, since it isn't in the main-chain") // err is nil on this path, so there is nothing to wrap + } + } + err = cs.metadataDs.Put(checkpointKey, tskBytes) + if err != nil { + return err + } + + cs.checkpoint = ts + return nil +} + +func (cs *ChainStore) GetCheckpoint() *types.TipSet { + cs.heaviestLk.RLock() + chkpt := cs.checkpoint + cs.heaviestLk.RUnlock() + return chkpt +} + // Contains returns whether our BlockStore has all blocks in the supplied TipSet. 
func (cs *ChainStore) Contains(ts *types.TipSet) (bool, error) { for _, c := range ts.Cids() { diff --git a/chain/sync.go b/chain/sync.go index 66c9c18bd..6f594024d 100644 --- a/chain/sync.go +++ b/chain/sync.go @@ -131,10 +131,6 @@ type Syncer struct { tickerCtxCancel context.CancelFunc - checkptLk sync.Mutex - - checkpt types.TipSetKey - ds dtypes.MetadataDS } @@ -152,14 +148,8 @@ func NewSyncer(ds dtypes.MetadataDS, sm *stmgr.StateManager, exchange exchange.C return nil, err } - cp, err := loadCheckpoint(ds) - if err != nil { - return nil, xerrors.Errorf("error loading mpool config: %w", err) - } - s := &Syncer{ ds: ds, - checkpt: cp, beacon: beacon, bad: NewBadBlockCache(), Genesis: gent, @@ -561,7 +551,7 @@ func (syncer *Syncer) Sync(ctx context.Context, maybeHead *types.TipSet) error { return nil } - if err := syncer.collectChain(ctx, maybeHead, hts); err != nil { + if err := syncer.collectChain(ctx, maybeHead, hts, false); err != nil { span.AddAttributes(trace.StringAttribute("col_error", err.Error())) span.SetStatus(trace.Status{ Code: 13, @@ -1247,7 +1237,7 @@ func extractSyncState(ctx context.Context) *SyncerState { // // All throughout the process, we keep checking if the received blocks are in // the deny list, and short-circuit the process if so. 
-func (syncer *Syncer) collectHeaders(ctx context.Context, incoming *types.TipSet, known *types.TipSet) ([]*types.TipSet, error) { +func (syncer *Syncer) collectHeaders(ctx context.Context, incoming *types.TipSet, known *types.TipSet, ignoreCheckpoint bool) ([]*types.TipSet, error) { ctx, span := trace.StartSpan(ctx, "collectHeaders") defer span.End() ss := extractSyncState(ctx) @@ -1416,7 +1406,7 @@ loop: // We have now ascertained that this is *not* a 'fast forward' log.Warnf("(fork detected) synced header chain (%s - %d) does not link to our best block (%s - %d)", incoming.Cids(), incoming.Height(), known.Cids(), known.Height()) - fork, err := syncer.syncFork(ctx, base, known) + fork, err := syncer.syncFork(ctx, base, known, ignoreCheckpoint) if err != nil { if xerrors.Is(err, ErrForkTooLong) || xerrors.Is(err, ErrForkCheckpoint) { // TODO: we're marking this block bad in the same way that we mark invalid blocks bad. Maybe distinguish? @@ -1442,11 +1432,14 @@ var ErrForkCheckpoint = fmt.Errorf("fork would require us to diverge from checkp // If the fork is too long (build.ForkLengthThreshold), or would cause us to diverge from the checkpoint (ErrForkCheckpoint), // we add the entire subchain to the denylist. Else, we find the common ancestor, and add the missing chain // fragment until the fork point to the returned []TipSet. -func (syncer *Syncer) syncFork(ctx context.Context, incoming *types.TipSet, known *types.TipSet) ([]*types.TipSet, error) { +func (syncer *Syncer) syncFork(ctx context.Context, incoming *types.TipSet, known *types.TipSet, ignoreCheckpoint bool) ([]*types.TipSet, error) { - chkpt := syncer.GetCheckpoint() - if known.Key() == chkpt { - return nil, ErrForkCheckpoint + var chkpt *types.TipSet + if !ignoreCheckpoint { + chkpt = syncer.store.GetCheckpoint() + if known.Equals(chkpt) { + return nil, ErrForkCheckpoint + } } // TODO: Does this mean we always ask for ForkLengthThreshold blocks from the network, even if we just need, like, 2? Yes. 
@@ -1488,7 +1481,7 @@ func (syncer *Syncer) syncFork(ctx context.Context, incoming *types.TipSet, know } // We will be forking away from nts, check that it isn't checkpointed - if nts.Key() == chkpt { + if nts.Equals(chkpt) { return nil, ErrForkCheckpoint } @@ -1699,14 +1692,14 @@ func persistMessages(ctx context.Context, bs bstore.Blockstore, bst *exchange.Co // // 3. StageMessages: having acquired the headers and found a common tipset, // we then move forward, requesting the full blocks, including the messages. -func (syncer *Syncer) collectChain(ctx context.Context, ts *types.TipSet, hts *types.TipSet) error { +func (syncer *Syncer) collectChain(ctx context.Context, ts *types.TipSet, hts *types.TipSet, ignoreCheckpoint bool) error { ctx, span := trace.StartSpan(ctx, "collectChain") defer span.End() ss := extractSyncState(ctx) ss.Init(hts, ts) - headers, err := syncer.collectHeaders(ctx, ts, hts) + headers, err := syncer.collectHeaders(ctx, ts, hts, ignoreCheckpoint) if err != nil { ss.Error(err) return err diff --git a/chain/sync_test.go b/chain/sync_test.go index fb2528c59..3176d9ec3 100644 --- a/chain/sync_test.go +++ b/chain/sync_test.go @@ -793,6 +793,11 @@ func TestSyncCheckpointHead(t *testing.T) { p1Head := tu.getHead(p1) require.True(tu.t, p1Head.Equals(a.TipSet())) tu.assertBad(p1, b.TipSet()) + + // Should be able to switch forks. + tu.checkpointTs(p1, b.TipSet().Key()) + p1Head = tu.getHead(p1) + require.True(tu.t, p1Head.Equals(b.TipSet())) } func TestSyncCheckpointEarlierThanHead(t *testing.T) { @@ -835,4 +840,9 @@ func TestSyncCheckpointEarlierThanHead(t *testing.T) { p1Head := tu.getHead(p1) require.True(tu.t, p1Head.Equals(a.TipSet())) tu.assertBad(p1, b.TipSet()) + + // Should be able to switch forks. + tu.checkpointTs(p1, b.TipSet().Key()) + p1Head = tu.getHead(p1) + require.True(tu.t, p1Head.Equals(b.TipSet())) }