diff --git a/chain/stmgr/utils.go b/chain/stmgr/utils.go index e6ba47665..ce9326e0a 100644 --- a/chain/stmgr/utils.go +++ b/chain/stmgr/utils.go @@ -426,7 +426,7 @@ func GetLookbackTipSetForRound(ctx context.Context, sm *StateManager, ts *types. return ts, nil } - lbts, err := sm.ChainStore().GetTipsetByHeight(ctx, lbr, ts) + lbts, err := sm.ChainStore().GetTipsetByHeight(ctx, lbr, ts, true) if err != nil { return nil, xerrors.Errorf("failed to get lookback tipset: %w", err) } diff --git a/chain/store/store.go b/chain/store/store.go index cde8add51..3985372c6 100644 --- a/chain/store/store.go +++ b/chain/store/store.go @@ -930,7 +930,11 @@ func (cs *ChainStore) GetRandomness(ctx context.Context, blks []cid.Cid, pers cr } } -func (cs *ChainStore) GetTipsetByHeight(ctx context.Context, h abi.ChainEpoch, ts *types.TipSet) (*types.TipSet, error) { +// GetTipsetByHeight returns the tipset on the chain behind 'ts' at the given +// height. In the case that the given height is a null round, the 'prev' flag +// selects the tipset before the null round if true, and the tipset following +// the null round if false. +func (cs *ChainStore) GetTipsetByHeight(ctx context.Context, h abi.ChainEpoch, ts *types.TipSet, prev bool) (*types.TipSet, error) { if ts == nil { ts = cs.GetHeaviestTipSet() } @@ -954,6 +958,9 @@ func (cs *ChainStore) GetTipsetByHeight(ctx context.Context, h abi.ChainEpoch, t } if h > pts.Height() { + if prev { + return pts, nil + } return ts, nil } if h == pts.Height() { diff --git a/cmd/lotus-bench/import.go b/cmd/lotus-bench/import.go index 8911856c3..c7062fba5 100644 --- a/cmd/lotus-bench/import.go +++ b/cmd/lotus-bench/import.go @@ -85,7 +85,7 @@ var importBenchCmd = &cli.Command{ } if h := cctx.Int64("height"); h != 0 { - tsh, err := cs.GetTipsetByHeight(context.TODO(), abi.ChainEpoch(h), head) + tsh, err := cs.GetTipsetByHeight(context.TODO(), abi.ChainEpoch(h), head, true) if err != nil { return err } diff --git a/node/impl/full/chain.go b/node/impl/full/chain.go index 076968987..8b00cfb34 100644 --- a/node/impl/full/chain.go +++ b/node/impl/full/chain.go @@ -174,7 +174,7 @@ func (a *ChainAPI) ChainGetTipSetByHeight(ctx context.Context, h abi.ChainEpoch, if err != nil { return nil, xerrors.Errorf("loading tipset %s: %w", tsk, err) } - return a.Chain.GetTipsetByHeight(ctx, h, ts) + return a.Chain.GetTipsetByHeight(ctx, h, ts, true) } func (a *ChainAPI) ChainReadObj(ctx context.Context, obj cid.Cid) ([]byte, error) {