From 8321fe2fda0b44d6df3750bcee28b8627525173b Mon Sep 17 00:00:00 2001 From: Martin HS Date: Wed, 14 Feb 2024 17:02:56 +0100 Subject: [PATCH] tests: fix goroutine leak related to state snapshot generation (#28974) --------- Co-authored-by: Felix Lange --- cmd/evm/staterunner.go | 18 +- core/state/snapshot/snapshot.go | 8 + .../internal/tracetest/calltrace_test.go | 22 +-- .../internal/tracetest/flat_calltrace_test.go | 6 +- .../internal/tracetest/prestate_test.go | 6 +- eth/tracers/tracers_test.go | 10 +- tests/state_test.go | 32 ++-- tests/state_test_util.go | 166 ++++++++++-------- 8 files changed, 144 insertions(+), 124 deletions(-) diff --git a/cmd/evm/staterunner.go b/cmd/evm/staterunner.go index 6e751b630..458d809ad 100644 --- a/cmd/evm/staterunner.go +++ b/cmd/evm/staterunner.go @@ -25,7 +25,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/tracers/logger" "github.com/ethereum/go-ethereum/tests" @@ -90,26 +89,27 @@ func runStateTest(fname string, cfg vm.Config, jsonOut, dump bool) error { if err != nil { return err } - var tests map[string]tests.StateTest - if err := json.Unmarshal(src, &tests); err != nil { + var testsByName map[string]tests.StateTest + if err := json.Unmarshal(src, &testsByName); err != nil { return err } + // Iterate over all the tests, run them and aggregate the results - results := make([]StatetestResult, 0, len(tests)) - for key, test := range tests { + results := make([]StatetestResult, 0, len(testsByName)) + for key, test := range testsByName { for _, st := range test.Subtests() { // Run the test and aggregate the result result := &StatetestResult{Name: key, Fork: st.Fork, Pass: true} - test.Run(st, cfg, false, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, statedb *state.StateDB) { + test.Run(st, cfg, false, rawdb.HashScheme, func(err error, tstate *tests.StateTestState) { var root common.Hash - if statedb != nil { - root = statedb.IntermediateRoot(false) + if tstate.StateDB != nil { + root = tstate.StateDB.IntermediateRoot(false) result.Root = &root if jsonOut { fmt.Fprintf(os.Stderr, "{\"stateRoot\": \"%#x\"}\n", root) } if dump { // Dump any state to aid debugging - cpy, _ := state.New(root, statedb.Database(), nil) + cpy, _ := state.New(root, tstate.StateDB.Database(), nil) dump := cpy.RawDump(nil) result.State = &dump } diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index 58aa375db..5c38cb725 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -258,6 +258,14 @@ func (t *Tree) Disable() { for _, layer := range t.layers { switch layer := layer.(type) { case *diskLayer: + + layer.lock.RLock() + generating := layer.genMarker != nil + layer.lock.RUnlock() + if !generating { + // Generator is already aborted or finished + break + } // If the base layer is generating, abort it if layer.genAbort != nil { abort := make(chan *generatorStats) diff --git a/eth/tracers/internal/tracetest/calltrace_test.go b/eth/tracers/internal/tracetest/calltrace_test.go index 0b43a021e..5eb0240e8 100644 --- a/eth/tracers/internal/tracetest/calltrace_test.go +++ b/eth/tracers/internal/tracetest/calltrace_test.go @@ -133,9 +133,9 @@ func testCallTracer(tracerName string, dirPath string, t *testing.T) { GasLimit: uint64(test.Context.GasLimit), BaseFee: test.Genesis.BaseFee, } - triedb, _, statedb = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) + state = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) ) - triedb.Close() + state.Close() tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig) if err != nil { @@ -145,7 +145,7 @@ func testCallTracer(tracerName string, dirPath string, t *testing.T) { if err != nil { t.Fatalf("failed to prepare transaction for tracing: %v", err) } - evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer}) + evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer}) vmRet, err := core.ApplyMessage(evm, msg, new(core.GasPool).AddGas(tx.Gas())) if err != nil { t.Fatalf("failed to execute transaction: %v", err) @@ -235,8 +235,8 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) { if err != nil { b.Fatalf("failed to prepare transaction for tracing: %v", err) } - triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) - defer triedb.Close() + state := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) + defer state.Close() b.ReportAllocs() b.ResetTimer() @@ -245,8 +245,8 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) { if err != nil { b.Fatalf("failed to create call tracer: %v", err) } - evm := vm.NewEVM(context, txContext, statedb, test.Genesis.Config, vm.Config{Tracer: tracer}) - snap := statedb.Snapshot() + evm := vm.NewEVM(context, txContext, state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer}) + snap := state.StateDB.Snapshot() st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) if _, err = st.TransitionDb(); err != nil { b.Fatalf("failed to execute transaction: %v", err) @@ -254,7 +254,7 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) { if _, err = tracer.GetResult(); err != nil { b.Fatal(err) } - statedb.RevertToSnapshot(snap) + state.StateDB.RevertToSnapshot(snap) } } @@ -362,7 +362,7 @@ func TestInternals(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), + state := tests.MakePreState(rawdb.NewMemoryDatabase(), core.GenesisAlloc{ to: core.GenesisAccount{ Code: tc.code, @@ -371,9 +371,9 @@ func TestInternals(t *testing.T) { Balance: big.NewInt(500000000000000), }, }, false, rawdb.HashScheme) - defer triedb.Close() + defer state.Close() - evm := vm.NewEVM(context, txContext, statedb, params.MainnetChainConfig, vm.Config{Tracer: tc.tracer}) + evm := vm.NewEVM(context, txContext, state.StateDB, params.MainnetChainConfig, vm.Config{Tracer: tc.tracer}) msg := &core.Message{ To: &to, From: origin, diff --git a/eth/tracers/internal/tracetest/flat_calltrace_test.go b/eth/tracers/internal/tracetest/flat_calltrace_test.go index b318548bc..abee48891 100644 --- a/eth/tracers/internal/tracetest/flat_calltrace_test.go +++ b/eth/tracers/internal/tracetest/flat_calltrace_test.go @@ -95,8 +95,8 @@ func flatCallTracerTestRunner(tracerName string, filename string, dirPath string Difficulty: (*big.Int)(test.Context.Difficulty), GasLimit: uint64(test.Context.GasLimit), } - triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) - defer triedb.Close() + state := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) + defer state.Close() // Create the tracer, the EVM environment and run it tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig) @@ -107,7 +107,7 @@ func flatCallTracerTestRunner(tracerName string, filename string, dirPath string if err != nil { return fmt.Errorf("failed to prepare transaction for tracing: %v", err) } - evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer}) + evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer}) st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) if _, err = st.TransitionDb(); err != nil { diff --git a/eth/tracers/internal/tracetest/prestate_test.go b/eth/tracers/internal/tracetest/prestate_test.go index 666a5fda7..8a60123dc 100644 --- a/eth/tracers/internal/tracetest/prestate_test.go +++ b/eth/tracers/internal/tracetest/prestate_test.go @@ -103,9 +103,9 @@ func testPrestateDiffTracer(tracerName string, dirPath string, t *testing.T) { GasLimit: uint64(test.Context.GasLimit), BaseFee: test.Genesis.BaseFee, } - triedb, _, statedb = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) + state = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) ) - defer triedb.Close() + defer state.Close() tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig) if err != nil { @@ -115,7 +115,7 @@ func testPrestateDiffTracer(tracerName string, dirPath string, t *testing.T) { if err != nil { t.Fatalf("failed to prepare transaction for tracing: %v", err) } - evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer}) + evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer}) st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) if _, err = st.TransitionDb(); err != nil { t.Fatalf("failed to execute transaction: %v", err) diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go index b10f3503e..234013760 100644 --- a/eth/tracers/tracers_test.go +++ b/eth/tracers/tracers_test.go @@ -79,8 +79,8 @@ func BenchmarkTransactionTrace(b *testing.B) { Code: []byte{}, Balance: big.NewInt(500000000000000), } - triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), alloc, false, rawdb.HashScheme) - defer triedb.Close() + state := tests.MakePreState(rawdb.NewMemoryDatabase(), alloc, false, rawdb.HashScheme) + defer state.Close() // Create the tracer, the EVM environment and run it tracer := logger.NewStructLogger(&logger.Config{ @@ -89,7 +89,7 @@ func BenchmarkTransactionTrace(b *testing.B) { //EnableMemory: false, //EnableReturnData: false, }) - evm := vm.NewEVM(context, txContext, statedb, params.AllEthashProtocolChanges, vm.Config{Tracer: tracer}) + evm := vm.NewEVM(context, txContext, state.StateDB, params.AllEthashProtocolChanges, vm.Config{Tracer: tracer}) msg, err := core.TransactionToMessage(tx, signer, context.BaseFee) if err != nil { b.Fatalf("failed to prepare transaction for tracing: %v", err) @@ -98,13 +98,13 @@ func BenchmarkTransactionTrace(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - snap := statedb.Snapshot() + snap := state.StateDB.Snapshot() st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) _, err = st.TransitionDb() if err != nil { b.Fatal(err) } - statedb.RevertToSnapshot(snap) + state.StateDB.RevertToSnapshot(snap) if have, want := len(tracer.StructLogs()), 244752; have != want { b.Fatalf("trace wrong, want %d steps, have %d", want, have) } diff --git a/tests/state_test.go b/tests/state_test.go index 3a7e83ae3..4eddf5ec3 100644 --- a/tests/state_test.go +++ b/tests/state_test.go @@ -32,8 +32,6 @@ import ( "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/tracers/logger" @@ -82,7 +80,7 @@ func TestState(t *testing.T) { t.Run(key+"/hash/trie", func(t *testing.T) { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { var result error - test.Run(subtest, vmconfig, false, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) { + test.Run(subtest, vmconfig, false, rawdb.HashScheme, func(err error, state *StateTestState) { result = st.checkFailure(t, err) }) return result @@ -91,9 +89,9 @@ func TestState(t *testing.T) { t.Run(key+"/hash/snap", func(t *testing.T) { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { var result error - test.Run(subtest, vmconfig, true, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) { - if snaps != nil && state != nil { - if _, err := snaps.Journal(state.IntermediateRoot(false)); err != nil { + test.Run(subtest, vmconfig, true, rawdb.HashScheme, func(err error, state *StateTestState) { + if state.Snapshots != nil && state.StateDB != nil { + if _, err := state.Snapshots.Journal(state.StateDB.IntermediateRoot(false)); err != nil { result = err return } @@ -106,7 +104,7 @@ func TestState(t *testing.T) { t.Run(key+"/path/trie", func(t *testing.T) { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { var result error - test.Run(subtest, vmconfig, false, rawdb.PathScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) { + test.Run(subtest, vmconfig, false, rawdb.PathScheme, func(err error, state *StateTestState) { result = st.checkFailure(t, err) }) return result @@ -115,9 +113,9 @@ func TestState(t *testing.T) { t.Run(key+"/path/snap", func(t *testing.T) { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { var result error - test.Run(subtest, vmconfig, true, rawdb.PathScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) { - if snaps != nil && state != nil { - if _, err := snaps.Journal(state.IntermediateRoot(false)); err != nil { + test.Run(subtest, vmconfig, true, rawdb.PathScheme, func(err error, state *StateTestState) { + if state.Snapshots != nil && state.StateDB != nil { + if _, err := state.Snapshots.Journal(state.StateDB.IntermediateRoot(false)); err != nil { result = err return } @@ -222,8 +220,8 @@ func runBenchmark(b *testing.B, t *StateTest) { vmconfig.ExtraEips = eips block := t.genesis(config).ToBlock() - triedb, _, statedb := MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, false, rawdb.HashScheme) - defer triedb.Close() + state := MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, false, rawdb.HashScheme) + defer state.Close() var baseFee *big.Int if rules.IsLondon { @@ -261,7 +259,7 @@ func runBenchmark(b *testing.B, t *StateTest) { context := core.NewEVMBlockContext(block.Header(), nil, &t.json.Env.Coinbase) context.GetHash = vmTestBlockHash context.BaseFee = baseFee - evm := vm.NewEVM(context, txContext, statedb, config, vmconfig) + evm := vm.NewEVM(context, txContext, state.StateDB, config, vmconfig) // Create "contract" for sender to cache code analysis. sender := vm.NewContract(vm.AccountRef(msg.From), vm.AccountRef(msg.From), @@ -274,8 +272,8 @@ func runBenchmark(b *testing.B, t *StateTest) { ) b.ResetTimer() for n := 0; n < b.N; n++ { - snapshot := statedb.Snapshot() - statedb.Prepare(rules, msg.From, context.Coinbase, msg.To, vm.ActivePrecompiles(rules), msg.AccessList) + snapshot := state.StateDB.Snapshot() + state.StateDB.Prepare(rules, msg.From, context.Coinbase, msg.To, vm.ActivePrecompiles(rules), msg.AccessList) b.StartTimer() start := time.Now() @@ -288,10 +286,10 @@ func runBenchmark(b *testing.B, t *StateTest) { b.StopTimer() elapsed += uint64(time.Since(start)) - refund += statedb.GetRefund() + refund += state.StateDB.GetRefund() gasUsed += msg.GasLimit - leftOverGas - statedb.RevertToSnapshot(snapshot) + state.StateDB.RevertToSnapshot(snapshot) } if elapsed < 1 { elapsed = 1 diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 92014ed82..56ddf61b6 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -194,20 +194,14 @@ func (t *StateTest) checkError(subtest StateSubtest, err error) error { } // Run executes a specific subtest and verifies the post-state and logs -func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bool, scheme string, postCheck func(err error, snaps *snapshot.Tree, state *state.StateDB)) (result error) { - triedb, snaps, statedb, root, err := t.RunNoVerify(subtest, vmconfig, snapshotter, scheme) - +func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bool, scheme string, postCheck func(err error, st *StateTestState)) (result error) { + st, root, err := t.RunNoVerify(subtest, vmconfig, snapshotter, scheme) // Invoke the callback at the end of function for further analysis. defer func() { - postCheck(result, snaps, statedb) - - if triedb != nil { - triedb.Close() - } - if snaps != nil { - snaps.Release() - } + postCheck(result, &st) + st.Close() }() + checkedErr := t.checkError(subtest, err) if checkedErr != nil { return checkedErr @@ -224,23 +218,24 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bo if root != common.Hash(post.Root) { return fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root) } - if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { + if logs := rlpHash(st.StateDB.Logs()); logs != common.Hash(post.Logs) { return fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) } - statedb, _ = state.New(root, statedb.Database(), snaps) + st.StateDB, _ = state.New(root, st.StateDB.Database(), st.Snapshots) return nil } -// RunNoVerify runs a specific subtest and returns the statedb and post-state root -func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapshotter bool, scheme string) (*triedb.Database, *snapshot.Tree, *state.StateDB, common.Hash, error) { +// RunNoVerify runs a specific subtest and returns the statedb and post-state root. +// Remember to call state.Close after verifying the test result! +func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapshotter bool, scheme string) (state StateTestState, root common.Hash, err error) { config, eips, err := GetChainConfig(subtest.Fork) if err != nil { - return nil, nil, nil, common.Hash{}, UnsupportedForkError{subtest.Fork} + return state, common.Hash{}, UnsupportedForkError{subtest.Fork} } vmconfig.ExtraEips = eips block := t.genesis(config).ToBlock() - triedb, snaps, statedb := MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, snapshotter, scheme) + state = MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, snapshotter, scheme) var baseFee *big.Int if config.IsLondon(new(big.Int)) { @@ -254,8 +249,18 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh post := t.json.Post[subtest.Fork][subtest.Index] msg, err := t.json.Tx.toMessage(post, baseFee) if err != nil { - triedb.Close() - return nil, nil, nil, common.Hash{}, err + return state, common.Hash{}, err + } + + { // Blob transactions may be present after the Cancun fork. + // In production, + // - the header is verified against the max in eip4844.go:VerifyEIP4844Header + // - the block body is verified against the header in block_validator.go:ValidateBody + // Here, we just do this shortcut smaller fix, since state tests do not + // utilize those codepaths + if len(msg.BlobHashes)*params.BlobTxBlobGasPerBlob > params.MaxBlobGasPerBlock { + return state, common.Hash{}, errors.New("blob gas exceeds maximum") + } } // Try to recover tx with current signer @@ -263,13 +268,10 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh var ttx types.Transaction err := ttx.UnmarshalBinary(post.TxBytes) if err != nil { - triedb.Close() - return nil, nil, nil, common.Hash{}, err + return state, common.Hash{}, err } - if _, err := types.Sender(types.LatestSigner(config), &ttx); err != nil { - triedb.Close() - return nil, nil, nil, common.Hash{}, err + return state, common.Hash{}, err } } @@ -290,78 +292,32 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh if config.IsCancun(new(big.Int), block.Time()) && t.json.Env.ExcessBlobGas != nil { context.BlobBaseFee = eip4844.CalcBlobFee(*t.json.Env.ExcessBlobGas) } - evm := vm.NewEVM(context, txContext, statedb, config, vmconfig) - - { // Blob transactions may be present after the Cancun fork. - // In production, - // - the header is verified against the max in eip4844.go:VerifyEIP4844Header - // - the block body is verified against the header in block_validator.go:ValidateBody - // Here, we just do this shortcut smaller fix, since state tests do not - // utilize those codepaths - if len(msg.BlobHashes)*params.BlobTxBlobGasPerBlob > params.MaxBlobGasPerBlock { - return nil, nil, nil, common.Hash{}, errors.New("blob gas exceeds maximum") - } - } + evm := vm.NewEVM(context, txContext, state.StateDB, config, vmconfig) // Execute the message. - snapshot := statedb.Snapshot() + snapshot := state.StateDB.Snapshot() gaspool := new(core.GasPool) gaspool.AddGas(block.GasLimit()) _, err = core.ApplyMessage(evm, msg, gaspool) if err != nil { - statedb.RevertToSnapshot(snapshot) + state.StateDB.RevertToSnapshot(snapshot) } // Add 0-value mining reward. This only makes a difference in the cases // where // - the coinbase self-destructed, or // - there are only 'bad' transactions, which aren't executed. In those cases, // the coinbase gets no txfee, so isn't created, and thus needs to be touched - statedb.AddBalance(block.Coinbase(), new(uint256.Int)) + state.StateDB.AddBalance(block.Coinbase(), new(uint256.Int)) // Commit state mutations into database. - root, _ := statedb.Commit(block.NumberU64(), config.IsEIP158(block.Number())) - return triedb, snaps, statedb, root, err + root, _ = state.StateDB.Commit(block.NumberU64(), config.IsEIP158(block.Number())) + return state, root, err } func (t *StateTest) gasLimit(subtest StateSubtest) uint64 { return t.json.Tx.GasLimit[t.json.Post[subtest.Fork][subtest.Index].Indexes.Gas] } -func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter bool, scheme string) (*triedb.Database, *snapshot.Tree, *state.StateDB) { - tconf := &triedb.Config{Preimages: true} - if scheme == rawdb.HashScheme { - tconf.HashDB = hashdb.Defaults - } else { - tconf.PathDB = pathdb.Defaults - } - triedb := triedb.NewDatabase(db, tconf) - sdb := state.NewDatabaseWithNodeDB(db, triedb) - statedb, _ := state.New(types.EmptyRootHash, sdb, nil) - for addr, a := range accounts { - statedb.SetCode(addr, a.Code) - statedb.SetNonce(addr, a.Nonce) - statedb.SetBalance(addr, uint256.MustFromBig(a.Balance)) - for k, v := range a.Storage { - statedb.SetState(addr, k, v) - } - } - // Commit and re-open to start with a clean state. - root, _ := statedb.Commit(0, false) - - var snaps *snapshot.Tree - if snapshotter { - snapconfig := snapshot.Config{ - CacheSize: 1, - Recovery: false, - NoBuild: false, - AsyncBuild: false, - } - snaps, _ = snapshot.New(snapconfig, db, triedb, root) - } - statedb, _ = state.New(root, sdb, snaps) - return triedb, snaps, statedb -} - func (t *StateTest) genesis(config *params.ChainConfig) *core.Genesis { genesis := &core.Genesis{ Config: config, @@ -478,3 +434,61 @@ func rlpHash(x interface{}) (h common.Hash) { func vmTestBlockHash(n uint64) common.Hash { return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String()))) } + +// StateTestState groups all the state database objects together for use in tests. +type StateTestState struct { + StateDB *state.StateDB + TrieDB *triedb.Database + Snapshots *snapshot.Tree +} + +// MakePreState creates a state containing the given allocation. +func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter bool, scheme string) StateTestState { + tconf := &triedb.Config{Preimages: true} + if scheme == rawdb.HashScheme { + tconf.HashDB = hashdb.Defaults + } else { + tconf.PathDB = pathdb.Defaults + } + triedb := triedb.NewDatabase(db, tconf) + sdb := state.NewDatabaseWithNodeDB(db, triedb) + statedb, _ := state.New(types.EmptyRootHash, sdb, nil) + for addr, a := range accounts { + statedb.SetCode(addr, a.Code) + statedb.SetNonce(addr, a.Nonce) + statedb.SetBalance(addr, uint256.MustFromBig(a.Balance)) + for k, v := range a.Storage { + statedb.SetState(addr, k, v) + } + } + // Commit and re-open to start with a clean state. + root, _ := statedb.Commit(0, false) + + // If snapshot is requested, initialize the snapshotter and use it in state. + var snaps *snapshot.Tree + if snapshotter { + snapconfig := snapshot.Config{ + CacheSize: 1, + Recovery: false, + NoBuild: false, + AsyncBuild: false, + } + snaps, _ = snapshot.New(snapconfig, db, triedb, root) + } + statedb, _ = state.New(root, sdb, snaps) + return StateTestState{statedb, triedb, snaps} +} + +// Close should be called when the state is no longer needed, ie. after running the test. +func (st *StateTestState) Close() { + if st.TrieDB != nil { + st.TrieDB.Close() + st.TrieDB = nil + } + if st.Snapshots != nil { + // Need to call Disable here to quit the snapshot generator goroutine. + st.Snapshots.Disable() + st.Snapshots.Release() + st.Snapshots = nil + } +}