tests: fix goroutine leak related to state snapshot generation (#28974)

---------

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Martin HS 2024-02-14 17:02:56 +01:00 committed by GitHub
parent 55a46c3b10
commit 8321fe2fda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 144 additions and 124 deletions

View File

@ -25,7 +25,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/tracers/logger" "github.com/ethereum/go-ethereum/eth/tracers/logger"
"github.com/ethereum/go-ethereum/tests" "github.com/ethereum/go-ethereum/tests"
@ -90,26 +89,27 @@ func runStateTest(fname string, cfg vm.Config, jsonOut, dump bool) error {
if err != nil { if err != nil {
return err return err
} }
var tests map[string]tests.StateTest var testsByName map[string]tests.StateTest
if err := json.Unmarshal(src, &tests); err != nil { if err := json.Unmarshal(src, &testsByName); err != nil {
return err return err
} }
// Iterate over all the tests, run them and aggregate the results // Iterate over all the tests, run them and aggregate the results
results := make([]StatetestResult, 0, len(tests)) results := make([]StatetestResult, 0, len(testsByName))
for key, test := range tests { for key, test := range testsByName {
for _, st := range test.Subtests() { for _, st := range test.Subtests() {
// Run the test and aggregate the result // Run the test and aggregate the result
result := &StatetestResult{Name: key, Fork: st.Fork, Pass: true} result := &StatetestResult{Name: key, Fork: st.Fork, Pass: true}
test.Run(st, cfg, false, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, statedb *state.StateDB) { test.Run(st, cfg, false, rawdb.HashScheme, func(err error, tstate *tests.StateTestState) {
var root common.Hash var root common.Hash
if statedb != nil { if tstate.StateDB != nil {
root = statedb.IntermediateRoot(false) root = tstate.StateDB.IntermediateRoot(false)
result.Root = &root result.Root = &root
if jsonOut { if jsonOut {
fmt.Fprintf(os.Stderr, "{\"stateRoot\": \"%#x\"}\n", root) fmt.Fprintf(os.Stderr, "{\"stateRoot\": \"%#x\"}\n", root)
} }
if dump { // Dump any state to aid debugging if dump { // Dump any state to aid debugging
cpy, _ := state.New(root, statedb.Database(), nil) cpy, _ := state.New(root, tstate.StateDB.Database(), nil)
dump := cpy.RawDump(nil) dump := cpy.RawDump(nil)
result.State = &dump result.State = &dump
} }

View File

@ -258,6 +258,14 @@ func (t *Tree) Disable() {
for _, layer := range t.layers { for _, layer := range t.layers {
switch layer := layer.(type) { switch layer := layer.(type) {
case *diskLayer: case *diskLayer:
layer.lock.RLock()
generating := layer.genMarker != nil
layer.lock.RUnlock()
if !generating {
// Generator is already aborted or finished
break
}
// If the base layer is generating, abort it // If the base layer is generating, abort it
if layer.genAbort != nil { if layer.genAbort != nil {
abort := make(chan *generatorStats) abort := make(chan *generatorStats)

View File

@ -133,9 +133,9 @@ func testCallTracer(tracerName string, dirPath string, t *testing.T) {
GasLimit: uint64(test.Context.GasLimit), GasLimit: uint64(test.Context.GasLimit),
BaseFee: test.Genesis.BaseFee, BaseFee: test.Genesis.BaseFee,
} }
triedb, _, statedb = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) state = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
) )
triedb.Close() state.Close()
tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig) tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig)
if err != nil { if err != nil {
@ -145,7 +145,7 @@ func testCallTracer(tracerName string, dirPath string, t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to prepare transaction for tracing: %v", err) t.Fatalf("failed to prepare transaction for tracing: %v", err)
} }
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer}) evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer})
vmRet, err := core.ApplyMessage(evm, msg, new(core.GasPool).AddGas(tx.Gas())) vmRet, err := core.ApplyMessage(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
if err != nil { if err != nil {
t.Fatalf("failed to execute transaction: %v", err) t.Fatalf("failed to execute transaction: %v", err)
@ -235,8 +235,8 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) {
if err != nil { if err != nil {
b.Fatalf("failed to prepare transaction for tracing: %v", err) b.Fatalf("failed to prepare transaction for tracing: %v", err)
} }
triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) state := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
defer triedb.Close() defer state.Close()
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
@ -245,8 +245,8 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) {
if err != nil { if err != nil {
b.Fatalf("failed to create call tracer: %v", err) b.Fatalf("failed to create call tracer: %v", err)
} }
evm := vm.NewEVM(context, txContext, statedb, test.Genesis.Config, vm.Config{Tracer: tracer}) evm := vm.NewEVM(context, txContext, state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer})
snap := statedb.Snapshot() snap := state.StateDB.Snapshot()
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
if _, err = st.TransitionDb(); err != nil { if _, err = st.TransitionDb(); err != nil {
b.Fatalf("failed to execute transaction: %v", err) b.Fatalf("failed to execute transaction: %v", err)
@ -254,7 +254,7 @@ func benchTracer(tracerName string, test *callTracerTest, b *testing.B) {
if _, err = tracer.GetResult(); err != nil { if _, err = tracer.GetResult(); err != nil {
b.Fatal(err) b.Fatal(err)
} }
statedb.RevertToSnapshot(snap) state.StateDB.RevertToSnapshot(snap)
} }
} }
@ -362,7 +362,7 @@ func TestInternals(t *testing.T) {
}, },
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), state := tests.MakePreState(rawdb.NewMemoryDatabase(),
core.GenesisAlloc{ core.GenesisAlloc{
to: core.GenesisAccount{ to: core.GenesisAccount{
Code: tc.code, Code: tc.code,
@ -371,9 +371,9 @@ func TestInternals(t *testing.T) {
Balance: big.NewInt(500000000000000), Balance: big.NewInt(500000000000000),
}, },
}, false, rawdb.HashScheme) }, false, rawdb.HashScheme)
defer triedb.Close() defer state.Close()
evm := vm.NewEVM(context, txContext, statedb, params.MainnetChainConfig, vm.Config{Tracer: tc.tracer}) evm := vm.NewEVM(context, txContext, state.StateDB, params.MainnetChainConfig, vm.Config{Tracer: tc.tracer})
msg := &core.Message{ msg := &core.Message{
To: &to, To: &to,
From: origin, From: origin,

View File

@ -95,8 +95,8 @@ func flatCallTracerTestRunner(tracerName string, filename string, dirPath string
Difficulty: (*big.Int)(test.Context.Difficulty), Difficulty: (*big.Int)(test.Context.Difficulty),
GasLimit: uint64(test.Context.GasLimit), GasLimit: uint64(test.Context.GasLimit),
} }
triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) state := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
defer triedb.Close() defer state.Close()
// Create the tracer, the EVM environment and run it // Create the tracer, the EVM environment and run it
tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig) tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig)
@ -107,7 +107,7 @@ func flatCallTracerTestRunner(tracerName string, filename string, dirPath string
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare transaction for tracing: %v", err) return fmt.Errorf("failed to prepare transaction for tracing: %v", err)
} }
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer}) evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer})
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
if _, err = st.TransitionDb(); err != nil { if _, err = st.TransitionDb(); err != nil {

View File

@ -103,9 +103,9 @@ func testPrestateDiffTracer(tracerName string, dirPath string, t *testing.T) {
GasLimit: uint64(test.Context.GasLimit), GasLimit: uint64(test.Context.GasLimit),
BaseFee: test.Genesis.BaseFee, BaseFee: test.Genesis.BaseFee,
} }
triedb, _, statedb = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme) state = tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false, rawdb.HashScheme)
) )
defer triedb.Close() defer state.Close()
tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig) tracer, err := tracers.DefaultDirectory.New(tracerName, new(tracers.Context), test.TracerConfig)
if err != nil { if err != nil {
@ -115,7 +115,7 @@ func testPrestateDiffTracer(tracerName string, dirPath string, t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to prepare transaction for tracing: %v", err) t.Fatalf("failed to prepare transaction for tracing: %v", err)
} }
evm := vm.NewEVM(context, core.NewEVMTxContext(msg), statedb, test.Genesis.Config, vm.Config{Tracer: tracer}) evm := vm.NewEVM(context, core.NewEVMTxContext(msg), state.StateDB, test.Genesis.Config, vm.Config{Tracer: tracer})
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
if _, err = st.TransitionDb(); err != nil { if _, err = st.TransitionDb(); err != nil {
t.Fatalf("failed to execute transaction: %v", err) t.Fatalf("failed to execute transaction: %v", err)

View File

@ -79,8 +79,8 @@ func BenchmarkTransactionTrace(b *testing.B) {
Code: []byte{}, Code: []byte{},
Balance: big.NewInt(500000000000000), Balance: big.NewInt(500000000000000),
} }
triedb, _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), alloc, false, rawdb.HashScheme) state := tests.MakePreState(rawdb.NewMemoryDatabase(), alloc, false, rawdb.HashScheme)
defer triedb.Close() defer state.Close()
// Create the tracer, the EVM environment and run it // Create the tracer, the EVM environment and run it
tracer := logger.NewStructLogger(&logger.Config{ tracer := logger.NewStructLogger(&logger.Config{
@ -89,7 +89,7 @@ func BenchmarkTransactionTrace(b *testing.B) {
//EnableMemory: false, //EnableMemory: false,
//EnableReturnData: false, //EnableReturnData: false,
}) })
evm := vm.NewEVM(context, txContext, statedb, params.AllEthashProtocolChanges, vm.Config{Tracer: tracer}) evm := vm.NewEVM(context, txContext, state.StateDB, params.AllEthashProtocolChanges, vm.Config{Tracer: tracer})
msg, err := core.TransactionToMessage(tx, signer, context.BaseFee) msg, err := core.TransactionToMessage(tx, signer, context.BaseFee)
if err != nil { if err != nil {
b.Fatalf("failed to prepare transaction for tracing: %v", err) b.Fatalf("failed to prepare transaction for tracing: %v", err)
@ -98,13 +98,13 @@ func BenchmarkTransactionTrace(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
snap := statedb.Snapshot() snap := state.StateDB.Snapshot()
st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
_, err = st.TransitionDb() _, err = st.TransitionDb()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
statedb.RevertToSnapshot(snap) state.StateDB.RevertToSnapshot(snap)
if have, want := len(tracer.StructLogs()), 244752; have != want { if have, want := len(tracer.StructLogs()), 244752; have != want {
b.Fatalf("trace wrong, want %d steps, have %d", want, have) b.Fatalf("trace wrong, want %d steps, have %d", want, have)
} }

View File

@ -32,8 +32,6 @@ import (
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/tracers/logger" "github.com/ethereum/go-ethereum/eth/tracers/logger"
@ -82,7 +80,7 @@ func TestState(t *testing.T) {
t.Run(key+"/hash/trie", func(t *testing.T) { t.Run(key+"/hash/trie", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
var result error var result error
test.Run(subtest, vmconfig, false, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) { test.Run(subtest, vmconfig, false, rawdb.HashScheme, func(err error, state *StateTestState) {
result = st.checkFailure(t, err) result = st.checkFailure(t, err)
}) })
return result return result
@ -91,9 +89,9 @@ func TestState(t *testing.T) {
t.Run(key+"/hash/snap", func(t *testing.T) { t.Run(key+"/hash/snap", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
var result error var result error
test.Run(subtest, vmconfig, true, rawdb.HashScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) { test.Run(subtest, vmconfig, true, rawdb.HashScheme, func(err error, state *StateTestState) {
if snaps != nil && state != nil { if state.Snapshots != nil && state.StateDB != nil {
if _, err := snaps.Journal(state.IntermediateRoot(false)); err != nil { if _, err := state.Snapshots.Journal(state.StateDB.IntermediateRoot(false)); err != nil {
result = err result = err
return return
} }
@ -106,7 +104,7 @@ func TestState(t *testing.T) {
t.Run(key+"/path/trie", func(t *testing.T) { t.Run(key+"/path/trie", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
var result error var result error
test.Run(subtest, vmconfig, false, rawdb.PathScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) { test.Run(subtest, vmconfig, false, rawdb.PathScheme, func(err error, state *StateTestState) {
result = st.checkFailure(t, err) result = st.checkFailure(t, err)
}) })
return result return result
@ -115,9 +113,9 @@ func TestState(t *testing.T) {
t.Run(key+"/path/snap", func(t *testing.T) { t.Run(key+"/path/snap", func(t *testing.T) {
withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error {
var result error var result error
test.Run(subtest, vmconfig, true, rawdb.PathScheme, func(err error, snaps *snapshot.Tree, state *state.StateDB) { test.Run(subtest, vmconfig, true, rawdb.PathScheme, func(err error, state *StateTestState) {
if snaps != nil && state != nil { if state.Snapshots != nil && state.StateDB != nil {
if _, err := snaps.Journal(state.IntermediateRoot(false)); err != nil { if _, err := state.Snapshots.Journal(state.StateDB.IntermediateRoot(false)); err != nil {
result = err result = err
return return
} }
@ -222,8 +220,8 @@ func runBenchmark(b *testing.B, t *StateTest) {
vmconfig.ExtraEips = eips vmconfig.ExtraEips = eips
block := t.genesis(config).ToBlock() block := t.genesis(config).ToBlock()
triedb, _, statedb := MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, false, rawdb.HashScheme) state := MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, false, rawdb.HashScheme)
defer triedb.Close() defer state.Close()
var baseFee *big.Int var baseFee *big.Int
if rules.IsLondon { if rules.IsLondon {
@ -261,7 +259,7 @@ func runBenchmark(b *testing.B, t *StateTest) {
context := core.NewEVMBlockContext(block.Header(), nil, &t.json.Env.Coinbase) context := core.NewEVMBlockContext(block.Header(), nil, &t.json.Env.Coinbase)
context.GetHash = vmTestBlockHash context.GetHash = vmTestBlockHash
context.BaseFee = baseFee context.BaseFee = baseFee
evm := vm.NewEVM(context, txContext, statedb, config, vmconfig) evm := vm.NewEVM(context, txContext, state.StateDB, config, vmconfig)
// Create "contract" for sender to cache code analysis. // Create "contract" for sender to cache code analysis.
sender := vm.NewContract(vm.AccountRef(msg.From), vm.AccountRef(msg.From), sender := vm.NewContract(vm.AccountRef(msg.From), vm.AccountRef(msg.From),
@ -274,8 +272,8 @@ func runBenchmark(b *testing.B, t *StateTest) {
) )
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
snapshot := statedb.Snapshot() snapshot := state.StateDB.Snapshot()
statedb.Prepare(rules, msg.From, context.Coinbase, msg.To, vm.ActivePrecompiles(rules), msg.AccessList) state.StateDB.Prepare(rules, msg.From, context.Coinbase, msg.To, vm.ActivePrecompiles(rules), msg.AccessList)
b.StartTimer() b.StartTimer()
start := time.Now() start := time.Now()
@ -288,10 +286,10 @@ func runBenchmark(b *testing.B, t *StateTest) {
b.StopTimer() b.StopTimer()
elapsed += uint64(time.Since(start)) elapsed += uint64(time.Since(start))
refund += statedb.GetRefund() refund += state.StateDB.GetRefund()
gasUsed += msg.GasLimit - leftOverGas gasUsed += msg.GasLimit - leftOverGas
statedb.RevertToSnapshot(snapshot) state.StateDB.RevertToSnapshot(snapshot)
} }
if elapsed < 1 { if elapsed < 1 {
elapsed = 1 elapsed = 1

View File

@ -194,20 +194,14 @@ func (t *StateTest) checkError(subtest StateSubtest, err error) error {
} }
// Run executes a specific subtest and verifies the post-state and logs // Run executes a specific subtest and verifies the post-state and logs
func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bool, scheme string, postCheck func(err error, snaps *snapshot.Tree, state *state.StateDB)) (result error) { func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bool, scheme string, postCheck func(err error, st *StateTestState)) (result error) {
triedb, snaps, statedb, root, err := t.RunNoVerify(subtest, vmconfig, snapshotter, scheme) st, root, err := t.RunNoVerify(subtest, vmconfig, snapshotter, scheme)
// Invoke the callback at the end of function for further analysis. // Invoke the callback at the end of function for further analysis.
defer func() { defer func() {
postCheck(result, snaps, statedb) postCheck(result, &st)
st.Close()
if triedb != nil {
triedb.Close()
}
if snaps != nil {
snaps.Release()
}
}() }()
checkedErr := t.checkError(subtest, err) checkedErr := t.checkError(subtest, err)
if checkedErr != nil { if checkedErr != nil {
return checkedErr return checkedErr
@ -224,23 +218,24 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bo
if root != common.Hash(post.Root) { if root != common.Hash(post.Root) {
return fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root) return fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root)
} }
if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { if logs := rlpHash(st.StateDB.Logs()); logs != common.Hash(post.Logs) {
return fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) return fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs)
} }
statedb, _ = state.New(root, statedb.Database(), snaps) st.StateDB, _ = state.New(root, st.StateDB.Database(), st.Snapshots)
return nil return nil
} }
// RunNoVerify runs a specific subtest and returns the statedb and post-state root // RunNoVerify runs a specific subtest and returns the statedb and post-state root.
func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapshotter bool, scheme string) (*triedb.Database, *snapshot.Tree, *state.StateDB, common.Hash, error) { // Remember to call state.Close after verifying the test result!
func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapshotter bool, scheme string) (state StateTestState, root common.Hash, err error) {
config, eips, err := GetChainConfig(subtest.Fork) config, eips, err := GetChainConfig(subtest.Fork)
if err != nil { if err != nil {
return nil, nil, nil, common.Hash{}, UnsupportedForkError{subtest.Fork} return state, common.Hash{}, UnsupportedForkError{subtest.Fork}
} }
vmconfig.ExtraEips = eips vmconfig.ExtraEips = eips
block := t.genesis(config).ToBlock() block := t.genesis(config).ToBlock()
triedb, snaps, statedb := MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, snapshotter, scheme) state = MakePreState(rawdb.NewMemoryDatabase(), t.json.Pre, snapshotter, scheme)
var baseFee *big.Int var baseFee *big.Int
if config.IsLondon(new(big.Int)) { if config.IsLondon(new(big.Int)) {
@ -254,8 +249,18 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh
post := t.json.Post[subtest.Fork][subtest.Index] post := t.json.Post[subtest.Fork][subtest.Index]
msg, err := t.json.Tx.toMessage(post, baseFee) msg, err := t.json.Tx.toMessage(post, baseFee)
if err != nil { if err != nil {
triedb.Close() return state, common.Hash{}, err
return nil, nil, nil, common.Hash{}, err }
{ // Blob transactions may be present after the Cancun fork.
// In production,
// - the header is verified against the max in eip4844.go:VerifyEIP4844Header
// - the block body is verified against the header in block_validator.go:ValidateBody
// Here, we just do this shortcut smaller fix, since state tests do not
// utilize those codepaths
if len(msg.BlobHashes)*params.BlobTxBlobGasPerBlob > params.MaxBlobGasPerBlock {
return state, common.Hash{}, errors.New("blob gas exceeds maximum")
}
} }
// Try to recover tx with current signer // Try to recover tx with current signer
@ -263,13 +268,10 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh
var ttx types.Transaction var ttx types.Transaction
err := ttx.UnmarshalBinary(post.TxBytes) err := ttx.UnmarshalBinary(post.TxBytes)
if err != nil { if err != nil {
triedb.Close() return state, common.Hash{}, err
return nil, nil, nil, common.Hash{}, err
} }
if _, err := types.Sender(types.LatestSigner(config), &ttx); err != nil { if _, err := types.Sender(types.LatestSigner(config), &ttx); err != nil {
triedb.Close() return state, common.Hash{}, err
return nil, nil, nil, common.Hash{}, err
} }
} }
@ -290,78 +292,32 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh
if config.IsCancun(new(big.Int), block.Time()) && t.json.Env.ExcessBlobGas != nil { if config.IsCancun(new(big.Int), block.Time()) && t.json.Env.ExcessBlobGas != nil {
context.BlobBaseFee = eip4844.CalcBlobFee(*t.json.Env.ExcessBlobGas) context.BlobBaseFee = eip4844.CalcBlobFee(*t.json.Env.ExcessBlobGas)
} }
evm := vm.NewEVM(context, txContext, statedb, config, vmconfig) evm := vm.NewEVM(context, txContext, state.StateDB, config, vmconfig)
{ // Blob transactions may be present after the Cancun fork.
// In production,
// - the header is verified against the max in eip4844.go:VerifyEIP4844Header
// - the block body is verified against the header in block_validator.go:ValidateBody
// Here, we just do this shortcut smaller fix, since state tests do not
// utilize those codepaths
if len(msg.BlobHashes)*params.BlobTxBlobGasPerBlob > params.MaxBlobGasPerBlock {
return nil, nil, nil, common.Hash{}, errors.New("blob gas exceeds maximum")
}
}
// Execute the message. // Execute the message.
snapshot := statedb.Snapshot() snapshot := state.StateDB.Snapshot()
gaspool := new(core.GasPool) gaspool := new(core.GasPool)
gaspool.AddGas(block.GasLimit()) gaspool.AddGas(block.GasLimit())
_, err = core.ApplyMessage(evm, msg, gaspool) _, err = core.ApplyMessage(evm, msg, gaspool)
if err != nil { if err != nil {
statedb.RevertToSnapshot(snapshot) state.StateDB.RevertToSnapshot(snapshot)
} }
// Add 0-value mining reward. This only makes a difference in the cases // Add 0-value mining reward. This only makes a difference in the cases
// where // where
// - the coinbase self-destructed, or // - the coinbase self-destructed, or
// - there are only 'bad' transactions, which aren't executed. In those cases, // - there are only 'bad' transactions, which aren't executed. In those cases,
// the coinbase gets no txfee, so isn't created, and thus needs to be touched // the coinbase gets no txfee, so isn't created, and thus needs to be touched
statedb.AddBalance(block.Coinbase(), new(uint256.Int)) state.StateDB.AddBalance(block.Coinbase(), new(uint256.Int))
// Commit state mutations into database. // Commit state mutations into database.
root, _ := statedb.Commit(block.NumberU64(), config.IsEIP158(block.Number())) root, _ = state.StateDB.Commit(block.NumberU64(), config.IsEIP158(block.Number()))
return triedb, snaps, statedb, root, err return state, root, err
} }
func (t *StateTest) gasLimit(subtest StateSubtest) uint64 { func (t *StateTest) gasLimit(subtest StateSubtest) uint64 {
return t.json.Tx.GasLimit[t.json.Post[subtest.Fork][subtest.Index].Indexes.Gas] return t.json.Tx.GasLimit[t.json.Post[subtest.Fork][subtest.Index].Indexes.Gas]
} }
func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter bool, scheme string) (*triedb.Database, *snapshot.Tree, *state.StateDB) {
tconf := &triedb.Config{Preimages: true}
if scheme == rawdb.HashScheme {
tconf.HashDB = hashdb.Defaults
} else {
tconf.PathDB = pathdb.Defaults
}
triedb := triedb.NewDatabase(db, tconf)
sdb := state.NewDatabaseWithNodeDB(db, triedb)
statedb, _ := state.New(types.EmptyRootHash, sdb, nil)
for addr, a := range accounts {
statedb.SetCode(addr, a.Code)
statedb.SetNonce(addr, a.Nonce)
statedb.SetBalance(addr, uint256.MustFromBig(a.Balance))
for k, v := range a.Storage {
statedb.SetState(addr, k, v)
}
}
// Commit and re-open to start with a clean state.
root, _ := statedb.Commit(0, false)
var snaps *snapshot.Tree
if snapshotter {
snapconfig := snapshot.Config{
CacheSize: 1,
Recovery: false,
NoBuild: false,
AsyncBuild: false,
}
snaps, _ = snapshot.New(snapconfig, db, triedb, root)
}
statedb, _ = state.New(root, sdb, snaps)
return triedb, snaps, statedb
}
func (t *StateTest) genesis(config *params.ChainConfig) *core.Genesis { func (t *StateTest) genesis(config *params.ChainConfig) *core.Genesis {
genesis := &core.Genesis{ genesis := &core.Genesis{
Config: config, Config: config,
@ -478,3 +434,61 @@ func rlpHash(x interface{}) (h common.Hash) {
func vmTestBlockHash(n uint64) common.Hash { func vmTestBlockHash(n uint64) common.Hash {
return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String()))) return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String())))
} }
// StateTestState groups all the state database objects together for use in tests.
type StateTestState struct {
StateDB *state.StateDB
TrieDB *triedb.Database
Snapshots *snapshot.Tree
}
// MakePreState creates a state containing the given allocation.
func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter bool, scheme string) StateTestState {
tconf := &triedb.Config{Preimages: true}
if scheme == rawdb.HashScheme {
tconf.HashDB = hashdb.Defaults
} else {
tconf.PathDB = pathdb.Defaults
}
triedb := triedb.NewDatabase(db, tconf)
sdb := state.NewDatabaseWithNodeDB(db, triedb)
statedb, _ := state.New(types.EmptyRootHash, sdb, nil)
for addr, a := range accounts {
statedb.SetCode(addr, a.Code)
statedb.SetNonce(addr, a.Nonce)
statedb.SetBalance(addr, uint256.MustFromBig(a.Balance))
for k, v := range a.Storage {
statedb.SetState(addr, k, v)
}
}
// Commit and re-open to start with a clean state.
root, _ := statedb.Commit(0, false)
// If snapshot is requested, initialize the snapshotter and use it in state.
var snaps *snapshot.Tree
if snapshotter {
snapconfig := snapshot.Config{
CacheSize: 1,
Recovery: false,
NoBuild: false,
AsyncBuild: false,
}
snaps, _ = snapshot.New(snapconfig, db, triedb, root)
}
statedb, _ = state.New(root, sdb, snaps)
return StateTestState{statedb, triedb, snaps}
}
// Close should be called when the state is no longer needed, ie. after running the test.
func (st *StateTestState) Close() {
if st.TrieDB != nil {
st.TrieDB.Close()
st.TrieDB = nil
}
if st.Snapshots != nil {
// Need to call Disable here to quit the snapshot generator goroutine.
st.Snapshots.Disable()
st.Snapshots.Release()
st.Snapshots = nil
}
}