diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index df2b1ed33..aa5f5900b 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -300,6 +300,12 @@ func (t *Tree) Update(blockRoot common.Hash, parentRoot common.Hash, destructs m // Cap traverses downwards the snapshot tree from a head block hash until the // number of allowed layers are crossed. All layers beyond the permitted number // are flattened downwards. +// +// Note, the final diff layer count in general will be one more than the amount +// requested. This happens because the bottom-most diff layer is the accumulator +// which may or may not overflow and cascade to disk. Since this last layer's +// survival is only known *after* capping, we need to omit it from the count if +// we want to ensure that *at least* the requested number of diff layers remain. func (t *Tree) Cap(root common.Hash, layers int) error { // Retrieve the head snapshot to cap from snap := t.Snapshot(root) @@ -324,10 +330,7 @@ func (t *Tree) Cap(root common.Hash, layers int) error { // Flattening the bottom-most diff layer requires special casing since there's // no child to rewire to the grandparent. In that case we can fake a temporary // child for the capping and then remove it. - var persisted *diskLayer - - switch layers { - case 0: + if layers == 0 { // If full commit was requested, flatten the diffs and merge onto disk diff.lock.RLock() base := diffToDisk(diff.flatten().(*diffLayer)) @@ -336,33 +339,9 @@ func (t *Tree) Cap(root common.Hash, layers int) error { // Replace the entire snapshot tree with the flat base t.layers = map[common.Hash]snapshot{base.root: base} return nil - - case 1: - // If full flattening was requested, flatten the diffs but only merge if the - // memory limit was reached - var ( - bottom *diffLayer - base *diskLayer - ) - diff.lock.RLock() - bottom = diff.flatten().(*diffLayer) - if bottom.memory >= aggregatorMemoryLimit { - base = diffToDisk(bottom) - } - diff.lock.RUnlock() - - // If all diff layers were removed, replace the entire snapshot tree - if base != nil { - t.layers = map[common.Hash]snapshot{base.root: base} - return nil - } - // Merge the new aggregated layer into the snapshot tree, clean stales below - t.layers[bottom.root] = bottom - - default: - // Many layers requested to be retained, cap normally - persisted = t.cap(diff, layers) } + persisted := t.cap(diff, layers) + // Remove any layer that is stale or links into a stale layer children := make(map[common.Hash][]common.Hash) for root, snap := range t.layers { @@ -405,9 +384,15 @@ func (t *Tree) Cap(root common.Hash, layers int) error { // layer limit is reached, memory cap is also enforced (but not before). // // The method returns the new disk layer if diffs were persisted into it. +// +// Note, the final diff layer count in general will be one more than the amount +// requested. This happens because the bottom-most diff layer is the accumulator +// which may or may not overflow and cascade to disk. Since this last layer's +// survival is only known *after* capping, we need to omit it from the count if +// we want to ensure that *at least* the requested number of diff layers remain. func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer { // Dive until we run out of layers or reach the persistent database - for ; layers > 2; layers-- { + for i := 0; i < layers-1; i++ { // If we still have diff layers below, continue down if parent, ok := diff.parent.(*diffLayer); ok { diff = parent diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go index a315fd216..4b787cfe2 100644 --- a/core/state/snapshot/snapshot_test.go +++ b/core/state/snapshot/snapshot_test.go @@ -162,8 +162,8 @@ func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) { defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit) aggregatorMemoryLimit = 0 - if err := snaps.Cap(common.HexToHash("0x03"), 2); err != nil { - t.Fatalf("failed to merge diff layer onto disk: %v", err) + if err := snaps.Cap(common.HexToHash("0x03"), 1); err != nil { + t.Fatalf("failed to merge accumulator onto disk: %v", err) } // Since the base layer was modified, ensure that data retrievald on the external reference fail if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { @@ -178,53 +178,6 @@ func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) { } } -// Tests that if a diff layer becomes stale, no active external references will -// be returned with junk data. This version of the test flattens every diff layer -// to check internal corner case around the bottom-most memory accumulator. -func TestDiffLayerExternalInvalidationFullFlatten(t *testing.T) { - // Create an empty base layer and a snapshot tree out of it - base := &diskLayer{ - diskdb: rawdb.NewMemoryDatabase(), - root: common.HexToHash("0x01"), - cache: fastcache.New(1024 * 500), - } - snaps := &Tree{ - layers: map[common.Hash]snapshot{ - base.root: base, - }, - } - // Commit two diffs on top and retrieve a reference to the bottommost - accounts := map[common.Hash][]byte{ - common.HexToHash("0xa1"): randomAccount(), - } - if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil { - t.Fatalf("failed to create a diff layer: %v", err) - } - if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil { - t.Fatalf("failed to create a diff layer: %v", err) - } - if n := len(snaps.layers); n != 3 { - t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3) - } - ref := snaps.Snapshot(common.HexToHash("0x02")) - - // Flatten the diff layer into the bottom accumulator - if err := snaps.Cap(common.HexToHash("0x03"), 1); err != nil { - t.Fatalf("failed to flatten diff layer into accumulator: %v", err) - } - // Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail - if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { - t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) - } - if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { - t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) - } - if n := len(snaps.layers); n != 2 { - t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2) - fmt.Println(snaps.layers) - } -} - // Tests that if a diff layer becomes stale, no active external references will // be returned with junk data. This version of the test retains the bottom diff // layer to check the usual mode of operation where the accumulator is retained. @@ -267,7 +220,7 @@ func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) { t.Errorf("layers modified, got %d exp %d", got, exp) } // Flatten the diff layer into the bottom accumulator - if err := snaps.Cap(common.HexToHash("0x04"), 2); err != nil { + if err := snaps.Cap(common.HexToHash("0x04"), 1); err != nil { t.Fatalf("failed to flatten diff layer into accumulator: %v", err) } // Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail @@ -389,7 +342,7 @@ func TestSnaphots(t *testing.T) { // Create a starting base layer and a snapshot tree out of it base := &diskLayer{ diskdb: rawdb.NewMemoryDatabase(), - root: common.HexToHash("0x01"), + root: makeRoot(1), cache: fastcache.New(1024 * 500), } snaps := &Tree{ @@ -397,17 +350,16 @@ func TestSnaphots(t *testing.T) { base.root: base, }, } - // Construct the snapshots with 128 layers + // Construct the snapshots with 129 layers, flattening whatever's above that var ( last = common.HexToHash("0x01") head common.Hash ) - // Flush another 128 layers, one diff will be flatten into the parent. - for i := 0; i < 128; i++ { + for i := 0; i < 129; i++ { head = makeRoot(uint64(i + 2)) snaps.Update(head, last, nil, setAccount(fmt.Sprintf("%d", i+2)), nil) last = head - snaps.Cap(head, 128) // 129 layers(128 diffs + 1 disk) are allowed, 129th is the persistent layer + snaps.Cap(head, 128) // 130 layers (128 diffs + 1 accumulator + 1 disk) } var cases = []struct { headRoot common.Hash @@ -417,22 +369,57 @@ func TestSnaphots(t *testing.T) { expectBottom common.Hash }{ {head, 0, false, 0, common.Hash{}}, - {head, 64, false, 64, makeRoot(127 + 2 - 63)}, - {head, 128, false, 128, makeRoot(2)}, // All diff layers - {head, 129, true, 128, makeRoot(2)}, // All diff layers - {head, 129, false, 129, common.HexToHash("0x01")}, // All diff layers + disk layer + {head, 64, false, 64, makeRoot(129 + 2 - 64)}, + {head, 128, false, 128, makeRoot(3)}, // Normal diff layers, no accumulator + {head, 129, true, 129, makeRoot(2)}, // All diff layers, including accumulator + {head, 130, false, 130, makeRoot(1)}, // All diff layers + disk layer } - for _, c := range cases { + for i, c := range cases { layers := snaps.Snapshots(c.headRoot, c.limit, c.nodisk) if len(layers) != c.expected { - t.Fatalf("Returned snapshot layers are mismatched, want %v, got %v", c.expected, len(layers)) + t.Errorf("non-overflow test %d: returned snapshot layers are mismatched, want %v, got %v", i, c.expected, len(layers)) } if len(layers) == 0 { continue } bottommost := layers[len(layers)-1] if bottommost.Root() != c.expectBottom { - t.Fatalf("Snapshot mismatch, want %v, get %v", c.expectBottom, bottommost.Root()) + t.Errorf("non-overflow test %d: snapshot mismatch, want %v, get %v", i, c.expectBottom, bottommost.Root()) + } + } + // Above we've tested the normal capping, which leaves the accumulator live. + // Test that if the bottommost accumulator diff layer overflows the allowed + // memory limit, the snapshot tree gets capped to one less layer. + // Commit the diff layer onto the disk and ensure it's persisted + defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit) + aggregatorMemoryLimit = 0 + + snaps.Cap(head, 128) // 129 (128 diffs + 1 overflown accumulator + 1 disk) + + cases = []struct { + headRoot common.Hash + limit int + nodisk bool + expected int + expectBottom common.Hash + }{ + {head, 0, false, 0, common.Hash{}}, + {head, 64, false, 64, makeRoot(129 + 2 - 64)}, + {head, 128, false, 128, makeRoot(3)}, // All diff layers, accumulator was flattened + {head, 129, true, 128, makeRoot(3)}, // All diff layers, accumulator was flattened + {head, 130, false, 129, makeRoot(2)}, // All diff layers + disk layer + } + for i, c := range cases { + layers := snaps.Snapshots(c.headRoot, c.limit, c.nodisk) + if len(layers) != c.expected { + t.Errorf("overflow test %d: returned snapshot layers are mismatched, want %v, got %v", i, c.expected, len(layers)) + } + if len(layers) == 0 { + continue + } + bottommost := layers[len(layers)-1] + if bottommost.Root() != c.expectBottom { + t.Errorf("overflow test %d: snapshot mismatch, want %v, get %v", i, c.expectBottom, bottommost.Root()) } } }