diff --git a/core/state/statedb.go b/core/state/statedb.go index 3f96e8707..6e9796e70 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -804,6 +804,9 @@ func (s *StateDB) clearJournalAndRefund() { // Commit writes the state to the underlying in-memory trie database. func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) { + if s.dbErr != nil { + return common.Hash{}, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr) + } // Finalize any pending changes and merge everything into the tries s.IntermediateRoot(deleteEmptyObjects) diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 1cb73f015..be2463535 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -680,3 +680,50 @@ func TestDeleteCreateRevert(t *testing.T) { t.Fatalf("self-destructed contract came alive") } } + +// TestMissingTrieNodes tests that if the statedb fails to load parts of the trie, +// the Commit operation fails with an error +// If we are missing trie nodes, we should not continue writing to the trie +func TestMissingTrieNodes(t *testing.T) { + + // Create an initial state with a few accounts + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + var root common.Hash + state, _ := New(common.Hash{}, db, nil) + addr := toAddr([]byte("so")) + { + state.SetBalance(addr, big.NewInt(1)) + state.SetCode(addr, []byte{1, 2, 3}) + a2 := toAddr([]byte("another")) + state.SetBalance(a2, big.NewInt(100)) + state.SetCode(a2, []byte{1, 2, 4}) + root, _ = state.Commit(false) + t.Logf("root: %x", root) + // force-flush + state.Database().TrieDB().Cap(0) + } + // Create a new state on the old root + state, _ = New(root, db, nil) + // Now we clear out the memdb + it := memDb.NewIterator(nil, nil) + for it.Next() { + k := it.Key() + // Leave the root intact + if !bytes.Equal(k, root[:]) { + t.Logf("key: %x", k) + memDb.Delete(k) + } + } + balance := state.GetBalance(addr) + // The removed elem should lead to it returning zero balance + if exp, got := uint64(0), balance.Uint64(); got != exp { + t.Errorf("expected %d, got %d", exp, got) + } + // Modify the state + state.SetBalance(addr, big.NewInt(2)) + root, err := state.Commit(false) + if err == nil { + t.Fatalf("expected error, got root :%x", root) + } +}